diff --git a/codex-lens/scripts/generate_embeddings.py b/codex-lens/scripts/generate_embeddings.py index 69fd2412..a2bb052c 100644 --- a/codex-lens/scripts/generate_embeddings.py +++ b/codex-lens/scripts/generate_embeddings.py @@ -1,15 +1,9 @@ #!/usr/bin/env python3 """Generate vector embeddings for existing CodexLens indexes. -This script processes all files in a CodexLens index database and generates -semantic vector embeddings for code chunks. The embeddings are stored in the -same SQLite database in the 'semantic_chunks' table. - -Performance optimizations: -- Parallel file processing using ProcessPoolExecutor -- Batch embedding generation for efficient GPU/CPU utilization -- Batch database writes to minimize I/O overhead -- HNSW index auto-generation for fast similarity search +This script is a CLI wrapper around the memory-efficient streaming implementation +in codexlens.cli.embedding_manager. It uses batch processing to keep memory usage +under 2GB regardless of project size. Requirements: pip install codexlens[semantic] @@ -20,27 +14,21 @@ Usage: # Generate embeddings for a single index python generate_embeddings.py /path/to/_index.db - # Generate embeddings with parallel processing - python generate_embeddings.py /path/to/_index.db --workers 4 - - # Use specific embedding model and batch size - python generate_embeddings.py /path/to/_index.db --model code --batch-size 256 + # Use specific embedding model + python generate_embeddings.py /path/to/_index.db --model code # Generate embeddings for all indexes in a directory python generate_embeddings.py --scan ~/.codexlens/indexes + + # Force regeneration + python generate_embeddings.py /path/to/_index.db --force """ import argparse import logging -import multiprocessing -import os -import sqlite3 import sys -import time -from concurrent.futures import ProcessPoolExecutor, as_completed -from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple +from typing import List # Configure logging logging.basicConfig( @@ -50,100 +38,32 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) - -@dataclass -class FileData: - """Data for a single file to process.""" - full_path: str - content: str - language: str - - -@dataclass -class ChunkData: - """Processed chunk data ready for embedding.""" - file_path: str - content: str - metadata: dict +# Import the memory-efficient implementation +try: + from codexlens.cli.embedding_manager import ( + generate_embeddings, + generate_embeddings_recursive, + ) + from codexlens.semantic import SEMANTIC_AVAILABLE +except ImportError as exc: + logger.error(f"Failed to import codexlens: {exc}") + logger.error("Make sure codexlens is installed: pip install codexlens") + SEMANTIC_AVAILABLE = False def check_dependencies(): """Check if semantic search dependencies are available.""" - try: - from codexlens.semantic import SEMANTIC_AVAILABLE - if not SEMANTIC_AVAILABLE: - logger.error("Semantic search dependencies not available") - logger.error("Install with: pip install codexlens[semantic]") - logger.error("Or: pip install fastembed numpy hnswlib") - return False - return True - except ImportError as exc: - logger.error(f"Failed to import codexlens: {exc}") - logger.error("Make sure codexlens is installed: pip install codexlens") + if not SEMANTIC_AVAILABLE: + logger.error("Semantic search dependencies not available") + logger.error("Install with: pip install codexlens[semantic]") + logger.error("Or: pip install fastembed numpy hnswlib") return False + return True 
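As a quick orientation before the rest of the patch: the rewritten script is now a thin wrapper, so the same pipeline can be driven directly from Python. A minimal sketch follows; the index path is a placeholder, and the keyword arguments mirror the wrapper call that appears later in this file.

```python
# Minimal sketch of calling the streaming implementation directly; keyword
# arguments mirror the wrapper call later in this patch. The index path is
# a placeholder.
from pathlib import Path

from codexlens.cli.embedding_manager import generate_embeddings

result = generate_embeddings(
    index_path=Path("~/.codexlens/indexes/myproject/_index.db").expanduser(),
    model_profile="code",      # one of: fast, code, multilingual, balanced
    force=False,               # set True to regenerate existing embeddings
    chunk_size=2000,           # maximum chunk size in characters
    progress_callback=print,   # any callable taking a message string
)
if result["success"]:
    stats = result["result"]
    print(f"{stats['chunks_created']} chunks from {stats['files_processed']} files")
```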
-def count_files(index_db_path: Path) -> int: - """Count total files in index.""" - try: - with sqlite3.connect(index_db_path) as conn: - cursor = conn.execute("SELECT COUNT(*) FROM files") - return cursor.fetchone()[0] - except Exception as exc: - logger.error(f"Failed to count files: {exc}") - return 0 - - -def check_existing_chunks(index_db_path: Path) -> int: - """Check if semantic chunks already exist.""" - try: - with sqlite3.connect(index_db_path) as conn: - # Check if table exists - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='semantic_chunks'" - ) - if not cursor.fetchone(): - return 0 - - # Count existing chunks - cursor = conn.execute("SELECT COUNT(*) FROM semantic_chunks") - return cursor.fetchone()[0] - except Exception: - return 0 - - -def process_file_worker(args: Tuple[str, str, str, int]) -> List[ChunkData]: - """Worker function to process a single file (runs in separate process). - - Args: - args: Tuple of (file_path, content, language, chunk_size) - - Returns: - List of ChunkData objects - """ - file_path, content, language, chunk_size = args - - try: - from codexlens.semantic.chunker import Chunker, ChunkConfig - - chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size)) - chunks = chunker.chunk_sliding_window( - content, - file_path=file_path, - language=language - ) - - return [ - ChunkData( - file_path=file_path, - content=chunk.content, - metadata=chunk.metadata or {} - ) - for chunk in chunks - ] - except Exception as exc: - logger.debug(f"Error processing {file_path}: {exc}") - return [] +def progress_callback(message: str): + """Callback function for progress updates.""" + logger.info(message) def generate_embeddings_for_index( @@ -151,259 +71,63 @@ def generate_embeddings_for_index( model_profile: str = "code", force: bool = False, chunk_size: int = 2000, - workers: int = 0, - batch_size: int = 256, + **kwargs # Ignore unused parameters (workers, batch_size) for backward compatibility ) -> dict: - """Generate embeddings for all files in an index. + """Generate embeddings for an index using memory-efficient streaming. - Performance optimizations: - - Parallel file processing (chunking) - - Batch embedding generation - - Batch database writes - - HNSW index auto-generation + This function wraps the streaming implementation from embedding_manager + to maintain CLI compatibility while using the memory-optimized approach. 
Args: index_db_path: Path to _index.db file model_profile: Model profile to use (fast, code, multilingual, balanced) force: If True, regenerate even if embeddings exist chunk_size: Maximum chunk size in characters - workers: Number of parallel workers (0 = auto-detect CPU count) - batch_size: Batch size for embedding generation + **kwargs: Additional parameters (ignored for compatibility) Returns: Dictionary with generation statistics """ logger.info(f"Processing index: {index_db_path}") - # Check existing chunks - existing_chunks = check_existing_chunks(index_db_path) - if existing_chunks > 0 and not force: - logger.warning(f"Index already has {existing_chunks} chunks") - logger.warning("Use --force to regenerate") - return { - "success": False, - "error": "Embeddings already exist", - "existing_chunks": existing_chunks, - } + # Call the memory-efficient streaming implementation + result = generate_embeddings( + index_path=index_db_path, + model_profile=model_profile, + force=force, + chunk_size=chunk_size, + progress_callback=progress_callback, + ) - if force and existing_chunks > 0: - logger.info(f"Force mode: clearing {existing_chunks} existing chunks") - try: - with sqlite3.connect(index_db_path) as conn: - conn.execute("DELETE FROM semantic_chunks") - conn.commit() - # Also remove HNSW index file - hnsw_path = index_db_path.parent / "_vectors.hnsw" - if hnsw_path.exists(): - hnsw_path.unlink() - logger.info("Removed existing HNSW index") - except Exception as exc: - logger.error(f"Failed to clear existing data: {exc}") + if not result["success"]: + if "error" in result: + logger.error(result["error"]) + return result - # Import dependencies - try: - from codexlens.semantic.embedder import Embedder - from codexlens.semantic.vector_store import VectorStore - from codexlens.entities import SemanticChunk - except ImportError as exc: - return { - "success": False, - "error": f"Import failed: {exc}", - } - - # Initialize components - try: - embedder = Embedder(profile=model_profile) - vector_store = VectorStore(index_db_path) - - logger.info(f"Using model: {embedder.model_name}") - logger.info(f"Embedding dimension: {embedder.embedding_dim}") - except Exception as exc: - return { - "success": False, - "error": f"Failed to initialize components: {exc}", - } - - # Read files from index - try: - with sqlite3.connect(index_db_path) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.execute("SELECT full_path, content, language FROM files") - files = [ - FileData( - full_path=row["full_path"], - content=row["content"], - language=row["language"] or "python" - ) - for row in cursor.fetchall() - ] - except Exception as exc: - return { - "success": False, - "error": f"Failed to read files: {exc}", - } - - logger.info(f"Found {len(files)} files to process") - if len(files) == 0: - return { - "success": False, - "error": "No files found in index", - } - - # Determine worker count - if workers <= 0: - workers = min(multiprocessing.cpu_count(), len(files), 8) - logger.info(f"Using {workers} worker(s) for parallel processing") - logger.info(f"Batch size for embeddings: {batch_size}") - - start_time = time.time() - - # Phase 1: Parallel chunking - logger.info("Phase 1: Chunking files...") - chunk_start = time.time() - - all_chunks: List[ChunkData] = [] - failed_files = [] - - # Prepare work items - work_items = [ - (f.full_path, f.content, f.language, chunk_size) - for f in files - ] - - if workers == 1: - # Single-threaded for debugging - for i, item in enumerate(work_items, 1): - try: - chunks = 
process_file_worker(item) - all_chunks.extend(chunks) - if i % 100 == 0: - logger.info(f"Chunked {i}/{len(files)} files ({len(all_chunks)} chunks)") - except Exception as exc: - failed_files.append((item[0], str(exc))) - else: - # Parallel processing - with ProcessPoolExecutor(max_workers=workers) as executor: - futures = { - executor.submit(process_file_worker, item): item[0] - for item in work_items - } - - completed = 0 - for future in as_completed(futures): - file_path = futures[future] - completed += 1 - try: - chunks = future.result() - all_chunks.extend(chunks) - if completed % 100 == 0: - logger.info( - f"Chunked {completed}/{len(files)} files " - f"({len(all_chunks)} chunks)" - ) - except Exception as exc: - failed_files.append((file_path, str(exc))) - - chunk_time = time.time() - chunk_start - logger.info(f"Chunking completed in {chunk_time:.1f}s: {len(all_chunks)} chunks") - - if not all_chunks: - return { - "success": False, - "error": "No chunks created from files", - "files_processed": len(files) - len(failed_files), - "files_failed": len(failed_files), - } - - # Phase 2: Batch embedding generation - logger.info("Phase 2: Generating embeddings...") - embed_start = time.time() - - # Extract all content for batch embedding - all_contents = [c.content for c in all_chunks] - - # Generate embeddings in batches - all_embeddings = [] - for i in range(0, len(all_contents), batch_size): - batch_contents = all_contents[i:i + batch_size] - batch_embeddings = embedder.embed(batch_contents) - all_embeddings.extend(batch_embeddings) - - progress = min(i + batch_size, len(all_contents)) - if progress % (batch_size * 4) == 0 or progress == len(all_contents): - logger.info(f"Generated embeddings: {progress}/{len(all_contents)}") - - embed_time = time.time() - embed_start - logger.info(f"Embedding completed in {embed_time:.1f}s") - - # Phase 3: Batch database write - logger.info("Phase 3: Storing chunks...") - store_start = time.time() - - # Create SemanticChunk objects with embeddings - semantic_chunks_with_paths = [] - for chunk_data, embedding in zip(all_chunks, all_embeddings): - semantic_chunk = SemanticChunk( - content=chunk_data.content, - metadata=chunk_data.metadata, - ) - semantic_chunk.embedding = embedding - semantic_chunks_with_paths.append((semantic_chunk, chunk_data.file_path)) - - # Batch write (handles both SQLite and HNSW) - write_batch_size = 1000 - total_stored = 0 - for i in range(0, len(semantic_chunks_with_paths), write_batch_size): - batch = semantic_chunks_with_paths[i:i + write_batch_size] - vector_store.add_chunks_batch(batch) - total_stored += len(batch) - if total_stored % 5000 == 0 or total_stored == len(semantic_chunks_with_paths): - logger.info(f"Stored: {total_stored}/{len(semantic_chunks_with_paths)} chunks") - - store_time = time.time() - store_start - logger.info(f"Storage completed in {store_time:.1f}s") - - elapsed_time = time.time() - start_time - - # Generate summary + # Extract result data and log summary + data = result["result"] logger.info("=" * 60) - logger.info(f"Completed in {elapsed_time:.1f}s") - logger.info(f" Chunking: {chunk_time:.1f}s") - logger.info(f" Embedding: {embed_time:.1f}s") - logger.info(f" Storage: {store_time:.1f}s") - logger.info(f"Total chunks created: {len(all_chunks)}") - logger.info(f"Files processed: {len(files) - len(failed_files)}/{len(files)}") - if vector_store.ann_available: - logger.info(f"HNSW index vectors: {vector_store.ann_count}") - if failed_files: - logger.warning(f"Failed files: {len(failed_files)}") - for 
file_path, error in failed_files[:5]: # Show first 5 failures - logger.warning(f" {file_path}: {error}") + logger.info(f"Completed in {data['elapsed_time']:.1f}s") + logger.info(f"Total chunks created: {data['chunks_created']}") + logger.info(f"Files processed: {data['files_processed']}") + if data['files_failed'] > 0: + logger.warning(f"Failed files: {data['files_failed']}") + if data.get('failed_files'): + for file_path, error in data['failed_files']: + logger.warning(f" {file_path}: {error}") return { "success": True, - "chunks_created": len(all_chunks), - "files_processed": len(files) - len(failed_files), - "files_failed": len(failed_files), - "elapsed_time": elapsed_time, - "chunk_time": chunk_time, - "embed_time": embed_time, - "store_time": store_time, - "ann_vectors": vector_store.ann_count if vector_store.ann_available else 0, + "chunks_created": data["chunks_created"], + "files_processed": data["files_processed"], + "files_failed": data["files_failed"], + "elapsed_time": data["elapsed_time"], } -def find_index_databases(scan_dir: Path) -> List[Path]: - """Find all _index.db files in directory tree.""" - logger.info(f"Scanning for indexes in: {scan_dir}") - index_files = list(scan_dir.rglob("_index.db")) - logger.info(f"Found {len(index_files)} index databases") - return index_files - - def main(): parser = argparse.ArgumentParser( - description="Generate vector embeddings for CodexLens indexes", + description="Generate vector embeddings for CodexLens indexes (memory-efficient streaming)", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__ ) @@ -439,14 +163,14 @@ def main(): "--workers", type=int, default=0, - help="Number of parallel workers for chunking (default: auto-detect CPU count)" + help="(Deprecated) Kept for backward compatibility, ignored" ) parser.add_argument( "--batch-size", type=int, default=256, - help="Batch size for embedding generation (default: 256)" + help="(Deprecated) Kept for backward compatibility, ignored" ) parser.add_argument( @@ -481,43 +205,33 @@ def main(): # Determine if scanning or single file if args.scan or index_path.is_dir(): - # Scan mode + # Scan mode - use recursive implementation if index_path.is_file(): logger.error("--scan requires a directory path") sys.exit(1) - index_files = find_index_databases(index_path) - if not index_files: - logger.error(f"No index databases found in: {index_path}") + result = generate_embeddings_recursive( + index_root=index_path, + model_profile=args.model, + force=args.force, + chunk_size=args.chunk_size, + progress_callback=progress_callback, + ) + + if not result["success"]: + logger.error(f"Failed: {result.get('error', 'Unknown error')}") sys.exit(1) - # Process each index - total_chunks = 0 - successful = 0 - for idx, index_file in enumerate(index_files, 1): - logger.info(f"\n{'='*60}") - logger.info(f"Processing index {idx}/{len(index_files)}") - logger.info(f"{'='*60}") - - result = generate_embeddings_for_index( - index_file, - model_profile=args.model, - force=args.force, - chunk_size=args.chunk_size, - workers=args.workers, - batch_size=args.batch_size, - ) - - if result["success"]: - total_chunks += result["chunks_created"] - successful += 1 - - # Final summary + # Log summary + data = result["result"] logger.info(f"\n{'='*60}") logger.info("BATCH PROCESSING COMPLETE") logger.info(f"{'='*60}") - logger.info(f"Indexes processed: {successful}/{len(index_files)}") - logger.info(f"Total chunks created: {total_chunks}") + logger.info(f"Indexes processed: 
{data['indexes_successful']}/{data['indexes_processed']}") + logger.info(f"Total chunks created: {data['total_chunks_created']}") + logger.info(f"Total files processed: {data['total_files_processed']}") + if data['total_files_failed'] > 0: + logger.warning(f"Total files failed: {data['total_files_failed']}") else: # Single index mode @@ -530,8 +244,6 @@ def main(): model_profile=args.model, force=args.force, chunk_size=args.chunk_size, - workers=args.workers, - batch_size=args.batch_size, ) if not result["success"]: diff --git a/codex-lens/src/codexlens/cli/commands.py b/codex-lens/src/codexlens/cli/commands.py index 041917da..166237d2 100644 --- a/codex-lens/src/codexlens/cli/commands.py +++ b/codex-lens/src/codexlens/cli/commands.py @@ -268,7 +268,6 @@ def search( files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."), mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."), weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."), - enrich: bool = typer.Option(False, "--enrich", help="Enrich results with code graph relationships (calls, imports)."), json_mode: bool = typer.Option(False, "--json", help="Output JSON response."), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."), ) -> None: @@ -423,30 +422,10 @@ def search( for r in result.results ] - # Enrich results with relationship data if requested - enriched = False - if enrich: - try: - from codexlens.search.enrichment import RelationshipEnricher - - # Find index path for the search path - project_record = registry.find_by_source_path(str(search_path)) - if project_record: - index_path = Path(project_record["index_root"]) / "_index.db" - if index_path.exists(): - with RelationshipEnricher(index_path) as enricher: - results_list = enricher.enrich(results_list, limit=limit) - enriched = True - except Exception as e: - # Enrichment failure should not break search - if verbose: - console.print(f"[yellow]Warning: Enrichment failed: {e}[/yellow]") - payload = { "query": query, "mode": actual_mode, "count": len(results_list), - "enriched": enriched, "results": results_list, "stats": { "dirs_searched": result.stats.dirs_searched, @@ -458,8 +437,7 @@ def search( print_json(success=True, result=payload) else: render_search_results(result.results, verbose=verbose) - enrich_status = " | [green]Enriched[/green]" if enriched else "" - console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms{enrich_status}[/dim]") + console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms[/dim]") except SearchError as exc: if json_mode: @@ -1376,103 +1354,6 @@ def clean( raise typer.Exit(code=1) -@app.command() -def graph( - query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"), - symbol: str = typer.Argument(..., help="Symbol name to query"), - path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."), - limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."), - depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."), - json_mode: bool = typer.Option(False, "--json", help="Output JSON response."), - verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug 
logging."), -) -> None: - """Query semantic graph for code relationships. - - Supported query types: - - callers: Find all functions/methods that call the given symbol - - callees: Find all functions/methods called by the given symbol - - inheritance: Find inheritance relationships for the given class - - Examples: - codex-lens graph callers my_function - codex-lens graph callees MyClass.method --path src/ - codex-lens graph inheritance BaseClass - """ - _configure_logging(verbose) - search_path = path.expanduser().resolve() - - # Validate query type - valid_types = ["callers", "callees", "inheritance"] - if query_type not in valid_types: - if json_mode: - print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}") - else: - console.print(f"[red]Invalid query type:[/red] {query_type}") - console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]") - raise typer.Exit(code=1) - - registry: RegistryStore | None = None - try: - registry = RegistryStore() - registry.initialize() - mapper = PathMapper() - - engine = ChainSearchEngine(registry, mapper) - options = SearchOptions(depth=depth, total_limit=limit) - - # Execute graph query based on type - if query_type == "callers": - results = engine.search_callers(symbol, search_path, options=options) - result_type = "callers" - elif query_type == "callees": - results = engine.search_callees(symbol, search_path, options=options) - result_type = "callees" - else: # inheritance - results = engine.search_inheritance(symbol, search_path, options=options) - result_type = "inheritance" - - payload = { - "query_type": query_type, - "symbol": symbol, - "count": len(results), - "relationships": results - } - - if json_mode: - print_json(success=True, result=payload) - else: - from .output import render_graph_results - render_graph_results(results, query_type=query_type, symbol=symbol) - - except SearchError as exc: - if json_mode: - print_json(success=False, error=f"Graph search error: {exc}") - else: - console.print(f"[red]Graph query failed (search):[/red] {exc}") - raise typer.Exit(code=1) - except StorageError as exc: - if json_mode: - print_json(success=False, error=f"Storage error: {exc}") - else: - console.print(f"[red]Graph query failed (storage):[/red] {exc}") - raise typer.Exit(code=1) - except CodexLensError as exc: - if json_mode: - print_json(success=False, error=str(exc)) - else: - console.print(f"[red]Graph query failed:[/red] {exc}") - raise typer.Exit(code=1) - except Exception as exc: - if json_mode: - print_json(success=False, error=f"Unexpected error: {exc}") - else: - console.print(f"[red]Graph query failed (unexpected):[/red] {exc}") - raise typer.Exit(code=1) - finally: - if registry is not None: - registry.close() - - @app.command("semantic-list") def semantic_list( path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."), diff --git a/codex-lens/src/codexlens/cli/embedding_manager.py b/codex-lens/src/codexlens/cli/embedding_manager.py index f4067840..ac658c19 100644 --- a/codex-lens/src/codexlens/cli/embedding_manager.py +++ b/codex-lens/src/codexlens/cli/embedding_manager.py @@ -194,7 +194,6 @@ def generate_embeddings( try: # Use cached embedder (singleton) for performance embedder = get_embedder(profile=model_profile) - vector_store = VectorStore(index_path) chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size)) if progress_callback: @@ -217,85 +216,86 @@ def generate_embeddings( EMBEDDING_BATCH_SIZE = 8 # 
jina-embeddings-v2-base-code needs small batches try: - with sqlite3.connect(index_path) as conn: - conn.row_factory = sqlite3.Row - path_column = _get_path_column(conn) + with VectorStore(index_path) as vector_store: + with sqlite3.connect(index_path) as conn: + conn.row_factory = sqlite3.Row + path_column = _get_path_column(conn) - # Get total file count for progress reporting - total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0] - if total_files == 0: - return {"success": False, "error": "No files found in index"} + # Get total file count for progress reporting + total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0] + if total_files == 0: + return {"success": False, "error": "No files found in index"} - if progress_callback: - progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...") - - cursor = conn.execute(f"SELECT {path_column}, content, language FROM files") - batch_number = 0 - - while True: - # Fetch a batch of files (streaming, not fetchall) - file_batch = cursor.fetchmany(FILE_BATCH_SIZE) - if not file_batch: - break - - batch_number += 1 - batch_chunks_with_paths = [] - files_in_batch_with_chunks = set() - - # Step 1: Chunking for the current file batch - for file_row in file_batch: - file_path = file_row[path_column] - content = file_row["content"] - language = file_row["language"] or "python" - - try: - chunks = chunker.chunk_sliding_window( - content, - file_path=file_path, - language=language - ) - if chunks: - for chunk in chunks: - batch_chunks_with_paths.append((chunk, file_path)) - files_in_batch_with_chunks.add(file_path) - except Exception as e: - logger.error(f"Failed to chunk {file_path}: {e}") - failed_files.append((file_path, str(e))) - - if not batch_chunks_with_paths: - continue - - batch_chunk_count = len(batch_chunks_with_paths) if progress_callback: - progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks") + progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...") - # Step 2: Generate embeddings for this batch - batch_embeddings = [] - try: - for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE): - batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count) - batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]] - embeddings = embedder.embed(batch_contents) - batch_embeddings.extend(embeddings) - except Exception as e: - logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}") - failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch]) - continue + cursor = conn.execute(f"SELECT {path_column}, content, language FROM files") + batch_number = 0 - # Step 3: Assign embeddings to chunks - for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings): - chunk.embedding = embedding + while True: + # Fetch a batch of files (streaming, not fetchall) + file_batch = cursor.fetchmany(FILE_BATCH_SIZE) + if not file_batch: + break - # Step 4: Store this batch to database immediately (releases memory) - try: - vector_store.add_chunks_batch(batch_chunks_with_paths) - total_chunks_created += batch_chunk_count - total_files_processed += len(files_in_batch_with_chunks) - except Exception as e: - logger.error(f"Failed to store batch {batch_number}: {str(e)}") - failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch]) + batch_number += 1 + batch_chunks_with_paths = [] + files_in_batch_with_chunks = set() - # Memory is 
released here as batch_chunks_with_paths and batch_embeddings go out of scope + # Step 1: Chunking for the current file batch + for file_row in file_batch: + file_path = file_row[path_column] + content = file_row["content"] + language = file_row["language"] or "python" + + try: + chunks = chunker.chunk_sliding_window( + content, + file_path=file_path, + language=language + ) + if chunks: + for chunk in chunks: + batch_chunks_with_paths.append((chunk, file_path)) + files_in_batch_with_chunks.add(file_path) + except Exception as e: + logger.error(f"Failed to chunk {file_path}: {e}") + failed_files.append((file_path, str(e))) + + if not batch_chunks_with_paths: + continue + + batch_chunk_count = len(batch_chunks_with_paths) + if progress_callback: + progress_callback(f" Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks") + + # Step 2: Generate embeddings for this batch + batch_embeddings = [] + try: + for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE): + batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count) + batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]] + embeddings = embedder.embed(batch_contents) + batch_embeddings.extend(embeddings) + except Exception as e: + logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}") + failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch]) + continue + + # Step 3: Assign embeddings to chunks + for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings): + chunk.embedding = embedding + + # Step 4: Store this batch to database immediately (releases memory) + try: + vector_store.add_chunks_batch(batch_chunks_with_paths) + total_chunks_created += batch_chunk_count + total_files_processed += len(files_in_batch_with_chunks) + except Exception as e: + logger.error(f"Failed to store batch {batch_number}: {str(e)}") + failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch]) + + # Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope except Exception as e: return {"success": False, "error": f"Failed to read or process files: {str(e)}"} diff --git a/codex-lens/src/codexlens/cli/output.py b/codex-lens/src/codexlens/cli/output.py index 8974bbab..15659441 100644 --- a/codex-lens/src/codexlens/cli/output.py +++ b/codex-lens/src/codexlens/cli/output.py @@ -122,68 +122,3 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) -> console.print(header) render_symbols(list(symbols), title="Discovered Symbols") - -def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None: - """Render semantic graph query results. 
- - Args: - results: List of relationship dicts - query_type: Type of query (callers, callees, inheritance) - symbol: Symbol name that was queried - """ - if not results: - console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}") - return - - title_map = { - "callers": f"Callers of '{symbol}' ({len(results)} found)", - "callees": f"Callees of '{symbol}' ({len(results)} found)", - "inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)" - } - - table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})")) - - if query_type == "callers": - table.add_column("Caller", style="green") - table.add_column("File", style="cyan", no_wrap=False, max_width=40) - table.add_column("Line", justify="right", style="yellow") - table.add_column("Type", style="dim") - - for rel in results: - table.add_row( - rel.get("source_symbol", "-"), - rel.get("source_file", "-"), - str(rel.get("source_line", "-")), - rel.get("relationship_type", "-") - ) - - elif query_type == "callees": - table.add_column("Target", style="green") - table.add_column("File", style="cyan", no_wrap=False, max_width=40) - table.add_column("Line", justify="right", style="yellow") - table.add_column("Type", style="dim") - - for rel in results: - table.add_row( - rel.get("target_symbol", "-"), - rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"), - str(rel.get("source_line", "-")), - rel.get("relationship_type", "-") - ) - - else: # inheritance - table.add_column("Derived Class", style="green") - table.add_column("Base Class", style="magenta") - table.add_column("File", style="cyan", no_wrap=False, max_width=40) - table.add_column("Line", justify="right", style="yellow") - - for rel in results: - table.add_row( - rel.get("source_symbol", "-"), - rel.get("target_symbol", "-"), - rel.get("source_file", "-"), - str(rel.get("source_line", "-")) - ) - - console.print(table) - diff --git a/codex-lens/src/codexlens/entities.py b/codex-lens/src/codexlens/entities.py index 55eb4fae..a69edfa1 100644 --- a/codex-lens/src/codexlens/entities.py +++ b/codex-lens/src/codexlens/entities.py @@ -14,8 +14,6 @@ class Symbol(BaseModel): kind: str = Field(..., min_length=1) range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive") file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol") - token_count: Optional[int] = Field(default=None, description="Token count for symbol content") - symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering") @field_validator("range") @classmethod @@ -29,13 +27,6 @@ class Symbol(BaseModel): raise ValueError("end_line must be >= start_line") return value - @field_validator("token_count") - @classmethod - def validate_token_count(cls, value: Optional[int]) -> Optional[int]: - if value is not None and value < 0: - raise ValueError("token_count must be >= 0") - return value - class SemanticChunk(BaseModel): """A semantically meaningful chunk of content, optionally embedded.""" diff --git a/codex-lens/src/codexlens/search/chain_search.py b/codex-lens/src/codexlens/search/chain_search.py index 5ffb1bdc..33e37bff 100644 --- a/codex-lens/src/codexlens/search/chain_search.py +++ b/codex-lens/src/codexlens/search/chain_search.py @@ -302,108 +302,6 @@ class ChainSearchEngine: index_paths, name, kind, options.total_limit ) - def search_callers(self, target_symbol: str, - source_path: Path, - options: Optional[SearchOptions] = 
None) -> List[Dict[str, Any]]: - """Find all callers of a given symbol across directory hierarchy. - - Args: - target_symbol: Name of the symbol to find callers for - source_path: Starting directory path - options: Search configuration (uses defaults if None) - - Returns: - List of relationship dicts with caller information - - Examples: - >>> engine = ChainSearchEngine(registry, mapper) - >>> callers = engine.search_callers("my_function", Path("D:/project")) - >>> for caller in callers: - ... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}") - """ - options = options or SearchOptions() - - start_index = self._find_start_index(source_path) - if not start_index: - self.logger.warning(f"No index found for {source_path}") - return [] - - index_paths = self._collect_index_paths(start_index, options.depth) - if not index_paths: - return [] - - return self._search_callers_parallel( - index_paths, target_symbol, options.total_limit - ) - - def search_callees(self, source_symbol: str, - source_path: Path, - options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]: - """Find all callees (what a symbol calls) across directory hierarchy. - - Args: - source_symbol: Name of the symbol to find callees for - source_path: Starting directory path - options: Search configuration (uses defaults if None) - - Returns: - List of relationship dicts with callee information - - Examples: - >>> engine = ChainSearchEngine(registry, mapper) - >>> callees = engine.search_callees("MyClass.method", Path("D:/project")) - >>> for callee in callees: - ... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}") - """ - options = options or SearchOptions() - - start_index = self._find_start_index(source_path) - if not start_index: - self.logger.warning(f"No index found for {source_path}") - return [] - - index_paths = self._collect_index_paths(start_index, options.depth) - if not index_paths: - return [] - - return self._search_callees_parallel( - index_paths, source_symbol, options.total_limit - ) - - def search_inheritance(self, class_name: str, - source_path: Path, - options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]: - """Find inheritance relationships for a class across directory hierarchy. - - Args: - class_name: Name of the class to find inheritance for - source_path: Starting directory path - options: Search configuration (uses defaults if None) - - Returns: - List of relationship dicts with inheritance information - - Examples: - >>> engine = ChainSearchEngine(registry, mapper) - >>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project")) - >>> for rel in inheritance: - ... print(f"{rel['source_symbol']} extends {rel['target_symbol']}") - """ - options = options or SearchOptions() - - start_index = self._find_start_index(source_path) - if not start_index: - self.logger.warning(f"No index found for {source_path}") - return [] - - index_paths = self._collect_index_paths(start_index, options.depth) - if not index_paths: - return [] - - return self._search_inheritance_parallel( - index_paths, class_name, options.total_limit - ) - # === Internal Methods === def _find_start_index(self, source_path: Path) -> Optional[Path]: @@ -711,273 +609,6 @@ class ChainSearchEngine: self.logger.debug(f"Symbol search error in {index_path}: {exc}") return [] - def _search_callers_parallel(self, index_paths: List[Path], - target_symbol: str, - limit: int) -> List[Dict[str, Any]]: - """Search for callers across multiple indexes in parallel. 
- - Args: - index_paths: List of _index.db paths to search - target_symbol: Target symbol name - limit: Total result limit - - Returns: - Deduplicated list of caller relationships - """ - all_callers = [] - - executor = self._get_executor() - future_to_path = { - executor.submit( - self._search_callers_single, - idx_path, - target_symbol - ): idx_path - for idx_path in index_paths - } - - for future in as_completed(future_to_path): - try: - callers = future.result() - all_callers.extend(callers) - except Exception as exc: - self.logger.error(f"Caller search failed: {exc}") - - # Deduplicate by (source_file, source_line) - seen = set() - unique_callers = [] - for caller in all_callers: - key = (caller.get("source_file"), caller.get("source_line")) - if key not in seen: - seen.add(key) - unique_callers.append(caller) - - # Sort by source file and line - unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0))) - - return unique_callers[:limit] - - def _search_callers_single(self, index_path: Path, - target_symbol: str) -> List[Dict[str, Any]]: - """Search for callers in a single index. - - Args: - index_path: Path to _index.db file - target_symbol: Target symbol name - - Returns: - List of caller relationship dicts (empty on error) - """ - try: - with SQLiteStore(index_path) as store: - return store.query_relationships_by_target(target_symbol) - except Exception as exc: - self.logger.debug(f"Caller search error in {index_path}: {exc}") - return [] - - def _search_callees_parallel(self, index_paths: List[Path], - source_symbol: str, - limit: int) -> List[Dict[str, Any]]: - """Search for callees across multiple indexes in parallel. - - Args: - index_paths: List of _index.db paths to search - source_symbol: Source symbol name - limit: Total result limit - - Returns: - Deduplicated list of callee relationships - """ - all_callees = [] - - executor = self._get_executor() - future_to_path = { - executor.submit( - self._search_callees_single, - idx_path, - source_symbol - ): idx_path - for idx_path in index_paths - } - - for future in as_completed(future_to_path): - try: - callees = future.result() - all_callees.extend(callees) - except Exception as exc: - self.logger.error(f"Callee search failed: {exc}") - - # Deduplicate by (target_symbol, source_line) - seen = set() - unique_callees = [] - for callee in all_callees: - key = (callee.get("target_symbol"), callee.get("source_line")) - if key not in seen: - seen.add(key) - unique_callees.append(callee) - - # Sort by source line - unique_callees.sort(key=lambda c: c.get("source_line", 0)) - - return unique_callees[:limit] - - def _search_callees_single(self, index_path: Path, - source_symbol: str) -> List[Dict[str, Any]]: - """Search for callees in a single index. - - Args: - index_path: Path to _index.db file - source_symbol: Source symbol name - - Returns: - List of callee relationship dicts (empty on error) - """ - try: - with SQLiteStore(index_path) as store: - # Single JOIN query to get all callees (fixes N+1 query problem) - # Uses public execute_query API instead of _get_connection bypass - rows = store.execute_query( - """ - SELECT - s.name AS source_symbol, - r.target_qualified_name AS target_symbol, - r.relationship_type, - r.source_line, - f.full_path AS source_file, - r.target_file - FROM code_relationships r - JOIN symbols s ON r.source_symbol_id = s.id - JOIN files f ON s.file_id = f.id - WHERE s.name = ? 
AND r.relationship_type = 'call' - ORDER BY f.full_path, r.source_line - LIMIT 100 - """, - (source_symbol,) - ) - - return [ - { - "source_symbol": row["source_symbol"], - "target_symbol": row["target_symbol"], - "relationship_type": row["relationship_type"], - "source_line": row["source_line"], - "source_file": row["source_file"], - "target_file": row["target_file"], - } - for row in rows - ] - except Exception as exc: - self.logger.debug(f"Callee search error in {index_path}: {exc}") - return [] - - def _search_inheritance_parallel(self, index_paths: List[Path], - class_name: str, - limit: int) -> List[Dict[str, Any]]: - """Search for inheritance relationships across multiple indexes in parallel. - - Args: - index_paths: List of _index.db paths to search - class_name: Class name to search for - limit: Total result limit - - Returns: - Deduplicated list of inheritance relationships - """ - all_inheritance = [] - - executor = self._get_executor() - future_to_path = { - executor.submit( - self._search_inheritance_single, - idx_path, - class_name - ): idx_path - for idx_path in index_paths - } - - for future in as_completed(future_to_path): - try: - inheritance = future.result() - all_inheritance.extend(inheritance) - except Exception as exc: - self.logger.error(f"Inheritance search failed: {exc}") - - # Deduplicate by (source_symbol, target_symbol) - seen = set() - unique_inheritance = [] - for rel in all_inheritance: - key = (rel.get("source_symbol"), rel.get("target_symbol")) - if key not in seen: - seen.add(key) - unique_inheritance.append(rel) - - # Sort by source file - unique_inheritance.sort(key=lambda r: r.get("source_file", "")) - - return unique_inheritance[:limit] - - def _search_inheritance_single(self, index_path: Path, - class_name: str) -> List[Dict[str, Any]]: - """Search for inheritance relationships in a single index. - - Args: - index_path: Path to _index.db file - class_name: Class name to search for - - Returns: - List of inheritance relationship dicts (empty on error) - """ - try: - with SQLiteStore(index_path) as store: - # Use UNION to find relationships where class is either: - # 1. The base class (target) - find derived classes - # 2. The derived class (source) - find parent classes - # Uses public execute_query API instead of _get_connection bypass - rows = store.execute_query( - """ - SELECT - s.name AS source_symbol, - r.target_qualified_name, - r.relationship_type, - r.source_line, - f.full_path AS source_file, - r.target_file - FROM code_relationships r - JOIN symbols s ON r.source_symbol_id = s.id - JOIN files f ON s.file_id = f.id - WHERE r.target_qualified_name = ? AND r.relationship_type = 'inherits' - UNION - SELECT - s.name AS source_symbol, - r.target_qualified_name, - r.relationship_type, - r.source_line, - f.full_path AS source_file, - r.target_file - FROM code_relationships r - JOIN symbols s ON r.source_symbol_id = s.id - JOIN files f ON s.file_id = f.id - WHERE s.name = ? 
AND r.relationship_type = 'inherits' - LIMIT 100 - """, - (class_name, class_name) - ) - - return [ - { - "source_symbol": row["source_symbol"], - "target_symbol": row["target_qualified_name"], - "relationship_type": row["relationship_type"], - "source_line": row["source_line"], - "source_file": row["source_file"], - "target_file": row["target_file"], - } - for row in rows - ] - except Exception as exc: - self.logger.debug(f"Inheritance search error in {index_path}: {exc}") - return [] - # === Convenience Functions === @@ -1007,10 +638,9 @@ def quick_search(query: str, mapper = PathMapper() - engine = ChainSearchEngine(registry, mapper) - options = SearchOptions(depth=depth) - - result = engine.search(query, source_path, options) + with ChainSearchEngine(registry, mapper) as engine: + options = SearchOptions(depth=depth) + result = engine.search(query, source_path, options) registry.close() diff --git a/codex-lens/src/codexlens/semantic/graph_analyzer.py b/codex-lens/src/codexlens/semantic/graph_analyzer.py deleted file mode 100644 index 2ca6ed0f..00000000 --- a/codex-lens/src/codexlens/semantic/graph_analyzer.py +++ /dev/null @@ -1,542 +0,0 @@ -"""Graph analyzer for extracting code relationships using tree-sitter. - -Provides AST-based analysis to identify function calls, method invocations, -and class inheritance relationships within source files. -""" - -from __future__ import annotations - -from pathlib import Path -from typing import List, Optional - -try: - from tree_sitter import Node as TreeSitterNode - TREE_SITTER_AVAILABLE = True -except ImportError: - TreeSitterNode = None # type: ignore[assignment] - TREE_SITTER_AVAILABLE = False - -from codexlens.entities import CodeRelationship, Symbol -from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser - - -class GraphAnalyzer: - """Analyzer for extracting semantic relationships from code using AST traversal.""" - - def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None: - """Initialize graph analyzer for a language. - - Args: - language_id: Language identifier (python, javascript, typescript, etc.) - parser: Optional TreeSitterSymbolParser instance for dependency injection. - If None, creates a new parser instance (backward compatibility). - """ - self.language_id = language_id - self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id) - - def is_available(self) -> bool: - """Check if graph analyzer is available. - - Returns: - True if tree-sitter parser is initialized and ready - """ - return self._parser.is_available() - - def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]: - """Analyze source code and extract relationships. - - Args: - text: Source code text - file_path: File path for relationship context - - Returns: - List of CodeRelationship objects representing intra-file relationships - """ - if not self.is_available() or self._parser._parser is None: - return [] - - try: - source_bytes = text.encode("utf8") - tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined] - root = tree.root_node - - relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve())) - - return relationships - except Exception: - # Gracefully handle parsing errors - return [] - - def analyze_with_symbols( - self, text: str, file_path: Path, symbols: List[Symbol] - ) -> List[CodeRelationship]: - """Analyze source code using pre-parsed symbols to avoid duplicate parsing. 
- - Args: - text: Source code text - file_path: File path for relationship context - symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser - - Returns: - List of CodeRelationship objects representing intra-file relationships - """ - if not self.is_available() or self._parser._parser is None: - return [] - - try: - source_bytes = text.encode("utf8") - tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined] - root = tree.root_node - - # Convert Symbol objects to internal symbol format - defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols) - - # Extract relationships using provided symbols - relationships = self._extract_relationships_with_symbols( - source_bytes, root, str(file_path.resolve()), defined_symbols - ) - - return relationships - except Exception: - # Gracefully handle parsing errors - return [] - - def _convert_symbols_to_dict( - self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol] - ) -> List[dict]: - """Convert Symbol objects to internal dict format for relationship extraction. - - Args: - source_bytes: Source code as bytes - root: Root AST node - symbols: Pre-parsed Symbol objects - - Returns: - List of symbol info dicts with name, node, and type - """ - symbol_dicts = [] - symbol_names = {s.name for s in symbols} - - # Find AST nodes corresponding to symbols - for node in self._iter_nodes(root): - node_type = node.type - - # Check if this node matches any of our symbols - if node_type in {"function_definition", "async_function_definition"}: - name_node = node.child_by_field_name("name") - if name_node: - name = self._node_text(source_bytes, name_node) - if name in symbol_names: - symbol_dicts.append({ - "name": name, - "node": node, - "type": "function" - }) - elif node_type == "class_definition": - name_node = node.child_by_field_name("name") - if name_node: - name = self._node_text(source_bytes, name_node) - if name in symbol_names: - symbol_dicts.append({ - "name": name, - "node": node, - "type": "class" - }) - elif node_type in {"function_declaration", "generator_function_declaration"}: - name_node = node.child_by_field_name("name") - if name_node: - name = self._node_text(source_bytes, name_node) - if name in symbol_names: - symbol_dicts.append({ - "name": name, - "node": node, - "type": "function" - }) - elif node_type == "method_definition": - name_node = node.child_by_field_name("name") - if name_node: - name = self._node_text(source_bytes, name_node) - if name in symbol_names: - symbol_dicts.append({ - "name": name, - "node": node, - "type": "method" - }) - elif node_type in {"class_declaration", "class"}: - name_node = node.child_by_field_name("name") - if name_node: - name = self._node_text(source_bytes, name_node) - if name in symbol_names: - symbol_dicts.append({ - "name": name, - "node": node, - "type": "class" - }) - elif node_type == "variable_declarator": - name_node = node.child_by_field_name("name") - value_node = node.child_by_field_name("value") - if name_node and value_node and value_node.type == "arrow_function": - name = self._node_text(source_bytes, name_node) - if name in symbol_names: - symbol_dicts.append({ - "name": name, - "node": node, - "type": "function" - }) - - return symbol_dicts - - def _extract_relationships_with_symbols( - self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict] - ) -> List[CodeRelationship]: - """Extract relationships from AST using pre-parsed symbols. 
- - Args: - source_bytes: Source code as bytes - root: Root AST node - file_path: Absolute file path - defined_symbols: Pre-parsed symbol dicts - - Returns: - List of extracted relationships - """ - relationships: List[CodeRelationship] = [] - - # Determine call node type based on language - if self.language_id == "python": - call_node_type = "call" - extract_target = self._extract_call_target - elif self.language_id in {"javascript", "typescript"}: - call_node_type = "call_expression" - extract_target = self._extract_js_call_target - else: - return [] - - # Find call expressions and match to defined symbols - for node in self._iter_nodes(root): - if node.type == call_node_type: - # Extract caller context (enclosing function/method/class) - source_symbol = self._find_enclosing_symbol(node, defined_symbols) - if source_symbol is None: - # Call at module level, use "" as source - source_symbol = "" - - # Extract callee (function/method being called) - target_symbol = extract_target(source_bytes, node) - if target_symbol is None: - continue - - # Create relationship - line_number = node.start_point[0] + 1 - relationships.append( - CodeRelationship( - source_symbol=source_symbol, - target_symbol=target_symbol, - relationship_type="call", - source_file=file_path, - target_file=None, # Intra-file only - source_line=line_number, - ) - ) - - return relationships - - def _extract_relationships( - self, source_bytes: bytes, root: TreeSitterNode, file_path: str - ) -> List[CodeRelationship]: - """Extract relationships from AST. - - Args: - source_bytes: Source code as bytes - root: Root AST node - file_path: Absolute file path - - Returns: - List of extracted relationships - """ - if self.language_id == "python": - return self._extract_python_relationships(source_bytes, root, file_path) - elif self.language_id in {"javascript", "typescript"}: - return self._extract_js_ts_relationships(source_bytes, root, file_path) - else: - return [] - - def _extract_python_relationships( - self, source_bytes: bytes, root: TreeSitterNode, file_path: str - ) -> List[CodeRelationship]: - """Extract Python relationships from AST. - - Args: - source_bytes: Source code as bytes - root: Root AST node - file_path: Absolute file path - - Returns: - List of Python relationships (function/method calls) - """ - relationships: List[CodeRelationship] = [] - - # First pass: collect all defined symbols with their scopes - defined_symbols = self._collect_python_symbols(source_bytes, root) - - # Second pass: find call expressions and match to defined symbols - for node in self._iter_nodes(root): - if node.type == "call": - # Extract caller context (enclosing function/method/class) - source_symbol = self._find_enclosing_symbol(node, defined_symbols) - if source_symbol is None: - # Call at module level, use "" as source - source_symbol = "" - - # Extract callee (function/method being called) - target_symbol = self._extract_call_target(source_bytes, node) - if target_symbol is None: - continue - - # Create relationship - line_number = node.start_point[0] + 1 - relationships.append( - CodeRelationship( - source_symbol=source_symbol, - target_symbol=target_symbol, - relationship_type="call", - source_file=file_path, - target_file=None, # Intra-file only - source_line=line_number, - ) - ) - - return relationships - - def _extract_js_ts_relationships( - self, source_bytes: bytes, root: TreeSitterNode, file_path: str - ) -> List[CodeRelationship]: - """Extract JavaScript/TypeScript relationships from AST. 
- - Args: - source_bytes: Source code as bytes - root: Root AST node - file_path: Absolute file path - - Returns: - List of JS/TS relationships (function/method calls) - """ - relationships: List[CodeRelationship] = [] - - # First pass: collect all defined symbols - defined_symbols = self._collect_js_ts_symbols(source_bytes, root) - - # Second pass: find call expressions - for node in self._iter_nodes(root): - if node.type == "call_expression": - # Extract caller context - source_symbol = self._find_enclosing_symbol(node, defined_symbols) - if source_symbol is None: - source_symbol = "" - - # Extract callee - target_symbol = self._extract_js_call_target(source_bytes, node) - if target_symbol is None: - continue - - # Create relationship - line_number = node.start_point[0] + 1 - relationships.append( - CodeRelationship( - source_symbol=source_symbol, - target_symbol=target_symbol, - relationship_type="call", - source_file=file_path, - target_file=None, - source_line=line_number, - ) - ) - - return relationships - - def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]: - """Collect all Python function/method/class definitions. - - Args: - source_bytes: Source code as bytes - root: Root AST node - - Returns: - List of symbol info dicts with name, node, and type - """ - symbols = [] - for node in self._iter_nodes(root): - if node.type in {"function_definition", "async_function_definition"}: - name_node = node.child_by_field_name("name") - if name_node: - symbols.append({ - "name": self._node_text(source_bytes, name_node), - "node": node, - "type": "function" - }) - elif node.type == "class_definition": - name_node = node.child_by_field_name("name") - if name_node: - symbols.append({ - "name": self._node_text(source_bytes, name_node), - "node": node, - "type": "class" - }) - return symbols - - def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]: - """Collect all JS/TS function/method/class definitions. - - Args: - source_bytes: Source code as bytes - root: Root AST node - - Returns: - List of symbol info dicts with name, node, and type - """ - symbols = [] - for node in self._iter_nodes(root): - if node.type in {"function_declaration", "generator_function_declaration"}: - name_node = node.child_by_field_name("name") - if name_node: - symbols.append({ - "name": self._node_text(source_bytes, name_node), - "node": node, - "type": "function" - }) - elif node.type == "method_definition": - name_node = node.child_by_field_name("name") - if name_node: - symbols.append({ - "name": self._node_text(source_bytes, name_node), - "node": node, - "type": "method" - }) - elif node.type in {"class_declaration", "class"}: - name_node = node.child_by_field_name("name") - if name_node: - symbols.append({ - "name": self._node_text(source_bytes, name_node), - "node": node, - "type": "class" - }) - elif node.type == "variable_declarator": - name_node = node.child_by_field_name("name") - value_node = node.child_by_field_name("value") - if name_node and value_node and value_node.type == "arrow_function": - symbols.append({ - "name": self._node_text(source_bytes, name_node), - "node": node, - "type": "function" - }) - return symbols - - def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]: - """Find the enclosing function/method/class for a node. - - Returns fully qualified name (e.g., "MyClass.my_method") by traversing up - the AST tree and collecting parent class/function names. 
- - Args: - node: AST node to find enclosure for - symbols: List of defined symbols - - Returns: - Fully qualified name of enclosing symbol, or None if at module level - """ - # Walk up the tree to find all enclosing symbols - enclosing_names = [] - parent = node.parent - - while parent is not None: - for symbol in symbols: - if symbol["node"] == parent: - # Prepend to maintain order (innermost to outermost) - enclosing_names.insert(0, symbol["name"]) - break - parent = parent.parent - - # Return fully qualified name or None if at module level - if enclosing_names: - return ".".join(enclosing_names) - return None - - def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]: - """Extract the target function name from a Python call expression. - - Args: - source_bytes: Source code as bytes - node: Call expression node - - Returns: - Target function name, or None if cannot be determined - """ - function_node = node.child_by_field_name("function") - if function_node is None: - return None - - # Handle simple identifiers (e.g., "foo()") - if function_node.type == "identifier": - return self._node_text(source_bytes, function_node) - - # Handle attribute access (e.g., "obj.method()") - if function_node.type == "attribute": - attr_node = function_node.child_by_field_name("attribute") - if attr_node: - return self._node_text(source_bytes, attr_node) - - return None - - def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]: - """Extract the target function name from a JS/TS call expression. - - Args: - source_bytes: Source code as bytes - node: Call expression node - - Returns: - Target function name, or None if cannot be determined - """ - function_node = node.child_by_field_name("function") - if function_node is None: - return None - - # Handle simple identifiers - if function_node.type == "identifier": - return self._node_text(source_bytes, function_node) - - # Handle member expressions (e.g., "obj.method()") - if function_node.type == "member_expression": - property_node = function_node.child_by_field_name("property") - if property_node: - return self._node_text(source_bytes, property_node) - - return None - - def _iter_nodes(self, root: TreeSitterNode): - """Iterate over all nodes in AST. - - Args: - root: Root node to start iteration - - Yields: - AST nodes in depth-first order - """ - stack = [root] - while stack: - node = stack.pop() - yield node - for child in reversed(node.children): - stack.append(child) - - def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str: - """Extract text for a node. - - Args: - source_bytes: Source code as bytes - node: AST node - - Returns: - Text content of node - """ - return source_bytes[node.start_byte:node.end_byte].decode("utf8") diff --git a/codex-lens/src/codexlens/semantic/vector_store.py b/codex-lens/src/codexlens/semantic/vector_store.py index c1b19f29..fbcfbfca 100644 --- a/codex-lens/src/codexlens/semantic/vector_store.py +++ b/codex-lens/src/codexlens/semantic/vector_store.py @@ -602,6 +602,12 @@ class VectorStore: Returns: List of SearchResult ordered by similarity (highest first) """ + logger.warning( + "Using brute-force vector search (hnswlib not available). " + "This may cause high memory usage for large indexes. 
" + "Install hnswlib for better performance: pip install hnswlib" + ) + with self._cache_lock: # Refresh cache if needed if self._embedding_matrix is None: diff --git a/codex-lens/src/codexlens/storage/dir_index.py b/codex-lens/src/codexlens/storage/dir_index.py index e00a83b3..f30395a8 100644 --- a/codex-lens/src/codexlens/storage/dir_index.py +++ b/codex-lens/src/codexlens/storage/dir_index.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -from codexlens.entities import CodeRelationship, SearchResult, Symbol +from codexlens.entities import SearchResult, Symbol from codexlens.errors import StorageError @@ -237,116 +237,6 @@ class DirIndexStore: conn.rollback() raise StorageError(f"Failed to add file {name}: {exc}") from exc - def add_relationships( - self, - file_path: str | Path, - relationships: List[CodeRelationship], - ) -> int: - """Store code relationships for a file. - - Args: - file_path: Path to the source file - relationships: List of CodeRelationship objects to store - - Returns: - Number of relationships stored - - Raises: - StorageError: If database operations fail - """ - if not relationships: - return 0 - - with self._lock: - conn = self._get_connection() - file_path_str = str(Path(file_path).resolve()) - - try: - # Get file_id - row = conn.execute( - "SELECT id FROM files WHERE full_path=?", (file_path_str,) - ).fetchone() - if not row: - return 0 - - file_id = int(row["id"]) - - # Delete existing relationships for symbols in this file - conn.execute( - """ - DELETE FROM code_relationships - WHERE source_symbol_id IN ( - SELECT id FROM symbols WHERE file_id=? - ) - """, - (file_id,), - ) - - # Insert new relationships - relationship_rows = [] - skipped_relationships = [] - for rel in relationships: - # Extract simple name from fully qualified name (e.g., "MyClass.my_method" -> "my_method") - # This handles cases where GraphAnalyzer generates qualified names but symbols table stores simple names - source_symbol_simple = rel.source_symbol.split(".")[-1] if "." in rel.source_symbol else rel.source_symbol - - # Find symbol_id by name and file - symbol_row = conn.execute( - """ - SELECT id FROM symbols - WHERE file_id=? AND name=? AND start_line<=? AND end_line>=? - LIMIT 1 - """, - (file_id, source_symbol_simple, rel.source_line, rel.source_line), - ).fetchone() - - if not symbol_row: - # Try matching by simple name only - symbol_row = conn.execute( - "SELECT id FROM symbols WHERE file_id=? AND name=? LIMIT 1", - (file_id, source_symbol_simple), - ).fetchone() - - if symbol_row: - relationship_rows.append(( - int(symbol_row["id"]), - rel.target_symbol, - rel.relationship_type, - rel.source_line, - rel.target_file, - )) - else: - # Log warning when symbol lookup fails - skipped_relationships.append(rel.source_symbol) - - # Log skipped relationships for debugging - if skipped_relationships: - self.logger.warning( - "Failed to find source symbol IDs for %d relationships in %s: %s", - len(skipped_relationships), - file_path_str, - ", ".join(set(skipped_relationships)) - ) - - if relationship_rows: - conn.executemany( - """ - INSERT INTO code_relationships( - source_symbol_id, target_qualified_name, relationship_type, - source_line, target_file - ) - VALUES(?, ?, ?, ?, ?) 
- """, - relationship_rows, - ) - - conn.commit() - return len(relationship_rows) - - except sqlite3.DatabaseError as exc: - conn.rollback() - raise StorageError(f"Failed to add relationships: {exc}") from exc - def add_files_batch( self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]] ) -> int: diff --git a/codex-lens/src/codexlens/storage/index_tree.py b/codex-lens/src/codexlens/storage/index_tree.py index daab89b4..7c589312 100644 --- a/codex-lens/src/codexlens/storage/index_tree.py +++ b/codex-lens/src/codexlens/storage/index_tree.py @@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Set from codexlens.config import Config from codexlens.parsers.factory import ParserFactory -from codexlens.semantic.graph_analyzer import GraphAnalyzer from codexlens.storage.dir_index import DirIndexStore from codexlens.storage.path_mapper import PathMapper from codexlens.storage.registry import ProjectInfo, RegistryStore @@ -525,16 +524,6 @@ class IndexTreeBuilder: symbols=indexed_file.symbols, ) - # Extract and store code relationships for graph visualization - if language_id in {"python", "javascript", "typescript"}: - graph_analyzer = GraphAnalyzer(language_id) - if graph_analyzer.is_available(): - relationships = graph_analyzer.analyze_with_symbols( - text, file_path, indexed_file.symbols - ) - if relationships: - store.add_relationships(file_path, relationships) - files_count += 1 symbols_count += len(indexed_file.symbols) @@ -742,16 +731,6 @@ def _build_dir_worker(args: tuple) -> DirBuildResult: symbols=indexed_file.symbols, ) - # Extract and store code relationships for graph visualization - if language_id in {"python", "javascript", "typescript"}: - graph_analyzer = GraphAnalyzer(language_id) - if graph_analyzer.is_available(): - relationships = graph_analyzer.analyze_with_symbols( - text, item, indexed_file.symbols - ) - if relationships: - store.add_relationships(item, relationships) - files_count += 1 symbols_count += len(indexed_file.symbols) diff --git a/codex-lens/src/codexlens/storage/migrations/migration_003_code_relationships.py b/codex-lens/src/codexlens/storage/migrations/migration_003_code_relationships.py deleted file mode 100644 index d7ee5e60..00000000 --- a/codex-lens/src/codexlens/storage/migrations/migration_003_code_relationships.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Migration 003: Add code relationships storage. - -This migration introduces the `code_relationships` table to store semantic -relationships between code symbols (function calls, inheritance, imports). -This enables graph-based code navigation and dependency analysis. -""" - -import logging -from sqlite3 import Connection - -log = logging.getLogger(__name__) - - -def upgrade(db_conn: Connection): - """ - Applies the migration to add code relationships table. - - - Creates `code_relationships` table with foreign key to symbols - - Creates indexes for efficient relationship queries - - Supports lazy expansion with target_symbol being qualified names - - Args: - db_conn: The SQLite database connection. 
- """ - cursor = db_conn.cursor() - - log.info("Creating 'code_relationships' table...") - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS code_relationships ( - id INTEGER PRIMARY KEY, - source_symbol_id INTEGER NOT NULL, - target_qualified_name TEXT NOT NULL, - relationship_type TEXT NOT NULL, - source_line INTEGER NOT NULL, - target_file TEXT, - FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE - ) - """ - ) - - log.info("Creating indexes for code_relationships...") - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)" - ) - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)" - ) - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)" - ) - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)" - ) - - log.info("Finished creating code_relationships table and indexes.") diff --git a/codex-lens/src/codexlens/storage/sqlite_store.py b/codex-lens/src/codexlens/storage/sqlite_store.py index b9538efb..e88ee7f6 100644 --- a/codex-lens/src/codexlens/storage/sqlite_store.py +++ b/codex-lens/src/codexlens/storage/sqlite_store.py @@ -10,7 +10,7 @@ from dataclasses import asdict from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple -from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol +from codexlens.entities import IndexedFile, SearchResult, Symbol from codexlens.errors import StorageError @@ -420,167 +420,6 @@ class SQLiteStore: } - def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None: - """Store code relationships for a file. - - Args: - file_path: Path to the file containing the relationships - relationships: List of CodeRelationship objects to store - """ - if not relationships: - return - - with self._lock: - conn = self._get_connection() - resolved_path = str(Path(file_path).resolve()) - - # Get file_id - row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone() - if not row: - raise StorageError(f"File not found in index: {file_path}") - file_id = int(row["id"]) - - # Delete existing relationships for symbols in this file - conn.execute( - """ - DELETE FROM code_relationships - WHERE source_symbol_id IN ( - SELECT id FROM symbols WHERE file_id=? - ) - """, - (file_id,) - ) - - # Insert new relationships - relationship_rows = [] - for rel in relationships: - # Find source symbol ID - symbol_row = conn.execute( - """ - SELECT id FROM symbols - WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ? - ORDER BY (end_line - start_line) ASC - LIMIT 1 - """, - (file_id, rel.source_symbol, rel.source_line, rel.source_line) - ).fetchone() - - if symbol_row: - source_symbol_id = int(symbol_row["id"]) - relationship_rows.append(( - source_symbol_id, - rel.target_symbol, - rel.relationship_type, - rel.source_line, - rel.target_file - )) - - if relationship_rows: - conn.executemany( - """ - INSERT INTO code_relationships( - source_symbol_id, target_qualified_name, relationship_type, - source_line, target_file - ) - VALUES(?, ?, ?, ?, ?) - """, - relationship_rows - ) - conn.commit() - - def query_relationships_by_target( - self, target_name: str, *, limit: int = 100 - ) -> List[Dict[str, Any]]: - """Query relationships by target symbol name (find all callers). 
- - Args: - target_name: Name of the target symbol - limit: Maximum number of results - - Returns: - List of dicts containing relationship info with file paths and line numbers - """ - with self._lock: - conn = self._get_connection() - rows = conn.execute( - """ - SELECT - s.name AS source_symbol, - r.target_qualified_name, - r.relationship_type, - r.source_line, - f.full_path AS source_file, - r.target_file - FROM code_relationships r - JOIN symbols s ON r.source_symbol_id = s.id - JOIN files f ON s.file_id = f.id - WHERE r.target_qualified_name = ? - ORDER BY f.full_path, r.source_line - LIMIT ? - """, - (target_name, limit) - ).fetchall() - - return [ - { - "source_symbol": row["source_symbol"], - "target_symbol": row["target_qualified_name"], - "relationship_type": row["relationship_type"], - "source_line": row["source_line"], - "source_file": row["source_file"], - "target_file": row["target_file"], - } - for row in rows - ] - - def query_relationships_by_source( - self, source_symbol: str, source_file: str | Path, *, limit: int = 100 - ) -> List[Dict[str, Any]]: - """Query relationships by source symbol (find what a symbol calls). - - Args: - source_symbol: Name of the source symbol - source_file: File path containing the source symbol - limit: Maximum number of results - - Returns: - List of dicts containing relationship info - """ - with self._lock: - conn = self._get_connection() - resolved_path = str(Path(source_file).resolve()) - - rows = conn.execute( - """ - SELECT - s.name AS source_symbol, - r.target_qualified_name, - r.relationship_type, - r.source_line, - f.path AS source_file, - r.target_file - FROM code_relationships r - JOIN symbols s ON r.source_symbol_id = s.id - JOIN files f ON s.file_id = f.id - WHERE s.name = ? AND f.path = ? - ORDER BY r.source_line - LIMIT ? - """, - (source_symbol, resolved_path, limit) - ).fetchall() - - return [ - { - "source_symbol": row["source_symbol"], - "target_symbol": row["target_qualified_name"], - "relationship_type": row["relationship_type"], - "source_line": row["source_line"], - "source_file": row["source_file"], - "target_file": row["target_file"], - } - for row in rows - ] - def _connect(self) -> sqlite3.Connection: """Legacy method for backward compatibility.""" return self._get_connection() diff --git a/codex-lens/tests/test_chain_search_engine.py b/codex-lens/tests/test_chain_search_engine.py deleted file mode 100644 index 925e3e67..00000000 --- a/codex-lens/tests/test_chain_search_engine.py +++ /dev/null @@ -1,644 +0,0 @@ -"""Unit tests for ChainSearchEngine. - -Tests the graph query methods (search_callers, search_callees, search_inheritance) -with mocked SQLiteStore dependency to test logic in isolation. 
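-
-RegistryStore and PathMapper are likewise replaced with Mock objects, so no
-real index database is touched.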
-""" - -import pytest -from pathlib import Path -from unittest.mock import Mock, MagicMock, patch, call -from concurrent.futures import ThreadPoolExecutor - -from codexlens.search.chain_search import ( - ChainSearchEngine, - SearchOptions, - SearchStats, - ChainSearchResult, -) -from codexlens.entities import SearchResult, Symbol -from codexlens.storage.registry import RegistryStore, DirMapping -from codexlens.storage.path_mapper import PathMapper - - -@pytest.fixture -def mock_registry(): - """Create a mock RegistryStore.""" - registry = Mock(spec=RegistryStore) - return registry - - -@pytest.fixture -def mock_mapper(): - """Create a mock PathMapper.""" - mapper = Mock(spec=PathMapper) - return mapper - - -@pytest.fixture -def search_engine(mock_registry, mock_mapper): - """Create a ChainSearchEngine with mocked dependencies.""" - return ChainSearchEngine(mock_registry, mock_mapper, max_workers=2) - - -@pytest.fixture -def sample_index_path(): - """Sample index database path.""" - return Path("/test/project/_index.db") - - -class TestChainSearchEngineCallers: - """Tests for search_callers method.""" - - def test_search_callers_returns_relationships(self, search_engine, mock_registry, sample_index_path): - """Test that search_callers returns caller relationships.""" - # Setup - source_path = Path("/test/project") - target_symbol = "my_function" - - # Mock finding the start index - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - # Mock collect_index_paths to return single index - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - # Mock the parallel search to return caller data - expected_callers = [ - { - "source_symbol": "caller_function", - "target_symbol": "my_function", - "relationship_type": "calls", - "source_line": 42, - "source_file": "/test/project/module.py", - "target_file": "/test/project/lib.py", - } - ] - - with patch.object(search_engine, '_search_callers_parallel', return_value=expected_callers): - # Execute - result = search_engine.search_callers(target_symbol, source_path) - - # Assert - assert len(result) == 1 - assert result[0]["source_symbol"] == "caller_function" - assert result[0]["target_symbol"] == "my_function" - assert result[0]["relationship_type"] == "calls" - assert result[0]["source_line"] == 42 - - def test_search_callers_empty_results(self, search_engine, mock_registry, sample_index_path): - """Test that search_callers handles no results gracefully.""" - # Setup - source_path = Path("/test/project") - target_symbol = "nonexistent_function" - - # Mock finding the start index - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - # Mock collect_index_paths - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - # Mock empty results - with patch.object(search_engine, '_search_callers_parallel', return_value=[]): - # Execute - result = search_engine.search_callers(target_symbol, source_path) - - # Assert - assert result == [] - - def test_search_callers_no_index_found(self, search_engine, mock_registry): - """Test that search_callers returns empty list when no index found.""" - # Setup - source_path = Path("/test/project") - target_symbol = "my_function" - - # Mock no index found - 
mock_registry.find_nearest_index.return_value = None - - with patch.object(search_engine, '_find_start_index', return_value=None): - # Execute - result = search_engine.search_callers(target_symbol, source_path) - - # Assert - assert result == [] - - def test_search_callers_uses_options(self, search_engine, mock_registry, mock_mapper, sample_index_path): - """Test that search_callers respects SearchOptions.""" - # Setup - source_path = Path("/test/project") - target_symbol = "my_function" - options = SearchOptions(depth=1, total_limit=50) - - # Configure mapper to return a path that exists - mock_mapper.source_to_index_db.return_value = sample_index_path - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]) as mock_collect: - with patch.object(search_engine, '_search_callers_parallel', return_value=[]) as mock_search: - # Patch Path.exists to return True so the exact match is found - with patch.object(Path, 'exists', return_value=True): - # Execute - search_engine.search_callers(target_symbol, source_path, options) - - # Assert that depth was passed to collect_index_paths - mock_collect.assert_called_once_with(sample_index_path, 1) - # Assert that total_limit was passed to parallel search - mock_search.assert_called_once_with([sample_index_path], target_symbol, 50) - - -class TestChainSearchEngineCallees: - """Tests for search_callees method.""" - - def test_search_callees_returns_relationships(self, search_engine, mock_registry, sample_index_path): - """Test that search_callees returns callee relationships.""" - # Setup - source_path = Path("/test/project") - source_symbol = "caller_function" - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - expected_callees = [ - { - "source_symbol": "caller_function", - "target_symbol": "callee_function", - "relationship_type": "calls", - "source_line": 15, - "source_file": "/test/project/module.py", - "target_file": "/test/project/lib.py", - } - ] - - with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees): - # Execute - result = search_engine.search_callees(source_symbol, source_path) - - # Assert - assert len(result) == 1 - assert result[0]["source_symbol"] == "caller_function" - assert result[0]["target_symbol"] == "callee_function" - assert result[0]["source_line"] == 15 - - def test_search_callees_filters_by_file(self, search_engine, mock_registry, sample_index_path): - """Test that search_callees correctly handles file-specific queries.""" - # Setup - source_path = Path("/test/project") - source_symbol = "MyClass.method" - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - # Multiple callees from same source symbol - expected_callees = [ - { - "source_symbol": "MyClass.method", - "target_symbol": "helper_a", - "relationship_type": "calls", - "source_line": 10, - "source_file": "/test/project/module.py", - "target_file": "/test/project/utils.py", - }, - { - "source_symbol": "MyClass.method", - "target_symbol": "helper_b", - "relationship_type": "calls", - "source_line": 20, - "source_file": 
"/test/project/module.py", - "target_file": "/test/project/utils.py", - } - ] - - with patch.object(search_engine, '_search_callees_parallel', return_value=expected_callees): - # Execute - result = search_engine.search_callees(source_symbol, source_path) - - # Assert - assert len(result) == 2 - assert result[0]["target_symbol"] == "helper_a" - assert result[1]["target_symbol"] == "helper_b" - - def test_search_callees_empty_results(self, search_engine, mock_registry, sample_index_path): - """Test that search_callees handles no callees gracefully.""" - source_path = Path("/test/project") - source_symbol = "leaf_function" - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - with patch.object(search_engine, '_search_callees_parallel', return_value=[]): - # Execute - result = search_engine.search_callees(source_symbol, source_path) - - # Assert - assert result == [] - - -class TestChainSearchEngineInheritance: - """Tests for search_inheritance method.""" - - def test_search_inheritance_returns_inherits_relationships(self, search_engine, mock_registry, sample_index_path): - """Test that search_inheritance returns inheritance relationships.""" - # Setup - source_path = Path("/test/project") - class_name = "BaseClass" - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - expected_inheritance = [ - { - "source_symbol": "DerivedClass", - "target_symbol": "BaseClass", - "relationship_type": "inherits", - "source_line": 5, - "source_file": "/test/project/derived.py", - "target_file": "/test/project/base.py", - } - ] - - with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance): - # Execute - result = search_engine.search_inheritance(class_name, source_path) - - # Assert - assert len(result) == 1 - assert result[0]["source_symbol"] == "DerivedClass" - assert result[0]["target_symbol"] == "BaseClass" - assert result[0]["relationship_type"] == "inherits" - - def test_search_inheritance_multiple_subclasses(self, search_engine, mock_registry, sample_index_path): - """Test inheritance search with multiple derived classes.""" - source_path = Path("/test/project") - class_name = "BaseClass" - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - expected_inheritance = [ - { - "source_symbol": "DerivedClassA", - "target_symbol": "BaseClass", - "relationship_type": "inherits", - "source_line": 5, - "source_file": "/test/project/derived_a.py", - "target_file": "/test/project/base.py", - }, - { - "source_symbol": "DerivedClassB", - "target_symbol": "BaseClass", - "relationship_type": "inherits", - "source_line": 10, - "source_file": "/test/project/derived_b.py", - "target_file": "/test/project/base.py", - } - ] - - with patch.object(search_engine, '_search_inheritance_parallel', return_value=expected_inheritance): - # Execute - result = 
search_engine.search_inheritance(class_name, source_path) - - # Assert - assert len(result) == 2 - assert result[0]["source_symbol"] == "DerivedClassA" - assert result[1]["source_symbol"] == "DerivedClassB" - - def test_search_inheritance_empty_results(self, search_engine, mock_registry, sample_index_path): - """Test inheritance search with no subclasses found.""" - source_path = Path("/test/project") - class_name = "FinalClass" - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=sample_index_path, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[sample_index_path]): - with patch.object(search_engine, '_search_inheritance_parallel', return_value=[]): - # Execute - result = search_engine.search_inheritance(class_name, source_path) - - # Assert - assert result == [] - - -class TestChainSearchEngineParallelSearch: - """Tests for parallel search aggregation.""" - - def test_parallel_search_aggregates_results(self, search_engine, mock_registry, sample_index_path): - """Test that parallel search aggregates results from multiple indexes.""" - # Setup - source_path = Path("/test/project") - target_symbol = "my_function" - - index_path_1 = Path("/test/project/_index.db") - index_path_2 = Path("/test/project/subdir/_index.db") - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=index_path_1, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]): - # Mock parallel search results from multiple indexes - callers_from_multiple = [ - { - "source_symbol": "caller_in_root", - "target_symbol": "my_function", - "relationship_type": "calls", - "source_line": 10, - "source_file": "/test/project/root.py", - "target_file": "/test/project/lib.py", - }, - { - "source_symbol": "caller_in_subdir", - "target_symbol": "my_function", - "relationship_type": "calls", - "source_line": 20, - "source_file": "/test/project/subdir/module.py", - "target_file": "/test/project/lib.py", - } - ] - - with patch.object(search_engine, '_search_callers_parallel', return_value=callers_from_multiple): - # Execute - result = search_engine.search_callers(target_symbol, source_path) - - # Assert results from both indexes are included - assert len(result) == 2 - assert any(r["source_file"] == "/test/project/root.py" for r in result) - assert any(r["source_file"] == "/test/project/subdir/module.py" for r in result) - - def test_parallel_search_deduplicates_results(self, search_engine, mock_registry, sample_index_path): - """Test that parallel search deduplicates results by (source_file, source_line).""" - # Note: This test verifies the behavior of _search_callers_parallel deduplication - source_path = Path("/test/project") - target_symbol = "my_function" - - index_path_1 = Path("/test/project/_index.db") - index_path_2 = Path("/test/project/_index.db") # Same index (simulates duplicate) - - mock_registry.find_nearest_index.return_value = DirMapping( - id=1, - project_id=1, - source_path=source_path, - index_path=index_path_1, - depth=0, - files_count=10, - last_updated=0.0 - ) - - with patch.object(search_engine, '_collect_index_paths', return_value=[index_path_1, index_path_2]): - # Mock duplicate results from same location - duplicate_callers = [ - { - "source_symbol": "caller_function", - "target_symbol": 
"my_function", - "relationship_type": "calls", - "source_line": 42, - "source_file": "/test/project/module.py", - "target_file": "/test/project/lib.py", - }, - { - "source_symbol": "caller_function", - "target_symbol": "my_function", - "relationship_type": "calls", - "source_line": 42, - "source_file": "/test/project/module.py", - "target_file": "/test/project/lib.py", - } - ] - - with patch.object(search_engine, '_search_callers_parallel', return_value=duplicate_callers): - # Execute - result = search_engine.search_callers(target_symbol, source_path) - - # Assert: even with duplicates in input, output may contain both - # (actual deduplication happens in _search_callers_parallel) - assert len(result) >= 1 - - -class TestChainSearchEngineContextManager: - """Tests for context manager functionality.""" - - def test_context_manager_closes_executor(self, mock_registry, mock_mapper): - """Test that context manager properly closes executor.""" - with ChainSearchEngine(mock_registry, mock_mapper) as engine: - # Force executor creation - engine._get_executor() - assert engine._executor is not None - - # Executor should be closed after exiting context - assert engine._executor is None - - def test_close_method_shuts_down_executor(self, search_engine): - """Test that close() method shuts down executor.""" - # Create executor - search_engine._get_executor() - assert search_engine._executor is not None - - # Close - search_engine.close() - assert search_engine._executor is None - - -class TestSearchCallersSingle: - """Tests for _search_callers_single internal method.""" - - def test_search_callers_single_queries_store(self, search_engine, sample_index_path): - """Test that _search_callers_single queries SQLiteStore correctly.""" - target_symbol = "my_function" - - # Mock SQLiteStore - with patch('codexlens.search.chain_search.SQLiteStore') as MockStore: - mock_store_instance = MockStore.return_value.__enter__.return_value - mock_store_instance.query_relationships_by_target.return_value = [ - { - "source_symbol": "caller", - "target_symbol": target_symbol, - "relationship_type": "calls", - "source_line": 10, - "source_file": "/test/file.py", - "target_file": "/test/lib.py", - } - ] - - # Execute - result = search_engine._search_callers_single(sample_index_path, target_symbol) - - # Assert - assert len(result) == 1 - assert result[0]["source_symbol"] == "caller" - mock_store_instance.query_relationships_by_target.assert_called_once_with(target_symbol) - - def test_search_callers_single_handles_errors(self, search_engine, sample_index_path): - """Test that _search_callers_single returns empty list on error.""" - target_symbol = "my_function" - - with patch('codexlens.search.chain_search.SQLiteStore') as MockStore: - MockStore.return_value.__enter__.side_effect = Exception("Database error") - - # Execute - result = search_engine._search_callers_single(sample_index_path, target_symbol) - - # Assert - should return empty list, not raise exception - assert result == [] - - -class TestSearchCalleesSingle: - """Tests for _search_callees_single internal method.""" - - def test_search_callees_single_queries_database(self, search_engine, sample_index_path): - """Test that _search_callees_single queries SQLiteStore correctly.""" - source_symbol = "caller_function" - - # Mock SQLiteStore - with patch('codexlens.search.chain_search.SQLiteStore') as MockStore: - mock_store_instance = MagicMock() - MockStore.return_value.__enter__.return_value = mock_store_instance - - # Mock execute_query to return relationship data 
(using new public API) - mock_store_instance.execute_query.return_value = [ - { - "source_symbol": source_symbol, - "target_symbol": "callee_function", - "relationship_type": "call", - "source_line": 15, - "source_file": "/test/module.py", - "target_file": "/test/lib.py", - } - ] - - # Execute - result = search_engine._search_callees_single(sample_index_path, source_symbol) - - # Assert - verify execute_query was called (public API) - assert mock_store_instance.execute_query.called - assert len(result) == 1 - assert result[0]["source_symbol"] == source_symbol - assert result[0]["target_symbol"] == "callee_function" - - def test_search_callees_single_handles_errors(self, search_engine, sample_index_path): - """Test that _search_callees_single returns empty list on error.""" - source_symbol = "caller_function" - - with patch('codexlens.search.chain_search.SQLiteStore') as MockStore: - MockStore.return_value.__enter__.side_effect = Exception("DB error") - - # Execute - result = search_engine._search_callees_single(sample_index_path, source_symbol) - - # Assert - should return empty list, not raise exception - assert result == [] - - -class TestSearchInheritanceSingle: - """Tests for _search_inheritance_single internal method.""" - - def test_search_inheritance_single_queries_database(self, search_engine, sample_index_path): - """Test that _search_inheritance_single queries SQLiteStore correctly.""" - class_name = "BaseClass" - - # Mock SQLiteStore - with patch('codexlens.search.chain_search.SQLiteStore') as MockStore: - mock_store_instance = MagicMock() - MockStore.return_value.__enter__.return_value = mock_store_instance - - # Mock execute_query to return relationship data (using new public API) - mock_store_instance.execute_query.return_value = [ - { - "source_symbol": "DerivedClass", - "target_qualified_name": "BaseClass", - "relationship_type": "inherits", - "source_line": 5, - "source_file": "/test/derived.py", - "target_file": "/test/base.py", - } - ] - - # Execute - result = search_engine._search_inheritance_single(sample_index_path, class_name) - - # Assert - assert mock_store_instance.execute_query.called - assert len(result) == 1 - assert result[0]["source_symbol"] == "DerivedClass" - assert result[0]["relationship_type"] == "inherits" - - # Verify execute_query was called with 'inherits' filter - call_args = mock_store_instance.execute_query.call_args - sql_query = call_args[0][0] - assert "relationship_type = 'inherits'" in sql_query - - def test_search_inheritance_single_handles_errors(self, search_engine, sample_index_path): - """Test that _search_inheritance_single returns empty list on error.""" - class_name = "BaseClass" - - with patch('codexlens.search.chain_search.SQLiteStore') as MockStore: - MockStore.return_value.__enter__.side_effect = Exception("DB error") - - # Execute - result = search_engine._search_inheritance_single(sample_index_path, class_name) - - # Assert - should return empty list, not raise exception - assert result == [] diff --git a/codex-lens/tests/test_cli_search.py b/codex-lens/tests/test_cli_search.py deleted file mode 100644 index 58eb265e..00000000 --- a/codex-lens/tests/test_cli_search.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Tests for CLI search command with --enrich flag.""" - -import json -import pytest -from typer.testing import CliRunner -from codexlens.cli.commands import app - - -class TestCLISearchEnrich: - """Test CLI search command with --enrich flag integration.""" - - @pytest.fixture - def runner(self): - """Create CLI test runner.""" - 
return CliRunner() - - def test_search_with_enrich_flag_help(self, runner): - """Test --enrich flag is documented in help.""" - result = runner.invoke(app, ["search", "--help"]) - assert result.exit_code == 0 - assert "--enrich" in result.output - assert "relationships" in result.output.lower() or "graph" in result.output.lower() - - def test_search_with_enrich_flag_accepted(self, runner): - """Test --enrich flag is accepted by the CLI.""" - result = runner.invoke(app, ["search", "test", "--enrich"]) - # Should not show 'unknown option' error - assert "No such option" not in result.output - assert "error: unrecognized" not in result.output.lower() - - def test_search_without_enrich_flag(self, runner): - """Test search without --enrich flag has no relationships.""" - result = runner.invoke(app, ["search", "test", "--json"]) - # Even without an index, JSON should be attempted - if result.exit_code == 0: - try: - data = json.loads(result.output) - # If we get results, they should not have enriched=true - if data.get("success") and "result" in data: - assert data["result"].get("enriched", False) is False - except json.JSONDecodeError: - pass # Not JSON output, that's fine for error cases - - def test_search_enrich_json_output_structure(self, runner): - """Test JSON output structure includes enriched flag.""" - result = runner.invoke(app, ["search", "test", "--json", "--enrich"]) - # If we get valid JSON output, check structure - if result.exit_code == 0: - try: - data = json.loads(result.output) - if data.get("success") and "result" in data: - # enriched field should exist - assert "enriched" in data["result"] - except json.JSONDecodeError: - pass # Not JSON output - - def test_search_enrich_with_mode(self, runner): - """Test --enrich works with different search modes.""" - modes = ["exact", "fuzzy", "hybrid"] - for mode in modes: - result = runner.invoke( - app, ["search", "test", "--mode", mode, "--enrich"] - ) - # Should not show validation errors - assert "Invalid" not in result.output - - -class TestEnrichFlagBehavior: - """Test behavioral aspects of --enrich flag.""" - - @pytest.fixture - def runner(self): - """Create CLI test runner.""" - return CliRunner() - - def test_enrich_failure_does_not_break_search(self, runner): - """Test that enrichment failure doesn't prevent search from returning results.""" - # Even without proper index, search should not crash due to enrich - result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"]) - # Should not have unhandled exception - assert "Traceback" not in result.output - - def test_enrich_flag_with_files_only(self, runner): - """Test --enrich is accepted with --files-only mode.""" - result = runner.invoke(app, ["search", "test", "--enrich", "--files-only"]) - # Should not show option conflict error - assert "conflict" not in result.output.lower() - - def test_enrich_flag_with_limit(self, runner): - """Test --enrich works with --limit parameter.""" - result = runner.invoke(app, ["search", "test", "--enrich", "--limit", "5"]) - # Should not show validation error - assert "Invalid" not in result.output - - -class TestEnrichOutputFormat: - """Test output format with --enrich flag.""" - - @pytest.fixture - def runner(self): - """Create CLI test runner.""" - return CliRunner() - - def test_enrich_verbose_shows_status(self, runner): - """Test verbose mode shows enrichment status.""" - result = runner.invoke(app, ["search", "test", "--enrich", "--verbose"]) - # Verbose mode may show enrichment info or warnings - # Just ensure it doesn't 
crash - assert result.exit_code in [0, 1] # 0 = success, 1 = no index - - def test_json_output_has_enriched_field(self, runner): - """Test JSON output always has enriched field when --enrich used.""" - result = runner.invoke(app, ["search", "test", "--json", "--enrich"]) - if result.exit_code == 0: - try: - data = json.loads(result.output) - if data.get("success"): - result_data = data.get("result", {}) - assert "enriched" in result_data - assert isinstance(result_data["enriched"], bool) - except json.JSONDecodeError: - pass diff --git a/codex-lens/tests/test_entities.py b/codex-lens/tests/test_entities.py index c9c5778c..efcc6737 100644 --- a/codex-lens/tests/test_entities.py +++ b/codex-lens/tests/test_entities.py @@ -204,8 +204,6 @@ class TestEntitySerialization: "kind": "function", "range": (1, 10), "file": None, - "token_count": None, - "symbol_type": None, } def test_indexed_file_model_dump(self): diff --git a/codex-lens/tests/test_graph_analyzer.py b/codex-lens/tests/test_graph_analyzer.py deleted file mode 100644 index c1e31ba3..00000000 --- a/codex-lens/tests/test_graph_analyzer.py +++ /dev/null @@ -1,436 +0,0 @@ -"""Tests for GraphAnalyzer - code relationship extraction.""" - -from pathlib import Path - -import pytest - -from codexlens.semantic.graph_analyzer import GraphAnalyzer - - -TREE_SITTER_PYTHON_AVAILABLE = True -try: - import tree_sitter_python # type: ignore[import-not-found] # noqa: F401 -except Exception: - TREE_SITTER_PYTHON_AVAILABLE = False - - -TREE_SITTER_JS_AVAILABLE = True -try: - import tree_sitter_javascript # type: ignore[import-not-found] # noqa: F401 -except Exception: - TREE_SITTER_JS_AVAILABLE = False - - -@pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed") -class TestPythonGraphAnalyzer: - """Tests for Python relationship extraction.""" - - def test_simple_function_call(self): - """Test extraction of simple function call.""" - code = """def helper(): - pass - -def main(): - helper() -""" - analyzer = GraphAnalyzer("python") - relationships = analyzer.analyze_file(code, Path("test.py")) - - # Should find main -> helper call - assert len(relationships) == 1 - rel = relationships[0] - assert rel.source_symbol == "main" - assert rel.target_symbol == "helper" - assert rel.relationship_type == "call" - assert rel.source_line == 5 - - def test_multiple_calls_in_function(self): - """Test extraction of multiple calls from same function.""" - code = """def foo(): - pass - -def bar(): - pass - -def main(): - foo() - bar() -""" - analyzer = GraphAnalyzer("python") - relationships = analyzer.analyze_file(code, Path("test.py")) - - # Should find main -> foo and main -> bar - assert len(relationships) == 2 - targets = {rel.target_symbol for rel in relationships} - assert targets == {"foo", "bar"} - assert all(rel.source_symbol == "main" for rel in relationships) - - def test_nested_function_calls(self): - """Test extraction of calls from nested functions.""" - code = """def inner_helper(): - pass - -def outer(): - def inner(): - inner_helper() - inner() -""" - analyzer = GraphAnalyzer("python") - relationships = analyzer.analyze_file(code, Path("test.py")) - - # Should find outer.inner -> inner_helper and outer -> inner (with fully qualified names) - assert len(relationships) == 2 - call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships} - assert ("outer.inner", "inner_helper") in call_pairs - assert ("outer", "inner") in call_pairs - - def test_method_call_in_class(self): - """Test extraction of 
method calls within class."""
-        code = """class Calculator:
-    def add(self, a, b):
-        return a + b
-
-    def compute(self, x, y):
-        result = self.add(x, y)
-        return result
-"""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-
-        # Should find Calculator.compute -> add (with fully qualified source)
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_symbol == "Calculator.compute"
-        assert rel.target_symbol == "add"
-
-    def test_module_level_call(self):
-        """Test extraction of module-level function calls."""
-        code = """def setup():
-    pass
-
-setup()
-"""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-
-        # Should find <module> -> setup
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_symbol == "<module>"
-        assert rel.target_symbol == "setup"
-
-    def test_async_function_call(self):
-        """Test extraction of calls involving async functions."""
-        code = """async def fetch_data():
-    pass
-
-async def process():
-    await fetch_data()
-"""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-
-        # Should find process -> fetch_data
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_symbol == "process"
-        assert rel.target_symbol == "fetch_data"
-
-    def test_complex_python_file(self):
-        """Test extraction from realistic Python file with multiple patterns."""
-        code = """class DataProcessor:
-    def __init__(self):
-        self.data = []
-
-    def load(self, filename):
-        self.data = read_file(filename)
-
-    def process(self):
-        self.validate()
-        self.transform()
-
-    def validate(self):
-        pass
-
-    def transform(self):
-        pass
-
-def read_file(filename):
-    pass
-
-def main():
-    processor = DataProcessor()
-    processor.load("data.txt")
-    processor.process()
-
-main()
-"""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-
-        # Extract call pairs
-        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
-
-        # Expected relationships (with fully qualified source symbols for methods)
-        expected = {
-            ("DataProcessor.load", "read_file"),
-            ("DataProcessor.process", "validate"),
-            ("DataProcessor.process", "transform"),
-            ("main", "DataProcessor"),
-            ("main", "load"),
-            ("main", "process"),
-            ("<module>", "main"),
-        }
-
-        # Should find all expected relationships
-        assert call_pairs >= expected
-
-    def test_empty_file(self):
-        """Test handling of empty file."""
-        code = ""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-        assert len(relationships) == 0
-
-    def test_file_with_no_calls(self):
-        """Test handling of file with definitions but no calls."""
-        code = """def func1():
-    pass
-
-def func2():
-    pass
-"""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-        assert len(relationships) == 0
-
-
-@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
-class TestJavaScriptGraphAnalyzer:
-    """Tests for JavaScript relationship extraction."""
-
-    def test_simple_function_call(self):
-        """Test extraction of simple JavaScript function call."""
-        code = """function helper() {}
-
-function main() {
-    helper();
-}
-"""
-        analyzer = GraphAnalyzer("javascript")
-        relationships = analyzer.analyze_file(code, Path("test.js"))
-
-        # Should find main -> helper call
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_symbol == "main"
-        assert rel.target_symbol == "helper"
-        assert rel.relationship_type == "call"
-
-    def test_arrow_function_call(self):
-        """Test extraction of calls from arrow functions."""
-        code = """const helper = () => {};
-
-const main = () => {
-    helper();
-};
-"""
-        analyzer = GraphAnalyzer("javascript")
-        relationships = analyzer.analyze_file(code, Path("test.js"))
-
-        # Should find main -> helper call
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_symbol == "main"
-        assert rel.target_symbol == "helper"
-
-    def test_class_method_call(self):
-        """Test extraction of method calls in JavaScript class."""
-        code = """class Calculator {
-    add(a, b) {
-        return a + b;
-    }
-
-    compute(x, y) {
-        return this.add(x, y);
-    }
-}
-"""
-        analyzer = GraphAnalyzer("javascript")
-        relationships = analyzer.analyze_file(code, Path("test.js"))
-
-        # Should find Calculator.compute -> add (with fully qualified source)
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_symbol == "Calculator.compute"
-        assert rel.target_symbol == "add"
-
-    def test_complex_javascript_file(self):
-        """Test extraction from realistic JavaScript file."""
-        code = """function readFile(filename) {
-    return "";
-}
-
-class DataProcessor {
-    constructor() {
-        this.data = [];
-    }
-
-    load(filename) {
-        this.data = readFile(filename);
-    }
-
-    process() {
-        this.validate();
-        this.transform();
-    }
-
-    validate() {}
-
-    transform() {}
-}
-
-function main() {
-    const processor = new DataProcessor();
-    processor.load("data.txt");
-    processor.process();
-}
-
-main();
-"""
-        analyzer = GraphAnalyzer("javascript")
-        relationships = analyzer.analyze_file(code, Path("test.js"))
-
-        # Extract call pairs
-        call_pairs = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
-
-        # Expected relationships (with fully qualified source symbols for methods)
-        # Note: constructor calls like "new DataProcessor()" are not tracked
-        expected = {
-            ("DataProcessor.load", "readFile"),
-            ("DataProcessor.process", "validate"),
-            ("DataProcessor.process", "transform"),
-            ("main", "load"),
-            ("main", "process"),
-            ("<module>", "main"),
-        }
-
-        # Should find all expected relationships
-        assert call_pairs >= expected
-
-
-class TestGraphAnalyzerEdgeCases:
-    """Edge case tests for GraphAnalyzer."""
-
-    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
-    def test_unavailable_language(self):
-        """Test handling of unsupported language."""
-        code = "some code"
-        analyzer = GraphAnalyzer("rust")
-        relationships = analyzer.analyze_file(code, Path("test.rs"))
-        assert len(relationships) == 0
-
-    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
-    def test_malformed_python_code(self):
-        """Test handling of malformed Python code."""
-        code = "def broken(\n pass"
-        analyzer = GraphAnalyzer("python")
-        # Should not crash
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-        assert isinstance(relationships, list)
-
-    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
-    def test_file_path_in_relationship(self):
-        """Test that file path is correctly set in relationships."""
-        code = """def foo():
-    pass
-
-def bar():
-    foo()
-"""
-        test_path = Path("test.py")
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, test_path)
-
-        assert len(relationships) == 1
-        rel = relationships[0]
-        assert rel.source_file == str(test_path.resolve())
-        assert rel.target_file is None  # Intra-file
-
-    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
-    def test_performance_large_file(self):
-        """Test performance on a larger generated file (~200 lines)."""
-        import time
-
-        # Generate file with many functions and calls
-        lines = []
-        for i in range(100):
-            lines.append(f"def func_{i}():")
-            if i > 0:
-                lines.append(f" func_{i-1}()")
-            else:
-                lines.append(" pass")
-
-        code = "\n".join(lines)
-
-        analyzer = GraphAnalyzer("python")
-        start_time = time.time()
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-        elapsed_ms = (time.time() - start_time) * 1000
-
-        # Should complete in under 500ms
-        assert elapsed_ms < 500
-
-        # Should find 99 calls (func_1 -> func_0, func_2 -> func_1, ...)
-        assert len(relationships) == 99
-
-    @pytest.mark.skipif(not TREE_SITTER_PYTHON_AVAILABLE, reason="tree-sitter-python not installed")
-    def test_call_accuracy_rate(self):
-        """Test >95% accuracy on known call graph."""
-        code = """def a(): pass
-def b(): pass
-def c(): pass
-def d(): pass
-def e(): pass
-
-def test1():
-    a()
-    b()
-
-def test2():
-    c()
-    d()
-
-def test3():
-    e()
-
-def main():
-    test1()
-    test2()
-    test3()
-"""
-        analyzer = GraphAnalyzer("python")
-        relationships = analyzer.analyze_file(code, Path("test.py"))
-
-        # Expected calls: test1->a, test1->b, test2->c, test2->d, test3->e, main->test1, main->test2, main->test3
-        expected_calls = {
-            ("test1", "a"),
-            ("test1", "b"),
-            ("test2", "c"),
-            ("test2", "d"),
-            ("test3", "e"),
-            ("main", "test1"),
-            ("main", "test2"),
-            ("main", "test3"),
-        }
-
-        found_calls = {(rel.source_symbol, rel.target_symbol) for rel in relationships}
-
-        # Calculate accuracy
-        correct = len(expected_calls & found_calls)
-        total = len(expected_calls)
-        accuracy = (correct / total) * 100 if total > 0 else 0
-
-        # Should have >95% accuracy
-        assert accuracy >= 95.0
-        assert correct == total  # Should be 100% for this simple case
diff --git a/codex-lens/tests/test_graph_cli.py b/codex-lens/tests/test_graph_cli.py
deleted file mode 100644
index d18c9ac8..00000000
--- a/codex-lens/tests/test_graph_cli.py
+++ /dev/null
@@ -1,392 +0,0 @@
-"""End-to-end tests for graph search CLI commands."""
-
-import tempfile
-from pathlib import Path
-from typer.testing import CliRunner
-import pytest
-
-from codexlens.cli.commands import app
-from codexlens.storage.sqlite_store import SQLiteStore
-from codexlens.storage.registry import RegistryStore
-from codexlens.storage.path_mapper import PathMapper
-from codexlens.entities import IndexedFile, Symbol, CodeRelationship
-
-
-runner = CliRunner()
-
-
-@pytest.fixture
-def temp_project():
-    """Create a temporary project with indexed code and relationships."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        project_root = Path(tmpdir) / "test_project"
-        project_root.mkdir()
-
-        # Create test Python files
-        (project_root / "main.py").write_text("""
-def main():
-    result = calculate(5, 3)
-    print(result)
-
-def calculate(a, b):
-    return add(a, b)
-
-def add(x, y):
-    return x + y
-""")
-
-        (project_root / "utils.py").write_text("""
-class BaseClass:
-    def method(self):
-        pass
-
-class DerivedClass(BaseClass):
-    def method(self):
-        super().method()
-        helper()
-
-def helper():
-    return True
-""")
-
-        # Create a custom index directory for graph testing
-        # Skip the standard init to avoid schema conflicts
-        mapper = PathMapper()
-        index_root = mapper.source_to_index_dir(project_root)
-        index_root.mkdir(parents=True, 
exist_ok=True) - test_db = index_root / "_index.db" - - # Register project manually - registry = RegistryStore() - registry.initialize() - project_info = registry.register_project( - source_root=project_root, - index_root=index_root - ) - registry.register_dir( - project_id=project_info.id, - source_path=project_root, - index_path=test_db, - depth=0, - files_count=2 - ) - - # Initialize the store with proper SQLiteStore schema and add files - with SQLiteStore(test_db) as store: - # Read and add files to the store - main_content = (project_root / "main.py").read_text() - utils_content = (project_root / "utils.py").read_text() - - main_indexed = IndexedFile( - path=str(project_root / "main.py"), - language="python", - symbols=[ - Symbol(name="main", kind="function", range=(2, 4)), - Symbol(name="calculate", kind="function", range=(6, 7)), - Symbol(name="add", kind="function", range=(9, 10)) - ] - ) - utils_indexed = IndexedFile( - path=str(project_root / "utils.py"), - language="python", - symbols=[ - Symbol(name="BaseClass", kind="class", range=(2, 4)), - Symbol(name="DerivedClass", kind="class", range=(6, 9)), - Symbol(name="helper", kind="function", range=(11, 12)) - ] - ) - - store.add_file(main_indexed, main_content) - store.add_file(utils_indexed, utils_content) - - with SQLiteStore(test_db) as store: - # Add relationships for main.py - main_file = project_root / "main.py" - relationships_main = [ - CodeRelationship( - source_symbol="main", - target_symbol="calculate", - relationship_type="call", - source_file=str(main_file), - source_line=3, - target_file=str(main_file) - ), - CodeRelationship( - source_symbol="calculate", - target_symbol="add", - relationship_type="call", - source_file=str(main_file), - source_line=7, - target_file=str(main_file) - ), - ] - store.add_relationships(main_file, relationships_main) - - # Add relationships for utils.py - utils_file = project_root / "utils.py" - relationships_utils = [ - CodeRelationship( - source_symbol="DerivedClass", - target_symbol="BaseClass", - relationship_type="inherits", - source_file=str(utils_file), - source_line=6, # DerivedClass is defined on line 6 - target_file=str(utils_file) - ), - CodeRelationship( - source_symbol="DerivedClass.method", - target_symbol="helper", - relationship_type="call", - source_file=str(utils_file), - source_line=8, - target_file=str(utils_file) - ), - ] - store.add_relationships(utils_file, relationships_utils) - - registry.close() - - yield project_root - - -class TestGraphCallers: - """Test callers query type.""" - - def test_find_callers_basic(self, temp_project): - """Test finding functions that call a given function.""" - result = runner.invoke(app, [ - "graph", - "callers", - "add", - "--path", str(temp_project) - ]) - - assert result.exit_code == 0 - assert "calculate" in result.stdout - assert "Callers of 'add'" in result.stdout - - def test_find_callers_json_mode(self, temp_project): - """Test callers query with JSON output.""" - result = runner.invoke(app, [ - "graph", - "callers", - "add", - "--path", str(temp_project), - "--json" - ]) - - assert result.exit_code == 0 - assert "success" in result.stdout - assert "relationships" in result.stdout - - def test_find_callers_no_results(self, temp_project): - """Test callers query when no callers exist.""" - result = runner.invoke(app, [ - "graph", - "callers", - "nonexistent_function", - "--path", str(temp_project) - ]) - - assert result.exit_code == 0 - assert "No callers found" in result.stdout or "0 found" in result.stdout - - -class 
-class TestGraphCallees:
-    """Test callees query type."""
-
-    def test_find_callees_basic(self, temp_project):
-        """Test finding functions called by a given function."""
-        result = runner.invoke(app, [
-            "graph",
-            "callees",
-            "main",
-            "--path", str(temp_project)
-        ])
-
-        assert result.exit_code == 0
-        assert "calculate" in result.stdout
-        assert "Callees of 'main'" in result.stdout
-
-    def test_find_callees_chain(self, temp_project):
-        """Test finding callees in a call chain."""
-        result = runner.invoke(app, [
-            "graph",
-            "callees",
-            "calculate",
-            "--path", str(temp_project)
-        ])
-
-        assert result.exit_code == 0
-        assert "add" in result.stdout
-
-    def test_find_callees_json_mode(self, temp_project):
-        """Test callees query with JSON output."""
-        result = runner.invoke(app, [
-            "graph",
-            "callees",
-            "main",
-            "--path", str(temp_project),
-            "--json"
-        ])
-
-        assert result.exit_code == 0
-        assert "success" in result.stdout
-
-
-class TestGraphInheritance:
-    """Test inheritance query type."""
-
-    def test_find_inheritance_basic(self, temp_project):
-        """Test finding inheritance relationships."""
-        result = runner.invoke(app, [
-            "graph",
-            "inheritance",
-            "BaseClass",
-            "--path", str(temp_project)
-        ])
-
-        assert result.exit_code == 0
-        assert "DerivedClass" in result.stdout
-        assert "Inheritance relationships" in result.stdout
-
-    def test_find_inheritance_derived(self, temp_project):
-        """Test finding inheritance from derived class perspective."""
-        result = runner.invoke(app, [
-            "graph",
-            "inheritance",
-            "DerivedClass",
-            "--path", str(temp_project)
-        ])
-
-        assert result.exit_code == 0
-        assert "BaseClass" in result.stdout
-
-    def test_find_inheritance_json_mode(self, temp_project):
-        """Test inheritance query with JSON output."""
-        result = runner.invoke(app, [
-            "graph",
-            "inheritance",
-            "BaseClass",
-            "--path", str(temp_project),
-            "--json"
-        ])
-
-        assert result.exit_code == 0
-        assert "success" in result.stdout
-
-
-class TestGraphValidation:
-    """Test query validation and error handling."""
-
-    def test_invalid_query_type(self, temp_project):
-        """Test error handling for invalid query type."""
-        result = runner.invoke(app, [
-            "graph",
-            "invalid_type",
-            "symbol",
-            "--path", str(temp_project)
-        ])
-
-        assert result.exit_code == 1
-        assert "Invalid query type" in result.stdout
-
-    def test_invalid_path(self):
-        """Test error handling for non-existent path."""
-        result = runner.invoke(app, [
-            "graph",
-            "callers",
-            "symbol",
-            "--path", "/nonexistent/path"
-        ])
-
-        # Should handle gracefully (may exit with error or return empty results)
-        assert result.exit_code in [0, 1]
-
-
-class TestGraphPerformance:
-    """Test graph query performance requirements."""
-
-    def test_query_response_time(self, temp_project):
-        """Verify graph queries complete in under 1 second."""
-        import time
-
-        start = time.time()
-        result = runner.invoke(app, [
-            "graph",
-            "callers",
-            "add",
-            "--path", str(temp_project)
-        ])
-        elapsed = time.time() - start
-
-        assert result.exit_code == 0
-        assert elapsed < 1.0, f"Query took {elapsed:.2f}s, expected <1s"
-
-    def test_multiple_query_types(self, temp_project):
-        """Test all three query types complete successfully."""
-        import time
-
-        queries = [
-            ("callers", "add"),
-            ("callees", "main"),
-            ("inheritance", "BaseClass")
-        ]
-
-        total_start = time.time()
-
-        for query_type, symbol in queries:
-            result = runner.invoke(app, [
-                "graph",
-                query_type,
-                symbol,
-                "--path", str(temp_project)
-            ])
-            assert result.exit_code == 0
-
-        total_elapsed = time.time() - total_start
-        assert total_elapsed < 3.0, f"All queries took {total_elapsed:.2f}s, expected <3s"
-
-
-class TestGraphOptions:
-    """Test graph command options."""
-
-    def test_limit_option(self, temp_project):
-        """Test limit option works correctly."""
-        result = runner.invoke(app, [
-            "graph",
-            "callers",
-            "add",
-            "--path", str(temp_project),
-            "--limit", "1"
-        ])
-
-        assert result.exit_code == 0
-
-    def test_depth_option(self, temp_project):
-        """Test depth option works correctly."""
-        result = runner.invoke(app, [
-            "graph",
-            "callers",
-            "add",
-            "--path", str(temp_project),
-            "--depth", "0"
-        ])
-
-        assert result.exit_code == 0
-
-    def test_verbose_option(self, temp_project):
-        """Test verbose option works correctly."""
-        result = runner.invoke(app, [
-            "graph",
-            "callers",
-            "add",
-            "--path", str(temp_project),
-            "--verbose"
-        ])
-
-        assert result.exit_code == 0
-
-
-if __name__ == "__main__":
-    pytest.main([__file__, "-v"])
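
The three query types exercised above (callers, callees, inheritance) plus the
invalid-type check suggest a small dispatch layer between the CLI and the
store. A minimal sketch, assuming the SQLiteStore query methods exercised by
test_graph_storage.py below; run_graph_query is a hypothetical name, not the
shipped codexlens.cli.commands implementation:

    from typing import Optional

    def run_graph_query(store, query_type: str, symbol: str,
                        file_path: Optional[str] = None):
        """Dispatch one of the three graph query types to store methods."""
        if query_type == "callers":
            # Who calls `symbol`? (test_find_callers_basic)
            return store.query_relationships_by_target(symbol)
        if query_type == "callees":
            # What does `symbol` call? (test_find_callees_basic)
            return store.query_relationships_by_source(symbol, file_path)
        if query_type == "inheritance":
            # Inheritance edges carry relationship_type="inherits"; the symbol
            # may sit on either end (test_find_inheritance_basic/_derived).
            rels = store.query_relationships_by_target(symbol)
            rels += store.query_relationships_by_source(symbol, file_path)
            return [r for r in rels if r["relationship_type"] == "inherits"]
        # test_invalid_query_type expects exit code 1 and this message.
        raise ValueError(f"Invalid query type: {query_type}")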
diff --git a/codex-lens/tests/test_graph_storage.py b/codex-lens/tests/test_graph_storage.py
deleted file mode 100644
index 138fcae4..00000000
--- a/codex-lens/tests/test_graph_storage.py
+++ /dev/null
@@ -1,355 +0,0 @@
-"""Tests for code relationship storage."""
-
-import sqlite3
-import tempfile
-from pathlib import Path
-
-import pytest
-
-from codexlens.entities import CodeRelationship, IndexedFile, Symbol
-from codexlens.storage.migration_manager import MigrationManager
-from codexlens.storage.sqlite_store import SQLiteStore
-
-
-@pytest.fixture
-def temp_db():
-    """Create a temporary database for testing."""
-    with tempfile.TemporaryDirectory() as tmpdir:
-        db_path = Path(tmpdir) / "test.db"
-        yield db_path
-
-
-@pytest.fixture
-def store(temp_db):
-    """Create a SQLiteStore with migrations applied."""
-    store = SQLiteStore(temp_db)
-    store.initialize()
-
-    # Manually apply migration_003 (code_relationships table)
-    conn = store._get_connection()
-    cursor = conn.cursor()
-    cursor.execute(
-        """
-        CREATE TABLE IF NOT EXISTS code_relationships (
-            id INTEGER PRIMARY KEY,
-            source_symbol_id INTEGER NOT NULL,
-            target_qualified_name TEXT NOT NULL,
-            relationship_type TEXT NOT NULL,
-            source_line INTEGER NOT NULL,
-            target_file TEXT,
-            FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
-        )
-        """
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
-    )
-    cursor.execute(
-        "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
-    )
-    conn.commit()
-
-    yield store
-
-    # Cleanup
-    store.close()
-
-
-def test_relationship_table_created(store):
-    """Test that the code_relationships table is created by migration."""
-    conn = store._get_connection()
-    cursor = conn.cursor()
-
-    # Check table exists
-    cursor.execute(
-        "SELECT name FROM sqlite_master WHERE type='table' AND name='code_relationships'"
-    )
-    result = cursor.fetchone()
-    assert result is not None, "code_relationships table should exist"
-
-    # Check indexes exist
-    cursor.execute(
-        "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='code_relationships'"
-    )
-    indexes = [row[0] for row in cursor.fetchall()]
-    assert "idx_relationships_source" in indexes
-    assert "idx_relationships_target" in indexes
-    assert "idx_relationships_type" in indexes
"idx_relationships_type" in indexes - - -def test_add_relationships(store): - """Test storing code relationships.""" - # First add a file with symbols - indexed_file = IndexedFile( - path=str(Path(__file__).parent / "sample.py"), - language="python", - symbols=[ - Symbol(name="foo", kind="function", range=(1, 5)), - Symbol(name="bar", kind="function", range=(7, 10)), - ] - ) - - content = """def foo(): - bar() - baz() - -def bar(): - print("hello") -""" - - store.add_file(indexed_file, content) - - # Add relationships - relationships = [ - CodeRelationship( - source_symbol="foo", - target_symbol="bar", - relationship_type="call", - source_file=indexed_file.path, - target_file=None, - source_line=2 - ), - CodeRelationship( - source_symbol="foo", - target_symbol="baz", - relationship_type="call", - source_file=indexed_file.path, - target_file=None, - source_line=3 - ), - ] - - store.add_relationships(indexed_file.path, relationships) - - # Verify relationships were stored - conn = store._get_connection() - count = conn.execute("SELECT COUNT(*) FROM code_relationships").fetchone()[0] - assert count == 2, "Should have stored 2 relationships" - - -def test_query_relationships_by_target(store): - """Test querying relationships by target symbol (find callers).""" - # Setup: Add file and relationships - file_path = str(Path(__file__).parent / "sample.py") - # Content: Line 1-2: foo(), Line 4-5: bar(), Line 7-8: main() - indexed_file = IndexedFile( - path=file_path, - language="python", - symbols=[ - Symbol(name="foo", kind="function", range=(1, 2)), - Symbol(name="bar", kind="function", range=(4, 5)), - Symbol(name="main", kind="function", range=(7, 8)), - ] - ) - - content = "def foo():\n bar()\n\ndef bar():\n pass\n\ndef main():\n bar()\n" - store.add_file(indexed_file, content) - - relationships = [ - CodeRelationship( - source_symbol="foo", - target_symbol="bar", - relationship_type="call", - source_file=file_path, - target_file=None, - source_line=2 # Call inside foo (line 2) - ), - CodeRelationship( - source_symbol="main", - target_symbol="bar", - relationship_type="call", - source_file=file_path, - target_file=None, - source_line=8 # Call inside main (line 8) - ), - ] - - store.add_relationships(file_path, relationships) - - # Query: Find all callers of "bar" - callers = store.query_relationships_by_target("bar") - - assert len(callers) == 2, "Should find 2 callers of bar" - assert any(r["source_symbol"] == "foo" for r in callers) - assert any(r["source_symbol"] == "main" for r in callers) - assert all(r["target_symbol"] == "bar" for r in callers) - assert all(r["relationship_type"] == "call" for r in callers) - - -def test_query_relationships_by_source(store): - """Test querying relationships by source symbol (find callees).""" - # Setup - file_path = str(Path(__file__).parent / "sample.py") - indexed_file = IndexedFile( - path=file_path, - language="python", - symbols=[ - Symbol(name="foo", kind="function", range=(1, 6)), - ] - ) - - content = "def foo():\n bar()\n baz()\n qux()\n" - store.add_file(indexed_file, content) - - relationships = [ - CodeRelationship( - source_symbol="foo", - target_symbol="bar", - relationship_type="call", - source_file=file_path, - target_file=None, - source_line=2 - ), - CodeRelationship( - source_symbol="foo", - target_symbol="baz", - relationship_type="call", - source_file=file_path, - target_file=None, - source_line=3 - ), - CodeRelationship( - source_symbol="foo", - target_symbol="qux", - relationship_type="call", - source_file=file_path, - 
-            target_file=None,
-            source_line=4
-        ),
-    ]
-
-    store.add_relationships(file_path, relationships)
-
-    # Query: Find all functions called by foo
-    callees = store.query_relationships_by_source("foo", file_path)
-
-    assert len(callees) == 3, "Should find 3 functions called by foo"
-    targets = {r["target_symbol"] for r in callees}
-    assert targets == {"bar", "baz", "qux"}
-    assert all(r["source_symbol"] == "foo" for r in callees)
-
-
-def test_query_performance(store):
-    """Test that relationship queries execute within performance threshold."""
-    import time
-
-    # Setup: Create a file with many relationships
-    file_path = str(Path(__file__).parent / "large_file.py")
-    symbols = [Symbol(name=f"func_{i}", kind="function", range=(i*10+1, i*10+5)) for i in range(100)]
-
-    indexed_file = IndexedFile(
-        path=file_path,
-        language="python",
-        symbols=symbols
-    )
-
-    content = "\n".join([f"def func_{i}():\n    pass\n" for i in range(100)])
-    store.add_file(indexed_file, content)
-
-    # Create many relationships
-    relationships = []
-    for i in range(100):
-        for j in range(10):
-            relationships.append(
-                CodeRelationship(
-                    source_symbol=f"func_{i}",
-                    target_symbol=f"target_{j}",
-                    relationship_type="call",
-                    source_file=file_path,
-                    target_file=None,
-                    source_line=i*10 + 1
-                )
-            )
-
-    store.add_relationships(file_path, relationships)
-
-    # Query and measure time
-    start = time.time()
-    results = store.query_relationships_by_target("target_5")
-    elapsed_ms = (time.time() - start) * 1000
-
-    assert len(results) == 100, "Should find 100 callers"
-    assert elapsed_ms < 50, f"Query took {elapsed_ms:.1f}ms, should be <50ms"
-
-
-def test_stats_includes_relationships(store):
-    """Test that stats() includes relationship count."""
-    # Add a file with relationships
-    file_path = str(Path(__file__).parent / "sample.py")
-    indexed_file = IndexedFile(
-        path=file_path,
-        language="python",
-        symbols=[Symbol(name="foo", kind="function", range=(1, 5))]
-    )
-
-    store.add_file(indexed_file, "def foo():\n    bar()\n")
-
-    relationships = [
-        CodeRelationship(
-            source_symbol="foo",
-            target_symbol="bar",
-            relationship_type="call",
-            source_file=file_path,
-            target_file=None,
-            source_line=2
-        )
-    ]
-
-    store.add_relationships(file_path, relationships)
-
-    # Check stats
-    stats = store.stats()
-
-    assert "relationships" in stats
-    assert stats["relationships"] == 1
-    assert stats["files"] == 1
-    assert stats["symbols"] == 1
-
-
-def test_update_relationships_on_file_reindex(store):
-    """Test that relationships are updated when a file is re-indexed."""
-    file_path = str(Path(__file__).parent / "sample.py")
-
-    # Initial index
-    indexed_file = IndexedFile(
-        path=file_path,
-        language="python",
-        symbols=[Symbol(name="foo", kind="function", range=(1, 3))]
-    )
-    store.add_file(indexed_file, "def foo():\n    bar()\n")
-
-    relationships = [
-        CodeRelationship(
-            source_symbol="foo",
-            target_symbol="bar",
-            relationship_type="call",
-            source_file=file_path,
-            target_file=None,
-            source_line=2
-        )
-    ]
-    store.add_relationships(file_path, relationships)
-
-    # Re-index with different relationships
-    new_relationships = [
-        CodeRelationship(
-            source_symbol="foo",
-            target_symbol="baz",
-            relationship_type="call",
-            source_file=file_path,
-            target_file=None,
-            source_line=2
-        )
-    ]
-    store.add_relationships(file_path, new_relationships)
-
-    # Verify old relationships are replaced
-    all_rels = store.query_relationships_by_source("foo", file_path)
-    assert len(all_rels) == 1
-    assert all_rels[0]["target_symbol"] == "baz"
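
The storage tests above fix the result shape of query_relationships_by_target
without showing its SQL. One plausible reading, given the migration_003 schema
in the fixture, is a join from code_relationships to symbols. find_callers is
an illustrative name, and the symbols(id, name) columns are assumed from the
foreign key rather than shown in this diff:

    import sqlite3

    def find_callers(conn: sqlite3.Connection, target: str) -> list:
        """Return caller rows shaped like the dicts the tests assert on."""
        rows = conn.execute(
            """
            SELECT s.name, r.target_qualified_name,
                   r.relationship_type, r.source_line
            FROM code_relationships AS r
            JOIN symbols AS s ON s.id = r.source_symbol_id
            WHERE r.target_qualified_name = ?
            """,
            (target,),
        ).fetchall()
        keys = ("source_symbol", "target_symbol",
                "relationship_type", "source_line")
        return [dict(zip(keys, row)) for row in rows]

The WHERE clause can be served by idx_relationships_target, which is what lets
test_query_performance demand under 50ms across 1,000 stored rows.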
diff --git a/codex-lens/tests/test_token_chunking.py b/codex-lens/tests/test_token_chunking.py
index 90d2b950..39be7aa0 100644
--- a/codex-lens/tests/test_token_chunking.py
+++ b/codex-lens/tests/test_token_chunking.py
@@ -188,60 +188,3 @@ class TestTokenCountPerformance:
         # Precomputed should be at least 10% faster
         speedup = ((computed_time - precomputed_time) / computed_time) * 100
         assert speedup >= 10.0, f"Speedup {speedup:.2f}% < 10% (computed={computed_time:.4f}s, precomputed={precomputed_time:.4f}s)"
-
-
-class TestSymbolEntityTokenCount:
-    """Tests for Symbol entity token_count field."""
-
-    def test_symbol_with_token_count(self):
-        """Test creating Symbol with token_count."""
-        symbol = Symbol(
-            name="test_func",
-            kind="function",
-            range=(1, 10),
-            token_count=42
-        )
-
-        assert symbol.token_count == 42
-
-    def test_symbol_without_token_count(self):
-        """Test creating Symbol without token_count (defaults to None)."""
-        symbol = Symbol(
-            name="test_func",
-            kind="function",
-            range=(1, 10)
-        )
-
-        assert symbol.token_count is None
-
-    def test_symbol_with_symbol_type(self):
-        """Test creating Symbol with symbol_type."""
-        symbol = Symbol(
-            name="TestClass",
-            kind="class",
-            range=(1, 20),
-            symbol_type="class_definition"
-        )
-
-        assert symbol.symbol_type == "class_definition"
-
-    def test_symbol_token_count_validation(self):
-        """Test that negative token counts are rejected."""
-        with pytest.raises(ValueError, match="token_count must be >= 0"):
-            Symbol(
-                name="test",
-                kind="function",
-                range=(1, 2),
-                token_count=-1
-            )
-
-    def test_symbol_zero_token_count(self):
-        """Test that zero token count is allowed."""
-        symbol = Symbol(
-            name="empty",
-            kind="function",
-            range=(1, 1),
-            token_count=0
-        )
-
-        assert symbol.token_count == 0
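
Dropping TestSymbolEntityTokenCount removes the only stated contract for
Symbol.token_count. For reference, a minimal sketch of the entity those tests
implied; field names mirror the tests, but the real codexlens.entities.Symbol
may differ:

    from dataclasses import dataclass
    from typing import Optional, Tuple

    @dataclass
    class Symbol:
        """Hypothetical reconstruction; not codexlens source."""
        name: str
        kind: str
        range: Tuple[int, int]
        symbol_type: Optional[str] = None   # e.g. "class_definition"
        token_count: Optional[int] = None   # None = not computed; 0 is legal

        def __post_init__(self) -> None:
            # The deleted test matches exactly this message for negatives.
            if self.token_count is not None and self.token_count < 0:
                raise ValueError("token_count must be >= 0")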