mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-10 02:24:35 +08:00
refactor: remove graph indexing, fix memory leaks, optimize embedding generation

Key changes:
1. Remove graph indexing
   - Delete graph_analyzer.py and its related migration files
   - Remove the CLI graph command and the --enrich flag
   - Strip the graph query methods from chain_search.py (370 lines)
   - Delete the related test files
2. Fix embedding-generation memory issues
   - Refactor generate_embeddings.py to use streaming batch processing
   - Switch to embedding_manager's memory-safe implementation
   - Shrink the file from 548 to 259 lines (52.7% reduction)
3. Fix memory leaks
   - chain_search.py: quick_search now manages ChainSearchEngine with a with statement
   - embedding_manager.py: manage VectorStore with a with statement
   - vector_store.py: add a memory warning for brute-force search
4. Code cleanup
   - Remove the token_count and symbol_type fields from the Symbol model
   - Clean up the related test cases

Tests: 760 passed, 7 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
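Point 2 is the heart of the change. As a reading aid, here is a minimal sketch of the streaming shape the new generate_embeddings follows. It is illustrative only: FILE_BATCH_SIZE's value and the exact signatures of embedder, chunker, and VectorStore are assumptions that mirror the diff below, not the real module interfaces.

# Illustrative sketch, not the exact implementation (see the diff below).
# VectorStore, embedder, and chunker stand in for the real helpers.
import sqlite3

FILE_BATCH_SIZE = 50       # assumed value; files are streamed per batch
EMBEDDING_BATCH_SIZE = 8   # jina-embeddings-v2-base-code needs small batches

def embed_index(index_path, embedder, chunker):
    with VectorStore(index_path) as vector_store:            # closed even on error
        with sqlite3.connect(index_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.execute("SELECT full_path, content, language FROM files")
            while True:
                file_batch = cursor.fetchmany(FILE_BATCH_SIZE)  # stream, never fetchall()
                if not file_batch:
                    break
                # Chunk only this batch of files
                chunks = [
                    (chunk, row["full_path"])
                    for row in file_batch
                    for chunk in chunker.chunk_sliding_window(
                        row["content"],
                        file_path=row["full_path"],
                        language=row["language"] or "python",
                    )
                ]
                # Embed in small sub-batches to bound model memory
                for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
                    part = chunks[i:i + EMBEDDING_BATCH_SIZE]
                    embeddings = embedder.embed([c.content for c, _ in part])
                    for (chunk, _), emb in zip(part, embeddings):
                        chunk.embedding = emb
                # Flush immediately so this batch can be garbage-collected
                vector_store.add_chunks_batch(chunks)

Because each batch is written to the store before the next is fetched, peak memory stays proportional to FILE_BATCH_SIZE rather than to the size of the index.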
@@ -268,7 +268,6 @@ def search(
    files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
    mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
    weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."),
    enrich: bool = typer.Option(False, "--enrich", help="Enrich results with code graph relationships (calls, imports)."),
    json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
@@ -423,30 +422,10 @@ def search(
                    for r in result.results
                ]

                # Enrich results with relationship data if requested
                enriched = False
                if enrich:
                    try:
                        from codexlens.search.enrichment import RelationshipEnricher

                        # Find index path for the search path
                        project_record = registry.find_by_source_path(str(search_path))
                        if project_record:
                            index_path = Path(project_record["index_root"]) / "_index.db"
                            if index_path.exists():
                                with RelationshipEnricher(index_path) as enricher:
                                    results_list = enricher.enrich(results_list, limit=limit)
                                    enriched = True
                    except Exception as e:
                        # Enrichment failure should not break search
                        if verbose:
                            console.print(f"[yellow]Warning: Enrichment failed: {e}[/yellow]")

                payload = {
                    "query": query,
                    "mode": actual_mode,
                    "count": len(results_list),
                    "enriched": enriched,
                    "results": results_list,
                    "stats": {
                        "dirs_searched": result.stats.dirs_searched,
@@ -458,8 +437,7 @@ def search(
                print_json(success=True, result=payload)
            else:
                render_search_results(result.results, verbose=verbose)
                enrich_status = " | [green]Enriched[/green]" if enriched else ""
                console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms{enrich_status}[/dim]")
                console.print(f"[dim]Mode: {actual_mode} | Searched {result.stats.dirs_searched} directories in {result.stats.time_ms:.1f}ms[/dim]")

    except SearchError as exc:
        if json_mode:
@@ -1376,103 +1354,6 @@ def clean(
        raise typer.Exit(code=1)


@app.command()
def graph(
    query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
    symbol: str = typer.Argument(..., help="Symbol name to query"),
    path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
    limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
    depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
    json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
    """Query semantic graph for code relationships.

    Supported query types:
    - callers: Find all functions/methods that call the given symbol
    - callees: Find all functions/methods called by the given symbol
    - inheritance: Find inheritance relationships for the given class

    Examples:
        codex-lens graph callers my_function
        codex-lens graph callees MyClass.method --path src/
        codex-lens graph inheritance BaseClass
    """
    _configure_logging(verbose)
    search_path = path.expanduser().resolve()

    # Validate query type
    valid_types = ["callers", "callees", "inheritance"]
    if query_type not in valid_types:
        if json_mode:
            print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
        else:
            console.print(f"[red]Invalid query type:[/red] {query_type}")
            console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
        raise typer.Exit(code=1)

    registry: RegistryStore | None = None
    try:
        registry = RegistryStore()
        registry.initialize()
        mapper = PathMapper()

        engine = ChainSearchEngine(registry, mapper)
        options = SearchOptions(depth=depth, total_limit=limit)

        # Execute graph query based on type
        if query_type == "callers":
            results = engine.search_callers(symbol, search_path, options=options)
            result_type = "callers"
        elif query_type == "callees":
            results = engine.search_callees(symbol, search_path, options=options)
            result_type = "callees"
        else:  # inheritance
            results = engine.search_inheritance(symbol, search_path, options=options)
            result_type = "inheritance"

        payload = {
            "query_type": query_type,
            "symbol": symbol,
            "count": len(results),
            "relationships": results
        }

        if json_mode:
            print_json(success=True, result=payload)
        else:
            from .output import render_graph_results
            render_graph_results(results, query_type=query_type, symbol=symbol)

    except SearchError as exc:
        if json_mode:
            print_json(success=False, error=f"Graph search error: {exc}")
        else:
            console.print(f"[red]Graph query failed (search):[/red] {exc}")
        raise typer.Exit(code=1)
    except StorageError as exc:
        if json_mode:
            print_json(success=False, error=f"Storage error: {exc}")
        else:
            console.print(f"[red]Graph query failed (storage):[/red] {exc}")
        raise typer.Exit(code=1)
    except CodexLensError as exc:
        if json_mode:
            print_json(success=False, error=str(exc))
        else:
            console.print(f"[red]Graph query failed:[/red] {exc}")
        raise typer.Exit(code=1)
    except Exception as exc:
        if json_mode:
            print_json(success=False, error=f"Unexpected error: {exc}")
        else:
            console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
        raise typer.Exit(code=1)
    finally:
        if registry is not None:
            registry.close()


@app.command("semantic-list")
def semantic_list(
    path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),

@@ -194,7 +194,6 @@ def generate_embeddings(
    try:
        # Use cached embedder (singleton) for performance
        embedder = get_embedder(profile=model_profile)
        vector_store = VectorStore(index_path)
        chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size))

        if progress_callback:
@@ -217,85 +216,86 @@ def generate_embeddings(
    EMBEDDING_BATCH_SIZE = 8  # jina-embeddings-v2-base-code needs small batches

    try:
        with sqlite3.connect(index_path) as conn:
            conn.row_factory = sqlite3.Row
            path_column = _get_path_column(conn)
        with VectorStore(index_path) as vector_store:
            with sqlite3.connect(index_path) as conn:
                conn.row_factory = sqlite3.Row
                path_column = _get_path_column(conn)

            # Get total file count for progress reporting
            total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
            if total_files == 0:
                return {"success": False, "error": "No files found in index"}
                # Get total file count for progress reporting
                total_files = conn.execute("SELECT COUNT(*) FROM files").fetchone()[0]
                if total_files == 0:
                    return {"success": False, "error": "No files found in index"}

            if progress_callback:
                progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")

            cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
            batch_number = 0

            while True:
                # Fetch a batch of files (streaming, not fetchall)
                file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
                if not file_batch:
                    break

                batch_number += 1
                batch_chunks_with_paths = []
                files_in_batch_with_chunks = set()

                # Step 1: Chunking for the current file batch
                for file_row in file_batch:
                    file_path = file_row[path_column]
                    content = file_row["content"]
                    language = file_row["language"] or "python"

                    try:
                        chunks = chunker.chunk_sliding_window(
                            content,
                            file_path=file_path,
                            language=language
                        )
                        if chunks:
                            for chunk in chunks:
                                batch_chunks_with_paths.append((chunk, file_path))
                            files_in_batch_with_chunks.add(file_path)
                    except Exception as e:
                        logger.error(f"Failed to chunk {file_path}: {e}")
                        failed_files.append((file_path, str(e)))

                if not batch_chunks_with_paths:
                    continue

                batch_chunk_count = len(batch_chunks_with_paths)
                if progress_callback:
                    progress_callback(f"  Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")
                    progress_callback(f"Processing {total_files} files in batches of {FILE_BATCH_SIZE}...")

                # Step 2: Generate embeddings for this batch
                batch_embeddings = []
                try:
                    for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
                        batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
                        batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
                        embeddings = embedder.embed(batch_contents)
                        batch_embeddings.extend(embeddings)
                except Exception as e:
                    logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
                    failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
                    continue
                cursor = conn.execute(f"SELECT {path_column}, content, language FROM files")
                batch_number = 0

                # Step 3: Assign embeddings to chunks
                for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
                    chunk.embedding = embedding
                while True:
                    # Fetch a batch of files (streaming, not fetchall)
                    file_batch = cursor.fetchmany(FILE_BATCH_SIZE)
                    if not file_batch:
                        break

                # Step 4: Store this batch to database immediately (releases memory)
                try:
                    vector_store.add_chunks_batch(batch_chunks_with_paths)
                    total_chunks_created += batch_chunk_count
                    total_files_processed += len(files_in_batch_with_chunks)
                except Exception as e:
                    logger.error(f"Failed to store batch {batch_number}: {str(e)}")
                    failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
                    batch_number += 1
                    batch_chunks_with_paths = []
                    files_in_batch_with_chunks = set()

                # Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope
                    # Step 1: Chunking for the current file batch
                    for file_row in file_batch:
                        file_path = file_row[path_column]
                        content = file_row["content"]
                        language = file_row["language"] or "python"

                        try:
                            chunks = chunker.chunk_sliding_window(
                                content,
                                file_path=file_path,
                                language=language
                            )
                            if chunks:
                                for chunk in chunks:
                                    batch_chunks_with_paths.append((chunk, file_path))
                                files_in_batch_with_chunks.add(file_path)
                        except Exception as e:
                            logger.error(f"Failed to chunk {file_path}: {e}")
                            failed_files.append((file_path, str(e)))

                    if not batch_chunks_with_paths:
                        continue

                    batch_chunk_count = len(batch_chunks_with_paths)
                    if progress_callback:
                        progress_callback(f"  Batch {batch_number}: {len(file_batch)} files, {batch_chunk_count} chunks")

                    # Step 2: Generate embeddings for this batch
                    batch_embeddings = []
                    try:
                        for i in range(0, batch_chunk_count, EMBEDDING_BATCH_SIZE):
                            batch_end = min(i + EMBEDDING_BATCH_SIZE, batch_chunk_count)
                            batch_contents = [chunk.content for chunk, _ in batch_chunks_with_paths[i:batch_end]]
                            embeddings = embedder.embed(batch_contents)
                            batch_embeddings.extend(embeddings)
                    except Exception as e:
                        logger.error(f"Failed to generate embeddings for batch {batch_number}: {str(e)}")
                        failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])
                        continue

                    # Step 3: Assign embeddings to chunks
                    for (chunk, _), embedding in zip(batch_chunks_with_paths, batch_embeddings):
                        chunk.embedding = embedding

                    # Step 4: Store this batch to database immediately (releases memory)
                    try:
                        vector_store.add_chunks_batch(batch_chunks_with_paths)
                        total_chunks_created += batch_chunk_count
                        total_files_processed += len(files_in_batch_with_chunks)
                    except Exception as e:
                        logger.error(f"Failed to store batch {batch_number}: {str(e)}")
                        failed_files.extend([(file_row[path_column], str(e)) for file_row in file_batch])

                    # Memory is released here as batch_chunks_with_paths and batch_embeddings go out of scope

    except Exception as e:
        return {"success": False, "error": f"Failed to read or process files: {str(e)}"}

@@ -122,68 +122,3 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
    console.print(header)
    render_symbols(list(symbols), title="Discovered Symbols")


def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
    """Render semantic graph query results.

    Args:
        results: List of relationship dicts
        query_type: Type of query (callers, callees, inheritance)
        symbol: Symbol name that was queried
    """
    if not results:
        console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
        return

    title_map = {
        "callers": f"Callers of '{symbol}' ({len(results)} found)",
        "callees": f"Callees of '{symbol}' ({len(results)} found)",
        "inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
    }

    table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))

    if query_type == "callers":
        table.add_column("Caller", style="green")
        table.add_column("File", style="cyan", no_wrap=False, max_width=40)
        table.add_column("Line", justify="right", style="yellow")
        table.add_column("Type", style="dim")

        for rel in results:
            table.add_row(
                rel.get("source_symbol", "-"),
                rel.get("source_file", "-"),
                str(rel.get("source_line", "-")),
                rel.get("relationship_type", "-")
            )

    elif query_type == "callees":
        table.add_column("Target", style="green")
        table.add_column("File", style="cyan", no_wrap=False, max_width=40)
        table.add_column("Line", justify="right", style="yellow")
        table.add_column("Type", style="dim")

        for rel in results:
            table.add_row(
                rel.get("target_symbol", "-"),
                rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
                str(rel.get("source_line", "-")),
                rel.get("relationship_type", "-")
            )

    else:  # inheritance
        table.add_column("Derived Class", style="green")
        table.add_column("Base Class", style="magenta")
        table.add_column("File", style="cyan", no_wrap=False, max_width=40)
        table.add_column("Line", justify="right", style="yellow")

        for rel in results:
            table.add_row(
                rel.get("source_symbol", "-"),
                rel.get("target_symbol", "-"),
                rel.get("source_file", "-"),
                str(rel.get("source_line", "-"))
            )

    console.print(table)


@@ -14,8 +14,6 @@ class Symbol(BaseModel):
    kind: str = Field(..., min_length=1)
    range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
    file: Optional[str] = Field(default=None, description="Full path to the file containing this symbol")
    token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
    symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")

    @field_validator("range")
    @classmethod
@@ -29,13 +27,6 @@ class Symbol(BaseModel):
            raise ValueError("end_line must be >= start_line")
        return value

    @field_validator("token_count")
    @classmethod
    def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
        if value is not None and value < 0:
            raise ValueError("token_count must be >= 0")
        return value


class SemanticChunk(BaseModel):
    """A semantically meaningful chunk of content, optionally embedded."""

@@ -302,108 +302,6 @@ class ChainSearchEngine:
            index_paths, name, kind, options.total_limit
        )

    def search_callers(self, target_symbol: str,
                       source_path: Path,
                       options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
        """Find all callers of a given symbol across directory hierarchy.

        Args:
            target_symbol: Name of the symbol to find callers for
            source_path: Starting directory path
            options: Search configuration (uses defaults if None)

        Returns:
            List of relationship dicts with caller information

        Examples:
            >>> engine = ChainSearchEngine(registry, mapper)
            >>> callers = engine.search_callers("my_function", Path("D:/project"))
            >>> for caller in callers:
            ...     print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
        """
        options = options or SearchOptions()

        start_index = self._find_start_index(source_path)
        if not start_index:
            self.logger.warning(f"No index found for {source_path}")
            return []

        index_paths = self._collect_index_paths(start_index, options.depth)
        if not index_paths:
            return []

        return self._search_callers_parallel(
            index_paths, target_symbol, options.total_limit
        )

    def search_callees(self, source_symbol: str,
                       source_path: Path,
                       options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
        """Find all callees (what a symbol calls) across directory hierarchy.

        Args:
            source_symbol: Name of the symbol to find callees for
            source_path: Starting directory path
            options: Search configuration (uses defaults if None)

        Returns:
            List of relationship dicts with callee information

        Examples:
            >>> engine = ChainSearchEngine(registry, mapper)
            >>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
            >>> for callee in callees:
            ...     print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
        """
        options = options or SearchOptions()

        start_index = self._find_start_index(source_path)
        if not start_index:
            self.logger.warning(f"No index found for {source_path}")
            return []

        index_paths = self._collect_index_paths(start_index, options.depth)
        if not index_paths:
            return []

        return self._search_callees_parallel(
            index_paths, source_symbol, options.total_limit
        )

    def search_inheritance(self, class_name: str,
                           source_path: Path,
                           options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
        """Find inheritance relationships for a class across directory hierarchy.

        Args:
            class_name: Name of the class to find inheritance for
            source_path: Starting directory path
            options: Search configuration (uses defaults if None)

        Returns:
            List of relationship dicts with inheritance information

        Examples:
            >>> engine = ChainSearchEngine(registry, mapper)
            >>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
            >>> for rel in inheritance:
            ...     print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
        """
        options = options or SearchOptions()

        start_index = self._find_start_index(source_path)
        if not start_index:
            self.logger.warning(f"No index found for {source_path}")
            return []

        index_paths = self._collect_index_paths(start_index, options.depth)
        if not index_paths:
            return []

        return self._search_inheritance_parallel(
            index_paths, class_name, options.total_limit
        )

    # === Internal Methods ===

    def _find_start_index(self, source_path: Path) -> Optional[Path]:
@@ -711,273 +609,6 @@ class ChainSearchEngine:
            self.logger.debug(f"Symbol search error in {index_path}: {exc}")
            return []

    def _search_callers_parallel(self, index_paths: List[Path],
                                 target_symbol: str,
                                 limit: int) -> List[Dict[str, Any]]:
        """Search for callers across multiple indexes in parallel.

        Args:
            index_paths: List of _index.db paths to search
            target_symbol: Target symbol name
            limit: Total result limit

        Returns:
            Deduplicated list of caller relationships
        """
        all_callers = []

        executor = self._get_executor()
        future_to_path = {
            executor.submit(
                self._search_callers_single,
                idx_path,
                target_symbol
            ): idx_path
            for idx_path in index_paths
        }

        for future in as_completed(future_to_path):
            try:
                callers = future.result()
                all_callers.extend(callers)
            except Exception as exc:
                self.logger.error(f"Caller search failed: {exc}")

        # Deduplicate by (source_file, source_line)
        seen = set()
        unique_callers = []
        for caller in all_callers:
            key = (caller.get("source_file"), caller.get("source_line"))
            if key not in seen:
                seen.add(key)
                unique_callers.append(caller)

        # Sort by source file and line
        unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))

        return unique_callers[:limit]

    def _search_callers_single(self, index_path: Path,
                               target_symbol: str) -> List[Dict[str, Any]]:
        """Search for callers in a single index.

        Args:
            index_path: Path to _index.db file
            target_symbol: Target symbol name

        Returns:
            List of caller relationship dicts (empty on error)
        """
        try:
            with SQLiteStore(index_path) as store:
                return store.query_relationships_by_target(target_symbol)
        except Exception as exc:
            self.logger.debug(f"Caller search error in {index_path}: {exc}")
            return []

    def _search_callees_parallel(self, index_paths: List[Path],
                                 source_symbol: str,
                                 limit: int) -> List[Dict[str, Any]]:
        """Search for callees across multiple indexes in parallel.

        Args:
            index_paths: List of _index.db paths to search
            source_symbol: Source symbol name
            limit: Total result limit

        Returns:
            Deduplicated list of callee relationships
        """
        all_callees = []

        executor = self._get_executor()
        future_to_path = {
            executor.submit(
                self._search_callees_single,
                idx_path,
                source_symbol
            ): idx_path
            for idx_path in index_paths
        }

        for future in as_completed(future_to_path):
            try:
                callees = future.result()
                all_callees.extend(callees)
            except Exception as exc:
                self.logger.error(f"Callee search failed: {exc}")

        # Deduplicate by (target_symbol, source_line)
        seen = set()
        unique_callees = []
        for callee in all_callees:
            key = (callee.get("target_symbol"), callee.get("source_line"))
            if key not in seen:
                seen.add(key)
                unique_callees.append(callee)

        # Sort by source line
        unique_callees.sort(key=lambda c: c.get("source_line", 0))

        return unique_callees[:limit]

    def _search_callees_single(self, index_path: Path,
                               source_symbol: str) -> List[Dict[str, Any]]:
        """Search for callees in a single index.

        Args:
            index_path: Path to _index.db file
            source_symbol: Source symbol name

        Returns:
            List of callee relationship dicts (empty on error)
        """
        try:
            with SQLiteStore(index_path) as store:
                # Single JOIN query to get all callees (fixes N+1 query problem)
                # Uses public execute_query API instead of _get_connection bypass
                rows = store.execute_query(
                    """
                    SELECT
                        s.name AS source_symbol,
                        r.target_qualified_name AS target_symbol,
                        r.relationship_type,
                        r.source_line,
                        f.full_path AS source_file,
                        r.target_file
                    FROM code_relationships r
                    JOIN symbols s ON r.source_symbol_id = s.id
                    JOIN files f ON s.file_id = f.id
                    WHERE s.name = ? AND r.relationship_type = 'call'
                    ORDER BY f.full_path, r.source_line
                    LIMIT 100
                    """,
                    (source_symbol,)
                )

                return [
                    {
                        "source_symbol": row["source_symbol"],
                        "target_symbol": row["target_symbol"],
                        "relationship_type": row["relationship_type"],
                        "source_line": row["source_line"],
                        "source_file": row["source_file"],
                        "target_file": row["target_file"],
                    }
                    for row in rows
                ]
        except Exception as exc:
            self.logger.debug(f"Callee search error in {index_path}: {exc}")
            return []

    def _search_inheritance_parallel(self, index_paths: List[Path],
                                     class_name: str,
                                     limit: int) -> List[Dict[str, Any]]:
        """Search for inheritance relationships across multiple indexes in parallel.

        Args:
            index_paths: List of _index.db paths to search
            class_name: Class name to search for
            limit: Total result limit

        Returns:
            Deduplicated list of inheritance relationships
        """
        all_inheritance = []

        executor = self._get_executor()
        future_to_path = {
            executor.submit(
                self._search_inheritance_single,
                idx_path,
                class_name
            ): idx_path
            for idx_path in index_paths
        }

        for future in as_completed(future_to_path):
            try:
                inheritance = future.result()
                all_inheritance.extend(inheritance)
            except Exception as exc:
                self.logger.error(f"Inheritance search failed: {exc}")

        # Deduplicate by (source_symbol, target_symbol)
        seen = set()
        unique_inheritance = []
        for rel in all_inheritance:
            key = (rel.get("source_symbol"), rel.get("target_symbol"))
            if key not in seen:
                seen.add(key)
                unique_inheritance.append(rel)

        # Sort by source file
        unique_inheritance.sort(key=lambda r: r.get("source_file", ""))

        return unique_inheritance[:limit]

    def _search_inheritance_single(self, index_path: Path,
                                   class_name: str) -> List[Dict[str, Any]]:
        """Search for inheritance relationships in a single index.

        Args:
            index_path: Path to _index.db file
            class_name: Class name to search for

        Returns:
            List of inheritance relationship dicts (empty on error)
        """
        try:
            with SQLiteStore(index_path) as store:
                # Use UNION to find relationships where class is either:
                # 1. The base class (target) - find derived classes
                # 2. The derived class (source) - find parent classes
                # Uses public execute_query API instead of _get_connection bypass
                rows = store.execute_query(
                    """
                    SELECT
                        s.name AS source_symbol,
                        r.target_qualified_name,
                        r.relationship_type,
                        r.source_line,
                        f.full_path AS source_file,
                        r.target_file
                    FROM code_relationships r
                    JOIN symbols s ON r.source_symbol_id = s.id
                    JOIN files f ON s.file_id = f.id
                    WHERE r.target_qualified_name = ? AND r.relationship_type = 'inherits'
                    UNION
                    SELECT
                        s.name AS source_symbol,
                        r.target_qualified_name,
                        r.relationship_type,
                        r.source_line,
                        f.full_path AS source_file,
                        r.target_file
                    FROM code_relationships r
                    JOIN symbols s ON r.source_symbol_id = s.id
                    JOIN files f ON s.file_id = f.id
                    WHERE s.name = ? AND r.relationship_type = 'inherits'
                    LIMIT 100
                    """,
                    (class_name, class_name)
                )

                return [
                    {
                        "source_symbol": row["source_symbol"],
                        "target_symbol": row["target_qualified_name"],
                        "relationship_type": row["relationship_type"],
                        "source_line": row["source_line"],
                        "source_file": row["source_file"],
                        "target_file": row["target_file"],
                    }
                    for row in rows
                ]
        except Exception as exc:
            self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
            return []


# === Convenience Functions ===

@@ -1007,10 +638,9 @@ def quick_search(query: str,

    mapper = PathMapper()

    engine = ChainSearchEngine(registry, mapper)
    options = SearchOptions(depth=depth)

    result = engine.search(query, source_path, options)
    with ChainSearchEngine(registry, mapper) as engine:
        options = SearchOptions(depth=depth)
        result = engine.search(query, source_path, options)

    registry.close()
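For context, the leak fix above reduces to using the engine as a context manager at every call site. A minimal caller-side sketch (the query string and path are hypothetical; the pattern mirrors the diff):

# __exit__ is expected to release the engine's pooled resources even if search() raises.
with ChainSearchEngine(registry, mapper) as engine:
    options = SearchOptions(depth=-1)
    result = engine.search("my_function", Path("."), options)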

@@ -1,542 +0,0 @@
"""Graph analyzer for extracting code relationships using tree-sitter.

Provides AST-based analysis to identify function calls, method invocations,
and class inheritance relationships within source files.
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional

try:
    from tree_sitter import Node as TreeSitterNode
    TREE_SITTER_AVAILABLE = True
except ImportError:
    TreeSitterNode = None  # type: ignore[assignment]
    TREE_SITTER_AVAILABLE = False

from codexlens.entities import CodeRelationship, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser


class GraphAnalyzer:
    """Analyzer for extracting semantic relationships from code using AST traversal."""

    def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
        """Initialize graph analyzer for a language.

        Args:
            language_id: Language identifier (python, javascript, typescript, etc.)
            parser: Optional TreeSitterSymbolParser instance for dependency injection.
                If None, creates a new parser instance (backward compatibility).
        """
        self.language_id = language_id
        self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)

    def is_available(self) -> bool:
        """Check if graph analyzer is available.

        Returns:
            True if tree-sitter parser is initialized and ready
        """
        return self._parser.is_available()

    def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
        """Analyze source code and extract relationships.

        Args:
            text: Source code text
            file_path: File path for relationship context

        Returns:
            List of CodeRelationship objects representing intra-file relationships
        """
        if not self.is_available() or self._parser._parser is None:
            return []

        try:
            source_bytes = text.encode("utf8")
            tree = self._parser._parser.parse(source_bytes)  # type: ignore[attr-defined]
            root = tree.root_node

            relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))

            return relationships
        except Exception:
            # Gracefully handle parsing errors
            return []

    def analyze_with_symbols(
        self, text: str, file_path: Path, symbols: List[Symbol]
    ) -> List[CodeRelationship]:
        """Analyze source code using pre-parsed symbols to avoid duplicate parsing.

        Args:
            text: Source code text
            file_path: File path for relationship context
            symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser

        Returns:
            List of CodeRelationship objects representing intra-file relationships
        """
        if not self.is_available() or self._parser._parser is None:
            return []

        try:
            source_bytes = text.encode("utf8")
            tree = self._parser._parser.parse(source_bytes)  # type: ignore[attr-defined]
            root = tree.root_node

            # Convert Symbol objects to internal symbol format
            defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)

            # Extract relationships using provided symbols
            relationships = self._extract_relationships_with_symbols(
                source_bytes, root, str(file_path.resolve()), defined_symbols
            )

            return relationships
        except Exception:
            # Gracefully handle parsing errors
            return []

    def _convert_symbols_to_dict(
        self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
    ) -> List[dict]:
        """Convert Symbol objects to internal dict format for relationship extraction.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            symbols: Pre-parsed Symbol objects

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbol_dicts = []
        symbol_names = {s.name for s in symbols}

        # Find AST nodes corresponding to symbols
        for node in self._iter_nodes(root):
            node_type = node.type

            # Check if this node matches any of our symbols
            if node_type in {"function_definition", "async_function_definition"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })
            elif node_type == "class_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "class"
                        })
            elif node_type in {"function_declaration", "generator_function_declaration"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })
            elif node_type == "method_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "method"
                        })
            elif node_type in {"class_declaration", "class"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "class"
                        })
            elif node_type == "variable_declarator":
                name_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if name_node and value_node and value_node.type == "arrow_function":
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })

        return symbol_dicts

    def _extract_relationships_with_symbols(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
    ) -> List[CodeRelationship]:
        """Extract relationships from AST using pre-parsed symbols.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path
            defined_symbols: Pre-parsed symbol dicts

        Returns:
            List of extracted relationships
        """
        relationships: List[CodeRelationship] = []

        # Determine call node type based on language
        if self.language_id == "python":
            call_node_type = "call"
            extract_target = self._extract_call_target
        elif self.language_id in {"javascript", "typescript"}:
            call_node_type = "call_expression"
            extract_target = self._extract_js_call_target
        else:
            return []

        # Find call expressions and match to defined symbols
        for node in self._iter_nodes(root):
            if node.type == call_node_type:
                # Extract caller context (enclosing function/method/class)
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    # Call at module level, use "<module>" as source
                    source_symbol = "<module>"

                # Extract callee (function/method being called)
                target_symbol = extract_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,  # Intra-file only
                        source_line=line_number,
                    )
                )

        return relationships

    def _extract_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of extracted relationships
        """
        if self.language_id == "python":
            return self._extract_python_relationships(source_bytes, root, file_path)
        elif self.language_id in {"javascript", "typescript"}:
            return self._extract_js_ts_relationships(source_bytes, root, file_path)
        else:
            return []

    def _extract_python_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract Python relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of Python relationships (function/method calls)
        """
        relationships: List[CodeRelationship] = []

        # First pass: collect all defined symbols with their scopes
        defined_symbols = self._collect_python_symbols(source_bytes, root)

        # Second pass: find call expressions and match to defined symbols
        for node in self._iter_nodes(root):
            if node.type == "call":
                # Extract caller context (enclosing function/method/class)
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    # Call at module level, use "<module>" as source
                    source_symbol = "<module>"

                # Extract callee (function/method being called)
                target_symbol = self._extract_call_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,  # Intra-file only
                        source_line=line_number,
                    )
                )

        return relationships

    def _extract_js_ts_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract JavaScript/TypeScript relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of JS/TS relationships (function/method calls)
        """
        relationships: List[CodeRelationship] = []

        # First pass: collect all defined symbols
        defined_symbols = self._collect_js_ts_symbols(source_bytes, root)

        # Second pass: find call expressions
        for node in self._iter_nodes(root):
            if node.type == "call_expression":
                # Extract caller context
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    source_symbol = "<module>"

                # Extract callee
                target_symbol = self._extract_js_call_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,
                        source_line=line_number,
                    )
                )

        return relationships

    def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
        """Collect all Python function/method/class definitions.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbols = []
        for node in self._iter_nodes(root):
            if node.type in {"function_definition", "async_function_definition"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
            elif node.type == "class_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "class"
                    })
        return symbols

    def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
        """Collect all JS/TS function/method/class definitions.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbols = []
        for node in self._iter_nodes(root):
            if node.type in {"function_declaration", "generator_function_declaration"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
            elif node.type == "method_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "method"
                    })
            elif node.type in {"class_declaration", "class"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "class"
                    })
            elif node.type == "variable_declarator":
                name_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if name_node and value_node and value_node.type == "arrow_function":
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
        return symbols

    def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
        """Find the enclosing function/method/class for a node.

        Returns fully qualified name (e.g., "MyClass.my_method") by traversing up
        the AST tree and collecting parent class/function names.

        Args:
            node: AST node to find enclosure for
            symbols: List of defined symbols

        Returns:
            Fully qualified name of enclosing symbol, or None if at module level
        """
        # Walk up the tree to find all enclosing symbols
        enclosing_names = []
        parent = node.parent

        while parent is not None:
            for symbol in symbols:
                if symbol["node"] == parent:
                    # Prepend to maintain order (innermost to outermost)
                    enclosing_names.insert(0, symbol["name"])
                    break
            parent = parent.parent

        # Return fully qualified name or None if at module level
        if enclosing_names:
            return ".".join(enclosing_names)
        return None

    def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
        """Extract the target function name from a Python call expression.

        Args:
            source_bytes: Source code as bytes
            node: Call expression node

        Returns:
            Target function name, or None if cannot be determined
        """
        function_node = node.child_by_field_name("function")
        if function_node is None:
            return None

        # Handle simple identifiers (e.g., "foo()")
        if function_node.type == "identifier":
            return self._node_text(source_bytes, function_node)

        # Handle attribute access (e.g., "obj.method()")
        if function_node.type == "attribute":
            attr_node = function_node.child_by_field_name("attribute")
            if attr_node:
                return self._node_text(source_bytes, attr_node)

        return None

    def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
        """Extract the target function name from a JS/TS call expression.

        Args:
            source_bytes: Source code as bytes
            node: Call expression node

        Returns:
            Target function name, or None if cannot be determined
        """
        function_node = node.child_by_field_name("function")
        if function_node is None:
            return None

        # Handle simple identifiers
        if function_node.type == "identifier":
            return self._node_text(source_bytes, function_node)

        # Handle member expressions (e.g., "obj.method()")
        if function_node.type == "member_expression":
            property_node = function_node.child_by_field_name("property")
            if property_node:
                return self._node_text(source_bytes, property_node)

        return None

    def _iter_nodes(self, root: TreeSitterNode):
        """Iterate over all nodes in AST.

        Args:
            root: Root node to start iteration

        Yields:
            AST nodes in depth-first order
        """
        stack = [root]
        while stack:
            node = stack.pop()
            yield node
            for child in reversed(node.children):
                stack.append(child)

    def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
        """Extract text for a node.

        Args:
            source_bytes: Source code as bytes
            node: AST node

        Returns:
            Text content of node
        """
        return source_bytes[node.start_byte:node.end_byte].decode("utf8")
@@ -602,6 +602,12 @@ class VectorStore:
        Returns:
            List of SearchResult ordered by similarity (highest first)
        """
        logger.warning(
            "Using brute-force vector search (hnswlib not available). "
            "This may cause high memory usage for large indexes. "
            "Install hnswlib for better performance: pip install hnswlib"
        )

        with self._cache_lock:
            # Refresh cache if needed
            if self._embedding_matrix is None:

@@ -17,7 +17,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from codexlens.entities import CodeRelationship, SearchResult, Symbol
from codexlens.entities import SearchResult, Symbol
from codexlens.errors import StorageError


@@ -237,116 +237,6 @@ class DirIndexStore:
            conn.rollback()
            raise StorageError(f"Failed to add file {name}: {exc}") from exc

    def add_relationships(
        self,
        file_path: str | Path,
        relationships: List[CodeRelationship],
    ) -> int:
        """Store code relationships for a file.

        Args:
            file_path: Path to the source file
            relationships: List of CodeRelationship objects to store

        Returns:
            Number of relationships stored

        Raises:
            StorageError: If database operations fail
        """
        if not relationships:
            return 0

        with self._lock:
            conn = self._get_connection()
            file_path_str = str(Path(file_path).resolve())

            try:
                # Get file_id
                row = conn.execute(
                    "SELECT id FROM files WHERE full_path=?", (file_path_str,)
                ).fetchone()
                if not row:
                    return 0

                file_id = int(row["id"])

                # Delete existing relationships for symbols in this file
                conn.execute(
                    """
                    DELETE FROM code_relationships
                    WHERE source_symbol_id IN (
                        SELECT id FROM symbols WHERE file_id=?
                    )
                    """,
                    (file_id,),
                )

                # Insert new relationships
                relationship_rows = []
                skipped_relationships = []
                for rel in relationships:
                    # Extract simple name from fully qualified name (e.g., "MyClass.my_method" -> "my_method")
                    # This handles cases where GraphAnalyzer generates qualified names but symbols table stores simple names
                    source_symbol_simple = rel.source_symbol.split(".")[-1] if "." in rel.source_symbol else rel.source_symbol

                    # Find symbol_id by name and file
                    symbol_row = conn.execute(
                        """
                        SELECT id FROM symbols
                        WHERE file_id=? AND name=? AND start_line<=? AND end_line>=?
                        LIMIT 1
                        """,
                        (file_id, source_symbol_simple, rel.source_line, rel.source_line),
                    ).fetchone()

                    if not symbol_row:
                        # Try matching by simple name only
                        symbol_row = conn.execute(
                            "SELECT id FROM symbols WHERE file_id=? AND name=? LIMIT 1",
                            (file_id, source_symbol_simple),
                        ).fetchone()

                    if symbol_row:
                        relationship_rows.append((
                            int(symbol_row["id"]),
                            rel.target_symbol,
                            rel.relationship_type,
                            rel.source_line,
                            rel.target_file,
                        ))
                    else:
                        # Log warning when symbol lookup fails
                        skipped_relationships.append(rel.source_symbol)

                # Log skipped relationships for debugging
                if skipped_relationships:
                    self.logger.warning(
                        "Failed to find source symbol IDs for %d relationships in %s: %s",
                        len(skipped_relationships),
                        file_path_str,
                        ", ".join(set(skipped_relationships))
                    )

                if relationship_rows:
                    conn.executemany(
                        """
                        INSERT INTO code_relationships(
                            source_symbol_id, target_qualified_name, relationship_type,
                            source_line, target_file
                        )
                        VALUES(?, ?, ?, ?, ?)
                        """,
                        relationship_rows,
                    )

                conn.commit()
                return len(relationship_rows)

            except sqlite3.DatabaseError as exc:
                conn.rollback()
                raise StorageError(f"Failed to add relationships: {exc}") from exc

    def add_files_batch(
        self, files: List[Tuple[str, Path, str, str, Optional[List[Symbol]]]]
    ) -> int:

@@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Set

from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory
from codexlens.semantic.graph_analyzer import GraphAnalyzer
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import ProjectInfo, RegistryStore
@@ -525,16 +524,6 @@ class IndexTreeBuilder:
                symbols=indexed_file.symbols,
            )

            # Extract and store code relationships for graph visualization
            if language_id in {"python", "javascript", "typescript"}:
                graph_analyzer = GraphAnalyzer(language_id)
                if graph_analyzer.is_available():
                    relationships = graph_analyzer.analyze_with_symbols(
                        text, file_path, indexed_file.symbols
                    )
                    if relationships:
                        store.add_relationships(file_path, relationships)

            files_count += 1
            symbols_count += len(indexed_file.symbols)

@@ -742,16 +731,6 @@ def _build_dir_worker(args: tuple) -> DirBuildResult:
                symbols=indexed_file.symbols,
            )

            # Extract and store code relationships for graph visualization
            if language_id in {"python", "javascript", "typescript"}:
                graph_analyzer = GraphAnalyzer(language_id)
                if graph_analyzer.is_available():
                    relationships = graph_analyzer.analyze_with_symbols(
                        text, item, indexed_file.symbols
                    )
                    if relationships:
                        store.add_relationships(item, relationships)

            files_count += 1
            symbols_count += len(indexed_file.symbols)


@@ -1,57 +0,0 @@
"""
Migration 003: Add code relationships storage.

This migration introduces the `code_relationships` table to store semantic
relationships between code symbols (function calls, inheritance, imports).
This enables graph-based code navigation and dependency analysis.
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection):
    """
    Applies the migration to add code relationships table.

    - Creates `code_relationships` table with foreign key to symbols
    - Creates indexes for efficient relationship queries
    - Supports lazy expansion with target_symbol being qualified names

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating 'code_relationships' table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS code_relationships (
            id INTEGER PRIMARY KEY,
            source_symbol_id INTEGER NOT NULL,
            target_qualified_name TEXT NOT NULL,
            relationship_type TEXT NOT NULL,
            source_line INTEGER NOT NULL,
            target_file TEXT,
            FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for code_relationships...")
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
    )
    cursor.execute(
        "CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
    )

    log.info("Finished creating code_relationships table and indexes.")
@@ -10,7 +10,7 @@ from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
from codexlens.entities import IndexedFile, SearchResult, Symbol
from codexlens.errors import StorageError


@@ -420,167 +420,6 @@ class SQLiteStore:
        }

    def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
        """Store code relationships for a file.

        Args:
            file_path: Path to the file containing the relationships
            relationships: List of CodeRelationship objects to store
        """
        if not relationships:
            return

        with self._lock:
            conn = self._get_connection()
            resolved_path = str(Path(file_path).resolve())

            # Get file_id
            row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
            if not row:
                raise StorageError(f"File not found in index: {file_path}")
            file_id = int(row["id"])

            # Delete existing relationships for symbols in this file
            conn.execute(
                """
                DELETE FROM code_relationships
                WHERE source_symbol_id IN (
                    SELECT id FROM symbols WHERE file_id=?
                )
                """,
                (file_id,)
            )

            # Insert new relationships
            relationship_rows = []
            for rel in relationships:
                # Find source symbol ID
                symbol_row = conn.execute(
                    """
                    SELECT id FROM symbols
                    WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
                    ORDER BY (end_line - start_line) ASC
                    LIMIT 1
                    """,
                    (file_id, rel.source_symbol, rel.source_line, rel.source_line)
                ).fetchone()

                if symbol_row:
                    source_symbol_id = int(symbol_row["id"])
                    relationship_rows.append((
                        source_symbol_id,
                        rel.target_symbol,
                        rel.relationship_type,
                        rel.source_line,
                        rel.target_file
                    ))

            if relationship_rows:
                conn.executemany(
                    """
                    INSERT INTO code_relationships(
                        source_symbol_id, target_qualified_name, relationship_type,
                        source_line, target_file
                    )
                    VALUES(?, ?, ?, ?, ?)
                    """,
                    relationship_rows
                )
                conn.commit()

    def query_relationships_by_target(
        self, target_name: str, *, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Query relationships by target symbol name (find all callers).

        Args:
            target_name: Name of the target symbol
            limit: Maximum number of results

        Returns:
            List of dicts containing relationship info with file paths and line numbers
        """
        with self._lock:
            conn = self._get_connection()
            rows = conn.execute(
                """
                SELECT
                    s.name AS source_symbol,
                    r.target_qualified_name,
                    r.relationship_type,
                    r.source_line,
                    f.full_path AS source_file,
                    r.target_file
                FROM code_relationships r
                JOIN symbols s ON r.source_symbol_id = s.id
                JOIN files f ON s.file_id = f.id
                WHERE r.target_qualified_name = ?
                ORDER BY f.full_path, r.source_line
                LIMIT ?
                """,
                (target_name, limit)
            ).fetchall()

            return [
                {
                    "source_symbol": row["source_symbol"],
                    "target_symbol": row["target_qualified_name"],
                    "relationship_type": row["relationship_type"],
                    "source_line": row["source_line"],
                    "source_file": row["source_file"],
                    "target_file": row["target_file"],
                }
                for row in rows
            ]

    def query_relationships_by_source(
        self, source_symbol: str, source_file: str | Path, *, limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Query relationships by source symbol (find what a symbol calls).

        Args:
            source_symbol: Name of the source symbol
            source_file: File path containing the source symbol
            limit: Maximum number of results

        Returns:
            List of dicts containing relationship info
        """
        with self._lock:
            conn = self._get_connection()
            resolved_path = str(Path(source_file).resolve())

            rows = conn.execute(
                """
                SELECT
                    s.name AS source_symbol,
                    r.target_qualified_name,
                    r.relationship_type,
                    r.source_line,
                    f.path AS source_file,
                    r.target_file
                FROM code_relationships r
                JOIN symbols s ON r.source_symbol_id = s.id
                JOIN files f ON s.file_id = f.id
                WHERE s.name = ? AND f.path = ?
                ORDER BY r.source_line
                LIMIT ?
                """,
                (source_symbol, resolved_path, limit)
            ).fetchall()

            return [
                {
                    "source_symbol": row["source_symbol"],
                    "target_symbol": row["target_qualified_name"],
                    "relationship_type": row["relationship_type"],
                    "source_line": row["source_line"],
                    "source_file": row["source_file"],
                    "target_file": row["target_file"],
                }
                for row in rows
            ]

    def _connect(self) -> sqlite3.Connection:
        """Legacy method for backward compatibility."""
        return self._get_connection()