Add comprehensive tests for ast-grep and tree-sitter relationship extraction

- Introduced test suite for AstGrepPythonProcessor covering pattern definitions, parsing, and relationship extraction.
- Added comparison tests between tree-sitter and ast-grep for consistency in relationship extraction.
- Implemented tests for ast-grep binding module to verify functionality and availability.
- Ensured tests cover various scenarios including inheritance, function calls, and imports.
This commit is contained in:
catlog22
2026-02-15 21:14:14 +08:00
parent 126a357aa2
commit 48a6a1f2aa
56 changed files with 10622 additions and 374 deletions

View File

@@ -22,6 +22,9 @@ dependencies = [
"tree-sitter-typescript>=0.23",
"pathspec>=0.11",
"watchdog>=3.0",
# ast-grep for pattern-based AST matching (PyO3 bindings)
# Note: May have compatibility issues with Python 3.13
"ast-grep-py>=0.3.0; python_version < '3.13'",
]
[project.optional-dependencies]

View File

@@ -189,6 +189,9 @@ class Config:
api_batch_size_max: int = 2048 # Absolute upper limit for batch size
chars_per_token_estimate: int = 4 # Characters per token estimation ratio
# Parser configuration
use_astgrep: bool = False # Use ast-grep for Python relationship extraction (tree-sitter is default)
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()

View File

@@ -3,6 +3,12 @@
from __future__ import annotations
from .factory import ParserFactory
from .astgrep_binding import AstGrepBinding, is_astgrep_available, get_supported_languages
__all__ = ["ParserFactory"]
__all__ = [
"ParserFactory",
"AstGrepBinding",
"is_astgrep_available",
"get_supported_languages",
]

View File

@@ -0,0 +1,320 @@
"""ast-grep based parser binding for CodexLens.
Provides AST-level pattern matching via ast-grep-py (PyO3 bindings).
Note: This module wraps the official ast-grep Python bindings for pattern-based
code analysis. If ast-grep-py is unavailable, the parser returns None gracefully.
Callers should use tree-sitter or regex-based fallbacks.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# Import patterns from centralized definition (avoid duplication)
from codexlens.parsers.patterns.python import get_pattern, PYTHON_PATTERNS
# Graceful import pattern following treesitter_parser.py convention
try:
from ast_grep_py import SgNode, SgRoot
ASTGREP_AVAILABLE = True
except ImportError:
SgNode = None # type: ignore[assignment,misc]
SgRoot = None # type: ignore[assignment,misc]
ASTGREP_AVAILABLE = False
log = logging.getLogger(__name__)
class AstGrepBinding:
    """Wrapper for ast-grep-py bindings with CodexLens integration.

    Provides pattern-based AST matching for code relationship extraction.
    Patterns are declarative and use metavariables: ``$NAME`` captures a
    single node, ``$$$ARGS`` captures a sequence of nodes.

    When ast-grep-py could not be imported (``ASTGREP_AVAILABLE`` is False) or
    the language is not listed in ``LANGUAGE_MAP``, ``is_available()`` is
    False, ``parse()`` returns False and every ``find_*`` method returns an
    empty list.
    """

    # Language ID mapping to ast-grep language names
    LANGUAGE_MAP = {
        "python": "python",
        "javascript": "javascript",
        "typescript": "typescript",
        "tsx": "tsx",
    }

    def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
        """Initialize ast-grep binding for a language.

        Args:
            language_id: Language identifier (python, javascript, typescript, tsx)
            path: Optional file path for language variant detection
        """
        self.language_id = language_id
        self.path = path
        self._language: Optional[str] = None
        self._root: Optional[SgRoot] = None  # type: ignore[valid-type]

        # Without the bindings the language stays None, which keeps
        # is_available() False for the lifetime of this instance.
        if ASTGREP_AVAILABLE:
            self._initialize_language()

    def _initialize_language(self) -> None:
        """Resolve ``self._language`` from the language id and file suffix."""
        # A ".tsx" file declared as TypeScript is parsed with the TSX grammar.
        if self.language_id == "typescript" and self.path is not None:
            if self.path.suffix.lower() == ".tsx":
                self._language = "tsx"
                return

        # Unknown language ids map to None (binding reported as unavailable).
        self._language = self.LANGUAGE_MAP.get(self.language_id)

    def is_available(self) -> bool:
        """Check if ast-grep binding is available and ready.

        Returns:
            True if ast-grep-py is installed and language is supported
        """
        return ASTGREP_AVAILABLE and self._language is not None

    def parse(self, source_code: str) -> bool:
        """Parse source code into ast-grep syntax tree.

        The parsed root is stored on the instance and used by ``find_all``.

        Args:
            source_code: Source code text to parse

        Returns:
            True if parsing succeeds, False otherwise
        """
        if not self.is_available() or SgRoot is None:
            return False

        try:
            self._root = SgRoot(source_code, self._language)  # type: ignore[misc]
            return True
        except (ValueError, TypeError, RuntimeError) as e:
            log.debug("ast-grep parse error: %s", e)
            # Drop any previously parsed tree so stale matches are not returned.
            self._root = None
            return False

    def find_all(self, pattern: str) -> List[SgNode]:  # type: ignore[valid-type]
        """Find all matches for a pattern in the parsed source.

        Args:
            pattern: ast-grep pattern string (e.g., "class $NAME($$$BASES) $$$BODY")

        Returns:
            List of matching SgNode objects, empty if no matches or not parsed
        """
        if not self.is_available() or self._root is None:
            return []

        try:
            root_node = self._root.root()
            # ast-grep-py 0.40+ requires dict config format
            config = {"rule": {"pattern": pattern}}
            return list(root_node.find_all(config))
        except (ValueError, TypeError, AttributeError) as e:
            log.debug("ast-grep find_all error: %s", e)
            return []

    def find_inheritance(self) -> List[Dict[str, str]]:
        """Find all class inheritance declarations.

        Only implemented for Python; other languages yield an empty list.

        Returns:
            List of dicts with 'class_name' and 'bases' keys. 'bases' is the
            text of the base-class list, or an empty string when it could not
            be captured.
        """
        if self.language_id != "python":
            return []

        results: List[Dict[str, str]] = []
        for node in self.find_all(get_pattern("class_with_bases")):
            class_name = self._get_match(node, "NAME")
            if not class_name:
                continue
            # A single-node lookup does not capture a "$$$" multi-metavariable,
            # so fall back to the multi-match API when it comes back empty.
            bases = self._get_match(node, "BASES") or self._get_multi_match_text(node, "BASES")
            results.append({
                "class_name": class_name,
                "bases": bases,
            })
        return results

    def find_calls(self) -> List[Dict[str, str]]:
        """Find all function/method calls.

        Calls whose leading name is ``self`` or ``cls`` are skipped. Only
        implemented for Python; other languages yield an empty list.

        Returns:
            List of dicts with 'function' and 'line' keys ('line' is the
            1-based line number rendered as a string)
        """
        if self.language_id != "python":
            return []

        results: List[Dict[str, str]] = []
        for node in self.find_all(get_pattern("call")):
            func_name = self._get_match(node, "FUNC")
            if func_name:
                # Skip self. and cls. prefixed calls
                base = func_name.split(".", 1)[0]
                if base not in {"self", "cls"}:
                    results.append({
                        "function": func_name,
                        "line": str(self._get_line_number(node)),
                    })
        return results

    def find_imports(self) -> List[Dict[str, str]]:
        """Find all import statements.

        Only implemented for Python; other languages yield an empty list.

        Returns:
            List of dicts with 'module', 'type' and 'line' keys. 'type' is
            "import" or "from_import"; "from_import" entries additionally
            carry a 'names' key with the imported names text (possibly empty).
        """
        if self.language_id != "python":
            return []

        results: List[Dict[str, str]] = []

        # Find 'import X' statements
        for node in self.find_all(get_pattern("import_stmt")):
            module = self._get_match(node, "MODULE")
            if module:
                results.append({
                    "module": module,
                    "type": "import",
                    "line": str(self._get_line_number(node)),
                })

        # Find 'from X import Y' statements
        for node in self.find_all(get_pattern("import_from")):
            module = self._get_match(node, "MODULE")
            names = self._get_match(node, "NAMES")
            if module:
                results.append({
                    "module": module,
                    "names": names or "",
                    "type": "from_import",
                    "line": str(self._get_line_number(node)),
                })
        return results

    def _get_match(self, node: SgNode, metavar: str) -> str:  # type: ignore[valid-type]
        """Extract matched metavariable value from node.

        Args:
            node: SgNode with match
            metavar: Metavariable name (without $ prefix)

        Returns:
            Matched text or empty string
        """
        if node is None:
            return ""
        try:
            match = node.get_match(metavar)
            if match is not None:
                return match.text()
        except (ValueError, AttributeError, KeyError) as e:
            log.debug("ast-grep get_match error for %s: %s", metavar, e)
        return ""

    def _get_multi_match_text(
        self,
        node: SgNode,  # type: ignore[valid-type]
        metavar: str,
        separator: str = ", ",
    ) -> str:
        """Join the text of every node captured by a ``$$$`` metavariable.

        Captured items that are empty or a bare comma are dropped so the
        result reads like a normal comma-separated list.

        Args:
            node: SgNode with match
            metavar: Multi-metavariable name (without the $$$ prefix)
            separator: Text placed between the captured items

        Returns:
            Joined text, or empty string when nothing was captured
        """
        if node is None:
            return ""
        try:
            captured = node.get_multiple_matches(metavar)
            pieces = [item.text().strip() for item in (captured or [])]
        except (ValueError, AttributeError, KeyError, TypeError) as e:
            log.debug("ast-grep get_multiple_matches error for %s: %s", metavar, e)
            return ""
        return separator.join(piece for piece in pieces if piece and piece != ",")

    def _get_node_text(self, node: SgNode) -> str:  # type: ignore[valid-type]
        """Get full text of a node.

        Args:
            node: SgNode to extract text from

        Returns:
            Node's text content, or empty string if it cannot be read
        """
        if node is None:
            return ""
        try:
            return node.text()
        except (ValueError, AttributeError) as e:
            log.debug("ast-grep get_node_text error: %s", e)
            return ""

    def _get_line_number(self, node: SgNode) -> int:  # type: ignore[valid-type]
        """Get starting line number of a node.

        Args:
            node: SgNode to get line number for

        Returns:
            1-based line number, or 0 if it cannot be determined
        """
        if node is None:
            return 0
        try:
            range_info = node.range()
            # ast-grep-py 0.40+ returns Range object with .start.line attribute
            if hasattr(range_info, 'start') and hasattr(range_info.start, 'line'):
                return range_info.start.line + 1  # Convert to 1-based
            # Fallback for string format "(0,0)-(1,8)"
            if isinstance(range_info, str) and range_info:
                start_part = range_info.split('-')[0].strip('()')
                start_line = int(start_part.split(',')[0])
                return start_line + 1
        except (ValueError, AttributeError, TypeError, IndexError) as e:
            log.debug("ast-grep get_line_number error: %s", e)
        return 0

    def _get_line_range(self, node: SgNode) -> Tuple[int, int]:  # type: ignore[valid-type]
        """Get line range (start, end) of a node.

        Args:
            node: SgNode to get line range for

        Returns:
            Tuple of (start_line, end_line), both 1-based inclusive;
            (0, 0) if the range cannot be determined
        """
        if node is None:
            return (0, 0)
        try:
            range_info = node.range()
            # ast-grep-py 0.40+ returns Range object with .start.line and .end.line
            if hasattr(range_info, 'start') and hasattr(range_info, 'end'):
                start_line = getattr(range_info.start, 'line', 0)
                end_line = getattr(range_info.end, 'line', 0)
                return (start_line + 1, end_line + 1)  # Convert to 1-based
            # Fallback for string format "(0,0)-(1,8)"
            if isinstance(range_info, str) and range_info:
                parts = range_info.split('-')
                start_part = parts[0].strip('()')
                end_part = parts[1].strip('()')
                start_line = int(start_part.split(',')[0])
                end_line = int(end_part.split(',')[0])
                return (start_line + 1, end_line + 1)
        except (ValueError, AttributeError, TypeError, IndexError) as e:
            log.debug("ast-grep get_line_range error: %s", e)
        return (0, 0)

    def get_language(self) -> Optional[str]:
        """Get the configured ast-grep language.

        Returns:
            Language string or None if not configured
        """
        return self._language
def is_astgrep_available() -> bool:
    """Report whether the ast-grep-py bindings were imported successfully.

    Returns:
        The module-level ``ASTGREP_AVAILABLE`` flag set by the guarded import
    """
    bindings_importable = ASTGREP_AVAILABLE
    return bindings_importable
def get_supported_languages() -> List[str]:
    """List the language identifiers the ast-grep binding knows about.

    Returns:
        The keys of ``AstGrepBinding.LANGUAGE_MAP`` as a new list
    """
    return [language_id for language_id in AstGrepBinding.LANGUAGE_MAP]

View File

@@ -0,0 +1,931 @@
"""Ast-grep based processor for Python relationship extraction.
Provides pattern-based AST matching for extracting code relationships
(inheritance, calls, imports) from Python source code.
This processor wraps the ast-grep-py bindings and provides a higher-level
interface for relationship extraction, similar to TreeSitterSymbolParser.
Design Pattern:
- Follows TreeSitterSymbolParser class structure for consistency
- Uses declarative patterns defined in patterns/python/__init__.py
- Provides scope-aware relationship extraction with alias resolution
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
# Import patterns module
from codexlens.parsers.patterns.python import (
PYTHON_PATTERNS,
get_pattern,
get_metavar,
)
# Graceful import pattern following existing convention
try:
from ast_grep_py import SgNode, SgRoot
from codexlens.parsers.astgrep_binding import AstGrepBinding, ASTGREP_AVAILABLE
except ImportError:
SgNode = None # type: ignore[assignment,misc]
SgRoot = None # type: ignore[assignment,misc]
AstGrepBinding = None # type: ignore[assignment,misc]
ASTGREP_AVAILABLE = False
class BaseAstGrepProcessor(ABC):
    """Common scaffolding for processors built on the ast-grep binding.

    Owns an optional ``AstGrepBinding`` and exposes availability checking and
    one-shot pattern execution. Language-specific subclasses supply
    ``process_matches`` and ``parse``.
    """

    def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
        """Create the processor and, when possible, its underlying binding.

        Args:
            language_id: Language identifier (python, javascript, typescript)
            path: Optional file path for language variant detection
        """
        self.language_id = language_id
        self.path = path
        # The binding is only constructed when the guarded import succeeded.
        binding_usable = ASTGREP_AVAILABLE and AstGrepBinding is not None
        self._binding: Optional[AstGrepBinding] = (
            AstGrepBinding(language_id, path) if binding_usable else None
        )

    def is_available(self) -> bool:
        """Tell whether pattern matching can be performed.

        Returns:
            False when no binding exists, otherwise the binding's own
            ``is_available()`` result
        """
        binding = self._binding
        if binding is None:
            return False
        return binding.is_available()

    def run_ast_grep(self, source_code: str, pattern: str) -> List[SgNode]:  # type: ignore[valid-type]
        """Parse ``source_code`` and return every node matching ``pattern``.

        The source is (re)parsed on every call.

        Args:
            source_code: Source code text to analyze
            pattern: ast-grep pattern string

        Returns:
            Matching SgNode objects; an empty list when the processor is
            unavailable, parsing fails, or nothing matches
        """
        if not self.is_available():
            return []
        binding = self._binding
        if binding is None or not binding.parse(source_code):
            return []
        return binding.find_all(pattern)

    @abstractmethod
    def process_matches(
        self,
        matches: List[SgNode],  # type: ignore[valid-type]
        source_code: str,
        path: Path,
    ) -> List[CodeRelationship]:
        """Turn raw ast-grep matches into code relationships.

        Args:
            matches: List of matched SgNode objects
            source_code: Original source code
            path: File path being processed

        Returns:
            List of extracted code relationships
        """
        ...

    @abstractmethod
    def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
        """Analyze source text and build an indexed representation.

        Args:
            text: Source code text
            path: File path

        Returns:
            IndexedFile with symbols and relationships, None if unavailable
        """
        ...
class AstGrepPythonProcessor(BaseAstGrepProcessor):
"""Python-specific ast-grep processor for relationship extraction.
Extracts INHERITS, CALLS, and IMPORTS relationships from Python code
using declarative ast-grep patterns with scope-aware processing.
"""
def __init__(self, path: Optional[Path] = None) -> None:
    """Set up the processor with the language fixed to Python.

    Args:
        path: Optional file path, forwarded unchanged to the base class
    """
    super().__init__(language_id="python", path=path)
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
    """Build an IndexedFile (symbols + relationships) for Python source.

    Args:
        text: Python source code text
        path: File path; stored in resolved (absolute) form

    Returns:
        IndexedFile with language "python" and no chunks, or None when the
        processor is unavailable or extraction raised ValueError, TypeError
        or AttributeError
    """
    if not self.is_available():
        return None

    try:
        extracted_symbols = self._extract_symbols(text)
        extracted_relationships = self._extract_relationships(text, path)
        indexed = IndexedFile(
            path=str(path.resolve()),
            language="python",
            symbols=extracted_symbols,
            chunks=[],
            relationships=extracted_relationships,
        )
    except (ValueError, TypeError, AttributeError) as e:
        # Extraction failures are logged at debug level and reported as None.
        import logging
        logger = logging.getLogger(__name__)
        logger.debug(f"ast-grep parsing error: {e}")
        return None
    return indexed
def _extract_symbols(self, source_code: str) -> List[Symbol]:
    """Collect class, function and method symbols from Python source.

    A function is reported as a "method" when the innermost scope still open
    at its start line is a class; otherwise it is a "function".

    Args:
        source_code: Python source code

    Returns:
        List of Symbol objects ordered by starting line
    """
    # Each entry: (start_line, end_line, kind, name), kind is "class"/"function".
    definitions: List[Tuple[int, int, str, str]] = []

    for class_node in self.run_ast_grep(source_code, get_pattern("class_def")):
        class_name = self._get_match(class_node, "NAME")
        if class_name:
            first, last = self._get_line_range(class_node)
            definitions.append((first, last, "class", class_name))

    # Async definitions are gathered before plain ones so that a plain match
    # starting on the same line can be recognised as a duplicate and skipped.
    async_start_lines: set = set()
    for async_node in self.run_ast_grep(source_code, get_pattern("async_func_def")):
        async_name = self._get_match(async_node, "NAME")
        if async_name:
            first, last = self._get_line_range(async_node)
            definitions.append((first, last, "function", async_name))
            async_start_lines.add(first)

    for func_node in self.run_ast_grep(source_code, get_pattern("func_def")):
        func_name = self._get_match(func_node, "NAME")
        if not func_name:
            continue
        first, last = self._get_line_range(func_node)
        if first in async_start_lines:
            continue
        definitions.append((first, last, "function", func_name))

    # Walk definitions top to bottom, keeping a stack of still-open scopes.
    definitions.sort(key=lambda entry: entry[0])

    found: List[Symbol] = []
    open_scopes: List[Tuple[str, int, str]] = []  # (name, end_line, kind)
    for first, last, kind, name in definitions:
        # Discard scopes that ended before this definition begins.
        while open_scopes and open_scopes[-1][1] < first:
            open_scopes.pop()

        if kind == "class":
            symbol_kind = "class"
        elif open_scopes and open_scopes[-1][2] == "class":
            symbol_kind = "method"
        else:
            symbol_kind = "function"

        found.append(Symbol(
            name=name,
            kind=symbol_kind,
            range=(first, last),
        ))
        open_scopes.append((name, last, kind))

    return found
def _extract_relationships(self, source_code: str, path: Path) -> List[CodeRelationship]:
    """Extract code relationships with scope and alias resolution.

    Gathers class/function scope markers, inheritance, import and call
    matches into one list, orders them by line and hands them to
    ``_process_scope_and_aliases``.

    Args:
        source_code: Python source code
        path: File path; recorded in resolved (absolute) form

    Returns:
        List of CodeRelationship objects; empty when the processor is
        unavailable
    """
    if not self.is_available() or self._binding is None:
        return []

    source_file = str(path.resolve())

    # Collect all matches with line numbers and end lines for scope processing
    # Format: (start_line, end_line, match_type, symbol, node)
    all_matches: List[Tuple[int, int, str, str, Any]] = []

    # (start_line, name) of scopes already recorded. Set lookups replace a
    # rescan of all_matches for every candidate.
    recorded_classes: set = set()
    recorded_functions: set = set()

    # Classes with bases: record the scope and, when bases are found, the
    # inheritance entry. Insertion order matters: the stable sort below keeps
    # "class_def" ahead of "inherits" on the same line.
    for node in self.run_ast_grep(source_code, get_pattern("class_with_bases")):
        class_name = self._get_match(node, "NAME")
        start_line, end_line = self._get_line_range(node)
        if not class_name:
            continue
        all_matches.append((start_line, end_line, "class_def", class_name, node))
        recorded_classes.add((start_line, class_name))
        # Extract bases from node text (ast-grep-py 0.40+ doesn't capture $$$)
        node_text = self._binding._get_node_text(node) if self._binding else ""
        bases_text = self._extract_bases_from_class_text(node_text)
        if bases_text:
            all_matches.append((start_line, end_line, "inherits", bases_text, node))

    # Remaining class definitions, skipping ones already recorded above.
    for node in self.run_ast_grep(source_code, get_pattern("class_def")):
        class_name = self._get_match(node, "NAME")
        start_line, end_line = self._get_line_range(node)
        if class_name and (start_line, class_name) not in recorded_classes:
            recorded_classes.add((start_line, class_name))
            all_matches.append((start_line, end_line, "class_def", class_name, node))

    # Function scopes (plain, then async). A definition reported by both
    # patterns is recorded once, as _extract_symbols does; a repeated scope
    # marker would add nothing to the result.
    for pattern_name in ("func_def", "async_func_def"):
        for node in self.run_ast_grep(source_code, get_pattern(pattern_name)):
            func_name = self._get_match(node, "NAME")
            start_line, end_line = self._get_line_range(node)
            if func_name and (start_line, func_name) not in recorded_functions:
                recorded_functions.add((start_line, func_name))
                all_matches.append((start_line, end_line, "func_def", func_name, node))

    # Plain imports
    for node in self.run_ast_grep(source_code, get_pattern("import_stmt")):
        module = self._get_match(node, "MODULE")
        start_line, end_line = self._get_line_range(node)
        if module:
            all_matches.append((start_line, end_line, "import", module, node))

    # From-imports are encoded as "module:names" for the scope processor.
    for node in self.run_ast_grep(source_code, get_pattern("import_from")):
        module = self._get_match(node, "MODULE")
        names = self._get_match(node, "NAMES")
        start_line, end_line = self._get_line_range(node)
        if module:
            all_matches.append((start_line, end_line, "from_import", f"{module}:{names}", node))

    # Calls
    for node in self.run_ast_grep(source_code, get_pattern("call")):
        func = self._get_match(node, "FUNC")
        start_line, end_line = self._get_line_range(node)
        if func:
            # Skip self. and cls. prefixed calls
            base = func.split(".", 1)[0]
            if base not in {"self", "cls"}:
                all_matches.append((start_line, end_line, "call", func, node))

    # Order by start line; on the same line every non-call entry (scope
    # definitions, imports) comes before calls.
    all_matches.sort(key=lambda x: (x[0], x[2] == "call"))

    # Process with scope tracking
    return self._process_scope_and_aliases(all_matches, source_file)
def _process_scope_and_aliases(
    self,
    matches: List[Tuple[int, int, str, str, Any]],
    source_file: str,
) -> List[CodeRelationship]:
    """Convert ordered matches into relationships, tracking scope and aliases.

    Two parallel stacks are kept while walking the matches:
    - a scope stack of (name, end_line), seeded with "<module>";
    - an alias stack with one dict per scope; a new scope starts with a copy
      of its parent's aliases.
    Scopes whose end line lies before the current match are popped first.
    Inheritance and call targets are resolved through the innermost alias map.

    Args:
        matches: Sorted list of (start_line, end_line, type, symbol, node) tuples
        source_file: Source file path

    Returns:
        List of resolved CodeRelationship objects
    """
    collected: List[CodeRelationship] = []
    scopes: List[Tuple[str, int]] = [("<module>", float("inf"))]
    aliases: List[Dict[str, str]] = [{}]

    def emit(target: str, relationship_type: Any, line: int) -> None:
        """Record a relationship originating from the innermost scope."""
        collected.append(CodeRelationship(
            source_symbol=scopes[-1][0],
            target_symbol=target,
            relationship_type=relationship_type,
            source_file=source_file,
            target_file=None,
            source_line=line,
        ))

    def resolve(symbol: str) -> str:
        """Map a (possibly dotted) name through the innermost alias map."""
        current = aliases[-1]
        head, dot, tail = symbol.partition(".")
        if not dot:
            # Simple name: replace it when it is a known alias.
            return current.get(symbol, symbol)
        if head in current:
            # Dotted name: only the leading component is substituted.
            return f"{current[head]}.{tail}"
        return symbol

    for start_line, end_line, match_type, symbol, _node in matches:
        # Leave every scope that ended before this line (never the module).
        while len(scopes) > 1 and scopes[-1][1] < start_line:
            scopes.pop()
            aliases.pop()

        if match_type in ("class_def", "func_def"):
            # Enter the new scope; it inherits a copy of the parent's aliases.
            scopes.append((symbol, end_line))
            aliases.append(dict(aliases[-1]))

        elif match_type == "inherits":
            # symbol holds the raw base-list text of the class just entered.
            for raw_base in self._parse_base_classes(symbol):
                base_class = raw_base.strip()
                if base_class:
                    emit(resolve(base_class), RelationshipType.INHERITS, start_line)

        elif match_type == "import":
            # The first dotted component is registered as an alias for the
            # whole module path.
            aliases[-1][symbol.split(".", 1)[0]] = symbol
            emit(symbol, RelationshipType.IMPORTS, start_line)

        elif match_type == "from_import":
            # symbol is encoded as "module:names".
            module, _sep, names = symbol.partition(":")
            emit(module, RelationshipType.IMPORTS, start_line)

            if names and names != "*":
                for name in names.split(","):
                    name = name.strip()
                    if " as " in name:
                        as_parts = name.split(" as ")
                        original = as_parts[0].strip()
                        alias = as_parts[1].strip()
                        if alias:
                            aliases[-1][alias] = f"{module}.{original}"
                    elif name:
                        aliases[-1][name] = f"{module}.{name}"

        elif match_type == "call":
            emit(resolve(symbol), RelationshipType.CALL, start_line)

    return collected
def process_matches(
    self,
    matches: List[SgNode],  # type: ignore[valid-type]
    source_code: str,
    path: Path,
) -> List[CodeRelationship]:
    """Interpret each match as a call made from module scope.

    Simplified entry point: there is no scope tracking or alias resolution
    here; use parse() for the full extraction. Matches without a FUNC
    capture, and calls whose leading name is ``self`` or ``cls``, are
    ignored.

    Args:
        matches: List of matched SgNode objects
        source_code: Original source code (not consulted by this method)
        path: File path being processed

    Returns:
        List of CALL relationships with "<module>" as the source symbol
    """
    if not self.is_available() or self._binding is None:
        return []

    resolved_path = str(path.resolve())
    calls: List[CodeRelationship] = []

    for node in matches:
        callee = self._get_match(node, "FUNC")
        line = self._get_line_number(node)
        if not callee or callee.split(".", 1)[0] in {"self", "cls"}:
            continue
        calls.append(CodeRelationship(
            source_symbol="<module>",
            target_symbol=callee,
            relationship_type=RelationshipType.CALL,
            source_file=resolved_path,
            target_file=None,
            source_line=line,
        ))

    return calls
def _get_match(self, node: SgNode, metavar: str) -> str: # type: ignore[valid-type]
"""Extract matched metavariable value from node.
Args:
node: SgNode with match
metavar: Metavariable name (without $ prefix)
Returns:
Matched text or empty string
"""
if self._binding is None or node is None:
return ""
return self._binding._get_match(node, metavar)
def _get_line_number(self, node: SgNode) -> int: # type: ignore[valid-type]
"""Get starting line number of a node.
Args:
node: SgNode to get line number for
Returns:
1-based line number
"""
if self._binding is None or node is None:
return 0
return self._binding._get_line_number(node)
def _get_line_range(self, node: SgNode) -> Tuple[int, int]: # type: ignore[valid-type]
"""Get line range for a node.
Args:
node: SgNode to get range for
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
if self._binding is None or node is None:
return (0, 0)
return self._binding._get_line_range(node)
# =========================================================================
# Dedicated extraction methods for INHERITS, CALL, IMPORTS relationships
# =========================================================================
def extract_inherits(
    self,
    source_code: str,
    source_file: str,
    source_symbol: str = "<module>",
) -> List[CodeRelationship]:
    """Extract INHERITS relationships from Python code.

    Covers single inheritance (``class Child(Parent):``) and multiple
    inheritance (``class Child(A, B, C):``). One relationship is produced
    per base class, with the derived class name as the source symbol.

    Args:
        source_code: Python source code to analyze
        source_file: Path to the source file
        source_symbol: Containing scope; accepted for signature parity with
            the other extract_* methods but not used by this method

    Returns:
        List of CodeRelationship objects with INHERITS type
    """
    if not self.is_available():
        return []

    found: List[CodeRelationship] = []

    for class_node in self.run_ast_grep(source_code, get_pattern("class_with_bases")):
        derived = self._get_match(class_node, "NAME")
        line = self._get_line_number(class_node)
        if not derived:
            continue

        # ast-grep-py 0.40+ doesn't capture $$$ multi-matches, so the base
        # list is recovered from the class text instead.
        class_text = self._binding._get_node_text(class_node) if self._binding else ""
        bases_text = self._extract_bases_from_class_text(class_text)
        if not bases_text:
            continue

        for raw_base in self._parse_base_classes(bases_text):
            base_name = raw_base.strip()
            if base_name:
                found.append(CodeRelationship(
                    source_symbol=derived,
                    target_symbol=base_name,
                    relationship_type=RelationshipType.INHERITS,
                    source_file=source_file,
                    target_file=None,
                    source_line=line,
                ))

    return found
def _extract_bases_from_class_text(self, class_text: str) -> str:
"""Extract base classes text from class definition.
Args:
class_text: Full text of class definition (e.g., "class Dog(Animal):\\n pass")
Returns:
Text inside parentheses (e.g., "Animal") or empty string
"""
import re
# Match "class Name(BASES):" - extract BASES
match = re.search(r'class\s+\w+\s*\(([^)]*)\)\s*:', class_text)
if match:
return match.group(1).strip()
return ""
def extract_calls(
    self,
    source_code: str,
    source_file: str,
    source_symbol: str = "<module>",
    alias_map: Optional[Dict[str, str]] = None,
) -> List[CodeRelationship]:
    """Extract CALL relationships from Python code.

    Handles simple calls (``func()``), calls with arguments, method calls
    (``obj.method()``) and chained calls. Calls whose leading name is
    ``self``, ``cls`` or ``super`` are treated as internal and skipped.

    Args:
        source_code: Python source code to analyze
        source_file: Path to the source file
        source_symbol: The containing scope (class or module), used as the
            source of every relationship produced
        alias_map: Optional alias map for resolving imported names

    Returns:
        List of CodeRelationship objects with CALL type
    """
    if not self.is_available():
        return []

    aliases = alias_map if alias_map else {}
    internal_receivers = {"self", "cls", "super"}
    calls: List[CodeRelationship] = []

    for call_node in self.run_ast_grep(source_code, get_pattern("call")):
        callee = self._get_match(call_node, "FUNC")
        line = self._get_line_number(call_node)
        if not callee:
            continue
        if callee.split(".", 1)[0] in internal_receivers:
            continue

        calls.append(CodeRelationship(
            source_symbol=source_symbol,
            target_symbol=self._resolve_call_alias(callee, aliases),
            relationship_type=RelationshipType.CALL,
            source_file=source_file,
            target_file=None,
            source_line=line,
        ))

    return calls
def extract_imports(
    self,
    source_code: str,
    source_file: str,
    source_symbol: str = "<module>",
) -> Tuple[List[CodeRelationship], Dict[str, str]]:
    """Extract IMPORTS relationships from Python code.

    Runs the ``import_stmt``, ``import_with_alias``, ``import_from``,
    ``from_import_star`` and ``relative_import`` patterns, in that order,
    and records one IMPORTS relationship per match. The import forms
    targeted are:
    - Simple import: import os
    - Import with alias: import numpy as np
    - From import: from typing import List
    - From import with alias: from collections import defaultdict as dd
    - Relative import: from .module import func
    - Star import: from module import *

    Args:
        source_code: Python source code to analyze
        source_file: Path to the source file
        source_symbol: The containing scope (class or module)

    Returns:
        Tuple of:
        - List of CodeRelationship objects with IMPORTS type
        - Dict mapping local names to fully qualified names (alias map).
          Star imports and relative imports add no alias entries.
          Both values are empty when the processor is unavailable.
    """
    if not self.is_available():
        return [], {}

    relationships: List[CodeRelationship] = []
    alias_map: Dict[str, str] = {}

    def record(target: str, line) -> None:
        """Append one IMPORTS relationship pointing at ``target``."""
        relationships.append(CodeRelationship(
            source_symbol=source_symbol,
            target_symbol=target,
            relationship_type=RelationshipType.IMPORTS,
            source_file=source_file,
            target_file=None,
            source_line=line,
        ))

    # Process simple imports: import X
    for node in self.run_ast_grep(source_code, get_pattern("import_stmt")):
        module = self._get_match(node, "MODULE")
        line = self._get_line_number(node)
        if module:
            # ``import a.b`` binds only the top-level name ``a`` (to package
            # ``a``). Mapping ``a`` to ``a.b`` would make a later call such
            # as ``a.b.f()`` resolve to ``a.b.b.f``.
            top_level = module.split(".", 1)[0]
            alias_map[top_level] = top_level
            record(module, line)

    # Process import with alias: import X as Y
    for node in self.run_ast_grep(source_code, get_pattern("import_with_alias")):
        module = self._get_match(node, "MODULE")
        alias = self._get_match(node, "ALIAS")
        line = self._get_line_number(node)
        if module and alias:
            alias_map[alias] = module
            record(module, line)

    # Process from imports: from X import Y
    for node in self.run_ast_grep(source_code, get_pattern("import_from")):
        module = self._get_match(node, "MODULE")
        names = self._get_match(node, "NAMES")
        line = self._get_line_number(node)
        if not module:
            continue
        # One relationship per statement, targeting the module itself
        record(module, line)
        if not names or names == "*":
            continue
        # Each imported name becomes a local alias for ``module.name``
        for entry in names.split(","):
            entry = entry.strip()
            if " as " in entry:
                # Handle "name as alias" syntax
                parts = entry.split(" as ")
                original = parts[0].strip()
                alias = parts[1].strip()
                alias_map[alias] = f"{module}.{original}"
            elif entry:
                alias_map[entry] = f"{module}.{entry}"

    # Process star imports: from X import *
    for node in self.run_ast_grep(source_code, get_pattern("from_import_star")):
        module = self._get_match(node, "MODULE")
        line = self._get_line_number(node)
        if module:
            record(f"{module}.*", line)

    # Process relative imports: from .X import Y
    for node in self.run_ast_grep(source_code, get_pattern("relative_import")):
        module = self._get_match(node, "MODULE")
        line = self._get_line_number(node)
        # Prepend dot for relative module path; a bare "." is recorded when
        # no module text was captured.
        rel_module = f".{module}" if module else "."
        record(rel_module, line)

    return relationships, alias_map
# =========================================================================
# Helper methods for pattern processing
# =========================================================================
def _parse_base_classes(self, bases_text: str) -> List[str]:
"""Parse base class names from inheritance text.
Handles single and multiple inheritance with proper comma splitting.
Accounts for nested parentheses and complex type annotations.
Args:
bases_text: Text inside the parentheses of class definition
Returns:
List of base class names
"""
if not bases_text:
return []
# Simple comma split (may not handle all edge cases)
bases = []
depth = 0
current = []
for char in bases_text:
if char == "(":
depth += 1
current.append(char)
elif char == ")":
depth -= 1
current.append(char)
elif char == "," and depth == 0:
base = "".join(current).strip()
if base:
bases.append(base)
current = []
else:
current.append(char)
# Add the last base class
if current:
base = "".join(current).strip()
if base:
bases.append(base)
return bases
def _resolve_call_alias(self, func_name: str, alias_map: Dict[str, str]) -> str:
"""Resolve a function call name using import aliases.
Args:
func_name: The function/method name as it appears in code
alias_map: Mapping of local names to fully qualified names
Returns:
Resolved function name (fully qualified if possible)
"""
if "." not in func_name:
# Simple function call - check if it's an alias
return alias_map.get(func_name, func_name)
# Method call or qualified name - resolve the base
parts = func_name.split(".", 1)
base = parts[0]
rest = parts[1]
if base in alias_map:
return f"{alias_map[base]}.{rest}"
return func_name
def is_astgrep_processor_available() -> bool:
    """Report whether the ast-grep processor can be used.

    Returns:
        The value of the module-level ``ASTGREP_AVAILABLE`` flag.
    """
    return ASTGREP_AVAILABLE
__all__ = [
"BaseAstGrepProcessor",
"AstGrepPythonProcessor",
"is_astgrep_processor_available",
]

View File

@@ -0,0 +1,5 @@
"""ast-grep pattern definitions for various languages.
This package contains language-specific pattern definitions for
extracting code relationships using ast-grep declarative patterns.
"""

View File

@@ -0,0 +1,204 @@
"""Python ast-grep patterns for relationship extraction.
This module defines declarative patterns for extracting code relationships
(inheritance, calls, imports) from Python source code using ast-grep.
Pattern Syntax (ast-grep-py 0.40+):
$VAR - Single metavariable (matches one AST node)
$$$VAR - Multiple metavariable (matches zero or more nodes)
Example:
"class $CLASS_NAME($$$BASES) $$$BODY" matches:
class MyClass(BaseClass):
pass
with $CLASS_NAME = "MyClass", $$$BASES = "BaseClass", $$$BODY = "pass"
YAML Pattern Files:
inherits.yaml - INHERITS relationship patterns (single/multiple inheritance)
imports.yaml - IMPORTS relationship patterns (import, from...import, as)
call.yaml - CALL relationship patterns (function/method calls)
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional
# Directory containing YAML pattern files
PATTERNS_DIR = Path(__file__).parent
# Python ast-grep patterns organized by relationship type
# Note: ast-grep-py 0.40+ uses $$$ for zero-or-more multi-match
# Maps a pattern name to its ast-grep pattern string; looked up by get_pattern().
# Several entries share the same shape and differ only in metavariable name:
#   "call" / "call_with_args" / "constructor_call"
#   "import_from" / "from_import_single"
#   "class_with_bases" uses $NAME where the inheritance patterns use $CLASS_NAME
PYTHON_PATTERNS: Dict[str, str] = {
    # Class definitions with inheritance
    "class_def": "class $NAME $$$BODY",
    "class_with_bases": "class $NAME($$$BASES) $$$BODY",
    # Single inheritance: class Child(Parent):
    "single_inheritance": "class $CLASS_NAME($BASE) $$$BODY",
    # Multiple inheritance: class Child(A, B, C):
    "multiple_inheritance": "class $CLASS_NAME($BASE, $$$MORE_BASES) $$$BODY",
    # Function definitions (use $$$ for zero-or-more params)
    "func_def": "def $NAME($$$PARAMS): $$$BODY",
    "async_func_def": "async def $NAME($$$PARAMS): $$$BODY",
    # Import statements - basic forms
    "import_stmt": "import $MODULE",
    "import_from": "from $MODULE import $NAMES",
    # Import statements - extended forms
    "import_with_alias": "import $MODULE as $ALIAS",
    "import_multiple": "import $FIRST, $$$REST",
    "from_import_single": "from $MODULE import $NAME",
    "from_import_with_alias": "from $MODULE import $NAME as $ALIAS",
    "from_import_multiple": "from $MODULE import $FIRST, $$$REST",
    "from_import_star": "from $MODULE import *",
    "relative_import": "from .$$$MODULE import $NAMES",
    # Function/method calls - basic form (use $$$ for zero-or-more args)
    "call": "$FUNC($$$ARGS)",
    "method_call": "$OBJ.$METHOD($$$ARGS)",
    # Function/method calls - specific forms
    "simple_call": "$FUNC()",
    "call_with_args": "$FUNC($$$ARGS)",
    "chained_call": "$OBJ.$METHOD($$$ARGS).$$$CHAIN",
    "constructor_call": "$CLASS($$$ARGS)",
}
# Metavariable names for extracting match data
# Maps a descriptive key to the metavariable name used in the pattern strings,
# without the leading "$"; read by get_metavar(). Keys from different groups
# may share a value (e.g. "class_name" and "func_name" are both "NAME").
METAVARS: Dict[str, str] = {
    # Class patterns
    "class_name": "NAME",
    "class_bases": "BASES",
    "class_body": "BODY",
    "inherit_class": "CLASS_NAME",
    "inherit_base": "BASE",
    "inherit_more_bases": "MORE_BASES",
    # Function patterns
    "func_name": "NAME",
    "func_params": "PARAMS",
    "func_body": "BODY",
    # Import patterns
    "import_module": "MODULE",
    "import_names": "NAMES",
    "import_alias": "ALIAS",
    "import_first": "FIRST",
    "import_rest": "REST",
    # Call patterns
    "call_func": "FUNC",
    "call_obj": "OBJ",
    "call_method": "METHOD",
    "call_args": "ARGS",
    "call_class": "CLASS",
    "call_chain": "CHAIN",
}
# Relationship pattern mapping - expanded for new patterns
# Maps a relationship type to the PYTHON_PATTERNS keys that can extract it;
# read by get_patterns_for_relationship().
# NOTE(review): "class_def", "func_def", "async_func_def" and "chained_call"
# are defined in PYTHON_PATTERNS but not listed under any relationship here;
# confirm the omission of "chained_call" from "calls" is intentional.
RELATIONSHIP_PATTERNS: Dict[str, List[str]] = {
    "inheritance": ["class_with_bases", "single_inheritance", "multiple_inheritance"],
    "imports": [
        "import_stmt", "import_from",
        "import_with_alias", "import_multiple",
        "from_import_single", "from_import_with_alias",
        "from_import_multiple", "from_import_star",
        "relative_import",
    ],
    "calls": ["call", "method_call", "simple_call", "call_with_args", "constructor_call"],
}
# YAML pattern file mapping
# Maps a relationship type to the file name of its YAML pattern file; the
# names are joined onto PATTERNS_DIR by get_yaml_pattern_path() and
# list_yaml_pattern_files().
YAML_PATTERN_FILES: Dict[str, str] = {
    "inheritance": "inherits.yaml",
    "imports": "imports.yaml",
    "calls": "call.yaml",
}
def get_pattern(pattern_name: str) -> str:
    """Look up an ast-grep pattern string by its name.

    Args:
        pattern_name: A key of PYTHON_PATTERNS

    Returns:
        The pattern string registered under that key

    Raises:
        KeyError: If the name is not registered; the message lists every
            available pattern name
    """
    pattern = PYTHON_PATTERNS.get(pattern_name)
    if pattern is None:
        # Every registered value is a string, so None can only mean "missing"
        raise KeyError(f"Unknown pattern: {pattern_name}. Available: {list(PYTHON_PATTERNS.keys())}")
    return pattern
def get_patterns_for_relationship(rel_type: str) -> List[str]:
    """Get all patterns that can extract a given relationship type.

    Args:
        rel_type: Relationship type (inheritance, imports, calls)

    Returns:
        A new list of pattern names; empty when ``rel_type`` is not a key
        of RELATIONSHIP_PATTERNS.
    """
    # Return a copy so callers that mutate the result cannot alter the
    # module-level RELATIONSHIP_PATTERNS table.
    return list(RELATIONSHIP_PATTERNS.get(rel_type, ()))
def get_metavar(name: str) -> str:
    """Translate a descriptive key into a metavariable name (no ``$``).

    Args:
        name: A key of METAVARS

    Returns:
        The registered metavariable name (e.g., "NAME" not "$NAME"), or
        ``name`` upper-cased when the key is not registered
    """
    try:
        return METAVARS[name]
    except KeyError:
        # Unregistered keys fall back to their own upper-cased spelling
        return name.upper()
def get_yaml_pattern_path(rel_type: str) -> Optional[Path]:
    """Build the path of the YAML pattern file for a relationship type.

    The file's existence is not checked.

    Args:
        rel_type: Relationship type (inheritance, imports, calls)

    Returns:
        PATTERNS_DIR joined with the registered file name, or None when
        the relationship type has no registered file
    """
    yaml_name = YAML_PATTERN_FILES.get(rel_type)
    return PATTERNS_DIR / yaml_name if yaml_name else None
def list_yaml_pattern_files() -> Dict[str, Path]:
    """Collect the registered YAML pattern files that exist on disk.

    Returns:
        Dict mapping relationship type to YAML file path, restricted to
        the entries of YAML_PATTERN_FILES whose file exists in PATTERNS_DIR
    """
    candidates = {
        rel_type: PATTERNS_DIR / filename
        for rel_type, filename in YAML_PATTERN_FILES.items()
    }
    return {
        rel_type: yaml_path
        for rel_type, yaml_path in candidates.items()
        if yaml_path.exists()
    }
__all__ = [
"PYTHON_PATTERNS",
"METAVARS",
"RELATIONSHIP_PATTERNS",
"YAML_PATTERN_FILES",
"PATTERNS_DIR",
"get_pattern",
"get_patterns_for_relationship",
"get_metavar",
"get_yaml_pattern_path",
"list_yaml_pattern_files",
]

View File

@@ -0,0 +1,87 @@
# Python CALL patterns for ast-grep
# Extracts function and method call expressions
# Pattern metadata
id: python-call
language: python
description: Extract function and method calls from Python code
patterns:
# Simple function call
# Matches: func()
- id: simple_call
pattern: "$FUNC()"
message: "Found simple function call"
severity: hint
# Function call with arguments
# Matches: func(arg1, arg2)
- id: call_with_args
pattern: "$FUNC($$$ARGS)"
message: "Found function call with arguments"
severity: hint
# Method call
# Matches: obj.method()
- id: method_call
pattern: "$OBJ.$METHOD($$$ARGS)"
message: "Found method call"
severity: hint
# Chained method call
# Matches: obj.method1().method2()
- id: chained_call
pattern: "$OBJ.$METHOD($$$ARGS).$$$CHAIN"
message: "Found chained method call"
severity: hint
# Call with keyword arguments
# Matches: func(arg=value)
- id: call_with_kwargs
pattern: "$FUNC($$$ARGS, $KWARG=$VALUE$$$MORE)"
message: "Found call with keyword argument"
severity: hint
# Constructor call
# Matches: ClassName()
- id: constructor_call
pattern: "$CLASS($$$ARGS)"
message: "Found constructor call"
severity: hint
# Subscript call (not a real call, but often confused)
# This pattern helps exclude indexing from calls
- id: subscript_access
pattern: "$OBJ[$INDEX]"
message: "Found subscript access"
severity: hint
# Metavariables used:
# $FUNC - Function name being called
# $OBJ - Object receiving the method call
# $METHOD - Method name being called
# $$$ARGS - Call arguments (zero or more)
# $KWARG - Keyword argument name
# $VALUE - Keyword argument value
# $CLASS - Class name for constructor calls
# $INDEX - Index for subscript access
# $$$MORE - Additional arguments
# $$$CHAIN - Additional method chains
# Note: The generic call pattern "$FUNC($$$ARGS)" will match all function calls
# including method calls and constructor calls. More specific patterns help
# categorize the type of call.
# Examples matched:
# print("hello") -> call_with_args
# len(items) -> call_with_args
# obj.process() -> method_call
# obj.get().save() -> chained_call
# func(name=value) -> call_with_kwargs
# MyClass() -> constructor_call
# items[0] -> subscript_access (not a call)
# Filtering notes:
# - self.method() calls are typically filtered during processing
# - cls.method() calls are typically filtered during processing
# - super().method() calls may be handled specially

View File

@@ -0,0 +1,82 @@
# Python IMPORTS patterns for ast-grep
# Extracts import statements (import, from...import, as aliases)
# Pattern metadata
id: python-imports
language: python
description: Extract import statements from Python code
patterns:
# Simple import
# Matches: import os
- id: simple_import
pattern: "import $MODULE"
message: "Found simple import"
severity: hint
# Import with alias
# Matches: import numpy as np
- id: import_with_alias
pattern: "import $MODULE as $ALIAS"
message: "Found import with alias"
severity: hint
# Multiple imports
# Matches: import os, sys
- id: multiple_imports
pattern: "import $FIRST, $$$REST"
message: "Found multiple imports"
severity: hint
# From import (single name)
# Matches: from os import path
- id: from_import_single
pattern: "from $MODULE import $NAME"
message: "Found from-import single"
severity: hint
# From import with alias
# Matches: from collections import defaultdict as dd
- id: from_import_with_alias
pattern: "from $MODULE import $NAME as $ALIAS"
message: "Found from-import with alias"
severity: hint
# From import multiple names
# Matches: from typing import List, Dict, Optional
- id: from_import_multiple
pattern: "from $MODULE import $FIRST, $$$REST"
message: "Found from-import multiple"
severity: hint
# From import star
# Matches: from module import *
- id: from_import_star
pattern: "from $MODULE import *"
message: "Found star import"
severity: warning
# Relative import
# Matches: from .module import func
- id: relative_import
pattern: "from .$$$MODULE import $NAMES"
message: "Found relative import"
severity: hint
# Metavariables used:
# $MODULE - The module being imported
# $ALIAS - The alias for the import
# $NAME - The specific name being imported
# $FIRST - First item in a multi-item import
# $$$REST - Remaining items in a multi-item import
# $NAMES - Names being imported in from-import
# Examples matched:
# import os -> simple_import
# import numpy as np -> import_with_alias
# import os, sys, pathlib -> multiple_imports
# from os import path -> from_import_single
# from typing import List, Dict, Set -> from_import_multiple
# from collections import defaultdict -> from_import_single
# from .helpers import utils -> relative_import
# from module import * -> from_import_star

View File

@@ -0,0 +1,42 @@
# Python INHERITS patterns for ast-grep
# Extracts class inheritance relationships (single and multiple inheritance)
# Pattern metadata
id: python-inherits
language: python
description: Extract class inheritance relationships from Python code
# Single inheritance pattern
# Matches: class Child(Parent):
patterns:
- id: single_inheritance
pattern: "class $CLASS_NAME($BASE) $$$BODY"
message: "Found single inheritance"
severity: hint
# Multiple inheritance pattern
# Matches: class Child(Parent1, Parent2, Parent3):
- id: multiple_inheritance
pattern: "class $CLASS_NAME($BASE, $$$MORE_BASES) $$$BODY"
message: "Found multiple inheritance"
severity: hint
# Generic inheritance with any number of bases
# Matches: class Child(...): with any number of parent classes
- id: class_with_bases
pattern: "class $NAME($$$BASES) $$$BODY"
message: "Found class with base classes"
severity: hint
# Metavariables used:
# $CLASS_NAME - The name of the child class (single_inheritance, multiple_inheritance)
# $NAME - The name of the child class (class_with_bases)
# $BASE - First base class (single_inheritance, multiple_inheritance)
# $$$BASES - All base classes (class_with_bases)
# $$$MORE_BASES - Additional base classes after the first (multiple_inheritance)
# $$$BODY - Class body (statements, can be multiple)
# Examples matched:
# class Dog(Animal): -> single_inheritance
# class C(A, B): -> multiple_inheritance
# class D(BaseMixin, logging.Log): -> class_with_bases
# class E(A, B, C, D): -> multiple_inheritance

View File

@@ -11,7 +11,7 @@ return `None`; callers should use a regex-based fallback such as
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING
try:
from tree_sitter import Language as TreeSitterLanguage
@@ -27,26 +27,45 @@ except ImportError:
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
if TYPE_CHECKING:
from codexlens.config import Config
class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction."""
"""Parser using tree-sitter for AST-level symbol extraction.
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
Supports optional ast-grep integration for Python relationship extraction
when config.use_astgrep is True and ast-grep-py is available.
"""
def __init__(
    self,
    language_id: str,
    path: Optional[Path] = None,
    config: Optional["Config"] = None,
) -> None:
    """Initialize tree-sitter parser for a language.

    The tree-sitter parser is set up only when ``TREE_SITTER_AVAILABLE`` is
    true, and the ast-grep processor only when ``_should_use_astgrep()``
    returns True.

    Args:
        language_id: Language identifier (python, javascript, typescript, etc.)
        path: Optional file path for language variant detection (e.g., .tsx)
        config: Optional Config instance for parser feature toggles
    """
    self.language_id = language_id
    self.path = path
    self._config = config
    self._parser: Optional[object] = None
    self._language: Optional[TreeSitterLanguage] = None
    self._tokenizer = get_default_tokenizer()
    # Stays None unless _initialize_astgrep_processor() assigns a processor
    self._astgrep_processor = None

    if TREE_SITTER_AVAILABLE:
        self._initialize_parser()

    # Initialize ast-grep processor for Python if config enables it.
    # Must run after language_id and _config are assigned above, because
    # _should_use_astgrep() reads both.
    if self._should_use_astgrep():
        self._initialize_astgrep_processor()
def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None:
@@ -82,6 +101,31 @@ class TreeSitterSymbolParser:
self._parser = None
self._language = None
def _should_use_astgrep(self) -> bool:
"""Check if ast-grep should be used for relationship extraction.
Returns:
True if config.use_astgrep is True and language is Python
"""
if self._config is None:
return False
if not getattr(self._config, "use_astgrep", False):
return False
return self.language_id == "python"
def _initialize_astgrep_processor(self) -> None:
    """Initialize ast-grep processor for Python relationship extraction.

    When ``is_astgrep_processor_available()`` reports True, stores an
    ``AstGrepPythonProcessor`` built from ``self.path`` on
    ``self._astgrep_processor``; otherwise the attribute is left as it was.
    An ``ImportError`` raised anywhere inside the ``try`` body sets the
    attribute to None. Other exception types are not caught here.
    """
    try:
        # Imported inside the function so that a failing import is handled
        # by the except clause below.
        from codexlens.parsers.astgrep_processor import (
            AstGrepPythonProcessor,
            is_astgrep_processor_available,
        )
        if is_astgrep_processor_available():
            self._astgrep_processor = AstGrepPythonProcessor(self.path)
    except ImportError:
        self._astgrep_processor = None
def is_available(self) -> bool:
"""Check if tree-sitter parser is available.
@@ -138,7 +182,10 @@ class TreeSitterSymbolParser:
source_bytes, root = parsed
try:
symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path)
# Pass source_code for ast-grep integration
relationships = self._extract_relationships(
source_bytes, root, path, source_code=text
)
return IndexedFile(
path=str(path.resolve()),
@@ -173,13 +220,68 @@ class TreeSitterSymbolParser:
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
source_code: Optional[str] = None,
) -> List[CodeRelationship]:
"""Extract relationships, optionally using ast-grep for Python.
When config.use_astgrep is True and ast-grep is available for Python,
uses ast-grep for relationship extraction. Otherwise, uses tree-sitter.
Args:
source_bytes: Source code as bytes
root: Root AST node from tree-sitter
path: File path
source_code: Optional source code string (required for ast-grep)
Returns:
List of extracted relationships
"""
if self.language_id == "python":
# Try ast-grep first if configured and available
if self._astgrep_processor is not None and source_code is not None:
try:
astgrep_rels = self._extract_python_relationships_astgrep(
source_code, path
)
if astgrep_rels is not None:
return astgrep_rels
except Exception:
# Fall back to tree-sitter on ast-grep failure
pass
return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path)
return []
def _extract_python_relationships_astgrep(
self,
source_code: str,
path: Path,
) -> Optional[List[CodeRelationship]]:
"""Extract Python relationships using ast-grep processor.
Args:
source_code: Python source code text
path: File path
Returns:
List of relationships, or None if ast-grep unavailable
"""
if self._astgrep_processor is None:
return None
if not self._astgrep_processor.is_available():
return None
try:
indexed = self._astgrep_processor.parse(source_code, path)
if indexed is not None:
return indexed.relationships
except Exception:
pass
return None
def _extract_python_relationships(
self,
source_bytes: bytes,

View File

@@ -0,0 +1 @@
"""Tests for codexlens.parsers modules."""

View File

@@ -0,0 +1,444 @@
"""Tests for dedicated extraction methods: extract_inherits, extract_calls, extract_imports.
Tests pattern-based relationship extraction from Python source code
using ast-grep-py bindings for INHERITS, CALL, and IMPORTS relationships.
"""
from pathlib import Path
import pytest
from codexlens.parsers.astgrep_processor import (
AstGrepPythonProcessor,
is_astgrep_processor_available,
)
from codexlens.entities import RelationshipType
# Check if ast-grep is available for conditional test skipping
ASTGREP_AVAILABLE = is_astgrep_processor_available()
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractInherits:
"""Tests for extract_inherits method - INHERITS relationship extraction."""
def test_single_inheritance(self):
"""Test extraction of single inheritance relationship."""
processor = AstGrepPythonProcessor()
code = """
class Animal:
pass
class Dog(Animal):
pass
"""
relationships = processor.extract_inherits(code, "test.py")
assert len(relationships) == 1
rel = relationships[0]
assert rel.source_symbol == "Dog"
assert rel.target_symbol == "Animal"
assert rel.relationship_type == RelationshipType.INHERITS
def test_multiple_inheritance(self):
"""Test extraction of multiple inheritance relationships."""
processor = AstGrepPythonProcessor()
code = """
class A:
pass
class B:
pass
class C(A, B):
pass
"""
relationships = processor.extract_inherits(code, "test.py")
# Should have 2 relationships: C->A and C->B
assert len(relationships) == 2
targets = {r.target_symbol for r in relationships}
assert "A" in targets
assert "B" in targets
for rel in relationships:
assert rel.source_symbol == "C"
def test_no_inheritance(self):
"""Test that classes without inheritance return empty list."""
processor = AstGrepPythonProcessor()
code = """
class Standalone:
pass
"""
relationships = processor.extract_inherits(code, "test.py")
assert len(relationships) == 0
def test_nested_class_inheritance(self):
"""Test extraction of inheritance in nested classes."""
processor = AstGrepPythonProcessor()
code = """
class Outer:
class Inner(Base):
pass
"""
relationships = processor.extract_inherits(code, "test.py")
assert len(relationships) == 1
assert relationships[0].source_symbol == "Inner"
assert relationships[0].target_symbol == "Base"
def test_inheritance_with_complex_bases(self):
"""Test extraction with generic or complex base classes."""
processor = AstGrepPythonProcessor()
code = """
class Service(BaseService, mixins.Loggable):
pass
"""
relationships = processor.extract_inherits(code, "test.py")
assert len(relationships) == 2
targets = {r.target_symbol for r in relationships}
assert "BaseService" in targets
assert "mixins.Loggable" in targets
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractCalls:
"""Tests for extract_calls method - CALL relationship extraction."""
def test_simple_function_call(self):
"""Test extraction of simple function calls."""
processor = AstGrepPythonProcessor()
code = """
def main():
print("hello")
len([1, 2, 3])
"""
relationships = processor.extract_calls(code, "test.py", "main")
targets = {r.target_symbol for r in relationships}
assert "print" in targets
assert "len" in targets
def test_method_call(self):
"""Test extraction of method calls."""
processor = AstGrepPythonProcessor()
code = """
def process():
obj.method()
items.append(1)
"""
relationships = processor.extract_calls(code, "test.py", "process")
targets = {r.target_symbol for r in relationships}
assert "obj.method" in targets
assert "items.append" in targets
def test_skips_self_calls(self):
"""Test that self.method() calls are filtered."""
processor = AstGrepPythonProcessor()
code = """
class Service:
def process(self):
self.internal()
external_func()
"""
relationships = processor.extract_calls(code, "test.py", "Service")
targets = {r.target_symbol for r in relationships}
# self.internal should be filtered
assert "self.internal" not in targets
assert "internal" not in targets
assert "external_func" in targets
def test_skips_cls_calls(self):
"""Test that cls.method() calls are filtered."""
processor = AstGrepPythonProcessor()
code = """
class Factory:
@classmethod
def create(cls):
cls.helper()
other_func()
"""
relationships = processor.extract_calls(code, "test.py", "Factory")
targets = {r.target_symbol for r in relationships}
assert "cls.helper" not in targets
assert "other_func" in targets
def test_alias_resolution(self):
"""Test call alias resolution using import map."""
processor = AstGrepPythonProcessor()
code = """
def main():
np.array([1, 2, 3])
"""
alias_map = {"np": "numpy"}
relationships = processor.extract_calls(code, "test.py", "main", alias_map)
assert len(relationships) >= 1
# Should resolve np.array to numpy.array
assert any("numpy.array" in r.target_symbol for r in relationships)
def test_no_calls(self):
"""Test that code without calls returns empty list."""
processor = AstGrepPythonProcessor()
code = """
x = 1
y = x + 2
"""
relationships = processor.extract_calls(code, "test.py")
assert len(relationships) == 0
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractImports:
"""Tests for extract_imports method - IMPORTS relationship extraction."""
def test_simple_import(self):
"""Test extraction of simple import statements."""
processor = AstGrepPythonProcessor()
code = "import os"
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) == 1
assert relationships[0].target_symbol == "os"
assert relationships[0].relationship_type == RelationshipType.IMPORTS
assert alias_map.get("os") == "os"
def test_import_with_alias(self):
"""Test extraction of import with alias."""
processor = AstGrepPythonProcessor()
code = "import numpy as np"
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) == 1
assert relationships[0].target_symbol == "numpy"
assert alias_map.get("np") == "numpy"
def test_from_import(self):
"""Test extraction of from-import statements."""
processor = AstGrepPythonProcessor()
code = "from typing import List, Dict"
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) == 1
assert relationships[0].target_symbol == "typing"
assert alias_map.get("List") == "typing.List"
assert alias_map.get("Dict") == "typing.Dict"
def test_from_import_with_alias(self):
"""Test extraction of from-import with alias."""
processor = AstGrepPythonProcessor()
code = "from collections import defaultdict as dd"
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) == 1
# The alias map should map dd to collections.defaultdict
assert "dd" in alias_map
assert "defaultdict" in alias_map.get("dd", "")
def test_star_import(self):
"""Test extraction of star imports."""
processor = AstGrepPythonProcessor()
code = "from module import *"
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) >= 1
# Star import should be recorded
star_imports = [r for r in relationships if "*" in r.target_symbol]
assert len(star_imports) >= 1
def test_relative_import(self):
"""Test extraction of relative imports."""
processor = AstGrepPythonProcessor()
code = "from .utils import helper"
relationships, alias_map = processor.extract_imports(code, "test.py")
# Should capture the relative import
assert len(relationships) >= 1
rel_imports = [r for r in relationships if r.target_symbol.startswith(".")]
assert len(rel_imports) >= 1
def test_multiple_imports(self):
"""Test extraction of multiple import types."""
processor = AstGrepPythonProcessor()
code = """
import os
import sys
from typing import List
from collections import defaultdict as dd
"""
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) >= 4
targets = {r.target_symbol for r in relationships}
assert "os" in targets
assert "sys" in targets
assert "typing" in targets
assert "collections" in targets
def test_no_imports(self):
"""Test that code without imports returns empty list."""
processor = AstGrepPythonProcessor()
code = """
x = 1
def foo():
pass
"""
relationships, alias_map = processor.extract_imports(code, "test.py")
assert len(relationships) == 0
assert len(alias_map) == 0
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestExtractMethodsIntegration:
"""Integration tests combining multiple extraction methods."""
def test_full_file_extraction(self):
"""Test extracting all relationships from a complete file."""
processor = AstGrepPythonProcessor()
code = """
import os
from typing import List, Optional
class Base:
pass
class Service(Base):
def __init__(self):
self.data = []
def process(self):
result = os.path.join("a", "b")
items = List([1, 2, 3])
return result
def main():
svc = Service()
svc.process()
"""
source_file = "test.py"
# Extract all relationship types
imports, alias_map = processor.extract_imports(code, source_file)
inherits = processor.extract_inherits(code, source_file)
calls = processor.extract_calls(code, source_file, alias_map=alias_map)
# Verify we got all expected relationships
assert len(imports) >= 2 # os and typing
assert len(inherits) == 1 # Service -> Base
assert len(calls) >= 2 # os.path.join and others
# Verify inheritance
assert any(r.source_symbol == "Service" and r.target_symbol == "Base"
for r in inherits)
def test_alias_propagation(self):
"""Test that import aliases propagate to call resolution."""
processor = AstGrepPythonProcessor()
code = """
import numpy as np
def compute():
arr = np.array([1, 2, 3])
return np.sum(arr)
"""
source_file = "test.py"
imports, alias_map = processor.extract_imports(code, source_file)
calls = processor.extract_calls(code, source_file, alias_map=alias_map)
# Alias map should have np -> numpy
assert alias_map.get("np") == "numpy"
# Calls should resolve np.array and np.sum
resolved_targets = {r.target_symbol for r in calls}
# At minimum, np.array and np.sum should be captured
np_calls = [t for t in resolved_targets if "np" in t or "numpy" in t]
assert len(np_calls) >= 2
class TestExtractMethodFallback:
    """Tests for fallback behavior when ast-grep unavailable."""

    @staticmethod
    def _unavailable_processor():
        """Return a processor that reports itself unavailable, or skip.

        Skipping keeps these tests from passing without asserting anything
        when ast-grep-py is installed.
        """
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            pytest.skip("ast-grep-py is installed; unavailable fallback not exercised")
        return processor

    def test_extract_inherits_empty_when_unavailable(self):
        """Test extract_inherits returns empty list when unavailable."""
        processor = self._unavailable_processor()
        code = "class Dog(Animal): pass"
        relationships = processor.extract_inherits(code, "test.py")
        assert relationships == []

    def test_extract_calls_empty_when_unavailable(self):
        """Test extract_calls returns empty list when unavailable."""
        processor = self._unavailable_processor()
        code = "print('hello')"
        relationships = processor.extract_calls(code, "test.py")
        assert relationships == []

    def test_extract_imports_empty_when_unavailable(self):
        """Test extract_imports returns empty tuple when unavailable."""
        processor = self._unavailable_processor()
        code = "import os"
        relationships, alias_map = processor.extract_imports(code, "test.py")
        assert relationships == []
        assert alias_map == {}
class TestHelperMethods:
    """Direct checks of the processor's private string helpers."""

    def test_parse_base_classes_single(self):
        """A lone base name comes back as a one-element list."""
        bases = AstGrepPythonProcessor()._parse_base_classes("BaseClass")
        assert bases == ["BaseClass"]

    def test_parse_base_classes_multiple(self):
        """Comma-separated bases are split into individual names, in order."""
        bases = AstGrepPythonProcessor()._parse_base_classes("A, B, C")
        assert bases == ["A", "B", "C"]

    def test_parse_base_classes_with_generics(self):
        """A subscripted base such as ``Generic[T]`` is kept as one entry."""
        bases = AstGrepPythonProcessor()._parse_base_classes("Generic[T], Mixin")
        for expected in ("Generic[T]", "Mixin"):
            assert expected in bases

    def test_resolve_call_alias_simple(self):
        """A bare alias is replaced by the name it maps to."""
        resolved = AstGrepPythonProcessor()._resolve_call_alias("np", {"np": "numpy"})
        assert resolved == "numpy"

    def test_resolve_call_alias_qualified(self):
        """Only the aliased prefix of a dotted name is rewritten."""
        resolved = AstGrepPythonProcessor()._resolve_call_alias(
            "np.array", {"np": "numpy"}
        )
        assert resolved == "numpy.array"

    def test_resolve_call_alias_no_match(self):
        """A name with no alias entry is returned unchanged."""
        resolved = AstGrepPythonProcessor()._resolve_call_alias("myfunc", {})
        assert resolved == "myfunc"

View File

@@ -0,0 +1,402 @@
"""Tests for AstGrepPythonProcessor.
Tests pattern-based relationship extraction from Python source code
using ast-grep-py bindings.
"""
from pathlib import Path
import pytest
from codexlens.parsers.astgrep_processor import (
AstGrepPythonProcessor,
BaseAstGrepProcessor,
is_astgrep_processor_available,
)
from codexlens.parsers.patterns.python import (
PYTHON_PATTERNS,
METAVARS,
RELATIONSHIP_PATTERNS,
get_pattern,
get_patterns_for_relationship,
get_metavar,
)
# Check if ast-grep is available for conditional test skipping
ASTGREP_AVAILABLE = is_astgrep_processor_available()
class TestPatternDefinitions:
    """Checks on the pattern tables and accessors imported from the Python pattern module."""

    def test_python_patterns_exist(self):
        """Every pattern name the tests rely on is a key of PYTHON_PATTERNS."""
        expected_patterns = (
            "class_def",
            "class_with_bases",
            "func_def",
            "async_func_def",
            "import_stmt",
            "import_from",
            "call",
            "method_call",
        )
        for pattern_name in expected_patterns:
            assert pattern_name in PYTHON_PATTERNS, f"Missing pattern: {pattern_name}"

    def test_get_pattern_returns_correct_pattern(self):
        """get_pattern returns the exact pattern text for three known names."""
        # Note: ast-grep-py 0.40+ uses $$$ for zero-or-more multi-match
        expected_text = {
            "class_def": "class $NAME $$$BODY",
            "func_def": "def $NAME($$$PARAMS): $$$BODY",
            "import_stmt": "import $MODULE",
        }
        for name, text in expected_text.items():
            assert get_pattern(name) == text

    def test_get_pattern_raises_for_unknown(self):
        """An unknown pattern name raises KeyError."""
        with pytest.raises(KeyError):
            get_pattern("nonexistent_pattern")

    def test_metavars_defined(self):
        """The metavariable names used by the tests are keys of METAVARS."""
        expected_metavars = (
            "class_name",
            "func_name",
            "import_module",
            "call_func",
        )
        for var in expected_metavars:
            assert var in METAVARS, f"Missing metavar: {var}"

    def test_get_metavar(self):
        """get_metavar maps logical names to the metavariable identifiers."""
        expected_ids = {
            "class_name": "NAME",
            "func_name": "NAME",
            "import_module": "MODULE",
        }
        for logical, ident in expected_ids.items():
            assert get_metavar(logical) == ident

    def test_relationship_patterns_mapping(self):
        """Each relationship kind lists the pattern names expected for it."""
        assert "class_with_bases" in get_patterns_for_relationship("inheritance")
        import_patterns = get_patterns_for_relationship("imports")
        assert "import_stmt" in import_patterns
        assert "import_from" in get_patterns_for_relationship("imports")
        assert "call" in get_patterns_for_relationship("calls")
class TestAstGrepPythonProcessorAvailability:
    """Availability reporting, at instance level and at module level."""

    def test_is_available_returns_bool(self):
        """The instance-level check yields a bool."""
        available = AstGrepPythonProcessor().is_available()
        assert isinstance(available, bool)

    def test_is_available_matches_global_check(self):
        """Instance-level and module-level checks agree."""
        processor = AstGrepPythonProcessor()
        assert processor.is_available() == is_astgrep_processor_available()

    def test_module_level_check(self):
        """The module-level check yields a bool."""
        available = is_astgrep_processor_available()
        assert isinstance(available, bool)
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestAstGrepPythonProcessorParsing:
    """Tests for Python parsing with ast-grep.

    Every multi-line sample keeps real block indentation: the samples are
    passed to the processor as Python source, and a flattened ``class`` or
    ``def`` body is not a valid program.
    """

    @staticmethod
    def _relationships_of(result, kind):
        """Return the relationships of ``result`` whose ``relationship_type.value`` equals ``kind``."""
        return [
            r for r in result.relationships
            if r.relationship_type.value == kind
        ]

    def test_parse_simple_function(self):
        """Test parsing a simple function definition."""
        processor = AstGrepPythonProcessor()
        code = "def hello():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert result.language == "python"
        assert len(result.symbols) == 1
        assert result.symbols[0].name == "hello"
        assert result.symbols[0].kind == "function"

    def test_parse_class(self):
        """Test parsing a class definition."""
        processor = AstGrepPythonProcessor()
        code = "class MyClass:\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert len(result.symbols) == 1
        assert result.symbols[0].name == "MyClass"
        assert result.symbols[0].kind == "class"

    def test_parse_async_function(self):
        """Test parsing an async function definition."""
        processor = AstGrepPythonProcessor()
        code = "async def fetch_data():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert len(result.symbols) == 1
        assert result.symbols[0].name == "fetch_data"

    def test_parse_class_with_inheritance(self):
        """Test parsing class with inheritance."""
        processor = AstGrepPythonProcessor()
        code = """
class Base:
    pass
class Child(Base):
    pass
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        names = [s.name for s in result.symbols]
        assert "Base" in names
        assert "Child" in names
        # Check inheritance relationship
        inherits = self._relationships_of(result, "inherits")
        assert any(r.source_symbol == "Child" for r in inherits)

    def test_parse_imports(self):
        """Test parsing import statements."""
        processor = AstGrepPythonProcessor()
        code = """
import os
from sys import path
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        imports = self._relationships_of(result, "imports")
        assert len(imports) >= 1
        targets = {r.target_symbol for r in imports}
        assert "os" in targets

    def test_parse_function_calls(self):
        """Test parsing function calls."""
        processor = AstGrepPythonProcessor()
        code = """
def main():
    print("hello")
    len([1, 2, 3])
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        calls = self._relationships_of(result, "calls")
        targets = {r.target_symbol for r in calls}
        assert "print" in targets
        assert "len" in targets

    def test_parse_empty_file(self):
        """Test parsing an empty file."""
        processor = AstGrepPythonProcessor()
        result = processor.parse("", Path("test.py"))
        assert result is not None
        assert len(result.symbols) == 0

    def test_parse_returns_indexed_file(self):
        """Test that parse returns a result with path, language and list fields."""
        processor = AstGrepPythonProcessor()
        code = "def test():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        assert result.path.endswith("test.py")
        assert result.language == "python"
        assert isinstance(result.symbols, list)
        assert isinstance(result.chunks, list)
        assert isinstance(result.relationships, list)
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestAstGrepPythonProcessorRelationships:
    """Tests for relationship extraction.

    The embedded samples keep their block indentation so each one is valid
    Python; nesting decides, for example, that ``external_call()`` below is
    called from inside ``Service.process``.
    """

    def test_inheritance_extraction(self):
        """Test extraction of inheritance relationships."""
        processor = AstGrepPythonProcessor()
        code = """
class Animal:
    pass
class Dog(Animal):
    pass
class Cat(Animal):
    pass
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        inherits = [
            r for r in result.relationships
            if r.relationship_type.value == "inherits"
        ]
        # At least the two subclass -> Animal edges must be present.
        assert len(inherits) >= 2
        sources = {r.source_symbol for r in inherits}
        assert "Dog" in sources
        assert "Cat" in sources

    def test_call_extraction_skips_self(self):
        """Test that self.method() calls are filtered."""
        processor = AstGrepPythonProcessor()
        code = """
class Service:
    def process(self):
        self.internal()
        external_call()
def external_call():
    pass
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        calls = [
            r for r in result.relationships
            if r.relationship_type.value == "calls"
        ]
        targets = {r.target_symbol for r in calls}
        # self.internal should be filtered
        assert "self.internal" not in targets
        assert "external_call" in targets

    def test_import_with_alias_resolution(self):
        """Test import alias resolution in calls."""
        processor = AstGrepPythonProcessor()
        code = """
import os.path as osp
def main():
    osp.join("a", "b")
"""
        result = processor.parse(code, Path("test.py"))
        assert result is not None
        calls = [
            r for r in result.relationships
            if r.relationship_type.value == "calls"
        ]
        targets = {r.target_symbol for r in calls}
        # Should resolve osp to os.path
        assert any("os.path" in t for t in targets)
@pytest.mark.skipif(not ASTGREP_AVAILABLE, reason="ast-grep-py not installed")
class TestAstGrepPythonProcessorRunAstGrep:
    """Tests for run_ast_grep method."""

    # Function-definition pattern shared by the tests below.
    _FUNC_PATTERN = "def $NAME($$$PARAMS) $$$BODY"

    def test_run_ast_grep_returns_list(self):
        """Test run_ast_grep returns a list."""
        processor = AstGrepPythonProcessor()
        code = "def hello():\n pass"
        # run_ast_grep receives the source itself (the sibling tests call it
        # with no prior parse), so the private binding is not touched here.
        matches = processor.run_ast_grep(code, self._FUNC_PATTERN)
        assert isinstance(matches, list)

    def test_run_ast_grep_finds_matches(self):
        """Test run_ast_grep finds expected matches."""
        processor = AstGrepPythonProcessor()
        code = "def hello():\n pass"
        matches = processor.run_ast_grep(code, self._FUNC_PATTERN)
        assert len(matches) >= 1

    def test_run_ast_grep_empty_code(self):
        """Test run_ast_grep with empty code."""
        processor = AstGrepPythonProcessor()
        matches = processor.run_ast_grep("", self._FUNC_PATTERN)
        assert matches == []

    def test_run_ast_grep_no_matches(self):
        """Test run_ast_grep when pattern doesn't match."""
        processor = AstGrepPythonProcessor()
        code = "x = 1"
        matches = processor.run_ast_grep(code, "class $NAME $$$BODY")
        assert matches == []
class TestAstGrepPythonProcessorFallback:
    """Tests for fallback behavior when ast-grep unavailable.

    These assertions only apply when the processor reports itself as
    unavailable. When it is available the tests are reported as skipped
    instead of passing without having checked anything.
    """

    def test_parse_returns_none_when_unavailable(self):
        """Test parse returns None when ast-grep unavailable."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            pytest.skip("ast-grep is available; fallback path not exercised")
        code = "def test():\n pass"
        result = processor.parse(code, Path("test.py"))
        assert result is None

    def test_run_ast_grep_empty_when_unavailable(self):
        """Test run_ast_grep returns empty list when unavailable."""
        processor = AstGrepPythonProcessor()
        if processor.is_available():
            pytest.skip("ast-grep is available; fallback path not exercised")
        matches = processor.run_ast_grep("code", "pattern")
        assert matches == []
class TestBaseAstGrepProcessor:
    """Checks on the base processor class and its Python subclass."""

    def test_cannot_instantiate_base_class(self):
        """Constructing BaseAstGrepProcessor directly raises TypeError."""
        with pytest.raises(TypeError):
            BaseAstGrepProcessor("python")  # type: ignore[abstract]

    def test_subclass_implements_abstract_methods(self):
        """AstGrepPythonProcessor exposes callable ``process_matches`` and ``parse``."""
        processor = AstGrepPythonProcessor()
        # Presence first ...
        assert hasattr(processor, "process_matches")
        assert hasattr(processor, "parse")
        # ... then callability, in the same order.
        for method_name in ("process_matches", "parse"):
            assert callable(getattr(processor, method_name))
class TestPatternIntegration:
    """Checks that pattern lookups work through the processor module too."""

    def test_processor_uses_pattern_module(self):
        """``get_pattern`` is importable from the processor module and resolves known names."""
        # Imported locally so the name is taken from astgrep_processor itself.
        from codexlens.parsers.astgrep_processor import get_pattern

        for known in ("class_def", "func_def"):
            assert get_pattern(known) is not None

    def test_pattern_consistency(self):
        """Every listed pattern name resolves to a non-empty pattern."""
        patterns_needed = (
            "class_def",
            "class_with_bases",
            "func_def",
            "async_func_def",
            "import_stmt",
            "import_from",
            "call",
        )
        for pattern_name in patterns_needed:
            # get_pattern raises KeyError for unknown names, failing the test.
            resolved = get_pattern(pattern_name)
            assert resolved is not None
            assert len(resolved) > 0

View File

@@ -0,0 +1,526 @@
"""Comparison tests for tree-sitter vs ast-grep Python relationship extraction.
Validates that both parsers produce consistent output for Python relationship
extraction (INHERITS, CALL, IMPORTS).
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Set, Tuple
import pytest
from codexlens.config import Config
from codexlens.entities import CodeRelationship, RelationshipType
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
# Sample Python code for testing relationship extraction.
# The sample is parsed as real Python source, so class and function bodies
# must keep their indentation.
SAMPLE_PYTHON_CODE = '''
"""Module docstring."""
import os
import sys
from typing import List, Dict, Optional
from collections import defaultdict as dd
from pathlib import Path as PPath
class BaseClass:
    """Base class."""
    def base_method(self):
        pass
    def another_method(self):
        return self.base_method()
class Mixin:
    """Mixin class."""
    def mixin_func(self):
        return "mixin"
class ChildClass(BaseClass, Mixin):
    """Child class with multiple inheritance."""
    def __init__(self):
        super().__init__()
        self.data = dd(list)
    def process(self, items: List[str]) -> Dict[str, int]:
        result = {}
        for item in items:
            result[item] = len(item)
        return result
    def call_external(self, path: str) -> Optional[str]:
        p = PPath(path)
        if p.exists():
            return str(p.read_text())
        return None
def standalone_function():
    """Standalone function."""
    data = [1, 2, 3]
    return sum(data)
async def async_function():
    """Async function."""
    import asyncio
    await asyncio.sleep(1)
'''
def relationship_to_tuple(rel: CodeRelationship) -> Tuple[str, str, str, int]:
    """Flatten a relationship into a tuple that can be compared or hashed.

    Returns:
        (source_symbol, target_symbol, relationship_type, source_line),
        where the relationship type is given by its ``.value``.
    """
    source = rel.source_symbol
    target = rel.target_symbol
    kind = rel.relationship_type.value
    line = rel.source_line
    return source, target, kind, line
def extract_relationship_tuples(
    relationships: List[CodeRelationship],
) -> Set[Tuple[str, str, str]]:
    """Collapse relationships into a set of line-independent triples.

    Returns:
        Set of (source_symbol, target_symbol, relationship_type) tuples,
        where the relationship type is given by its ``.value``. Duplicates
        that differ only in line number collapse into one entry.
    """
    triples: Set[Tuple[str, str, str]] = set()
    for rel in relationships:
        triples.add(
            (rel.source_symbol, rel.target_symbol, rel.relationship_type.value)
        )
    return triples
def filter_by_type(
    relationships: List[CodeRelationship],
    rel_type: RelationshipType,
) -> List[CodeRelationship]:
    """Return, in order, the relationships whose ``relationship_type`` equals ``rel_type``."""
    matching: List[CodeRelationship] = []
    for candidate in relationships:
        if candidate.relationship_type == rel_type:
            matching.append(candidate)
    return matching
class TestTreeSitterVsAstGrepComparison:
    """Parse one sample with a default parser and an ast-grep-enabled parser and compare.

    Both parsers are ``TreeSitterSymbolParser`` instances; they differ only
    in the ``use_astgrep`` flag of the ``Config`` they are built with.
    """

    @pytest.fixture
    def sample_path(self, tmp_path: Path) -> Path:
        """Write SAMPLE_PYTHON_CODE to ``sample.py`` under ``tmp_path`` and return that path."""
        target = tmp_path / "sample.py"
        target.write_text(SAMPLE_PYTHON_CODE)
        return target

    @pytest.fixture
    def ts_parser_default(self) -> TreeSitterSymbolParser:
        """Parser built from an untouched Config (``use_astgrep`` is asserted to be False)."""
        cfg = Config()
        assert cfg.use_astgrep is False
        return TreeSitterSymbolParser("python", config=cfg)

    @pytest.fixture
    def ts_parser_astgrep(self) -> TreeSitterSymbolParser:
        """Parser built from a Config whose ``use_astgrep`` flag is switched on."""
        cfg = Config()
        cfg.use_astgrep = True
        return TreeSitterSymbolParser("python", config=cfg)

    def test_parser_availability(self, ts_parser_default: TreeSitterSymbolParser) -> None:
        """The default parser reports itself as available."""
        assert ts_parser_default.is_available()

    def test_astgrep_processor_initialization(
        self, ts_parser_astgrep: TreeSitterSymbolParser
    ) -> None:
        """The ast-grep-enabled parser keeps its config with the flag set.

        Only the stored config is inspected; whether a processor object was
        actually created is not asserted here.
        """
        cfg = ts_parser_astgrep._config
        assert cfg is not None
        assert ts_parser_astgrep._config.use_astgrep is True

    def _skip_if_astgrep_unavailable(
        self, ts_parser_astgrep: TreeSitterSymbolParser
    ) -> None:
        """Skip the calling test when the parser holds no ast-grep processor."""
        if ts_parser_astgrep._astgrep_processor is None:
            pytest.skip("ast-grep-py not installed")

    def _parse_with_both(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ):
        """Skip if ast-grep is missing, then parse the sample with both parsers.

        Returns:
            (default_result, astgrep_result); both are asserted to be non-None.
        """
        self._skip_if_astgrep_unavailable(ts_parser_astgrep)
        source_code = sample_path.read_text()
        result_ts = ts_parser_default.parse(source_code, sample_path)
        result_astgrep = ts_parser_astgrep.parse(source_code, sample_path)
        assert result_ts is not None
        assert result_astgrep is not None
        return result_ts, result_astgrep

    def test_parse_returns_valid_result(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """The default parser returns a Python result with symbols and relationships."""
        result = ts_parser_default.parse(sample_path.read_text(), sample_path)
        assert result is not None
        assert result.language == "python"
        assert len(result.symbols) > 0
        assert len(result.relationships) > 0

    def test_extracted_symbols_match(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Both parsers report exactly the same set of symbol names."""
        result_ts, result_astgrep = self._parse_with_both(
            ts_parser_default, ts_parser_astgrep, sample_path
        )
        ts_symbols = {sym.name for sym in result_ts.symbols}
        astgrep_symbols = {sym.name for sym in result_astgrep.symbols}
        assert ts_symbols == astgrep_symbols

    def test_inheritance_relationships(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """INHERITS edges are identical across parsers and match the sample."""
        result_ts, result_astgrep = self._parse_with_both(
            ts_parser_default, ts_parser_astgrep, sample_path
        )
        ts_tuples = extract_relationship_tuples(
            filter_by_type(result_ts.relationships, RelationshipType.INHERITS)
        )
        astgrep_tuples = extract_relationship_tuples(
            filter_by_type(result_astgrep.relationships, RelationshipType.INHERITS)
        )
        # Both should detect ChildClass(BaseClass, Mixin)
        assert ts_tuples == astgrep_tuples
        expected_inherits = {
            ("ChildClass", "BaseClass", "inherits"),
            ("ChildClass", "Mixin", "inherits"),
        }
        assert ts_tuples == expected_inherits

    def test_import_relationships(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """At least one parser covers every top-level module imported by the sample.

        Only the first dotted component of each import target is compared,
        and the final assertion passes if either parser covers the expected
        set.
        """
        result_ts, result_astgrep = self._parse_with_both(
            ts_parser_default, ts_parser_astgrep, sample_path
        )
        ts_tuples = extract_relationship_tuples(
            filter_by_type(result_ts.relationships, RelationshipType.IMPORTS)
        )
        astgrep_tuples = extract_relationship_tuples(
            filter_by_type(result_astgrep.relationships, RelationshipType.IMPORTS)
        )
        ts_modules = {target.split(".")[0] for _, target, _ in ts_tuples}
        astgrep_modules = {target.split(".")[0] for _, target, _ in astgrep_tuples}
        # os, sys, typing, collections, pathlib, plus asyncio imported
        # inside async_function.
        expected_modules = {"os", "sys", "typing", "collections", "pathlib", "asyncio"}
        assert ts_modules >= expected_modules or astgrep_modules >= expected_modules

    def test_call_relationships(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Both parsers find CALL edges and share at least one call target.

        An exact match is deliberately not required.
        """
        result_ts, result_astgrep = self._parse_with_both(
            ts_parser_default, ts_parser_astgrep, sample_path
        )
        ts_calls = filter_by_type(result_ts.relationships, RelationshipType.CALL)
        astgrep_calls = filter_by_type(
            result_astgrep.relationships, RelationshipType.CALL
        )
        assert len(ts_calls) > 0
        assert len(astgrep_calls) > 0
        ts_call_targets = {rel.target_symbol for rel in ts_calls}
        astgrep_call_targets = {rel.target_symbol for rel in astgrep_calls}
        shared = ts_call_targets.intersection(astgrep_call_targets)
        assert len(shared) > 0

    def test_relationship_count_similarity(
        self,
        ts_parser_default: TreeSitterSymbolParser,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """Total relationship counts agree to at least 95%.

        Consistency is the smaller count divided by the larger one, as a
        percentage; two empty results count as 100%.
        """
        result_ts, result_astgrep = self._parse_with_both(
            ts_parser_default, ts_parser_astgrep, sample_path
        )
        ts_count = len(result_ts.relationships)
        astgrep_count = len(result_astgrep.relationships)
        larger = max(ts_count, astgrep_count)
        if larger == 0:
            consistency = 100.0
        else:
            consistency = min(ts_count, astgrep_count) / larger * 100
        assert consistency >= 95.0, (
            f"Relationship consistency {consistency:.1f}% below 95% threshold "
            f"(tree-sitter: {ts_count}, ast-grep: {astgrep_count})"
        )

    def test_config_switch_affects_parser(
        self, sample_path: Path
    ) -> None:
        """``use_astgrep`` decides whether the parser holds an ast-grep processor."""
        config_astgrep = Config()
        config_astgrep.use_astgrep = True
        parser_default = TreeSitterSymbolParser("python", config=Config())
        parser_astgrep = TreeSitterSymbolParser("python", config=config_astgrep)

        # Default parser should not have ast-grep processor
        assert parser_default._astgrep_processor is None

        # The enabled parser's processor may still be None; its type is
        # checked only when one is present.
        processor = parser_astgrep._astgrep_processor
        if processor is None:
            return
        from codexlens.parsers.astgrep_processor import AstGrepPythonProcessor

        assert isinstance(processor, AstGrepPythonProcessor)

    def test_fallback_to_treesitter_on_astgrep_failure(
        self,
        ts_parser_astgrep: TreeSitterSymbolParser,
        sample_path: Path,
    ) -> None:
        """With ``use_astgrep`` on, parsing still yields a Python result with relationships.

        No skip is applied here, so this also runs when the parser holds no
        ast-grep processor.
        """
        result = ts_parser_astgrep.parse(sample_path.read_text(), sample_path)
        assert result is not None
        assert result.language == "python"
        assert len(result.relationships) > 0
class TestSimpleCodeSamples:
    """Test with simple code samples for precise comparison.

    The embedded samples keep their block indentation so each one is valid
    Python source for the parsers.
    """

    @staticmethod
    def _build_parsers() -> Tuple[TreeSitterSymbolParser, TreeSitterSymbolParser]:
        """Return (default parser, parser with ``use_astgrep`` enabled)."""
        config_ts = Config()
        config_ag = Config()
        config_ag.use_astgrep = True
        return (
            TreeSitterSymbolParser("python", config=config_ts),
            TreeSitterSymbolParser("python", config=config_ag),
        )

    def test_simple_inheritance(self) -> None:
        """Test simple single inheritance."""
        code = """
class Parent:
    pass
class Child(Parent):
    pass
"""
        self._compare_parsers(code, expected_inherits={("Child", "Parent")})

    def test_multiple_inheritance(self) -> None:
        """Test multiple inheritance."""
        code = """
class A:
    pass
class B:
    pass
class C(A, B):
    pass
"""
        self._compare_parsers(
            code, expected_inherits={("C", "A"), ("C", "B")}
        )

    def test_simple_imports(self) -> None:
        """Test that both parsers report the same import targets for plain imports."""
        code = """
import os
import sys
"""
        parser_ts, parser_ag = self._build_parsers()
        source_path = Path("test.py")
        result_ts = parser_ts.parse(code, source_path)
        result_ag = parser_ag.parse(code, source_path)
        assert result_ts is not None
        # The comparison is made only when the second parser produced a result.
        if result_ag is not None:
            ts_imports = {
                r.target_symbol
                for r in result_ts.relationships
                if r.relationship_type == RelationshipType.IMPORTS
            }
            ag_imports = {
                r.target_symbol
                for r in result_ag.relationships
                if r.relationship_type == RelationshipType.IMPORTS
            }
            assert ts_imports == ag_imports

    def test_imports_inside_function(self) -> None:
        """Test simple import inside a function scope is recorded.

        Note: tree-sitter parser requires a scope to record imports.
        Module-level imports without any function/class are not recorded
        because scope_stack is empty at module level.
        """
        code = """
def my_function():
    import collections
    return collections
"""
        parser_ts, parser_ag = self._build_parsers()
        source_path = Path("test.py")
        result_ts = parser_ts.parse(code, source_path)
        result_ag = parser_ag.parse(code, source_path)
        assert result_ts is not None
        # Get import relationship targets
        ts_imports = [
            r.target_symbol
            for r in result_ts.relationships
            if r.relationship_type == RelationshipType.IMPORTS
        ]
        # Should have collections
        ts_has_collections = any("collections" in t for t in ts_imports)
        assert ts_has_collections, f"Expected collections import, got: {ts_imports}"
        # The second parser is checked only when it produced a result.
        if result_ag is not None:
            ag_imports = [
                r.target_symbol
                for r in result_ag.relationships
                if r.relationship_type == RelationshipType.IMPORTS
            ]
            ag_has_collections = any("collections" in t for t in ag_imports)
            assert ag_has_collections, f"Expected collections import in ast-grep, got: {ag_imports}"

    def _compare_parsers(
        self,
        code: str,
        expected_inherits: Set[Tuple[str, str]],
    ) -> None:
        """Assert both parsers yield exactly ``expected_inherits`` for ``code``.

        Args:
            code: Python source to parse.
            expected_inherits: Expected (source_symbol, target_symbol) pairs
                of INHERITS relationships.
        """
        parser_ts, parser_ag = self._build_parsers()
        source_path = Path("test.py")
        result_ts = parser_ts.parse(code, source_path)
        assert result_ts is not None
        # Verify tree-sitter finds expected inheritance
        ts_inherits = {
            (r.source_symbol, r.target_symbol)
            for r in result_ts.relationships
            if r.relationship_type == RelationshipType.INHERITS
        }
        assert ts_inherits == expected_inherits
        # The second parser is checked only when it produced a result.
        result_ag = parser_ag.parse(code, source_path)
        if result_ag is not None:
            ag_inherits = {
                (r.source_symbol, r.target_symbol)
                for r in result_ag.relationships
                if r.relationship_type == RelationshipType.INHERITS
            }
            assert ag_inherits == expected_inherits
if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run exits non-zero when this
    # file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))

View File

@@ -0,0 +1,191 @@
"""Tests for ast-grep binding module.
Verifies basic import and functionality of AstGrepBinding.
Run with: python -m pytest tests/test_astgrep_binding.py -v
"""
from __future__ import annotations
import pytest
from pathlib import Path
class TestAstGrepBindingAvailability:
    """Module-level availability helpers of ``astgrep_binding``."""

    def test_is_astgrep_available_function(self):
        """is_astgrep_available() yields a bool."""
        from codexlens.parsers.astgrep_binding import is_astgrep_available

        assert isinstance(is_astgrep_available(), bool)

    def test_get_supported_languages(self):
        """get_supported_languages() yields a list containing the three expected ids."""
        from codexlens.parsers.astgrep_binding import get_supported_languages

        languages = get_supported_languages()
        assert isinstance(languages, list)
        for expected in ("python", "javascript", "typescript"):
            assert expected in languages
class TestAstGrepBindingInit:
    """Construction of AstGrepBinding and the attributes it exposes."""

    def test_init_python(self):
        """A binding created for Python reports ``language_id == "python"``."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        assert AstGrepBinding("python").language_id == "python"

    def test_init_typescript_with_tsx(self):
        """A TypeScript binding given a ``.tsx`` path still reports ``"typescript"``."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        tsx_binding = AstGrepBinding("typescript", Path("component.tsx"))
        assert tsx_binding.language_id == "typescript"

    def test_is_available_returns_boolean(self):
        """The instance-level availability check yields a bool."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        available = AstGrepBinding("python").is_available()
        assert isinstance(available, bool)
def _is_astgrep_installed():
"""Check if ast-grep-py is installed."""
try:
import ast_grep_py # noqa: F401
return True
except ImportError:
return False
@pytest.mark.skipif(
    not _is_astgrep_installed(),
    reason="ast-grep-py not installed"
)
class TestAstGrepBindingWithAstGrep:
    """Tests that require ast-grep-py to be installed.

    The embedded samples keep their block indentation so each one is valid
    Python source.
    """

    @staticmethod
    def _available_binding():
        """Return a Python AstGrepBinding, skipping the test if it is not available."""
        from codexlens.parsers.astgrep_binding import AstGrepBinding

        binding = AstGrepBinding("python")
        if not binding.is_available():
            pytest.skip("ast-grep not available")
        return binding

    def test_parse_simple_python(self):
        """Test parsing simple Python code."""
        binding = self._available_binding()
        source = "x = 1"
        result = binding.parse(source)
        assert result is True

    def test_find_inheritance(self):
        """Test finding class inheritance."""
        binding = self._available_binding()
        source = """
class MyClass(BaseClass):
    pass
"""
        binding.parse(source)
        results = binding.find_inheritance()
        # The number of matches depends on the pattern, so only the result
        # type is checked, as in the sibling find_* tests.
        assert isinstance(results, list)

    def test_find_calls(self):
        """Test finding function calls."""
        binding = self._available_binding()
        source = """
def foo():
    bar()
    baz.qux()
"""
        binding.parse(source)
        results = binding.find_calls()
        assert isinstance(results, list)

    def test_find_imports(self):
        """Test finding import statements."""
        binding = self._available_binding()
        source = """
import os
from typing import List
"""
        binding.parse(source)
        results = binding.find_imports()
        assert isinstance(results, list)
def test_basic_import():
    """Fail with a descriptive message if the binding module's public names cannot be imported."""
    try:
        from codexlens.parsers.astgrep_binding import (  # noqa: F401
            AstGrepBinding,
            is_astgrep_available,
            get_supported_languages,
            ASTGREP_AVAILABLE,
        )
    except ImportError as e:
        pytest.fail(f"Failed to import astgrep_binding: {e}")
def test_availability_flag():
    """The module exposes ``ASTGREP_AVAILABLE`` and it is a bool."""
    from codexlens.parsers.astgrep_binding import ASTGREP_AVAILABLE as flag

    assert isinstance(flag, bool)
if __name__ == "__main__":
    # Manual smoke check: print availability and, when the binding is usable,
    # what it finds in a small sample.
    print("Testing astgrep_binding module...")
    from codexlens.parsers.astgrep_binding import (
        AstGrepBinding,
        is_astgrep_available,
        get_supported_languages,
    )
    print(f"ast-grep available: {is_astgrep_available()}")
    print(f"Supported languages: {get_supported_languages()}")
    binding = AstGrepBinding("python")
    print(f"Python binding available: {binding.is_available()}")
    if binding.is_available():
        # The sample keeps its block indentation so it is valid Python.
        test_code = """
import os
from typing import List
class MyClass(BaseClass):
    def method(self):
        self.helper()
        external_func()
def helper():
    pass
"""
        binding.parse(test_code)
        print(f"Inheritance found: {binding.find_inheritance()}")
        print(f"Calls found: {binding.find_calls()}")
        print(f"Imports found: {binding.find_imports()}")
    else:
        print("Note: ast-grep-py not installed. To install:")
        print(" pip install ast-grep-py")
        print(" Note: May have compatibility issues with Python 3.13")
    print("Basic verification complete!")