Add comprehensive tests for tokenizer, performance benchmarks, and TreeSitter parser functionality

- Implemented unit tests for the Tokenizer class, covering various text inputs, edge cases, and fallback mechanisms.
- Created performance benchmarks comparing tiktoken and pure Python implementations for token counting.
- Developed extensive tests for TreeSitterSymbolParser across Python, JavaScript, and TypeScript, ensuring accurate symbol extraction and parsing.
- Added configuration documentation for MCP integration and custom prompts, enhancing usability and flexibility.
- Introduced a refactor script for GraphAnalyzer to streamline future improvements.
catlog22
2025-12-15 14:36:09 +08:00
parent 82dcafff00
commit 0fe16963cd
49 changed files with 9307 additions and 438 deletions
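The commit message above mentions benchmarks comparing tiktoken with a pure-Python path; as a rough illustration only, a minimal timing harness against the character-count fallback added in this commit (the committed benchmarks may measure different implementations):

import timeit

from codexlens.parsers.tokenizer import Tokenizer

sample = "def add(a, b):\n    return a + b\n" * 200
tok = Tokenizer()  # tiktoken-backed when the library is installed

# Compare the tokenizer call with the bare character-count estimate it falls back to.
tiktoken_time = timeit.timeit(lambda: tok.count_tokens(sample), number=100)
fallback_time = timeit.timeit(lambda: max(1, len(sample) // 4), number=100)
print(f"tokenizer: {tiktoken_time:.4f}s  char-count estimate: {fallback_time:.4f}s")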

View File

@@ -1100,6 +1100,103 @@ def clean(
raise typer.Exit(code=1)
@app.command()
def graph(
query_type: str = typer.Argument(..., help="Query type: callers, callees, or inheritance"),
symbol: str = typer.Argument(..., help="Symbol name to query"),
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
limit: int = typer.Option(50, "--limit", "-n", min=1, max=500, help="Max results."),
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited)."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
"""Query semantic graph for code relationships.
Supported query types:
- callers: Find all functions/methods that call the given symbol
- callees: Find all functions/methods called by the given symbol
- inheritance: Find inheritance relationships for the given class
Examples:
codex-lens graph callers my_function
codex-lens graph callees MyClass.method --path src/
codex-lens graph inheritance BaseClass
"""
_configure_logging(verbose)
search_path = path.expanduser().resolve()
# Validate query type
valid_types = ["callers", "callees", "inheritance"]
if query_type not in valid_types:
if json_mode:
print_json(success=False, error=f"Invalid query type: {query_type}. Must be one of: {', '.join(valid_types)}")
else:
console.print(f"[red]Invalid query type:[/red] {query_type}")
console.print(f"[dim]Valid types: {', '.join(valid_types)}[/dim]")
raise typer.Exit(code=1)
registry: RegistryStore | None = None
try:
registry = RegistryStore()
registry.initialize()
mapper = PathMapper()
engine = ChainSearchEngine(registry, mapper)
options = SearchOptions(depth=depth, total_limit=limit)
# Execute graph query based on type
if query_type == "callers":
results = engine.search_callers(symbol, search_path, options=options)
result_type = "callers"
elif query_type == "callees":
results = engine.search_callees(symbol, search_path, options=options)
result_type = "callees"
else: # inheritance
results = engine.search_inheritance(symbol, search_path, options=options)
result_type = "inheritance"
payload = {
"query_type": query_type,
"symbol": symbol,
"count": len(results),
"relationships": results
}
if json_mode:
print_json(success=True, result=payload)
else:
from .output import render_graph_results
render_graph_results(results, query_type=query_type, symbol=symbol)
except SearchError as exc:
if json_mode:
print_json(success=False, error=f"Graph search error: {exc}")
else:
console.print(f"[red]Graph query failed (search):[/red] {exc}")
raise typer.Exit(code=1)
except StorageError as exc:
if json_mode:
print_json(success=False, error=f"Storage error: {exc}")
else:
console.print(f"[red]Graph query failed (storage):[/red] {exc}")
raise typer.Exit(code=1)
except CodexLensError as exc:
if json_mode:
print_json(success=False, error=str(exc))
else:
console.print(f"[red]Graph query failed:[/red] {exc}")
raise typer.Exit(code=1)
except Exception as exc:
if json_mode:
print_json(success=False, error=f"Unexpected error: {exc}")
else:
console.print(f"[red]Graph query failed (unexpected):[/red] {exc}")
raise typer.Exit(code=1)
finally:
if registry is not None:
registry.close()
@app.command("semantic-list")
def semantic_list(
path: Path = typer.Option(Path("."), "--path", "-p", help="Project path to list metadata from."),
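For reference, a hedged sketch of exercising the new graph command through Typer's test runner; the codexlens.cli import path is an assumption, and the printed envelope is whatever print_json(success=True, result=payload) emits:

from typer.testing import CliRunner

from codexlens.cli import app  # import path assumed; `app` is the Typer instance decorated above

runner = CliRunner()
result = runner.invoke(app, ["graph", "callers", "my_function", "--json"])
print(result.output)  # JSON envelope wrapping {"query_type": "callers", "symbol": ..., "count": ..., "relationships": [...]}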

View File

@@ -89,3 +89,68 @@ def render_file_inspect(path: str, language: str, symbols: Iterable[Symbol]) ->
console.print(header)
render_symbols(list(symbols), title="Discovered Symbols")
def render_graph_results(results: list[dict[str, Any]], *, query_type: str, symbol: str) -> None:
"""Render semantic graph query results.
Args:
results: List of relationship dicts
query_type: Type of query (callers, callees, inheritance)
symbol: Symbol name that was queried
"""
if not results:
console.print(f"[yellow]No {query_type} found for symbol:[/yellow] {symbol}")
return
title_map = {
"callers": f"Callers of '{symbol}' ({len(results)} found)",
"callees": f"Callees of '{symbol}' ({len(results)} found)",
"inheritance": f"Inheritance relationships for '{symbol}' ({len(results)} found)"
}
table = Table(title=title_map.get(query_type, f"Graph Results ({len(results)})"))
if query_type == "callers":
table.add_column("Caller", style="green")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
table.add_column("Type", style="dim")
for rel in results:
table.add_row(
rel.get("source_symbol", "-"),
rel.get("source_file", "-"),
str(rel.get("source_line", "-")),
rel.get("relationship_type", "-")
)
elif query_type == "callees":
table.add_column("Target", style="green")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
table.add_column("Type", style="dim")
for rel in results:
table.add_row(
rel.get("target_symbol", "-"),
rel.get("target_file", "-") if rel.get("target_file") else rel.get("source_file", "-"),
str(rel.get("source_line", "-")),
rel.get("relationship_type", "-")
)
else: # inheritance
table.add_column("Derived Class", style="green")
table.add_column("Base Class", style="magenta")
table.add_column("File", style="cyan", no_wrap=False, max_width=40)
table.add_column("Line", justify="right", style="yellow")
for rel in results:
table.add_row(
rel.get("source_symbol", "-"),
rel.get("target_symbol", "-"),
rel.get("source_file", "-"),
str(rel.get("source_line", "-"))
)
console.print(table)
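A small sketch of feeding a hand-built relationship dict to this renderer; the import path is inferred from the relative `from .output import render_graph_results` above, and the keys mirror those read by the callers branch:

from codexlens.output import render_graph_results  # module path assumed

sample = [
    {
        "source_symbol": "load_config",
        "source_file": "src/app/config.py",
        "source_line": 42,
        "relationship_type": "call",
    }
]
render_graph_results(sample, query_type="callers", symbol="parse_toml")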

View File

@@ -83,6 +83,9 @@ class Config:
llm_timeout_ms: int = 300000
llm_batch_size: int = 5
# Hybrid chunker configuration
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()

View File

@@ -13,6 +13,8 @@ class Symbol(BaseModel):
name: str = Field(..., min_length=1)
kind: str = Field(..., min_length=1)
range: Tuple[int, int] = Field(..., description="(start_line, end_line), 1-based inclusive")
token_count: Optional[int] = Field(default=None, description="Token count for symbol content")
symbol_type: Optional[str] = Field(default=None, description="Extended symbol type for filtering")
@field_validator("range")
@classmethod
@@ -26,6 +28,13 @@ class Symbol(BaseModel):
raise ValueError("end_line must be >= start_line")
return value
@field_validator("token_count")
@classmethod
def validate_token_count(cls, value: Optional[int]) -> Optional[int]:
if value is not None and value < 0:
raise ValueError("token_count must be >= 0")
return value
class SemanticChunk(BaseModel):
"""A semantically meaningful chunk of content, optionally embedded."""
@@ -61,6 +70,25 @@ class IndexedFile(BaseModel):
return cleaned
class CodeRelationship(BaseModel):
"""A relationship between code symbols (e.g., function calls, inheritance)."""
source_symbol: str = Field(..., min_length=1, description="Name of source symbol")
target_symbol: str = Field(..., min_length=1, description="Name of target symbol")
relationship_type: str = Field(..., min_length=1, description="Type of relationship (call, inherits, etc.)")
source_file: str = Field(..., min_length=1, description="File path containing source symbol")
target_file: Optional[str] = Field(default=None, description="File path containing target (None if same file)")
source_line: int = Field(..., ge=1, description="Line number where relationship occurs (1-based)")
@field_validator("relationship_type")
@classmethod
def validate_relationship_type(cls, value: str) -> str:
allowed_types = {"call", "inherits", "imports"}
if value not in allowed_types:
raise ValueError(f"relationship_type must be one of {allowed_types}")
return value
class SearchResult(BaseModel):
"""A unified search result for lexical or semantic search."""

View File

@@ -10,19 +10,11 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Protocol
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
except Exception: # pragma: no cover
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import IndexedFile, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class Parser(Protocol):
@@ -34,10 +26,24 @@ class SimpleRegexParser:
language_id: str
def parse(self, text: str, path: Path) -> IndexedFile:
# Try tree-sitter first for supported languages
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
)
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols(text)
symbols = _parse_python_symbols_regex(text)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols(text, self.language_id, path)
symbols = _parse_js_ts_symbols_regex(text)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
elif self.language_id == "go":
@@ -64,120 +70,35 @@ class ParserFactory:
return self._parsers[language_id]
# Regex-based fallback parsers
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_TREE_SITTER_LANGUAGE_CACHE: Dict[str, TreeSitterLanguage] = {}
def _get_tree_sitter_language(language_id: str, path: Path | None = None) -> TreeSitterLanguage | None:
if TreeSitterLanguage is None:
return None
cache_key = language_id
if language_id == "typescript" and path is not None and path.suffix.lower() == ".tsx":
cache_key = "tsx"
cached = _TREE_SITTER_LANGUAGE_CACHE.get(cache_key)
if cached is not None:
return cached
try:
if cache_key == "python":
import tree_sitter_python # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_python.language())
elif cache_key == "javascript":
import tree_sitter_javascript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_javascript.language())
elif cache_key == "typescript":
import tree_sitter_typescript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
elif cache_key == "tsx":
import tree_sitter_typescript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
return None
except Exception:
return None
_TREE_SITTER_LANGUAGE_CACHE[cache_key] = language
return language
def _parse_python_symbols(text: str) -> List[Symbol]:
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser("python")
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
def _iter_tree_sitter_nodes(root: TreeSitterNode) -> Iterable[TreeSitterNode]:
stack: List[TreeSitterNode] = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(source_bytes: bytes, node: TreeSitterNode) -> str:
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(node: TreeSitterNode) -> tuple[int, int]:
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def _python_kind_for_function_node(node: TreeSitterNode) -> str:
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _parse_python_symbols_tree_sitter(text: str) -> List[Symbol] | None:
if TreeSitterParser is None:
return None
language = _get_tree_sitter_language("python")
if language is None:
return None
parser = TreeSitterParser()
if hasattr(parser, "set_language"):
parser.set_language(language) # type: ignore[attr-defined]
else:
parser.language = language # type: ignore[assignment]
source_bytes = text.encode("utf8")
tree = parser.parse(source_bytes)
root = tree.root_node
symbols: List[Symbol] = []
for node in _iter_tree_sitter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="class",
range=_node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind=_python_kind_for_function_node(node),
range=_node_range(node),
))
return symbols
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Optional[Path] = None,
) -> List[Symbol]:
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser(language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
@@ -202,13 +123,6 @@ def _parse_python_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_python_symbols(text: str) -> List[Symbol]:
symbols = _parse_python_symbols_tree_sitter(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
@@ -217,88 +131,6 @@ _JS_ARROW_RE = re.compile(
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
def _js_has_class_ancestor(node: TreeSitterNode) -> bool:
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _parse_js_ts_symbols_tree_sitter(
text: str,
language_id: str,
path: Path | None = None,
) -> List[Symbol] | None:
if TreeSitterParser is None:
return None
language = _get_tree_sitter_language(language_id, path)
if language is None:
return None
parser = TreeSitterParser()
if hasattr(parser, "set_language"):
parser.set_language(language) # type: ignore[attr-defined]
else:
parser.language = language # type: ignore[assignment]
source_bytes = text.encode("utf8")
tree = parser.parse(source_bytes)
root = tree.root_node
symbols: List[Symbol] = []
for node in _iter_tree_sitter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="class",
range=_node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="function",
range=_node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="function",
range=_node_range(node),
))
elif node.type == "method_definition" and _js_has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = _node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=_node_range(node),
))
return symbols
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
in_class = False
@@ -338,17 +170,6 @@ def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
return symbols
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Path | None = None,
) -> List[Symbol]:
symbols = _parse_js_ts_symbols_tree_sitter(text, language_id, path)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("

View File

@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.
Provides accurate token counting using tiktoken with character count fallback.
"""
from __future__ import annotations
from typing import Optional
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
class Tokenizer:
"""Token counter with tiktoken primary and character count fallback."""
def __init__(self, encoding_name: str = "cl100k_base") -> None:
"""Initialize tokenizer.
Args:
encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
"""
self._encoding: Optional[object] = None
self._encoding_name = encoding_name
if TIKTOKEN_AVAILABLE:
try:
self._encoding = tiktoken.get_encoding(encoding_name)
except Exception:
# Fallback to character counting if encoding fails
self._encoding = None
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Uses tiktoken if available, otherwise falls back to character count / 4.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if not text:
return 0
if self._encoding is not None:
try:
return len(self._encoding.encode(text)) # type: ignore[attr-defined]
except Exception:
# Fall through to character count fallback
pass
# Fallback: rough estimate using character count
# Average of ~4 characters per token for English text
return max(1, len(text) // 4)
def is_using_tiktoken(self) -> bool:
"""Check if tiktoken is being used.
Returns:
True if tiktoken is available and initialized
"""
return self._encoding is not None
# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None
def get_default_tokenizer() -> Tokenizer:
"""Get the global default tokenizer instance.
Returns:
Shared Tokenizer instance
"""
global _default_tokenizer
if _default_tokenizer is None:
_default_tokenizer = Tokenizer()
return _default_tokenizer
def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
"""Count tokens in text using default or provided tokenizer.
Args:
text: Text to count tokens for
tokenizer: Optional tokenizer instance (uses default if None)
Returns:
Estimated token count
"""
if tokenizer is None:
tokenizer = get_default_tokenizer()
return tokenizer.count_tokens(text)
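Minimal usage of the new module; count_tokens and get_default_tokenizer are the helpers imported elsewhere in this commit:

from codexlens.parsers.tokenizer import count_tokens, get_default_tokenizer

tok = get_default_tokenizer()
print(tok.is_using_tiktoken())  # False when tiktoken is not installed or the encoding failed to load
print(count_tokens("def foo():\n    return 1\n"))  # exact tiktoken count, or roughly len(text) // 4 as a fallback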

View File

@@ -0,0 +1,335 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing with fallback to regex-based parsing.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import IndexedFile, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction."""
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
"""Initialize tree-sitter parser for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
path: Optional file path for language variant detection (e.g., .tsx)
"""
self.language_id = language_id
self.path = path
self._parser: Optional[object] = None
self._language: Optional[TreeSitterLanguage] = None
self._tokenizer = get_default_tokenizer()
if TREE_SITTER_AVAILABLE:
self._initialize_parser()
def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None:
return
try:
# Load language grammar
if self.language_id == "python":
import tree_sitter_python
self._language = TreeSitterLanguage(tree_sitter_python.language())
elif self.language_id == "javascript":
import tree_sitter_javascript
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
elif self.language_id == "typescript":
import tree_sitter_typescript
# Detect TSX files by extension
if self.path is not None and self.path.suffix.lower() == ".tsx":
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
else:
return
# Create parser
self._parser = TreeSitterParser()
if hasattr(self._parser, "set_language"):
self._parser.set_language(self._language) # type: ignore[attr-defined]
else:
self._parser.language = self._language # type: ignore[assignment]
except Exception:
# Gracefully handle missing language bindings
self._parser = None
self._language = None
def is_available(self) -> bool:
"""Check if tree-sitter parser is available.
Returns:
True if parser is initialized and ready
"""
return self._parser is not None and self._language is not None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
Args:
text: Source code text
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle parsing errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
"""Parse source code and extract symbols.
Args:
text: Source code text
path: File path
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
if not self.is_available() or self._parser is None:
return None
try:
symbols = self.parse_symbols(text)
if symbols is None:
return None
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
)
except Exception:
# Gracefully handle parsing errors
return None
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of extracted symbols
"""
if self.language_id == "python":
return self._extract_python_symbols(source_bytes, root)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_symbols(source_bytes, root)
else:
return []
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of Python symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind=self._python_function_kind(node),
range=self._node_range(node),
))
return symbols
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract JavaScript/TypeScript symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of JS/TS symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = self._node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=self._node_range(node),
))
return symbols
def _python_function_kind(self, node: TreeSitterNode) -> str:
"""Determine if Python function is a method or standalone function.
Args:
node: Function definition node
Returns:
'method' if inside a class, 'function' otherwise
"""
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
"""Check if node has a class ancestor.
Args:
node: AST node to check
Returns:
True if node is inside a class
"""
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
"""Get line range for a node.
Args:
node: AST node
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Args:
text: Text to count tokens for
Returns:
Token count
"""
return self._tokenizer.count_tokens(text)
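A short usage sketch of the parser on its own (grammar availability depends on the installed tree-sitter language bindings):

from pathlib import Path

from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser

source = "class Greeter:\n    def hello(self):\n        return 'hi'\n"
parser = TreeSitterSymbolParser("python", Path("greeter.py"))
if parser.is_available():  # False when tree-sitter or the python grammar is missing
    for sym in parser.parse_symbols(source) or []:
        print(sym.name, sym.kind, sym.range)  # Greeter/class and hello/method with 1-based line ranges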

View File

@@ -17,6 +17,7 @@ from codexlens.entities import SearchResult, Symbol
from codexlens.storage.registry import RegistryStore, DirMapping
from codexlens.storage.dir_index import DirIndexStore, SubdirLink
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.sqlite_store import SQLiteStore
@dataclass
@@ -278,6 +279,108 @@ class ChainSearchEngine:
index_paths, name, kind, options.total_limit
)
def search_callers(self, target_symbol: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find all callers of a given symbol across directory hierarchy.
Args:
target_symbol: Name of the symbol to find callers for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with caller information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> callers = engine.search_callers("my_function", Path("D:/project"))
>>> for caller in callers:
... print(f"{caller['source_symbol']} in {caller['source_file']}:{caller['source_line']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_callers_parallel(
index_paths, target_symbol, options.total_limit
)
def search_callees(self, source_symbol: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find all callees (what a symbol calls) across directory hierarchy.
Args:
source_symbol: Name of the symbol to find callees for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with callee information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> callees = engine.search_callees("MyClass.method", Path("D:/project"))
>>> for callee in callees:
... print(f"Calls {callee['target_symbol']} at line {callee['source_line']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_callees_parallel(
index_paths, source_symbol, options.total_limit
)
def search_inheritance(self, class_name: str,
source_path: Path,
options: Optional[SearchOptions] = None) -> List[Dict[str, Any]]:
"""Find inheritance relationships for a class across directory hierarchy.
Args:
class_name: Name of the class to find inheritance for
source_path: Starting directory path
options: Search configuration (uses defaults if None)
Returns:
List of relationship dicts with inheritance information
Examples:
>>> engine = ChainSearchEngine(registry, mapper)
>>> inheritance = engine.search_inheritance("BaseClass", Path("D:/project"))
>>> for rel in inheritance:
... print(f"{rel['source_symbol']} extends {rel['target_symbol']}")
"""
options = options or SearchOptions()
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
return []
index_paths = self._collect_index_paths(start_index, options.depth)
if not index_paths:
return []
return self._search_inheritance_parallel(
index_paths, class_name, options.total_limit
)
# === Internal Methods ===
def _find_start_index(self, source_path: Path) -> Optional[Path]:
@@ -553,6 +656,252 @@ class ChainSearchEngine:
self.logger.debug(f"Symbol search error in {index_path}: {exc}")
return []
def _search_callers_parallel(self, index_paths: List[Path],
target_symbol: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for callers across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
target_symbol: Target symbol name
limit: Total result limit
Returns:
Deduplicated list of caller relationships
"""
all_callers = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_callers_single,
idx_path,
target_symbol
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
callers = future.result()
all_callers.extend(callers)
except Exception as exc:
self.logger.error(f"Caller search failed: {exc}")
# Deduplicate by (source_file, source_line)
seen = set()
unique_callers = []
for caller in all_callers:
key = (caller.get("source_file"), caller.get("source_line"))
if key not in seen:
seen.add(key)
unique_callers.append(caller)
# Sort by source file and line
unique_callers.sort(key=lambda c: (c.get("source_file", ""), c.get("source_line", 0)))
return unique_callers[:limit]
def _search_callers_single(self, index_path: Path,
target_symbol: str) -> List[Dict[str, Any]]:
"""Search for callers in a single index.
Args:
index_path: Path to _index.db file
target_symbol: Target symbol name
Returns:
List of caller relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
return store.query_relationships_by_target(target_symbol)
except Exception as exc:
self.logger.debug(f"Caller search error in {index_path}: {exc}")
return []
def _search_callees_parallel(self, index_paths: List[Path],
source_symbol: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for callees across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
source_symbol: Source symbol name
limit: Total result limit
Returns:
Deduplicated list of callee relationships
"""
all_callees = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_callees_single,
idx_path,
source_symbol
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
callees = future.result()
all_callees.extend(callees)
except Exception as exc:
self.logger.error(f"Callee search failed: {exc}")
# Deduplicate by (target_symbol, source_line)
seen = set()
unique_callees = []
for callee in all_callees:
key = (callee.get("target_symbol"), callee.get("source_line"))
if key not in seen:
seen.add(key)
unique_callees.append(callee)
# Sort by source line
unique_callees.sort(key=lambda c: c.get("source_line", 0))
return unique_callees[:limit]
def _search_callees_single(self, index_path: Path,
source_symbol: str) -> List[Dict[str, Any]]:
"""Search for callees in a single index.
Args:
index_path: Path to _index.db file
source_symbol: Source symbol name
Returns:
List of callee relationship dicts (empty on error)
"""
try:
# Use the connection pool via SQLiteStore
with SQLiteStore(index_path) as store:
# Search across all files containing the symbol
# Get all files that have this symbol
conn = store._get_connection()
file_rows = conn.execute(
"""
SELECT DISTINCT f.path
FROM symbols s
JOIN files f ON s.file_id = f.id
WHERE s.name = ?
""",
(source_symbol,)
).fetchall()
# Collect results from all matching files
all_results = []
for file_row in file_rows:
file_path = file_row["path"]
results = store.query_relationships_by_source(source_symbol, file_path)
all_results.extend(results)
return all_results
except Exception as exc:
self.logger.debug(f"Callee search error in {index_path}: {exc}")
return []
def _search_inheritance_parallel(self, index_paths: List[Path],
class_name: str,
limit: int) -> List[Dict[str, Any]]:
"""Search for inheritance relationships across multiple indexes in parallel.
Args:
index_paths: List of _index.db paths to search
class_name: Class name to search for
limit: Total result limit
Returns:
Deduplicated list of inheritance relationships
"""
all_inheritance = []
executor = self._get_executor()
future_to_path = {
executor.submit(
self._search_inheritance_single,
idx_path,
class_name
): idx_path
for idx_path in index_paths
}
for future in as_completed(future_to_path):
try:
inheritance = future.result()
all_inheritance.extend(inheritance)
except Exception as exc:
self.logger.error(f"Inheritance search failed: {exc}")
# Deduplicate by (source_symbol, target_symbol)
seen = set()
unique_inheritance = []
for rel in all_inheritance:
key = (rel.get("source_symbol"), rel.get("target_symbol"))
if key not in seen:
seen.add(key)
unique_inheritance.append(rel)
# Sort by source file
unique_inheritance.sort(key=lambda r: r.get("source_file", ""))
return unique_inheritance[:limit]
def _search_inheritance_single(self, index_path: Path,
class_name: str) -> List[Dict[str, Any]]:
"""Search for inheritance relationships in a single index.
Args:
index_path: Path to _index.db file
class_name: Class name to search for
Returns:
List of inheritance relationship dicts (empty on error)
"""
try:
with SQLiteStore(index_path) as store:
conn = store._get_connection()
# Search both as base class (target) and derived class (source)
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE (s.name = ? OR r.target_qualified_name LIKE ?)
AND r.relationship_type = 'inherits'
ORDER BY f.path, r.source_line
LIMIT 100
""",
(class_name, f"%{class_name}%")
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
except Exception as exc:
self.logger.debug(f"Inheritance search error in {index_path}: {exc}")
return []
# === Convenience Functions ===

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@dataclass
@@ -14,6 +15,7 @@ class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 100 # Overlap for sliding window
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
@@ -22,6 +24,7 @@ class Chunker:
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
def chunk_by_symbol(
self,
@@ -29,10 +32,18 @@ class Chunker:
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
@@ -47,6 +58,13 @@ class Chunker:
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._tokenizer.count_tokens(chunk_content)
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
@@ -58,6 +76,7 @@ class Chunker:
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"token_count": token_count,
}
))
@@ -68,10 +87,19 @@ class Chunker:
content: str,
file_path: str | Path,
language: str,
line_mapping: Optional[List[int]] = None,
) -> List[SemanticChunk]:
"""Chunk code using sliding window approach.
Used for files without clear symbol boundaries or very long functions.
Args:
content: Source code content
file_path: Path to source file
language: Programming language
line_mapping: Optional list mapping content line indices to original line numbers
(1-indexed). If provided, line_mapping[i] is the original line number
for the i-th line in content.
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
@@ -92,6 +120,18 @@ class Chunker:
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
token_count = self._tokenizer.count_tokens(chunk_content)
# Calculate correct line numbers
if line_mapping:
# Use line mapping to get original line numbers
start_line = line_mapping[start]
end_line = line_mapping[end - 1]
else:
# Default behavior: treat content as starting at line 1
start_line = start + 1
end_line = end
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
@@ -99,9 +139,10 @@ class Chunker:
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start + 1,
"end_line": end,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"token_count": token_count,
}
))
chunk_idx += 1
@@ -119,12 +160,239 @@ class Chunker:
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk a file using the best strategy.
Uses symbol-based chunking when symbols are available,
and falls back to a sliding window for files without symbols.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
if symbols:
return self.chunk_by_symbol(content, symbols, file_path, language)
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
return self.chunk_sliding_window(content, file_path, language)
class DocstringExtractor:
"""Extract docstrings from source code."""
@staticmethod
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
"""Extract Python docstrings with their line ranges.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
docstrings: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
start_line = i + 1
if stripped.count(quote_type) >= 2:
docstring_content = line
end_line = i + 1
docstrings.append((docstring_content, start_line, end_line))
i += 1
continue
docstring_lines = [line]
i += 1
while i < len(lines):
docstring_lines.append(lines[i])
if quote_type in lines[i]:
break
i += 1
end_line = i + 1
docstring_content = "".join(docstring_lines)
docstrings.append((docstring_content, start_line, end_line))
i += 1
return docstrings
@staticmethod
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
"""Extract JSDoc comments with their line ranges.
Returns: List of (comment_content, start_line, end_line) tuples
"""
comments: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('/**'):
start_line = i + 1
comment_lines = [line]
i += 1
while i < len(lines):
comment_lines.append(lines[i])
if '*/' in lines[i]:
break
i += 1
end_line = i + 1
comment_content = "".join(comment_lines)
comments.append((comment_content, start_line, end_line))
i += 1
return comments
@classmethod
def extract_docstrings(
cls,
content: str,
language: str
) -> List[Tuple[str, int, int]]:
"""Extract docstrings based on language.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
if language == "python":
return cls.extract_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.extract_jsdoc_comments(content)
return []
class HybridChunker:
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
Composition-based strategy that:
1. Extracts docstrings as dedicated chunks
2. For remaining code, uses base chunker (symbol or sliding window)
"""
def __init__(
self,
base_chunker: Chunker | None = None,
config: ChunkConfig | None = None
) -> None:
"""Initialize hybrid chunker.
Args:
base_chunker: Chunker to use for non-docstring content
config: Configuration for chunking
"""
self.config = config or ChunkConfig()
self.base_chunker = base_chunker or Chunker(self.config)
self.docstring_extractor = DocstringExtractor()
def _get_excluded_line_ranges(
self,
docstrings: List[Tuple[str, int, int]]
) -> set[int]:
"""Get set of line numbers that are part of docstrings."""
excluded_lines: set[int] = set()
for _, start_line, end_line in docstrings:
for line_num in range(start_line, end_line + 1):
excluded_lines.add(line_num)
return excluded_lines
def _filter_symbols_outside_docstrings(
self,
symbols: List[Symbol],
excluded_lines: set[int]
) -> List[Symbol]:
"""Filter symbols to exclude those completely within docstrings."""
filtered: List[Symbol] = []
for symbol in symbols:
start_line, end_line = symbol.range
symbol_lines = set(range(start_line, end_line + 1))
if not symbol_lines.issubset(excluded_lines):
filtered.append(symbol)
return filtered
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk file using hybrid strategy.
Extracts docstrings first, then chunks remaining code.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
tokenizer = get_default_tokenizer()
# Step 1: Extract docstrings as dedicated chunks
docstrings = self.docstring_extractor.extract_docstrings(content, language)
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
token_count = tokenizer.count_tokens(docstring_content)
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
))
# Step 2: Get line ranges occupied by docstrings
excluded_lines = self._get_excluded_line_ranges(docstrings)
# Step 3: Filter symbols to exclude docstring-only ranges
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
# Step 4: Chunk remaining content using base chunker
if filtered_symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, filtered_symbols, file_path, language, symbol_token_counts
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
else:
lines = content.splitlines(keepends=True)
remaining_lines: List[str] = []
for i, line in enumerate(lines, start=1):
if i not in excluded_lines:
remaining_lines.append(line)
if remaining_lines:
remaining_content = "".join(remaining_lines)
if len(remaining_content.strip()) >= self.config.min_chunk_size:
base_chunks = self.base_chunker.chunk_sliding_window(
remaining_content, file_path, language
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
return chunks
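To illustrate the hybrid flow end to end, a small sketch (the chunker module path is assumed; values are illustrative):

from codexlens.entities import Symbol
from codexlens.chunker import ChunkConfig, HybridChunker  # module path assumed

source = '"""Utility helpers."""\n\ndef add(a, b):\n    """Add two numbers."""\n    return a + b\n'
symbols = [Symbol(name="add", kind="function", range=(3, 5))]

chunker = HybridChunker(config=ChunkConfig(min_chunk_size=10))
for chunk in chunker.chunk_file(source, symbols, "utils.py", "python"):
    # Emits two docstring chunks (module and function docstrings) plus one code chunk for add().
    print(chunk.metadata["chunk_type"], chunk.metadata["start_line"], chunk.metadata["end_line"])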

View File

@@ -0,0 +1,531 @@
"""Graph analyzer for extracting code relationships using tree-sitter.
Provides AST-based analysis to identify function calls, method invocations,
and class inheritance relationships within source files.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
try:
from tree_sitter import Node as TreeSitterNode
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterNode = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class GraphAnalyzer:
"""Analyzer for extracting semantic relationships from code using AST traversal."""
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
"""Initialize graph analyzer for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
parser: Optional TreeSitterSymbolParser instance for dependency injection.
If None, creates a new parser instance (backward compatibility).
"""
self.language_id = language_id
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
def is_available(self) -> bool:
"""Check if graph analyzer is available.
Returns:
True if tree-sitter parser is initialized and ready
"""
return self._parser.is_available()
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
"""Analyze source code and extract relationships.
Args:
text: Source code text
file_path: File path for relationship context
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def analyze_with_symbols(
self, text: str, file_path: Path, symbols: List[Symbol]
) -> List[CodeRelationship]:
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
Args:
text: Source code text
file_path: File path for relationship context
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
# Convert Symbol objects to internal symbol format
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
# Extract relationships using provided symbols
relationships = self._extract_relationships_with_symbols(
source_bytes, root, str(file_path.resolve()), defined_symbols
)
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def _convert_symbols_to_dict(
self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
) -> List[dict]:
"""Convert Symbol objects to internal dict format for relationship extraction.
Args:
source_bytes: Source code as bytes
root: Root AST node
symbols: Pre-parsed Symbol objects
Returns:
List of symbol info dicts with name, node, and type
"""
symbol_dicts = []
symbol_names = {s.name for s in symbols}
# Find AST nodes corresponding to symbols
for node in self._iter_nodes(root):
node_type = node.type
# Check if this node matches any of our symbols
if node_type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "method"
})
elif node_type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
return symbol_dicts
def _extract_relationships_with_symbols(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
) -> List[CodeRelationship]:
"""Extract relationships from AST using pre-parsed symbols.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
defined_symbols: Pre-parsed symbol dicts
Returns:
List of extracted relationships
"""
relationships: List[CodeRelationship] = []
# Determine call node type based on language
if self.language_id == "python":
call_node_type = "call"
extract_target = self._extract_call_target
elif self.language_id in {"javascript", "typescript"}:
call_node_type = "call_expression"
extract_target = self._extract_js_call_target
else:
return []
# Find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == call_node_type:
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = extract_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of extracted relationships
"""
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, file_path)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, file_path)
else:
return []
def _extract_python_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract Python relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of Python relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols with their scopes
defined_symbols = self._collect_python_symbols(source_bytes, root)
# Second pass: find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == "call":
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = self._extract_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_js_ts_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract JavaScript/TypeScript relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of JS/TS relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols
defined_symbols = self._collect_js_ts_symbols(source_bytes, root)
# Second pass: find call expressions
for node in self._iter_nodes(root):
if node.type == "call_expression":
# Extract caller context
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
source_symbol = "<module>"
# Extract callee
target_symbol = self._extract_js_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=line_number,
)
)
return relationships
def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all Python function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
return symbols
def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all JS/TS function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "method"
})
elif node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
return symbols
def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
"""Find the enclosing function/method/class for a node.
Args:
node: AST node to find enclosure for
symbols: List of defined symbols
Returns:
Name of enclosing symbol, or None if at module level
"""
# Walk up the tree to find enclosing symbol
parent = node.parent
while parent is not None:
for symbol in symbols:
if symbol["node"] == parent:
return symbol["name"]
parent = parent.parent
return None
def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a Python call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers (e.g., "foo()")
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle attribute access (e.g., "obj.method()")
if function_node.type == "attribute":
attr_node = function_node.child_by_field_name("attribute")
if attr_node:
return self._node_text(source_bytes, attr_node)
return None
def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a JS/TS call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle member expressions (e.g., "obj.method()")
if function_node.type == "member_expression":
property_node = function_node.child_by_field_name("property")
if property_node:
return self._node_text(source_bytes, property_node)
return None
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
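A minimal illustration of what the call-relationship pass produces (not part of this commit; the file path is hypothetical). For the snippet below, the Python extraction yields one "call" relationship per call site, attributed to the enclosing function:

# example.py (hypothetical input)
def helper():
    return 1

def main():
    value = helper()
    print(value)

# Expected output of the Python extraction pass (illustrative):
# CodeRelationship(source_symbol="main", target_symbol="helper", relationship_type="call",
#                  source_file="/abs/example.py", target_file=None, source_line=5)
# CodeRelationship(source_symbol="main", target_symbol="print", relationship_type="call",
#                  source_file="/abs/example.py", target_file=None, source_line=6)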

View File

@@ -75,6 +75,34 @@ class LLMEnhancer:
external LLM tools (gemini, qwen) via CCW CLI subprocess.
"""
CHUNK_REFINEMENT_PROMPT = '''PURPOSE: Identify optimal semantic split points in code chunk
TASK:
- Analyze the code structure to find natural semantic boundaries
- Identify logical groupings (functions, classes, related statements)
- Suggest split points that maintain semantic cohesion
MODE: analysis
EXPECTED: JSON format with split positions
=== CODE CHUNK ===
{code_chunk}
=== OUTPUT FORMAT ===
Return ONLY valid JSON (no markdown, no explanation):
{{
"split_points": [
{{
"line": <line_number>,
"reason": "brief reason for split (e.g., 'start of new function', 'end of class definition')"
}}
]
}}
Rules:
- Split at function/class/method boundaries
- Keep related code together (don't split mid-function)
- Aim for chunks between 500-2000 characters
- Return empty split_points if no good splits found'''
PROMPT_TEMPLATE = '''PURPOSE: Generate semantic summaries and search keywords for code files
TASK:
- For each code block, generate a concise summary (1-2 sentences)
@@ -168,42 +196,246 @@ Return ONLY valid JSON (no markdown, no explanation):
return results
def enhance_file(
self,
path: str,
content: str,
language: str,
working_dir: Optional[Path] = None,
) -> SemanticMetadata:
"""Enhance a single file with LLM-generated semantic metadata.
Convenience method that wraps enhance_files for single file processing.
Args:
path: File path
content: File content
language: Programming language
working_dir: Optional working directory for CCW CLI
Returns:
SemanticMetadata for the file
Raises:
ValueError: If enhancement fails
"""
file_data = FileData(path=path, content=content, language=language)
results = self.enhance_files([file_data], working_dir)
if path not in results:
# Return default metadata if enhancement failed
return SemanticMetadata(
summary=f"Code file written in {language}",
keywords=[language, "code"],
purpose="unknown",
file_path=path,
llm_tool=self.config.tool,
)
return results[path]
def refine_chunk_boundaries(
self,
chunk: SemanticChunk,
max_chunk_size: int = 2000,
working_dir: Optional[Path] = None,
) -> List[SemanticChunk]:
"""Refine chunk boundaries using LLM for large code chunks.
Uses LLM to identify semantic split points in large chunks,
breaking them into smaller, more cohesive pieces.
Args:
chunk: Original chunk to refine
max_chunk_size: Maximum characters before triggering refinement
working_dir: Optional working directory for CCW CLI
Returns:
List of refined chunks (original chunk if no splits or refinement fails)
"""
# Skip if chunk is small enough
if len(chunk.content) <= max_chunk_size:
return [chunk]
# Skip if LLM enhancement disabled or unavailable
if not self.config.enabled or not self.check_available():
return [chunk]
# Skip docstring chunks - only refine code chunks
if chunk.metadata.get("chunk_type") == "docstring":
return [chunk]
try:
# Build refinement prompt
prompt = self.CHUNK_REFINEMENT_PROMPT.format(code_chunk=chunk.content)
# Invoke LLM
result = self._invoke_ccw_cli(
prompt,
tool=self.config.tool,
working_dir=working_dir,
)
# Fallback if primary tool fails
if not result["success"] and self.config.fallback_tool:
result = self._invoke_ccw_cli(
prompt,
tool=self.config.fallback_tool,
working_dir=working_dir,
)
if not result["success"]:
logger.debug("LLM refinement failed, returning original chunk")
return [chunk]
# Parse split points
split_points = self._parse_split_points(result["stdout"])
if not split_points:
logger.debug("No split points identified, returning original chunk")
return [chunk]
# Split chunk at identified boundaries
refined_chunks = self._split_chunk_at_points(chunk, split_points)
logger.debug(
"Refined chunk into %d smaller chunks (was %d chars)",
len(refined_chunks),
len(chunk.content),
)
return refined_chunks
except Exception as e:
logger.warning("Chunk refinement error: %s, returning original chunk", e)
return [chunk]
def _parse_split_points(self, stdout: str) -> List[int]:
"""Parse split points from LLM response.
Args:
stdout: Raw stdout from CCW CLI
Returns:
List of line numbers where splits should occur (sorted)
"""
# Extract JSON from response
json_str = self._extract_json(stdout)
if not json_str:
return []
try:
data = json.loads(json_str)
split_points_data = data.get("split_points", [])
# Extract line numbers
lines = []
for point in split_points_data:
if isinstance(point, dict) and "line" in point:
line_num = point["line"]
if isinstance(line_num, int) and line_num > 0:
lines.append(line_num)
return sorted(set(lines))
except (json.JSONDecodeError, ValueError, TypeError) as e:
logger.debug("Failed to parse split points: %s", e)
return []
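# Illustrative behaviour (not part of this commit): for stdout containing
# '{"split_points": [{"line": 30, "reason": "new function"}, {"line": 12, "reason": "end of class"}, {"line": 12, "reason": "duplicate"}]}'
# this returns [12, 30]: duplicates are collapsed, results are sorted, and
# entries without a positive integer "line" value are skipped.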
def _split_chunk_at_points(
self,
chunk: SemanticChunk,
split_points: List[int],
) -> List[SemanticChunk]:
"""Split chunk at specified line numbers.
Args:
chunk: Original chunk to split
split_points: Sorted list of line numbers to split at
Returns:
List of smaller chunks
"""
lines = chunk.content.splitlines(keepends=True)
chunks: List[SemanticChunk] = []
# Get original metadata
base_metadata = dict(chunk.metadata)
original_start = base_metadata.get("start_line", 1)
# Add start and end boundaries
boundaries = [0] + split_points + [len(lines)]
for i in range(len(boundaries) - 1):
start_idx = boundaries[i]
end_idx = boundaries[i + 1]
# Skip empty sections
if start_idx >= end_idx:
continue
# Extract content
section_lines = lines[start_idx:end_idx]
section_content = "".join(section_lines)
# Skip if too small
if len(section_content.strip()) < 50:
continue
# Create new chunk with updated metadata
new_metadata = base_metadata.copy()
new_metadata["start_line"] = original_start + start_idx
new_metadata["end_line"] = original_start + end_idx - 1
new_metadata["refined_by_llm"] = True
new_metadata["original_chunk_size"] = len(chunk.content)
chunks.append(
SemanticChunk(
content=section_content,
embedding=None, # Embeddings will be regenerated
metadata=new_metadata,
)
)
# If no valid chunks created, return original
if not chunks:
return [chunk]
return chunks
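# Worked example (illustrative, not part of this commit): a 20-line chunk whose
# metadata reports start_line=100, with split_points=[6, 14], gives
# boundaries=[0, 6, 14, 20], producing sections lines[0:6], lines[6:14], and
# lines[14:20] with start_line metadata 100, 106, and 114 respectively;
# sections under 50 stripped characters are dropped.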
def _process_batch(

View File

@@ -149,15 +149,21 @@ class DirIndexStore:
# Replace symbols
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
if symbols:
# Extract token_count and symbol_type from symbol metadata if available
symbol_rows = []
for s in symbols:
token_count = getattr(s, 'token_count', None)
symbol_type = getattr(s, 'symbol_type', None) or s.kind
symbol_rows.append(
(file_id, s.name, s.kind, s.range[0], s.range[1], token_count, symbol_type)
)
conn.executemany(
"""
INSERT INTO symbols(file_id, name, kind, start_line, end_line, token_count, symbol_type)
VALUES(?, ?, ?, ?, ?, ?, ?)
""",
symbol_rows,
)
conn.commit()
@@ -216,15 +222,21 @@ class DirIndexStore:
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
if symbols:
# Extract token_count and symbol_type from symbol metadata if available
symbol_rows = []
for s in symbols:
token_count = getattr(s, 'token_count', None)
symbol_type = getattr(s, 'symbol_type', None) or s.kind
symbol_rows.append(
(file_id, s.name, s.kind, s.range[0], s.range[1], token_count, symbol_type)
)
conn.executemany(
"""
INSERT INTO symbols(file_id, name, kind, start_line, end_line, token_count, symbol_type)
VALUES(?, ?, ?, ?, ?, ?, ?)
""",
symbol_rows,
)
conn.commit()
@@ -1021,7 +1033,9 @@ class DirIndexStore:
name TEXT NOT NULL,
kind TEXT NOT NULL,
start_line INTEGER,
end_line INTEGER,
token_count INTEGER,
symbol_type TEXT
)
"""
)
@@ -1083,6 +1097,7 @@ class DirIndexStore:
conn.execute("CREATE INDEX IF NOT EXISTS idx_subdirs_name ON subdirs(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_file ON symbols(file_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_semantic_file ON semantic_metadata(file_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_keywords_keyword ON keywords(keyword)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_file_keywords_file_id ON file_keywords(file_id)")

View File

@@ -0,0 +1,48 @@
"""
Migration 002: Add token_count and symbol_type to symbols table.
This migration adds token counting metadata to symbols for accurate chunk
splitting and performance optimization. It also adds symbol_type for better
filtering in searches.
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add token metadata to symbols.
- Adds token_count column to symbols table
- Adds symbol_type column to symbols table (for future use)
- Creates index on symbol_type for efficient filtering
- Leaves existing symbols with NULL token_count (to be calculated lazily)
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Adding token_count column to symbols table...")
try:
cursor.execute("ALTER TABLE symbols ADD COLUMN token_count INTEGER")
log.info("Successfully added token_count column.")
except Exception as e:
# Column might already exist
log.warning(f"Could not add token_count column (might already exist): {e}")
log.info("Adding symbol_type column to symbols table...")
try:
cursor.execute("ALTER TABLE symbols ADD COLUMN symbol_type TEXT")
log.info("Successfully added symbol_type column.")
except Exception as e:
# Column might already exist
log.warning(f"Could not add symbol_type column (might already exist): {e}")
log.info("Creating index on symbol_type for efficient filtering...")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_symbols_type ON symbols(symbol_type)")
log.info("Migration 002 completed successfully.")

View File

@@ -0,0 +1,57 @@
"""
Migration 003: Add code relationships storage.
This migration introduces the `code_relationships` table to store semantic
relationships between code symbols (function calls, inheritance, imports).
This enables graph-based code navigation and dependency analysis.
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection):
"""
Applies the migration to add code relationships table.
- Creates `code_relationships` table with foreign key to symbols
- Creates indexes for efficient relationship queries
- Stores targets as qualified names (target_qualified_name) so they can be resolved lazily
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Creating 'code_relationships' table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT,
FOREIGN KEY (source_symbol_id) REFERENCES symbols (id) ON DELETE CASCADE
)
"""
)
log.info("Creating indexes for code_relationships...")
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source ON code_relationships (source_symbol_id)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_target ON code_relationships (target_qualified_name)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_type ON code_relationships (relationship_type)"
)
cursor.execute(
"CREATE INDEX IF NOT EXISTS idx_relationships_source_line ON code_relationships (source_line)"
)
log.info("Finished creating code_relationships table and indexes.")

View File

@@ -9,7 +9,7 @@ from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from codexlens.entities import CodeRelationship, IndexedFile, SearchResult, Symbol
from codexlens.errors import StorageError
@@ -309,13 +309,184 @@ class SQLiteStore:
"SELECT language, COUNT(*) AS c FROM files GROUP BY language ORDER BY c DESC"
).fetchall()
languages = {row["language"]: row["c"] for row in lang_rows}
# Include relationship count if table exists
relationship_count = 0
try:
rel_row = conn.execute("SELECT COUNT(*) AS c FROM code_relationships").fetchone()
relationship_count = int(rel_row["c"]) if rel_row else 0
except sqlite3.DatabaseError:
pass
return {
"files": int(file_count),
"symbols": int(symbol_count),
"relationships": relationship_count,
"languages": languages,
"db_path": str(self.db_path),
}
def add_relationships(self, file_path: str | Path, relationships: List[CodeRelationship]) -> None:
"""Store code relationships for a file.
Args:
file_path: Path to the file containing the relationships
relationships: List of CodeRelationship objects to store
"""
if not relationships:
return
with self._lock:
conn = self._get_connection()
resolved_path = str(Path(file_path).resolve())
# Get file_id
row = conn.execute("SELECT id FROM files WHERE path=?", (resolved_path,)).fetchone()
if not row:
raise StorageError(f"File not found in index: {file_path}")
file_id = int(row["id"])
# Delete existing relationships for symbols in this file
conn.execute(
"""
DELETE FROM code_relationships
WHERE source_symbol_id IN (
SELECT id FROM symbols WHERE file_id=?
)
""",
(file_id,)
)
# Insert new relationships
relationship_rows = []
for rel in relationships:
# Find source symbol ID
symbol_row = conn.execute(
"""
SELECT id FROM symbols
WHERE file_id=? AND name=? AND start_line <= ? AND end_line >= ?
ORDER BY (end_line - start_line) ASC
LIMIT 1
""",
(file_id, rel.source_symbol, rel.source_line, rel.source_line)
).fetchone()
if symbol_row:
source_symbol_id = int(symbol_row["id"])
relationship_rows.append((
source_symbol_id,
rel.target_symbol,
rel.relationship_type,
rel.source_line,
rel.target_file
))
if relationship_rows:
conn.executemany(
"""
INSERT INTO code_relationships(
source_symbol_id, target_qualified_name, relationship_type,
source_line, target_file
)
VALUES(?, ?, ?, ?, ?)
""",
relationship_rows
)
conn.commit()
def query_relationships_by_target(
self, target_name: str, *, limit: int = 100
) -> List[Dict[str, Any]]:
"""Query relationships by target symbol name (find all callers).
Args:
target_name: Name of the target symbol
limit: Maximum number of results
Returns:
List of dicts containing relationship info with file paths and line numbers
"""
with self._lock:
conn = self._get_connection()
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE r.target_qualified_name = ?
ORDER BY f.path, r.source_line
LIMIT ?
""",
(target_name, limit)
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
def query_relationships_by_source(
self, source_symbol: str, source_file: str | Path, *, limit: int = 100
) -> List[Dict[str, Any]]:
"""Query relationships by source symbol (find what a symbol calls).
Args:
source_symbol: Name of the source symbol
source_file: File path containing the source symbol
limit: Maximum number of results
Returns:
List of dicts containing relationship info
"""
with self._lock:
conn = self._get_connection()
resolved_path = str(Path(source_file).resolve())
rows = conn.execute(
"""
SELECT
s.name AS source_symbol,
r.target_qualified_name,
r.relationship_type,
r.source_line,
f.path AS source_file,
r.target_file
FROM code_relationships r
JOIN symbols s ON r.source_symbol_id = s.id
JOIN files f ON s.file_id = f.id
WHERE s.name = ? AND f.path = ?
ORDER BY r.source_line
LIMIT ?
""",
(source_symbol, resolved_path, limit)
).fetchall()
return [
{
"source_symbol": row["source_symbol"],
"target_symbol": row["target_qualified_name"],
"relationship_type": row["relationship_type"],
"source_line": row["source_line"],
"source_file": row["source_file"],
"target_file": row["target_file"],
}
for row in rows
]
def _connect(self) -> sqlite3.Connection:
"""Legacy method for backward compatibility."""
return self._get_connection()
@@ -348,6 +519,20 @@ class SQLiteStore:
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_name ON symbols(name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_symbols_kind ON symbols(kind)")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS code_relationships (
id INTEGER PRIMARY KEY,
source_symbol_id INTEGER NOT NULL REFERENCES symbols(id) ON DELETE CASCADE,
target_qualified_name TEXT NOT NULL,
relationship_type TEXT NOT NULL,
source_line INTEGER NOT NULL,
target_file TEXT
)
"""
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_target ON code_relationships(target_qualified_name)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_rel_source ON code_relationships(source_symbol_id)")
conn.commit()
except sqlite3.DatabaseError as exc:
raise StorageError(f"Failed to initialize database schema: {exc}") from exc