mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-09 02:24:11 +08:00
Add comprehensive tests for tokenizer, performance benchmarks, and TreeSitter parser functionality
- Implemented unit tests for the Tokenizer class, covering various text inputs, edge cases, and fallback mechanisms. - Created performance benchmarks comparing tiktoken and pure Python implementations for token counting. - Developed extensive tests for TreeSitterSymbolParser across Python, JavaScript, and TypeScript, ensuring accurate symbol extraction and parsing. - Added configuration documentation for MCP integration and custom prompts, enhancing usability and flexibility. - Introduced a refactor script for GraphAnalyzer to streamline future improvements.
This commit is contained in:
@@ -1100,6 +1100,103 @@ def clean(
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def graph(
|
||||
query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
|
||||
symbol: str = typer.Argument(..., help="Symbol name to query"),
|
||||
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
|
||||
limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
|
||||
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
|
||||
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
|
||||
) -> None:
|
||||
"""Query semantic graph for code relationships.
|
||||
|
||||
Supported query types:
|
||||
- callers: Find all functions/methods that call the given symbol
|
||||
- callees: Find all functions/methods called by the given symbol
|
||||
- inheritance: Find inheritance relationships for the given class
|
||||
|
||||
Examples:
|
||||
codex-lens graph callers my_function
|
||||
codex-lens graph callees MyClass.method --path src/
|
||||
codex-lens graph inheritance BaseClass
|
||||
"""
|
||||
_configure_logging(verbose)
|
||||
search_path = path.expanduser().resolve()
|
||||
|
||||
# Validate query type
|
||||
valid_types = ["callers", "callees", "inheritance"]
|
||||
if query_type not in valid_types:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
|
||||
else:
|
||||
console.print(f"[red]Invalid query type:[/red] {query_type}")
|
||||
console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
registry: RegistryStore | None = None
|
||||
try:
|
||||
registry = RegistryStore()
|
||||
registry.initialize()
|
||||
mapper = PathMapper()
|
||||
|
||||
engine = ChainSearchEngine(registry, mapper)
|
||||
options = SearchOptions(depth=depth, total_limit=limit)
|
||||
|
||||
# Execute graph query based on type
|
||||
if query_type == "callers":
|
||||
results = engine.search_callers(symbol, search_path, options=options)
|
||||
result_type = "callers"
|
||||
elif query_type == "callees":
|
||||
results = engine.search_callees(symbol, search_path, options=options)
|
||||
result_type = "callees"
|
||||
else: # inheritance
|
||||
results = engine.search_inheritance(symbol, search_path, options=options)
|
||||
result_type = "inheritance"
|
||||
|
||||
payload = {
|
||||
"query_type": query_type,
|
||||
"symbol": symbol,
|
||||
"count": len(results),
|
||||
"relationships": results
|
||||
}
|
||||
|
||||
if json_mode:
|
||||
print_json(success=True, result=payload)
|
||||
else:
|
||||
from .output import render_graph_results
|
||||
render_graph_results(results, query_type=query_type, symbol=symbol)
|
||||
|
||||
except SearchError as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Graph search error: {exc}")
|
||||
else:
|
||||
console.print(f"[red]Graph query failed (search):[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
except StorageError as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Storage error: {exc}")
|
||||
else:
|
||||
console.print(f"[red]Graph query failed (storage):[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
except CodexLensError as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=str(exc))
|
||||
else:
|
||||
console.print(f"[red]Graph query failed:[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as exc:
|
||||
if json_mode:
|
||||
print_json(success=False, error=f"Unexpected error: {exc}")
|
||||
else:
|
||||
console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
|
||||
raise typer.Exit(code=1)
|
||||
finally:
|
||||
if registry is not None:
|
||||
registry.close()
|
||||
|
||||
|
||||
@app.command("semantic-list")
|
||||
def semantic_list(
|
||||
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),
|
||||
|
||||
@@ -89,3 +89,68 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
|
||||
console.print(header)
|
||||
render_symbols(list(symbols), title="Discovered Symbols")
|
||||
|
||||
|
||||
def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
|
||||
"""Render semantic graph query results.
|
||||
|
||||
Args:
|
||||
results: List of relationship dicts
|
||||
query_type: Type of query (callers, callees, inheritance)
|
||||
symbol: Symbol name that was queried
|
||||
"""
|
||||
if not results:
|
||||
console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
|
||||
return
|
||||
|
||||
title_map = {
|
||||
"callers": f"Callers of '{symbol}' ({len(results)} found)",
|
||||
"callees": f"Callees of '{symbol}' ({len(results)} found)",
|
||||
"inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
|
||||
}
|
||||
|
||||
table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))
|
||||
|
||||
if query_type == "callers":
|
||||
table.add_column("Caller", style="green")
|
||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
||||
table.add_column("Line", justify="right", style="yellow")
|
||||
table.add_column("Type", style="dim")
|
||||
|
||||
for rel in results:
|
||||
table.add_row(
|
||||
rel.get("source_symbol", "-"),
|
||||
rel.get("source_file", "-"),
|
||||
str(rel.get("source_line", "-")),
|
||||
rel.get("relationship_type", "-")
|
||||
)
|
||||
|
||||
elif query_type == "callees":
|
||||
table.add_column("Target", style="green")
|
||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
||||
table.add_column("Line", justify="right", style="yellow")
|
||||
table.add_column("Type", style="dim")
|
||||
|
||||
for rel in results:
|
||||
table.add_row(
|
||||
rel.get("target_symbol", "-"),
|
||||
rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
|
||||
str(rel.get("source_line", "-")),
|
||||
rel.get("relationship_type", "-")
|
||||
)
|
||||
|
||||
else: # inheritance
|
||||
table.add_column("Derived Class", style="green")
|
||||
table.add_column("Base Class", style="magenta")
|
||||
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
|
||||
table.add_column("Line", justify="right", style="yellow")
|
||||
|
||||
for rel in results:
|
||||
table.add_row(
|
||||
rel.get("source_symbol", "-"),
|
||||
rel.get("target_symbol", "-"),
|
||||
rel.get("source_file", "-"),
|
||||
str(rel.get("source_line", "-"))
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@@ -83,6 +83,9 @@ class Config:
|
||||
llm_timeout_ms: int = 300000
|
||||
llm_batch_size: int = 5
|
||||
|
||||
# Hybrid chunker configuration
|
||||
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
|
||||
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
|
||||
def __post_init__(self) -> None:
|
||||
try:
|
||||
self.data_dir = self.data_dir.expanduser().resolve()
|
||||
|
||||
@@ -13,6 +13,8 @@ class Symbol(BaseModel):
|
||||
name: str = Field(..., min_length=1)
|
||||
kind: str = Field(..., min_length=1)
|
||||
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
|
||||
token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
|
||||
symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")
|
||||
|
||||
@field_validator("range")
|
||||
@classmethod
|
||||
@@ -26,6 +28,13 @@ class Symbol(BaseModel):
|
||||
raise ValueError("end_line must be >= start_line")
|
||||
return value
|
||||
|
||||
@field_validator("token_count")
|
||||
@classmethod
|
||||
def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("token_count must be >= 0")
|
||||
return value
|
||||
|
||||
|
||||
class SemanticChunk(BaseModel):
|
||||
"""A semantically meaningful chunk of content, optionally embedded."""
|
||||
@@ -61,6 +70,25 @@ class IndexedFile(BaseModel):
|
||||
return cleaned
|
||||
|
||||
|
||||
class CodeRelationship(BaseModel):
|
||||
"""A relationship between code symbols (e.g., function calls, inheritance)."""
|
||||
|
||||
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
|
||||
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
|
||||
relationship_type: str = Field(..., min_length=1, description="Type of relationship (call, inherits, etc.)")
|
||||
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
|
||||
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
|
||||
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
|
||||
|
||||
@field_validator("relationship_type")
|
||||
@classmethod
|
||||
def validate_relationship_type(cls, value: str) -> str:
|
||||
allowed_types = {"call", "inherits", "imports"}
|
||||
if value not in allowed_types:
|
||||
raise ValueError(f"relationship_type must be one of {allowed_types}")
|
||||
return value
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""A unified search result for lexical or semantic search."""
|
||||
|
||||
|
||||
@@ -10,19 +10,11 @@ from __future__ import annotations
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Protocol
|
||||
|
||||
try:
|
||||
from tree_sitter import Language as TreeSitterLanguage
|
||||
from tree_sitter import Node as TreeSitterNode
|
||||
from tree_sitter import Parser as TreeSitterParser
|
||||
except Exception: # pragma: no cover
|
||||
TreeSitterLanguage = None # type: ignore[assignment]
|
||||
TreeSitterNode = None # type: ignore[assignment]
|
||||
TreeSitterParser = None # type: ignore[assignment]
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import IndexedFile, Symbol
|
||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||
|
||||
|
||||
class Parser(Protocol):
|
||||
@@ -34,10 +26,24 @@ class SimpleRegexParser:
|
||||
language_id: str
|
||||
|
||||
def parse(self, text: str, path: Path) -> IndexedFile:
|
||||
# Try tree-sitter first for supported languages
|
||||
if self.language_id in {"python", "javascript", "typescript"}:
|
||||
ts_parser = TreeSitterSymbolParser(self.language_id, path)
|
||||
if ts_parser.is_available():
|
||||
symbols = ts_parser.parse_symbols(text)
|
||||
if symbols is not None:
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
)
|
||||
|
||||
# Fallback to regex parsing
|
||||
if self.language_id == "python":
|
||||
symbols = _parse_python_symbols(text)
|
||||
symbols = _parse_python_symbols_regex(text)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
symbols = _parse_js_ts_symbols(text, self.language_id, path)
|
||||
symbols = _parse_js_ts_symbols_regex(text)
|
||||
elif self.language_id == "java":
|
||||
symbols = _parse_java_symbols(text)
|
||||
elif self.language_id == "go":
|
||||
@@ -64,120 +70,35 @@ class ParserFactory:
|
||||
return self._parsers[language_id]
|
||||
|
||||
|
||||
# Regex-based fallback parsers
|
||||
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
|
||||
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
|
||||
|
||||
_TREE_SITTER_LANGUAGE_CACHE: Dict[str, TreeSitterLanguage] = {}
|
||||
|
||||
|
||||
def _get_tree_sitter_language(language_id: str, path: Path | None = None) -> TreeSitterLanguage | None:
|
||||
if TreeSitterLanguage is None:
|
||||
return None
|
||||
|
||||
cache_key = language_id
|
||||
if language_id == "typescript" and path is not None and path.suffix.lower() == ".tsx":
|
||||
cache_key = "tsx"
|
||||
|
||||
cached = _TREE_SITTER_LANGUAGE_CACHE.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
if cache_key == "python":
|
||||
import tree_sitter_python # type: ignore[import-not-found]
|
||||
|
||||
language = TreeSitterLanguage(tree_sitter_python.language())
|
||||
elif cache_key == "javascript":
|
||||
import tree_sitter_javascript # type: ignore[import-not-found]
|
||||
|
||||
language = TreeSitterLanguage(tree_sitter_javascript.language())
|
||||
elif cache_key == "typescript":
|
||||
import tree_sitter_typescript # type: ignore[import-not-found]
|
||||
|
||||
language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
|
||||
elif cache_key == "tsx":
|
||||
import tree_sitter_typescript # type: ignore[import-not-found]
|
||||
|
||||
language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
_TREE_SITTER_LANGUAGE_CACHE[cache_key] = language
|
||||
return language
|
||||
def _parse_python_symbols(text: str) -> List[Symbol]:
|
||||
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
|
||||
ts_parser = TreeSitterSymbolParser("python")
|
||||
if ts_parser.is_available():
|
||||
symbols = ts_parser.parse_symbols(text)
|
||||
if symbols is not None:
|
||||
return symbols
|
||||
return _parse_python_symbols_regex(text)
|
||||
|
||||
|
||||
def _iter_tree_sitter_nodes(root: TreeSitterNode) -> Iterable[TreeSitterNode]:
|
||||
stack: List[TreeSitterNode] = [root]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
for child in reversed(node.children):
|
||||
stack.append(child)
|
||||
|
||||
|
||||
def _node_text(source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
||||
|
||||
|
||||
def _node_range(node: TreeSitterNode) -> tuple[int, int]:
|
||||
start_line = node.start_point[0] + 1
|
||||
end_line = node.end_point[0] + 1
|
||||
return (start_line, max(start_line, end_line))
|
||||
|
||||
|
||||
def _python_kind_for_function_node(node: TreeSitterNode) -> str:
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"function_definition", "async_function_definition"}:
|
||||
return "function"
|
||||
if parent.type == "class_definition":
|
||||
return "method"
|
||||
parent = parent.parent
|
||||
return "function"
|
||||
|
||||
|
||||
def _parse_python_symbols_tree_sitter(text: str) -> List[Symbol] | None:
|
||||
if TreeSitterParser is None:
|
||||
return None
|
||||
|
||||
language = _get_tree_sitter_language("python")
|
||||
if language is None:
|
||||
return None
|
||||
|
||||
parser = TreeSitterParser()
|
||||
if hasattr(parser, "set_language"):
|
||||
parser.set_language(language) # type: ignore[attr-defined]
|
||||
else:
|
||||
parser.language = language # type: ignore[assignment]
|
||||
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = parser.parse(source_bytes)
|
||||
root = tree.root_node
|
||||
|
||||
symbols: List[Symbol] = []
|
||||
for node in _iter_tree_sitter_nodes(root):
|
||||
if node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=_node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=_node_range(node),
|
||||
))
|
||||
elif node.type in {"function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=_node_text(source_bytes, name_node),
|
||||
kind=_python_kind_for_function_node(node),
|
||||
range=_node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
def _parse_js_ts_symbols(
|
||||
text: str,
|
||||
language_id: str = "javascript",
|
||||
path: Optional[Path] = None,
|
||||
) -> List[Symbol]:
|
||||
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
|
||||
ts_parser = TreeSitterSymbolParser(language_id, path)
|
||||
if ts_parser.is_available():
|
||||
symbols = ts_parser.parse_symbols(text)
|
||||
if symbols is not None:
|
||||
return symbols
|
||||
return _parse_js_ts_symbols_regex(text)
|
||||
|
||||
|
||||
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
|
||||
@@ -202,13 +123,6 @@ def _parse_python_symbols_regex(text: str) -> List[Symbol]:
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_python_symbols(text: str) -> List[Symbol]:
|
||||
symbols = _parse_python_symbols_tree_sitter(text)
|
||||
if symbols is not None:
|
||||
return symbols
|
||||
return _parse_python_symbols_regex(text)
|
||||
|
||||
|
||||
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
|
||||
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
|
||||
_JS_ARROW_RE = re.compile(
|
||||
@@ -217,88 +131,6 @@ _JS_ARROW_RE = re.compile(
|
||||
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
|
||||
|
||||
|
||||
def _js_has_class_ancestor(node: TreeSitterNode) -> bool:
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"class_declaration", "class"}:
|
||||
return True
|
||||
parent = parent.parent
|
||||
return False
|
||||
|
||||
|
||||
def _parse_js_ts_symbols_tree_sitter(
|
||||
text: str,
|
||||
language_id: str,
|
||||
path: Path | None = None,
|
||||
) -> List[Symbol] | None:
|
||||
if TreeSitterParser is None:
|
||||
return None
|
||||
|
||||
language = _get_tree_sitter_language(language_id, path)
|
||||
if language is None:
|
||||
return None
|
||||
|
||||
parser = TreeSitterParser()
|
||||
if hasattr(parser, "set_language"):
|
||||
parser.set_language(language) # type: ignore[attr-defined]
|
||||
else:
|
||||
parser.language = language # type: ignore[assignment]
|
||||
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = parser.parse(source_bytes)
|
||||
root = tree.root_node
|
||||
|
||||
symbols: List[Symbol] = []
|
||||
for node in _iter_tree_sitter_nodes(root):
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=_node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=_node_range(node),
|
||||
))
|
||||
elif node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=_node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=_node_range(node),
|
||||
))
|
||||
elif node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is None
|
||||
or value_node is None
|
||||
or name_node.type not in {"identifier", "property_identifier"}
|
||||
or value_node.type != "arrow_function"
|
||||
):
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=_node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=_node_range(node),
|
||||
))
|
||||
elif node.type == "method_definition" and _js_has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
name = _node_text(source_bytes, name_node)
|
||||
if name == "constructor":
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=name,
|
||||
kind="method",
|
||||
range=_node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
|
||||
symbols: List[Symbol] = []
|
||||
in_class = False
|
||||
@@ -338,17 +170,6 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
|
||||
return symbols
|
||||
|
||||
|
||||
def _parse_js_ts_symbols(
|
||||
text: str,
|
||||
language_id: str = "javascript",
|
||||
path: Path | None = None,
|
||||
) -> List[Symbol]:
|
||||
symbols = _parse_js_ts_symbols_tree_sitter(text, language_id, path)
|
||||
if symbols is not None:
|
||||
return symbols
|
||||
return _parse_js_ts_symbols_regex(text)
|
||||
|
||||
|
||||
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
|
||||
_JAVA_METHOD_RE = re.compile(
|
||||
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
|
||||
|
||||
98
codex-lens/src/codexlens/parsers/tokenizer.py
Normal file
98
codex-lens/src/codexlens/parsers/tokenizer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Token counting utilities for CodexLens.
|
||||
|
||||
Provides accurate token counting using tiktoken with character count fallback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import tiktoken
|
||||
TIKTOKEN_AVAILABLE = True
|
||||
except ImportError:
|
||||
TIKTOKEN_AVAILABLE = False
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""Token counter with tiktoken primary and character count fallback."""
|
||||
|
||||
def __init__(self, encoding_name: str = "cl100k_base") -> None:
|
||||
"""Initialize tokenizer.
|
||||
|
||||
Args:
|
||||
encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
|
||||
"""
|
||||
self._encoding: Optional[object] = None
|
||||
self._encoding_name = encoding_name
|
||||
|
||||
if TIKTOKEN_AVAILABLE:
|
||||
try:
|
||||
self._encoding = tiktoken.get_encoding(encoding_name)
|
||||
except Exception:
|
||||
# Fallback to character counting if encoding fails
|
||||
self._encoding = None
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""Count tokens in text.
|
||||
|
||||
Uses tiktoken if available, otherwise falls back to character count / 4.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
if self._encoding is not None:
|
||||
try:
|
||||
return len(self._encoding.encode(text)) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
# Fall through to character count fallback
|
||||
pass
|
||||
|
||||
# Fallback: rough estimate using character count
|
||||
# Average of ~4 characters per token for English text
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
def is_using_tiktoken(self) -> bool:
|
||||
"""Check if tiktoken is being used.
|
||||
|
||||
Returns:
|
||||
True if tiktoken is available and initialized
|
||||
"""
|
||||
return self._encoding is not None
|
||||
|
||||
|
||||
# Global default tokenizer instance
|
||||
_default_tokenizer: Optional[Tokenizer] = None
|
||||
|
||||
|
||||
def get_default_tokenizer() -> Tokenizer:
|
||||
"""Get the global default tokenizer instance.
|
||||
|
||||
Returns:
|
||||
Shared Tokenizer instance
|
||||
"""
|
||||
global _default_tokenizer
|
||||
if _default_tokenizer is None:
|
||||
_default_tokenizer = Tokenizer()
|
||||
return _default_tokenizer
|
||||
|
||||
|
||||
def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
|
||||
"""Count tokens in text using default or provided tokenizer.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
tokenizer: Optional tokenizer instance (uses default if None)
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_default_tokenizer()
|
||||
return tokenizer.count_tokens(text)
|
||||
335
codex-lens/src/codexlens/parsers/treesitter_parser.py
Normal file
335
codex-lens/src/codexlens/parsers/treesitter_parser.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tree-sitter based parser for CodexLens.
|
||||
|
||||
Provides precise AST-level parsing with fallback to regex-based parsing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
try:
|
||||
from tree_sitter import Language as TreeSitterLanguage
|
||||
from tree_sitter import Node as TreeSitterNode
|
||||
from tree_sitter import Parser as TreeSitterParser
|
||||
TREE_SITTER_AVAILABLE = True
|
||||
except ImportError:
|
||||
TreeSitterLanguage = None # type: ignore[assignment]
|
||||
TreeSitterNode = None # type: ignore[assignment]
|
||||
TreeSitterParser = None # type: ignore[assignment]
|
||||
TREE_SITTER_AVAILABLE = False
|
||||
|
||||
from codexlens.entities import IndexedFile, Symbol
|
||||
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||
|
||||
|
||||
class TreeSitterSymbolParser:
|
||||
"""Parser using tree-sitter for AST-level symbol extraction."""
|
||||
|
||||
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
|
||||
"""Initialize tree-sitter parser for a language.
|
||||
|
||||
Args:
|
||||
language_id: Language identifier (python, javascript, typescript, etc.)
|
||||
path: Optional file path for language variant detection (e.g., .tsx)
|
||||
"""
|
||||
self.language_id = language_id
|
||||
self.path = path
|
||||
self._parser: Optional[object] = None
|
||||
self._language: Optional[TreeSitterLanguage] = None
|
||||
self._tokenizer = get_default_tokenizer()
|
||||
|
||||
if TREE_SITTER_AVAILABLE:
|
||||
self._initialize_parser()
|
||||
|
||||
def _initialize_parser(self) -> None:
|
||||
"""Initialize tree-sitter parser and language."""
|
||||
if TreeSitterParser is None or TreeSitterLanguage is None:
|
||||
return
|
||||
|
||||
try:
|
||||
# Load language grammar
|
||||
if self.language_id == "python":
|
||||
import tree_sitter_python
|
||||
self._language = TreeSitterLanguage(tree_sitter_python.language())
|
||||
elif self.language_id == "javascript":
|
||||
import tree_sitter_javascript
|
||||
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
|
||||
elif self.language_id == "typescript":
|
||||
import tree_sitter_typescript
|
||||
# Detect TSX files by extension
|
||||
if self.path is not None and self.path.suffix.lower() == ".tsx":
|
||||
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
|
||||
else:
|
||||
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
|
||||
else:
|
||||
return
|
||||
|
||||
# Create parser
|
||||
self._parser = TreeSitterParser()
|
||||
if hasattr(self._parser, "set_language"):
|
||||
self._parser.set_language(self._language) # type: ignore[attr-defined]
|
||||
else:
|
||||
self._parser.language = self._language # type: ignore[assignment]
|
||||
|
||||
except Exception:
|
||||
# Gracefully handle missing language bindings
|
||||
self._parser = None
|
||||
self._language = None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if tree-sitter parser is available.
|
||||
|
||||
Returns:
|
||||
True if parser is initialized and ready
|
||||
"""
|
||||
return self._parser is not None and self._language is not None
|
||||
|
||||
|
||||
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
|
||||
"""Parse source code and extract symbols without creating IndexedFile.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
|
||||
Returns:
|
||||
List of symbols if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
if not self.is_available() or self._parser is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
root = tree.root_node
|
||||
|
||||
return self._extract_symbols(source_bytes, root)
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return None
|
||||
|
||||
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
|
||||
"""Parse source code and extract symbols.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
path: File path
|
||||
|
||||
Returns:
|
||||
IndexedFile if parsing succeeds, None if tree-sitter unavailable
|
||||
"""
|
||||
if not self.is_available() or self._parser is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
symbols = self.parse_symbols(text)
|
||||
if symbols is None:
|
||||
return None
|
||||
|
||||
return IndexedFile(
|
||||
path=str(path.resolve()),
|
||||
language=self.language_id,
|
||||
symbols=symbols,
|
||||
chunks=[],
|
||||
)
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return None
|
||||
|
||||
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of extracted symbols
|
||||
"""
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_symbols(source_bytes, root)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_symbols(source_bytes, root)
|
||||
else:
|
||||
return []
|
||||
|
||||
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract Python symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of Python symbols (classes, functions, methods)
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type in {"function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind=self._python_function_kind(node),
|
||||
range=self._node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
|
||||
"""Extract JavaScript/TypeScript symbols from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of JS/TS symbols (classes, functions, methods)
|
||||
"""
|
||||
symbols: List[Symbol] = []
|
||||
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="class",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if (
|
||||
name_node is None
|
||||
or value_node is None
|
||||
or name_node.type not in {"identifier", "property_identifier"}
|
||||
or value_node.type != "arrow_function"
|
||||
):
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=self._node_text(source_bytes, name_node),
|
||||
kind="function",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
elif node.type == "method_definition" and self._has_class_ancestor(node):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node is None:
|
||||
continue
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name == "constructor":
|
||||
continue
|
||||
symbols.append(Symbol(
|
||||
name=name,
|
||||
kind="method",
|
||||
range=self._node_range(node),
|
||||
))
|
||||
|
||||
return symbols
|
||||
|
||||
def _python_function_kind(self, node: TreeSitterNode) -> str:
|
||||
"""Determine if Python function is a method or standalone function.
|
||||
|
||||
Args:
|
||||
node: Function definition node
|
||||
|
||||
Returns:
|
||||
'method' if inside a class, 'function' otherwise
|
||||
"""
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"function_definition", "async_function_definition"}:
|
||||
return "function"
|
||||
if parent.type == "class_definition":
|
||||
return "method"
|
||||
parent = parent.parent
|
||||
return "function"
|
||||
|
||||
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
|
||||
"""Check if node has a class ancestor.
|
||||
|
||||
Args:
|
||||
node: AST node to check
|
||||
|
||||
Returns:
|
||||
True if node is inside a class
|
||||
"""
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
if parent.type in {"class_declaration", "class"}:
|
||||
return True
|
||||
parent = parent.parent
|
||||
return False
|
||||
|
||||
def _iter_nodes(self, root: TreeSitterNode):
|
||||
"""Iterate over all nodes in AST.
|
||||
|
||||
Args:
|
||||
root: Root node to start iteration
|
||||
|
||||
Yields:
|
||||
AST nodes in depth-first order
|
||||
"""
|
||||
stack = [root]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
for child in reversed(node.children):
|
||||
stack.append(child)
|
||||
|
||||
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
"""Extract text for a node.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
Text content of node
|
||||
"""
|
||||
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
||||
|
||||
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
|
||||
"""Get line range for a node.
|
||||
|
||||
Args:
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
(start_line, end_line) tuple, 1-based inclusive
|
||||
"""
|
||||
start_line = node.start_point[0] + 1
|
||||
end_line = node.end_point[0] + 1
|
||||
return (start_line, max(start_line, end_line))
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
"""Count tokens in text.
|
||||
|
||||
Args:
|
||||
text: Text to count tokens for
|
||||
|
||||
Returns:
|
||||
Token count
|
||||
"""
|
||||
return self._tokenizer.count_tokens(text)
|
||||
@@ -17,6 +17,7 @@ from codexlens.entities import SearchResult, Symbol
|
||||
from codexlens.storage.registry import RegistryStore, DirMapping
|
||||
from codexlens.storage.dir_index import DirIndexStore, SubdirLink
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.sqlite_store import SQLiteStore
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -278,6 +279,108 @@ class ChainSearchEngine:
|
||||
index_paths, name, kind, options.total_limit
|
||||
)
|
||||
|
||||
def search_callers(self, target_symbol: str,
|
||||
source_path: Path,
|
||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
||||
"""Find all callers of a given symbol across directory hierarchy.
|
||||
|
||||
Args:
|
||||
target_symbol: Name of the symbol to find callers for
|
||||
source_path: Starting directory path
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
List of relationship dicts with caller information
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> callers = engine.search_callers("my_function", Path("D:/project"))
|
||||
>>> for caller in callers:
|
||||
... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
|
||||
"""
|
||||
options = options or SearchOptions()
|
||||
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
if not index_paths:
|
||||
return []
|
||||
|
||||
return self._search_callers_parallel(
|
||||
index_paths, target_symbol, options.total_limit
|
||||
)
|
||||
|
||||
def search_callees(self, source_symbol: str,
|
||||
source_path: Path,
|
||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
||||
"""Find all callees (what a symbol calls) across directory hierarchy.
|
||||
|
||||
Args:
|
||||
source_symbol: Name of the symbol to find callees for
|
||||
source_path: Starting directory path
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
List of relationship dicts with callee information
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
|
||||
>>> for callee in callees:
|
||||
... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
|
||||
"""
|
||||
options = options or SearchOptions()
|
||||
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
if not index_paths:
|
||||
return []
|
||||
|
||||
return self._search_callees_parallel(
|
||||
index_paths, source_symbol, options.total_limit
|
||||
)
|
||||
|
||||
def search_inheritance(self, class_name: str,
|
||||
source_path: Path,
|
||||
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
|
||||
"""Find inheritance relationships for a class across directory hierarchy.
|
||||
|
||||
Args:
|
||||
class_name: Name of the class to find inheritance for
|
||||
source_path: Starting directory path
|
||||
options: Search configuration (uses defaults if None)
|
||||
|
||||
Returns:
|
||||
List of relationship dicts with inheritance information
|
||||
|
||||
Examples:
|
||||
>>> engine = ChainSearchEngine(registry, mapper)
|
||||
>>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
|
||||
>>> for rel in inheritance:
|
||||
... print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
|
||||
"""
|
||||
options = options or SearchOptions()
|
||||
|
||||
start_index = self._find_start_index(source_path)
|
||||
if not start_index:
|
||||
self.logger.warning(f"No index found for {source_path}")
|
||||
return []
|
||||
|
||||
index_paths = self._collect_index_paths(start_index, options.depth)
|
||||
if not index_paths:
|
||||
return []
|
||||
|
||||
return self._search_inheritance_parallel(
|
||||
index_paths, class_name, options.total_limit
|
||||
)
|
||||
|
||||
# === Internal Methods ===
|
||||
|
||||
def _find_start_index(self, source_path: Path) -> Optional[Path]:
|
||||
@@ -553,6 +656,252 @@ class ChainSearchEngine:
|
||||
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
def _search_callers_parallel(self, index_paths: List[Path],
|
||||
target_symbol: str,
|
||||
limit: int) -> List[Dict[str, Any]]:
|
||||
"""Search for callers across multiple indexes in parallel.
|
||||
|
||||
Args:
|
||||
index_paths: List of _index.db paths to search
|
||||
target_symbol: Target symbol name
|
||||
limit: Total result limit
|
||||
|
||||
Returns:
|
||||
Deduplicated list of caller relationships
|
||||
"""
|
||||
all_callers = []
|
||||
|
||||
executor = self._get_executor()
|
||||
future_to_path = {
|
||||
executor.submit(
|
||||
self._search_callers_single,
|
||||
idx_path,
|
||||
target_symbol
|
||||
): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_path):
|
||||
try:
|
||||
callers = future.result()
|
||||
all_callers.extend(callers)
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Caller search failed: {exc}")
|
||||
|
||||
# Deduplicate by (source_file, source_line)
|
||||
seen = set()
|
||||
unique_callers = []
|
||||
for caller in all_callers:
|
||||
key = (caller.get("source_file"), caller.get("source_line"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_callers.append(caller)
|
||||
|
||||
# Sort by source file and line
|
||||
unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))
|
||||
|
||||
return unique_callers[:limit]
|
||||
|
||||
def _search_callers_single(self, index_path: Path,
|
||||
target_symbol: str) -> List[Dict[str, Any]]:
|
||||
"""Search for callers in a single index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
target_symbol: Target symbol name
|
||||
|
||||
Returns:
|
||||
List of caller relationship dicts (empty on error)
|
||||
"""
|
||||
try:
|
||||
with SQLiteStore(index_path) as store:
|
||||
return store.query_relationships_by_target(target_symbol)
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Caller search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
def _search_callees_parallel(self, index_paths: List[Path],
|
||||
source_symbol: str,
|
||||
limit: int) -> List[Dict[str, Any]]:
|
||||
"""Search for callees across multiple indexes in parallel.
|
||||
|
||||
Args:
|
||||
index_paths: List of _index.db paths to search
|
||||
source_symbol: Source symbol name
|
||||
limit: Total result limit
|
||||
|
||||
Returns:
|
||||
Deduplicated list of callee relationships
|
||||
"""
|
||||
all_callees = []
|
||||
|
||||
executor = self._get_executor()
|
||||
future_to_path = {
|
||||
executor.submit(
|
||||
self._search_callees_single,
|
||||
idx_path,
|
||||
source_symbol
|
||||
): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_path):
|
||||
try:
|
||||
callees = future.result()
|
||||
all_callees.extend(callees)
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Callee search failed: {exc}")
|
||||
|
||||
# Deduplicate by (target_symbol, source_line)
|
||||
seen = set()
|
||||
unique_callees = []
|
||||
for callee in all_callees:
|
||||
key = (callee.get("target_symbol"), callee.get("source_line"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_callees.append(callee)
|
||||
|
||||
# Sort by source line
|
||||
unique_callees.sort(key=lambda c: c.get("source_line", 0))
|
||||
|
||||
return unique_callees[:limit]
|
||||
|
||||
def _search_callees_single(self, index_path: Path,
|
||||
source_symbol: str) -> List[Dict[str, Any]]:
|
||||
"""Search for callees in a single index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
source_symbol: Source symbol name
|
||||
|
||||
Returns:
|
||||
List of callee relationship dicts (empty on error)
|
||||
"""
|
||||
try:
|
||||
# Use the connection pool via SQLiteStore
|
||||
with SQLiteStore(index_path) as store:
|
||||
# Search across all files containing the symbol
|
||||
# Get all files that have this symbol
|
||||
conn = store._get_connection()
|
||||
file_rows = conn.execute(
|
||||
"""
|
||||
SELECT DISTINCT f.path
|
||||
FROM symbols s
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE s.name = ?
|
||||
""",
|
||||
(source_symbol,)
|
||||
).fetchall()
|
||||
|
||||
# Collect results from all matching files
|
||||
all_results = []
|
||||
for file_row in file_rows:
|
||||
file_path = file_row["path"]
|
||||
results = store.query_relationships_by_source(source_symbol, file_path)
|
||||
all_results.extend(results)
|
||||
|
||||
return all_results
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Callee search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
def _search_inheritance_parallel(self, index_paths: List[Path],
|
||||
class_name: str,
|
||||
limit: int) -> List[Dict[str, Any]]:
|
||||
"""Search for inheritance relationships across multiple indexes in parallel.
|
||||
|
||||
Args:
|
||||
index_paths: List of _index.db paths to search
|
||||
class_name: Class name to search for
|
||||
limit: Total result limit
|
||||
|
||||
Returns:
|
||||
Deduplicated list of inheritance relationships
|
||||
"""
|
||||
all_inheritance = []
|
||||
|
||||
executor = self._get_executor()
|
||||
future_to_path = {
|
||||
executor.submit(
|
||||
self._search_inheritance_single,
|
||||
idx_path,
|
||||
class_name
|
||||
): idx_path
|
||||
for idx_path in index_paths
|
||||
}
|
||||
|
||||
for future in as_completed(future_to_path):
|
||||
try:
|
||||
inheritance = future.result()
|
||||
all_inheritance.extend(inheritance)
|
||||
except Exception as exc:
|
||||
self.logger.error(f"Inheritance search failed: {exc}")
|
||||
|
||||
# Deduplicate by (source_symbol, target_symbol)
|
||||
seen = set()
|
||||
unique_inheritance = []
|
||||
for rel in all_inheritance:
|
||||
key = (rel.get("source_symbol"), rel.get("target_symbol"))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_inheritance.append(rel)
|
||||
|
||||
# Sort by source file
|
||||
unique_inheritance.sort(key=lambda r: r.get("source_file", ""))
|
||||
|
||||
return unique_inheritance[:limit]
|
||||
|
||||
def _search_inheritance_single(self, index_path: Path,
|
||||
class_name: str) -> List[Dict[str, Any]]:
|
||||
"""Search for inheritance relationships in a single index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
class_name: Class name to search for
|
||||
|
||||
Returns:
|
||||
List of inheritance relationship dicts (empty on error)
|
||||
"""
|
||||
try:
|
||||
with SQLiteStore(index_path) as store:
|
||||
conn = store._get_connection()
|
||||
|
||||
# Search both as base class (target) and derived class (source)
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
s.name AS source_symbol,
|
||||
r.target_qualified_name,
|
||||
r.relationship_type,
|
||||
r.source_line,
|
||||
f.path AS source_file,
|
||||
r.target_file
|
||||
FROM code_relationships r
|
||||
JOIN symbols s ON r.source_symbol_id = s.id
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE (s.name = ? OR r.target_qualified_name LIKE ?)
|
||||
AND r.relationship_type = 'inherits'
|
||||
ORDER BY f.path, r.source_line
|
||||
LIMIT 100
|
||||
""",
|
||||
(class_name, f"%{class_name}%")
|
||||
).fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"source_symbol": row["source_symbol"],
|
||||
"target_symbol": row["target_qualified_name"],
|
||||
"relationship_type": row["relationship_type"],
|
||||
"source_line": row["source_line"],
|
||||
"source_file": row["source_file"],
|
||||
"target_file": row["target_file"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except Exception as exc:
|
||||
self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
|
||||
return []
|
||||
|
||||
|
||||
# === Convenience Functions ===
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codexlens.entities import SemanticChunk, Symbol
|
||||
from codexlens.parsers.tokenizer import get_default_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -14,6 +15,7 @@ class ChunkConfig:
|
||||
"""Configuration for chunking strategies."""
|
||||
max_chunk_size: int = 1000 # Max characters per chunk
|
||||
overlap: int = 100 # Overlap for sliding window
|
||||
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
|
||||
min_chunk_size: int = 50 # Minimum chunk size
|
||||
|
||||
|
||||
@@ -22,6 +24,7 @@ class Chunker:
|
||||
|
||||
def __init__(self, config: ChunkConfig | None = None) -> None:
|
||||
self.config = config or ChunkConfig()
|
||||
self._tokenizer = get_default_tokenizer()
|
||||
|
||||
def chunk_by_symbol(
|
||||
self,
|
||||
@@ -29,10 +32,18 @@ class Chunker:
|
||||
symbols: List[Symbol],
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk code by extracted symbols (functions, classes).
|
||||
|
||||
Each symbol becomes one chunk with its full content.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
symbols: List of extracted symbols
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||
"""
|
||||
chunks: List[SemanticChunk] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
@@ -47,6 +58,13 @@ class Chunker:
|
||||
if len(chunk_content.strip()) < self.config.min_chunk_size:
|
||||
continue
|
||||
|
||||
# Calculate token count if not provided
|
||||
token_count = None
|
||||
if symbol_token_counts and symbol.name in symbol_token_counts:
|
||||
token_count = symbol_token_counts[symbol.name]
|
||||
else:
|
||||
token_count = self._tokenizer.count_tokens(chunk_content)
|
||||
|
||||
chunks.append(SemanticChunk(
|
||||
content=chunk_content,
|
||||
embedding=None,
|
||||
@@ -58,6 +76,7 @@ class Chunker:
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "symbol",
|
||||
"token_count": token_count,
|
||||
}
|
||||
))
|
||||
|
||||
@@ -68,10 +87,19 @@ class Chunker:
|
||||
content: str,
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
line_mapping: Optional[List[int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk code using sliding window approach.
|
||||
|
||||
Used for files without clear symbol boundaries or very long functions.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
line_mapping: Optional list mapping content line indices to original line numbers
|
||||
(1-indexed). If provided, line_mapping[i] is the original line number
|
||||
for the i-th line in content.
|
||||
"""
|
||||
chunks: List[SemanticChunk] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
@@ -92,6 +120,18 @@ class Chunker:
|
||||
chunk_content = "".join(lines[start:end])
|
||||
|
||||
if len(chunk_content.strip()) >= self.config.min_chunk_size:
|
||||
token_count = self._tokenizer.count_tokens(chunk_content)
|
||||
|
||||
# Calculate correct line numbers
|
||||
if line_mapping:
|
||||
# Use line mapping to get original line numbers
|
||||
start_line = line_mapping[start]
|
||||
end_line = line_mapping[end - 1]
|
||||
else:
|
||||
# Default behavior: treat content as starting at line 1
|
||||
start_line = start + 1
|
||||
end_line = end
|
||||
|
||||
chunks.append(SemanticChunk(
|
||||
content=chunk_content,
|
||||
embedding=None,
|
||||
@@ -99,9 +139,10 @@ class Chunker:
|
||||
"file": str(file_path),
|
||||
"language": language,
|
||||
"chunk_index": chunk_idx,
|
||||
"start_line": start + 1,
|
||||
"end_line": end,
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "sliding_window",
|
||||
"token_count": token_count,
|
||||
}
|
||||
))
|
||||
chunk_idx += 1
|
||||
@@ -119,12 +160,239 @@ class Chunker:
|
||||
symbols: List[Symbol],
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk a file using the best strategy.
|
||||
|
||||
Uses symbol-based chunking if symbols available,
|
||||
falls back to sliding window for files without symbols.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
symbols: List of extracted symbols
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||
"""
|
||||
if symbols:
|
||||
return self.chunk_by_symbol(content, symbols, file_path, language)
|
||||
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
|
||||
return self.chunk_sliding_window(content, file_path, language)
|
||||
|
||||
class DocstringExtractor:
|
||||
"""Extract docstrings from source code."""
|
||||
|
||||
@staticmethod
|
||||
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
|
||||
"""Extract Python docstrings with their line ranges.
|
||||
|
||||
Returns: List of (docstring_content, start_line, end_line) tuples
|
||||
"""
|
||||
docstrings: List[Tuple[str, int, int]] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||
quote_type = '"""' if stripped.startswith('"""') else "'''"
|
||||
start_line = i + 1
|
||||
|
||||
if stripped.count(quote_type) >= 2:
|
||||
docstring_content = line
|
||||
end_line = i + 1
|
||||
docstrings.append((docstring_content, start_line, end_line))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
docstring_lines = [line]
|
||||
i += 1
|
||||
while i < len(lines):
|
||||
docstring_lines.append(lines[i])
|
||||
if quote_type in lines[i]:
|
||||
break
|
||||
i += 1
|
||||
|
||||
end_line = i + 1
|
||||
docstring_content = "".join(docstring_lines)
|
||||
docstrings.append((docstring_content, start_line, end_line))
|
||||
|
||||
i += 1
|
||||
|
||||
return docstrings
|
||||
|
||||
@staticmethod
|
||||
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
|
||||
"""Extract JSDoc comments with their line ranges.
|
||||
|
||||
Returns: List of (comment_content, start_line, end_line) tuples
|
||||
"""
|
||||
comments: List[Tuple[str, int, int]] = []
|
||||
lines = content.splitlines(keepends=True)
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped.startswith('/**'):
|
||||
start_line = i + 1
|
||||
comment_lines = [line]
|
||||
i += 1
|
||||
|
||||
while i < len(lines):
|
||||
comment_lines.append(lines[i])
|
||||
if '*/' in lines[i]:
|
||||
break
|
||||
i += 1
|
||||
|
||||
end_line = i + 1
|
||||
comment_content = "".join(comment_lines)
|
||||
comments.append((comment_content, start_line, end_line))
|
||||
|
||||
i += 1
|
||||
|
||||
return comments
|
||||
|
||||
@classmethod
|
||||
def extract_docstrings(
|
||||
cls,
|
||||
content: str,
|
||||
language: str
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
"""Extract docstrings based on language.
|
||||
|
||||
Returns: List of (docstring_content, start_line, end_line) tuples
|
||||
"""
|
||||
if language == "python":
|
||||
return cls.extract_python_docstrings(content)
|
||||
elif language in {"javascript", "typescript"}:
|
||||
return cls.extract_jsdoc_comments(content)
|
||||
return []
|
||||
|
||||
|
||||
class HybridChunker:
|
||||
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
|
||||
|
||||
Composition-based strategy that:
|
||||
1. Extracts docstrings as dedicated chunks
|
||||
2. For remaining code, uses base chunker (symbol or sliding window)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_chunker: Chunker | None = None,
|
||||
config: ChunkConfig | None = None
|
||||
) -> None:
|
||||
"""Initialize hybrid chunker.
|
||||
|
||||
Args:
|
||||
base_chunker: Chunker to use for non-docstring content
|
||||
config: Configuration for chunking
|
||||
"""
|
||||
self.config = config or ChunkConfig()
|
||||
self.base_chunker = base_chunker or Chunker(self.config)
|
||||
self.docstring_extractor = DocstringExtractor()
|
||||
|
||||
def _get_excluded_line_ranges(
|
||||
self,
|
||||
docstrings: List[Tuple[str, int, int]]
|
||||
) -> set[int]:
|
||||
"""Get set of line numbers that are part of docstrings."""
|
||||
excluded_lines: set[int] = set()
|
||||
for _, start_line, end_line in docstrings:
|
||||
for line_num in range(start_line, end_line + 1):
|
||||
excluded_lines.add(line_num)
|
||||
return excluded_lines
|
||||
|
||||
def _filter_symbols_outside_docstrings(
|
||||
self,
|
||||
symbols: List[Symbol],
|
||||
excluded_lines: set[int]
|
||||
) -> List[Symbol]:
|
||||
"""Filter symbols to exclude those completely within docstrings."""
|
||||
filtered: List[Symbol] = []
|
||||
for symbol in symbols:
|
||||
start_line, end_line = symbol.range
|
||||
symbol_lines = set(range(start_line, end_line + 1))
|
||||
if not symbol_lines.issubset(excluded_lines):
|
||||
filtered.append(symbol)
|
||||
return filtered
|
||||
|
||||
def chunk_file(
|
||||
self,
|
||||
content: str,
|
||||
symbols: List[Symbol],
|
||||
file_path: str | Path,
|
||||
language: str,
|
||||
symbol_token_counts: Optional[dict[str, int]] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Chunk file using hybrid strategy.
|
||||
|
||||
Extracts docstrings first, then chunks remaining code.
|
||||
|
||||
Args:
|
||||
content: Source code content
|
||||
symbols: List of extracted symbols
|
||||
file_path: Path to source file
|
||||
language: Programming language
|
||||
symbol_token_counts: Optional dict mapping symbol names to token counts
|
||||
"""
|
||||
chunks: List[SemanticChunk] = []
|
||||
tokenizer = get_default_tokenizer()
|
||||
|
||||
# Step 1: Extract docstrings as dedicated chunks
|
||||
docstrings = self.docstring_extractor.extract_docstrings(content, language)
|
||||
|
||||
for docstring_content, start_line, end_line in docstrings:
|
||||
if len(docstring_content.strip()) >= self.config.min_chunk_size:
|
||||
token_count = tokenizer.count_tokens(docstring_content)
|
||||
chunks.append(SemanticChunk(
|
||||
content=docstring_content,
|
||||
embedding=None,
|
||||
metadata={
|
||||
"file": str(file_path),
|
||||
"language": language,
|
||||
"chunk_type": "docstring",
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "hybrid",
|
||||
"token_count": token_count,
|
||||
}
|
||||
))
|
||||
|
||||
# Step 2: Get line ranges occupied by docstrings
|
||||
excluded_lines = self._get_excluded_line_ranges(docstrings)
|
||||
|
||||
# Step 3: Filter symbols to exclude docstring-only ranges
|
||||
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
|
||||
|
||||
# Step 4: Chunk remaining content using base chunker
|
||||
if filtered_symbols:
|
||||
base_chunks = self.base_chunker.chunk_by_symbol(
|
||||
content, filtered_symbols, file_path, language, symbol_token_counts
|
||||
)
|
||||
for chunk in base_chunks:
|
||||
chunk.metadata["strategy"] = "hybrid"
|
||||
chunk.metadata["chunk_type"] = "code"
|
||||
chunks.append(chunk)
|
||||
else:
|
||||
lines = content.splitlines(keepends=True)
|
||||
remaining_lines: List[str] = []
|
||||
|
||||
for i, line in enumerate(lines, start=1):
|
||||
if i not in excluded_lines:
|
||||
remaining_lines.append(line)
|
||||
|
||||
if remaining_lines:
|
||||
remaining_content = "".join(remaining_lines)
|
||||
if len(remaining_content.strip()) >= self.config.min_chunk_size:
|
||||
base_chunks = self.base_chunker.chunk_sliding_window(
|
||||
remaining_content, file_path, language
|
||||
)
|
||||
for chunk in base_chunks:
|
||||
chunk.metadata["strategy"] = "hybrid"
|
||||
chunk.metadata["chunk_type"] = "code"
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
531
codex-lens/src/codexlens/semantic/graph_analyzer.py
Normal file
531
codex-lens/src/codexlens/semantic/graph_analyzer.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""Graph analyzer for extracting code relationships using tree-sitter.
|
||||
|
||||
Provides AST-based analysis to identify function calls, method invocations,
|
||||
and class inheritance relationships within source files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
try:
|
||||
from tree_sitter import Node as TreeSitterNode
|
||||
TREE_SITTER_AVAILABLE = True
|
||||
except ImportError:
|
||||
TreeSitterNode = None # type: ignore[assignment]
|
||||
TREE_SITTER_AVAILABLE = False
|
||||
|
||||
from codexlens.entities import CodeRelationship, Symbol
|
||||
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
|
||||
|
||||
|
||||
class GraphAnalyzer:
|
||||
"""Analyzer for extracting semantic relationships from code using AST traversal."""
|
||||
|
||||
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
|
||||
"""Initialize graph analyzer for a language.
|
||||
|
||||
Args:
|
||||
language_id: Language identifier (python, javascript, typescript, etc.)
|
||||
parser: Optional TreeSitterSymbolParser instance for dependency injection.
|
||||
If None, creates a new parser instance (backward compatibility).
|
||||
"""
|
||||
self.language_id = language_id
|
||||
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if graph analyzer is available.
|
||||
|
||||
Returns:
|
||||
True if tree-sitter parser is initialized and ready
|
||||
"""
|
||||
return self._parser.is_available()
|
||||
|
||||
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
|
||||
"""Analyze source code and extract relationships.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
file_path: File path for relationship context
|
||||
|
||||
Returns:
|
||||
List of CodeRelationship objects representing intra-file relationships
|
||||
"""
|
||||
if not self.is_available() or self._parser._parser is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
root = tree.root_node
|
||||
|
||||
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
|
||||
|
||||
return relationships
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return []
|
||||
|
||||
def analyze_with_symbols(
|
||||
self, text: str, file_path: Path, symbols: List[Symbol]
|
||||
) -> List[CodeRelationship]:
|
||||
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
|
||||
|
||||
Args:
|
||||
text: Source code text
|
||||
file_path: File path for relationship context
|
||||
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
|
||||
|
||||
Returns:
|
||||
List of CodeRelationship objects representing intra-file relationships
|
||||
"""
|
||||
if not self.is_available() or self._parser._parser is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
source_bytes = text.encode("utf8")
|
||||
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
|
||||
root = tree.root_node
|
||||
|
||||
# Convert Symbol objects to internal symbol format
|
||||
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
|
||||
|
||||
# Extract relationships using provided symbols
|
||||
relationships = self._extract_relationships_with_symbols(
|
||||
source_bytes, root, str(file_path.resolve()), defined_symbols
|
||||
)
|
||||
|
||||
return relationships
|
||||
except Exception:
|
||||
# Gracefully handle parsing errors
|
||||
return []
|
||||
|
||||
def _convert_symbols_to_dict(
|
||||
self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
|
||||
) -> List[dict]:
|
||||
"""Convert Symbol objects to internal dict format for relationship extraction.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
symbols: Pre-parsed Symbol objects
|
||||
|
||||
Returns:
|
||||
List of symbol info dicts with name, node, and type
|
||||
"""
|
||||
symbol_dicts = []
|
||||
symbol_names = {s.name for s in symbols}
|
||||
|
||||
# Find AST nodes corresponding to symbols
|
||||
for node in self._iter_nodes(root):
|
||||
node_type = node.type
|
||||
|
||||
# Check if this node matches any of our symbols
|
||||
if node_type in {"function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name in symbol_names:
|
||||
symbol_dicts.append({
|
||||
"name": name,
|
||||
"node": node,
|
||||
"type": "function"
|
||||
})
|
||||
elif node_type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name in symbol_names:
|
||||
symbol_dicts.append({
|
||||
"name": name,
|
||||
"node": node,
|
||||
"type": "class"
|
||||
})
|
||||
elif node_type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name in symbol_names:
|
||||
symbol_dicts.append({
|
||||
"name": name,
|
||||
"node": node,
|
||||
"type": "function"
|
||||
})
|
||||
elif node_type == "method_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name in symbol_names:
|
||||
symbol_dicts.append({
|
||||
"name": name,
|
||||
"node": node,
|
||||
"type": "method"
|
||||
})
|
||||
elif node_type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name in symbol_names:
|
||||
symbol_dicts.append({
|
||||
"name": name,
|
||||
"node": node,
|
||||
"type": "class"
|
||||
})
|
||||
elif node_type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if name_node and value_node and value_node.type == "arrow_function":
|
||||
name = self._node_text(source_bytes, name_node)
|
||||
if name in symbol_names:
|
||||
symbol_dicts.append({
|
||||
"name": name,
|
||||
"node": node,
|
||||
"type": "function"
|
||||
})
|
||||
|
||||
return symbol_dicts
|
||||
|
||||
def _extract_relationships_with_symbols(
|
||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
|
||||
) -> List[CodeRelationship]:
|
||||
"""Extract relationships from AST using pre-parsed symbols.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
file_path: Absolute file path
|
||||
defined_symbols: Pre-parsed symbol dicts
|
||||
|
||||
Returns:
|
||||
List of extracted relationships
|
||||
"""
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
# Determine call node type based on language
|
||||
if self.language_id == "python":
|
||||
call_node_type = "call"
|
||||
extract_target = self._extract_call_target
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
call_node_type = "call_expression"
|
||||
extract_target = self._extract_js_call_target
|
||||
else:
|
||||
return []
|
||||
|
||||
# Find call expressions and match to defined symbols
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type == call_node_type:
|
||||
# Extract caller context (enclosing function/method/class)
|
||||
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
|
||||
if source_symbol is None:
|
||||
# Call at module level, use "<module>" as source
|
||||
source_symbol = "<module>"
|
||||
|
||||
# Extract callee (function/method being called)
|
||||
target_symbol = extract_target(source_bytes, node)
|
||||
if target_symbol is None:
|
||||
continue
|
||||
|
||||
# Create relationship
|
||||
line_number = node.start_point[0] + 1
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=source_symbol,
|
||||
target_symbol=target_symbol,
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None, # Intra-file only
|
||||
source_line=line_number,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
def _extract_relationships(
|
||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
|
||||
) -> List[CodeRelationship]:
|
||||
"""Extract relationships from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
file_path: Absolute file path
|
||||
|
||||
Returns:
|
||||
List of extracted relationships
|
||||
"""
|
||||
if self.language_id == "python":
|
||||
return self._extract_python_relationships(source_bytes, root, file_path)
|
||||
elif self.language_id in {"javascript", "typescript"}:
|
||||
return self._extract_js_ts_relationships(source_bytes, root, file_path)
|
||||
else:
|
||||
return []
|
||||
|
||||
def _extract_python_relationships(
|
||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
|
||||
) -> List[CodeRelationship]:
|
||||
"""Extract Python relationships from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
file_path: Absolute file path
|
||||
|
||||
Returns:
|
||||
List of Python relationships (function/method calls)
|
||||
"""
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
# First pass: collect all defined symbols with their scopes
|
||||
defined_symbols = self._collect_python_symbols(source_bytes, root)
|
||||
|
||||
# Second pass: find call expressions and match to defined symbols
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type == "call":
|
||||
# Extract caller context (enclosing function/method/class)
|
||||
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
|
||||
if source_symbol is None:
|
||||
# Call at module level, use "<module>" as source
|
||||
source_symbol = "<module>"
|
||||
|
||||
# Extract callee (function/method being called)
|
||||
target_symbol = self._extract_call_target(source_bytes, node)
|
||||
if target_symbol is None:
|
||||
continue
|
||||
|
||||
# Create relationship
|
||||
line_number = node.start_point[0] + 1
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=source_symbol,
|
||||
target_symbol=target_symbol,
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None, # Intra-file only
|
||||
source_line=line_number,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
def _extract_js_ts_relationships(
|
||||
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
|
||||
) -> List[CodeRelationship]:
|
||||
"""Extract JavaScript/TypeScript relationships from AST.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
file_path: Absolute file path
|
||||
|
||||
Returns:
|
||||
List of JS/TS relationships (function/method calls)
|
||||
"""
|
||||
relationships: List[CodeRelationship] = []
|
||||
|
||||
# First pass: collect all defined symbols
|
||||
defined_symbols = self._collect_js_ts_symbols(source_bytes, root)
|
||||
|
||||
# Second pass: find call expressions
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type == "call_expression":
|
||||
# Extract caller context
|
||||
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
|
||||
if source_symbol is None:
|
||||
source_symbol = "<module>"
|
||||
|
||||
# Extract callee
|
||||
target_symbol = self._extract_js_call_target(source_bytes, node)
|
||||
if target_symbol is None:
|
||||
continue
|
||||
|
||||
# Create relationship
|
||||
line_number = node.start_point[0] + 1
|
||||
relationships.append(
|
||||
CodeRelationship(
|
||||
source_symbol=source_symbol,
|
||||
target_symbol=target_symbol,
|
||||
relationship_type="call",
|
||||
source_file=file_path,
|
||||
target_file=None,
|
||||
source_line=line_number,
|
||||
)
|
||||
)
|
||||
|
||||
return relationships
|
||||
|
||||
def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
|
||||
"""Collect all Python function/method/class definitions.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of symbol info dicts with name, node, and type
|
||||
"""
|
||||
symbols = []
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type in {"function_definition", "async_function_definition"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
symbols.append({
|
||||
"name": self._node_text(source_bytes, name_node),
|
||||
"node": node,
|
||||
"type": "function"
|
||||
})
|
||||
elif node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
symbols.append({
|
||||
"name": self._node_text(source_bytes, name_node),
|
||||
"node": node,
|
||||
"type": "class"
|
||||
})
|
||||
return symbols
|
||||
|
||||
def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
|
||||
"""Collect all JS/TS function/method/class definitions.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
root: Root AST node
|
||||
|
||||
Returns:
|
||||
List of symbol info dicts with name, node, and type
|
||||
"""
|
||||
symbols = []
|
||||
for node in self._iter_nodes(root):
|
||||
if node.type in {"function_declaration", "generator_function_declaration"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
symbols.append({
|
||||
"name": self._node_text(source_bytes, name_node),
|
||||
"node": node,
|
||||
"type": "function"
|
||||
})
|
||||
elif node.type == "method_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
symbols.append({
|
||||
"name": self._node_text(source_bytes, name_node),
|
||||
"node": node,
|
||||
"type": "method"
|
||||
})
|
||||
elif node.type in {"class_declaration", "class"}:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
symbols.append({
|
||||
"name": self._node_text(source_bytes, name_node),
|
||||
"node": node,
|
||||
"type": "class"
|
||||
})
|
||||
elif node.type == "variable_declarator":
|
||||
name_node = node.child_by_field_name("name")
|
||||
value_node = node.child_by_field_name("value")
|
||||
if name_node and value_node and value_node.type == "arrow_function":
|
||||
symbols.append({
|
||||
"name": self._node_text(source_bytes, name_node),
|
||||
"node": node,
|
||||
"type": "function"
|
||||
})
|
||||
return symbols
|
||||
|
||||
def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
|
||||
"""Find the enclosing function/method/class for a node.
|
||||
|
||||
Args:
|
||||
node: AST node to find enclosure for
|
||||
symbols: List of defined symbols
|
||||
|
||||
Returns:
|
||||
Name of enclosing symbol, or None if at module level
|
||||
"""
|
||||
# Walk up the tree to find enclosing symbol
|
||||
parent = node.parent
|
||||
while parent is not None:
|
||||
for symbol in symbols:
|
||||
if symbol["node"] == parent:
|
||||
return symbol["name"]
|
||||
parent = parent.parent
|
||||
return None
|
||||
|
||||
def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
|
||||
"""Extract the target function name from a Python call expression.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
node: Call expression node
|
||||
|
||||
Returns:
|
||||
Target function name, or None if cannot be determined
|
||||
"""
|
||||
function_node = node.child_by_field_name("function")
|
||||
if function_node is None:
|
||||
return None
|
||||
|
||||
# Handle simple identifiers (e.g., "foo()")
|
||||
if function_node.type == "identifier":
|
||||
return self._node_text(source_bytes, function_node)
|
||||
|
||||
# Handle attribute access (e.g., "obj.method()")
|
||||
if function_node.type == "attribute":
|
||||
attr_node = function_node.child_by_field_name("attribute")
|
||||
if attr_node:
|
||||
return self._node_text(source_bytes, attr_node)
|
||||
|
||||
return None
|
||||
|
||||
def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
|
||||
"""Extract the target function name from a JS/TS call expression.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
node: Call expression node
|
||||
|
||||
Returns:
|
||||
Target function name, or None if cannot be determined
|
||||
"""
|
||||
function_node = node.child_by_field_name("function")
|
||||
if function_node is None:
|
||||
return None
|
||||
|
||||
# Handle simple identifiers
|
||||
if function_node.type == "identifier":
|
||||
return self._node_text(source_bytes, function_node)
|
||||
|
||||
# Handle member expressions (e.g., "obj.method()")
|
||||
if function_node.type == "member_expression":
|
||||
property_node = function_node.child_by_field_name("property")
|
||||
if property_node:
|
||||
return self._node_text(source_bytes, property_node)
|
||||
|
||||
return None
|
||||
|
||||
def _iter_nodes(self, root: TreeSitterNode):
|
||||
"""Iterate over all nodes in AST.
|
||||
|
||||
Args:
|
||||
root: Root node to start iteration
|
||||
|
||||
Yields:
|
||||
AST nodes in depth-first order
|
||||
"""
|
||||
stack = [root]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
for child in reversed(node.children):
|
||||
stack.append(child)
|
||||
|
||||
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
|
||||
"""Extract text for a node.
|
||||
|
||||
Args:
|
||||
source_bytes: Source code as bytes
|
||||
node: AST node
|
||||
|
||||
Returns:
|
||||
Text content of node
|
||||
"""
|
||||
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
|
||||
@@ -75,6 +75,34 @@ class LLMEnhancer:
|
||||
external LLM tools (gemini, qwen) via CCW CLI subprocess.
|
||||
"""
|
||||
|
||||
CHUNK_REFINEMENT_PROMPT = '''PURPOSE: Identify optimal semantic split points in code chunk
|
||||
TASK:
|
||||
- Analyze the code structure to find natural semantic boundaries
|
||||
- Identify logical groupings (functions, classes, related statements)
|
||||
- Suggest split points that maintain semantic cohesion
|
||||
MODE: analysis
|
||||
EXPECTED: JSON format with split positions
|
||||
|
||||
=== CODE CHUNK ===
|
||||
{code_chunk}
|
||||
|
||||
=== OUTPUT FORMAT ===
|
||||
Return ONLY valid JSON (no markdown, no explanation):
|
||||
{{
|
||||
"split_points": [
|
||||
{{
|
||||
"line": <line_number>,
|
||||
"reason": "brief reason for split (e.g., 'start of new function', 'end of class definition')"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Rules:
|
||||
- Split at function/class/method boundaries
|
||||
- Keep related code together (don't split mid-function)
|
||||
- Aim for chunks between 500-2000 characters
|
||||
- Return empty split_points if no good splits found'''
|
||||
|
||||
PROMPT_TEMPLATE = '''PURPOSE: Generate semantic summaries and search keywords for code files
|
||||
TASK:
|
||||
- For each code block, generate a concise summary (1-2 sentences)
|
||||
@@ -168,42 +196,246 @@ Return ONLY valid JSON (no markdown, no explanation):
|
||||
return results
|
||||
|
||||
def enhance_file(
|
||||
|
||||
self,
|
||||
|
||||
path: str,
|
||||
|
||||
content: str,
|
||||
|
||||
language: str,
|
||||
|
||||
working_dir: Optional[Path] = None,
|
||||
|
||||
) -> SemanticMetadata:
|
||||
|
||||
"""Enhance a single file with LLM-generated semantic metadata.
|
||||
|
||||
|
||||
|
||||
Convenience method that wraps enhance_files for single file processing.
|
||||
|
||||
|
||||
|
||||
Args:
|
||||
|
||||
path: File path
|
||||
|
||||
content: File content
|
||||
|
||||
language: Programming language
|
||||
|
||||
working_dir: Optional working directory for CCW CLI
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
SemanticMetadata for the file
|
||||
|
||||
|
||||
|
||||
Raises:
|
||||
|
||||
ValueError: If enhancement fails
|
||||
|
||||
"""
|
||||
|
||||
file_data = FileData(path=path, content=content, language=language)
|
||||
|
||||
results = self.enhance_files([file_data], working_dir)
|
||||
|
||||
|
||||
|
||||
if path not in results:
|
||||
|
||||
# Return default metadata if enhancement failed
|
||||
|
||||
return SemanticMetadata(
|
||||
|
||||
summary=f"Code file written in {language}",
|
||||
|
||||
keywords=[language, "code"],
|
||||
|
||||
purpose="unknown",
|
||||
|
||||
file_path=path,
|
||||
|
||||
llm_tool=self.config.tool,
|
||||
|
||||
)
|
||||
|
||||
|
||||
|
||||
return results[path]
|
||||
|
||||
def refine_chunk_boundaries(
|
||||
self,
|
||||
chunk: SemanticChunk,
|
||||
max_chunk_size: int = 2000,
|
||||
working_dir: Optional[Path] = None,
|
||||
) -> List[SemanticChunk]:
|
||||
"""Refine chunk boundaries using LLM for large code chunks.
|
||||
|
||||
Uses LLM to identify semantic split points in large chunks,
|
||||
breaking them into smaller, more cohesive pieces.
|
||||
|
||||
Args:
|
||||
chunk: Original chunk to refine
|
||||
max_chunk_size: Maximum characters before triggering refinement
|
||||
working_dir: Optional working directory for CCW CLI
|
||||
|
||||
Returns:
|
||||
SemanticMetadata for the file
|
||||
|
||||
Raises:
|
||||
ValueError: If enhancement fails
|
||||
List of refined chunks (original chunk if no splits or refinement fails)
|
||||
"""
|
||||
file_data = FileData(path=path, content=content, language=language)
|
||||
results = self.enhance_files([file_data], working_dir)
|
||||
# Skip if chunk is small enough
|
||||
if len(chunk.content) <= max_chunk_size:
|
||||
return [chunk]
|
||||
|
||||
if path not in results:
|
||||
# Return default metadata if enhancement failed
|
||||
return SemanticMetadata(
|
||||
summary=f"Code file written in {language}",
|
||||
keywords=[language, "code"],
|
||||
purpose="unknown",
|
||||
file_path=path,
|
||||
llm_tool=self.config.tool,
|
||||
# Skip if LLM enhancement disabled or unavailable
|
||||
if not self.config.enabled or not self.check_available():
|
||||
return [chunk]
|
||||
|
||||
# Skip docstring chunks - only refine code chunks
|
||||
if chunk.metadata.get("chunk_type") == "docstring":
|
||||
return [chunk]
|
||||
|
||||
try:
|
||||
# Build refinement prompt
|
||||
prompt = self.CHUNK_REFINEMENT_PROMPT.format(code_chunk=chunk.content)
|
||||
|
||||
# Invoke LLM
|
||||
result = self._invoke_ccw_cli(
|
||||
prompt,
|
||||
tool=self.config.tool,
|
||||
working_dir=working_dir,
|
||||
)
|
||||
|
||||
return results[path]
|
||||
# Fallback if primary tool fails
|
||||
if not result["success"] and self.config.fallback_tool:
|
||||
result = self._invoke_ccw_cli(
|
||||
prompt,
|
||||
tool=self.config.fallback_tool,
|
||||
working_dir=working_dir,
|
||||
)
|
||||
|
||||
if not result["success"]:
|
||||
logger.debug("LLM refinement failed, returning original chunk")
|
||||
return [chunk]
|
||||
|
||||
# Parse split points
|
||||
split_points = self._parse_split_points(result["stdout"])
|
||||
if not split_points:
|
||||
logger.debug("No split points identified, returning original chunk")
|
||||
return [chunk]
|
||||
|
||||
# Split chunk at identified boundaries
|
||||
refined_chunks = self._split_chunk_at_points(chunk, split_points)
|
||||
logger.debug(
|
||||
"Refined chunk into %d smaller chunks (was %d chars)",
|
||||
len(refined_chunks),
|
||||
len(chunk.content),
|
||||
)
|
||||
return refined_chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Chunk refinement error: %s, returning original chunk", e)
|
||||
return [chunk]
|
||||
|
||||
def _parse_split_points(self, stdout: str) -> List[int]:
|
||||
"""Parse split points from LLM response.
|
||||
|
||||
Args:
|
||||
stdout: Raw stdout from CCW CLI
|
||||
|
||||
Returns:
|
||||
List of line numbers where splits should occur (sorted)
|
||||
"""
|
||||
# Extract JSON from response
|
||||
json_str = self._extract_json(stdout)
|
||||
if not json_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
split_points_data = data.get("split_points", [])
|
||||
|
||||
# Extract line numbers
|
||||
lines = []
|
||||
for point in split_points_data:
|
||||
if isinstance(point, dict) and "line" in point:
|
||||
line_num = point["line"]
|
||||
if isinstance(line_num, int) and line_num > 0:
|
||||
lines.append(line_num)
|
||||
|
||||
return sorted(set(lines))
|
||||
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
||||
logger.debug("Failed to parse split points: %s", e)
|
||||
return []
|
||||
|
||||
def _split_chunk_at_points(
|
||||
self,
|
||||
chunk: SemanticChunk,
|
||||
split_points: List[int],
|
||||
) -> List[SemanticChunk]:
|
||||
"""Split chunk at specified line numbers.
|
||||
|
||||
Args:
|
||||
chunk: Original chunk to split
|
||||
split_points: Sorted list of line numbers to split at
|
||||
|
||||
Returns:
|
||||
List of smaller chunks
|
||||
"""
|
||||
lines = chunk.content.splitlines(keepends=True)
|
||||
chunks: List[SemanticChunk] = []
|
||||
|
||||
# Get original metadata
|
||||
base_metadata = dict(chunk.metadata)
|
||||
original_start = base_metadata.get("start_line", 1)
|
||||
|
||||
# Add start and end boundaries
|
||||
boundaries = [0] + split_points + [len(lines)]
|
||||
|
||||
for i in range(len(boundaries) - 1):
|
||||
start_idx = boundaries[i]
|
||||
end_idx = boundaries[i + 1]
|
||||
|
||||
# Skip empty sections
|
||||
if start_idx >= end_idx:
|
||||
continue
|
||||
|
||||
# Extract content
|
||||
section_lines = lines[start_idx:end_idx]
|
||||
section_content = "".join(section_lines)
|
||||
|
||||
# Skip if too small
|
||||
if len(section_content.strip()) < 50:
|
||||
continue
|
||||
|
||||
# Create new chunk with updated metadata
|
||||
new_metadata = base_metadata.copy()
|
||||
new_metadata["start_line"] = original_start + start_idx
|
||||
new_metadata["end_line"] = original_start + end_idx - 1
|
||||
new_metadata["refined_by_llm"] = True
|
||||
new_metadata["original_chunk_size"] = len(chunk.content)
|
||||
|
||||
chunks.append(
|
||||
SemanticChunk(
|
||||
content=section_content,
|
||||
embedding=None, # Embeddings will be regenerated
|
||||
metadata=new_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# If no valid chunks created, return original
|
||||
if not chunks:
|
||||
return [chunk]
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
|
||||
|
||||
def _process_batch(
|
||||
|
||||
@@ -149,15 +149,21 @@ class DirIndexStore:
|
||||
# Replace symbols
|
||||
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
|
||||
if symbols:
|
||||
# Extract token_count and symbol_type from symbol metadata if available
|
||||
symbol_rows = []
|
||||
for s in symbols:
|
||||
token_count = getattr(s, 'token_count', None)
|
||||
symbol_type = getattr(s, 'symbol_type', None) or s.kind
|
||||
symbol_rows.append(
|
||||
(file_id, s.name, s.kind, s.range[0], s.range[1], token_count, symbol_type)
|
||||
)
|
||||
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO symbols(file_id, name, kind, start_line, end_line)
|
||||
VALUES(?, ?, ?, ?, ?)
|
||||
INSERT INTO symbols(file_id, name, kind, start_line, end_line, token_count, symbol_type)
|
||||
VALUES(?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(file_id, s.name, s.kind, s.range[0], s.range[1])
|
||||
for s in symbols
|
||||
],
|
||||
symbol_rows,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -216,15 +222,21 @@ class DirIndexStore:
|
||||
|
||||
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
|
||||
if symbols:
|
||||
# Extract token_count and symbol_type from symbol metadata if available
|
||||
symbol_rows = []
|
||||
for s in symbols:
|
||||
token_count = getattr(s, 'token_count', None)
|
||||
symbol_type = getattr(s, 'symbol_type', None) or s.kind
|
||||
symbol_rows.append(
|
||||
(file_id, s.name, s.kind, s.range[0], s.range[1], token_count, symbol_type)
|
||||
)
|
||||
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO symbols(file_id, name, kind, start_line, end_line)
|
||||
VALUES(?, ?, ?, ?, ?)
|
||||
INSERT INTO symbols(file_id, name, kind, start_line, end_line, token_count, symbol_type)
|
||||
VALUES(?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(file_id, s.name, s.kind, s.range[0], s.range[1])
|
||||
for s in symbols
|
||||
],
|
||||
symbol_rows,
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
@@ -1021,7 +1033,9 @@ class DirIndexStore:
|
||||
name TEXT NOT NULL,
|
||||
kind TEXT NOT NULL,
|
||||
start_line INTEGER,
|
||||
end_line INTEGER
|
||||
end_line INTEGER,
|
||||
token_count INTEGER,
|
||||
symbol_type TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -1083,6 +1097,7 @@ class DirIndexStore:
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_subdirs_name ON subdirs(name)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_semantic_file ON semantic_metadata(file_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON keywords(keyword)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_file_id ON file_keywords(file_id)")
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Migration 002: Add token_count and symbol_type to symbols table.
|
||||
|
||||
This migration adds token counting metadata to symbols for accurate chunk
|
||||
splitting and performance optimization. It also adds symbol_type for better
|
||||
filtering in searches.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlite3 import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade(db_conn: Connection):
|
||||
"""
|
||||
Applies the migration to add token metadata to symbols.
|
||||
|
||||
- Adds token_count column to symbols table
|
||||
- Adds symbol_type column to symbols table (for future use)
|
||||
- Creates index on symbol_type for efficient filtering
|
||||
- Backfills existing symbols with NULL token_count (to be calculated lazily)
|
||||
|
||||
Args:
|
||||
db_conn: The SQLite database connection.
|
||||
"""
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Adding token_count column to symbols table...")
|
||||
try:
|
||||
cursor.execute("ALTER TABLE symbols ADD COLUMN token_count INTEGER")
|
||||
log.info("Successfully added token_count column.")
|
||||
except Exception as e:
|
||||
# Column might already exist
|
||||
log.warning(f"Could not add token_count column (might already exist): {e}")
|
||||
|
||||
log.info("Adding symbol_type column to symbols table...")
|
||||
try:
|
||||
cursor.execute("ALTER TABLE symbols ADD COLUMN symbol_type TEXT")
|
||||
log.info("Successfully added symbol_type column.")
|
||||
except Exception as e:
|
||||
# Column might already exist
|
||||
log.warning(f"Could not add symbol_type column (might already exist): {e}")
|
||||
|
||||
log.info("Creating index on symbol_type for efficient filtering...")
|
||||
cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
|
||||
|
||||
log.info("Migration 002 completed successfully.")
|
||||
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Migration 003: Add code relationships storage.
|
||||
|
||||
This migration introduces the `code_relationships` table to store semantic
|
||||
relationships between code symbols (function calls, inheritance, imports).
|
||||
This enables graph-based code navigation and dependency analysis.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlite3 import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade(db_conn: Connection):
|
||||
"""
|
||||
Applies the migration to add code relationships table.
|
||||
|
||||
- Creates `code_relationships` table with foreign key to symbols
|
||||
- Creates indexes for efficient relationship queries
|
||||
- Supports lazy expansion with target_symbol being qualified names
|
||||
|
||||
Args:
|
||||
db_conn: The SQLite database connection.
|
||||
"""
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Creating 'code_relationships' table...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS code_relationships (
|
||||
id INTEGER PRIMARY KEY,
|
||||
source_symbol_id INTEGER NOT NULL,
|
||||
target_qualified_name TEXT NOT NULL,
|
||||
relationship_type TEXT NOT NULL,
|
||||
source_line INTEGER NOT NULL,
|
||||
target_file TEXT,
|
||||
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating indexes for code_relationships...")
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
|
||||
)
|
||||
|
||||
log.info("Finished creating code_relationships table and indexes.")
|
||||
@@ -9,7 +9,7 @@ from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from codexlens.entities import IndexedFile, SearchResult, Symbol
|
||||
from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
|
||||
from codexlens.errors import StorageError
|
||||
|
||||
|
||||
@@ -309,13 +309,184 @@ class SQLiteStore:
|
||||
"SELECT language, COUNT(*) AS c FROM files GROUP BY language ORDER BY c DESC"
|
||||
).fetchall()
|
||||
languages = {row["language"]: row["c"] for row in lang_rows}
|
||||
# Include relationship count if table exists
|
||||
relationship_count = 0
|
||||
try:
|
||||
rel_row = conn.execute("SELECT COUNT(*) AS c FROM code_relationships").fetchone()
|
||||
relationship_count = int(rel_row["c"]) if rel_row else 0
|
||||
except sqlite3.DatabaseError:
|
||||
pass
|
||||
|
||||
return {
|
||||
"files": int(file_count),
|
||||
"symbols": int(symbol_count),
|
||||
"relationships": relationship_count,
|
||||
"languages": languages,
|
||||
"db_path": str(self.db_path),
|
||||
}
|
||||
|
||||
|
||||
def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
|
||||
"""Store code relationships for a file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file containing the relationships
|
||||
relationships: List of CodeRelationship objects to store
|
||||
"""
|
||||
if not relationships:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
resolved_path = str(Path(file_path).resolve())
|
||||
|
||||
# Get file_id
|
||||
row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
|
||||
if not row:
|
||||
raise StorageError(f"File not found in index: {file_path}")
|
||||
file_id = int(row["id"])
|
||||
|
||||
# Delete existing relationships for symbols in this file
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM code_relationships
|
||||
WHERE source_symbol_id IN (
|
||||
SELECT id FROM symbols WHERE file_id=?
|
||||
)
|
||||
""",
|
||||
(file_id,)
|
||||
)
|
||||
|
||||
# Insert new relationships
|
||||
relationship_rows = []
|
||||
for rel in relationships:
|
||||
# Find source symbol ID
|
||||
symbol_row = conn.execute(
|
||||
"""
|
||||
SELECT id FROM symbols
|
||||
WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
|
||||
ORDER BY (end_line - start_line) ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(file_id, rel.source_symbol, rel.source_line, rel.source_line)
|
||||
).fetchone()
|
||||
|
||||
if symbol_row:
|
||||
source_symbol_id = int(symbol_row["id"])
|
||||
relationship_rows.append((
|
||||
source_symbol_id,
|
||||
rel.target_symbol,
|
||||
rel.relationship_type,
|
||||
rel.source_line,
|
||||
rel.target_file
|
||||
))
|
||||
|
||||
if relationship_rows:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT INTO code_relationships(
|
||||
source_symbol_id, target_qualified_name, relationship_type,
|
||||
source_line, target_file
|
||||
)
|
||||
VALUES(?, ?, ?, ?, ?)
|
||||
""",
|
||||
relationship_rows
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def query_relationships_by_target(
|
||||
self, target_name: str, *, limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query relationships by target symbol name (find all callers).
|
||||
|
||||
Args:
|
||||
target_name: Name of the target symbol
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of dicts containing relationship info with file paths and line numbers
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
s.name AS source_symbol,
|
||||
r.target_qualified_name,
|
||||
r.relationship_type,
|
||||
r.source_line,
|
||||
f.path AS source_file,
|
||||
r.target_file
|
||||
FROM code_relationships r
|
||||
JOIN symbols s ON r.source_symbol_id = s.id
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE r.target_qualified_name = ?
|
||||
ORDER BY f.path, r.source_line
|
||||
LIMIT ?
|
||||
""",
|
||||
(target_name, limit)
|
||||
).fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"source_symbol": row["source_symbol"],
|
||||
"target_symbol": row["target_qualified_name"],
|
||||
"relationship_type": row["relationship_type"],
|
||||
"source_line": row["source_line"],
|
||||
"source_file": row["source_file"],
|
||||
"target_file": row["target_file"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def query_relationships_by_source(
|
||||
self, source_symbol: str, source_file: str | Path, *, limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query relationships by source symbol (find what a symbol calls).
|
||||
|
||||
Args:
|
||||
source_symbol: Name of the source symbol
|
||||
source_file: File path containing the source symbol
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of dicts containing relationship info
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
resolved_path = str(Path(source_file).resolve())
|
||||
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT
|
||||
s.name AS source_symbol,
|
||||
r.target_qualified_name,
|
||||
r.relationship_type,
|
||||
r.source_line,
|
||||
f.path AS source_file,
|
||||
r.target_file
|
||||
FROM code_relationships r
|
||||
JOIN symbols s ON r.source_symbol_id = s.id
|
||||
JOIN files f ON s.file_id = f.id
|
||||
WHERE s.name = ? AND f.path = ?
|
||||
ORDER BY r.source_line
|
||||
LIMIT ?
|
||||
""",
|
||||
(source_symbol, resolved_path, limit)
|
||||
).fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"source_symbol": row["source_symbol"],
|
||||
"target_symbol": row["target_qualified_name"],
|
||||
"relationship_type": row["relationship_type"],
|
||||
"source_line": row["source_line"],
|
||||
"source_file": row["source_file"],
|
||||
"target_file": row["target_file"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
"""Legacy method for backward compatibility."""
|
||||
return self._get_connection()
|
||||
@@ -348,6 +519,20 @@ class SQLiteStore:
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_kind ON symbols(kind)")
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS code_relationships (
|
||||
id INTEGER PRIMARY KEY,
|
||||
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
|
||||
target_qualified_name TEXT NOT NULL,
|
||||
relationship_type TEXT NOT NULL,
|
||||
source_line INTEGER NOT NULL,
|
||||
target_file TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
|
||||
conn.commit()
|
||||
except sqlite3.DatabaseError as exc:
|
||||
raise StorageError(f"Failed to initialize database schema: {exc}") from exc
|
||||
|
||||
Reference in New Issue
Block a user