Refactor code structure and remove redundant changes

This commit is contained in:
catlog22
2026-01-24 14:47:47 +08:00
parent cf5fecd66d
commit f2b0a5bbc9
113 changed files with 43217 additions and 235 deletions

View File

@@ -0,0 +1,8 @@
"""Parsers for CodexLens."""
from __future__ import annotations
from .factory import ParserFactory
__all__ = ["ParserFactory"]

View File

@@ -0,0 +1,202 @@
"""Optional encoding detection module for CodexLens.
Provides automatic encoding detection with graceful fallback to UTF-8.
Install with: pip install codexlens[encoding]
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Tuple, Optional
log = logging.getLogger(__name__)
# Feature flag for encoding detection availability
ENCODING_DETECTION_AVAILABLE = False
_import_error: Optional[str] = None
def _detect_chardet_backend() -> Tuple[bool, Optional[str]]:
"""Detect if chardet or charset-normalizer is available."""
try:
import chardet
return True, None
except ImportError:
pass
try:
from charset_normalizer import from_bytes
return True, None
except ImportError:
pass
return False, "chardet not available. Install with: pip install codexlens[encoding]"
# Initialize on module load
ENCODING_DETECTION_AVAILABLE, _import_error = _detect_chardet_backend()
def check_encoding_available() -> Tuple[bool, Optional[str]]:
"""Check if encoding detection dependencies are available.
Returns:
Tuple of (available, error_message)
"""
return ENCODING_DETECTION_AVAILABLE, _import_error
def detect_encoding(content_bytes: bytes, confidence_threshold: float = 0.7) -> str:
"""Detect encoding from file content bytes.
Uses chardet or charset-normalizer with configurable confidence threshold.
Falls back to UTF-8 if confidence is too low or detection unavailable.
Args:
content_bytes: Raw file content as bytes
confidence_threshold: Minimum confidence (0.0-1.0) to accept detection
Returns:
Detected encoding name (e.g., 'utf-8', 'iso-8859-1', 'gbk')
Returns 'utf-8' as fallback if detection fails or confidence too low
"""
if not ENCODING_DETECTION_AVAILABLE:
log.debug("Encoding detection not available, using UTF-8 fallback")
return "utf-8"
if not content_bytes:
return "utf-8"
try:
# Try chardet first
try:
import chardet
result = chardet.detect(content_bytes)
encoding = result.get("encoding")
confidence = result.get("confidence", 0.0)
if encoding and confidence >= confidence_threshold:
log.debug(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
# Normalize encoding name: replace underscores with hyphens
return encoding.lower().replace('_', '-')
else:
log.debug(
f"Low confidence encoding detection: {encoding} "
f"(confidence: {confidence:.2f}), using UTF-8 fallback"
)
return "utf-8"
except ImportError:
pass
# Fallback to charset-normalizer
try:
from charset_normalizer import from_bytes
results = from_bytes(content_bytes)
if results:
best = results.best()
if best and best.encoding:
log.debug(f"Detected encoding via charset-normalizer: {best.encoding}")
# Normalize encoding name: replace underscores with hyphens
return best.encoding.lower().replace('_', '-')
except ImportError:
pass
except Exception as e:
log.warning(f"Encoding detection failed: {e}, using UTF-8 fallback")
return "utf-8"
def read_file_safe(
path: Path | str,
confidence_threshold: float = 0.7,
max_detection_bytes: int = 100_000
) -> Tuple[str, str]:
"""Read file with automatic encoding detection and safe decoding.
Reads file bytes, detects encoding, and decodes with error replacement
to preserve file structure even with encoding issues.
Args:
path: Path to file to read
confidence_threshold: Minimum confidence for encoding detection
max_detection_bytes: Maximum bytes to use for encoding detection (default 100KB)
Returns:
Tuple of (content, detected_encoding)
- content: Decoded file content (with <20> for unmappable bytes)
- detected_encoding: Detected encoding name
Raises:
OSError: If file cannot be read
IsADirectoryError: If path is a directory
"""
file_path = Path(path) if isinstance(path, str) else path
# Read file bytes
try:
content_bytes = file_path.read_bytes()
except Exception as e:
log.error(f"Failed to read file {file_path}: {e}")
raise
# Detect encoding from first N bytes for performance
detection_sample = content_bytes[:max_detection_bytes] if len(content_bytes) > max_detection_bytes else content_bytes
encoding = detect_encoding(detection_sample, confidence_threshold)
# Decode with error replacement to preserve structure
try:
content = content_bytes.decode(encoding, errors='replace')
log.debug(f"Successfully decoded {file_path} using {encoding}")
return content, encoding
except Exception as e:
# Final fallback to UTF-8 with replacement
log.warning(f"Failed to decode {file_path} with {encoding}, using UTF-8: {e}")
content = content_bytes.decode('utf-8', errors='replace')
return content, 'utf-8'
def is_binary_file(path: Path | str, sample_size: int = 8192) -> bool:
"""Check if file is likely binary by sampling first bytes.
Uses heuristic: if >30% of sample bytes are null or non-text, consider binary.
Args:
path: Path to file to check
sample_size: Number of bytes to sample (default 8KB)
Returns:
True if file appears to be binary, False otherwise
"""
file_path = Path(path) if isinstance(path, str) else path
try:
with file_path.open('rb') as f:
sample = f.read(sample_size)
if not sample:
return False
# Count null bytes and non-printable characters
null_count = sample.count(b'\x00')
non_text_count = sum(1 for byte in sample if byte < 0x20 and byte not in (0x09, 0x0a, 0x0d))
# If >30% null bytes or >50% non-text, consider binary
null_ratio = null_count / len(sample)
non_text_ratio = non_text_count / len(sample)
return null_ratio > 0.3 or non_text_ratio > 0.5
except Exception as e:
log.debug(f"Binary check failed for {file_path}: {e}, assuming text")
return False
__all__ = [
"ENCODING_DETECTION_AVAILABLE",
"check_encoding_available",
"detect_encoding",
"read_file_safe",
"is_binary_file",
]

View File

@@ -0,0 +1,385 @@
"""Parser factory for CodexLens.
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
available. Regex fallbacks are retained to preserve the existing parser
interface and behavior in minimal environments.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Protocol
from codexlens.config import Config
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class Parser(Protocol):
def parse(self, text: str, path: Path) -> IndexedFile: ...
@dataclass
class SimpleRegexParser:
language_id: str
def parse(self, text: str, path: Path) -> IndexedFile:
# Try tree-sitter first for supported languages
if self.language_id in {"python", "javascript", "typescript"}:
ts_parser = TreeSitterSymbolParser(self.language_id, path)
if ts_parser.is_available():
indexed = ts_parser.parse(text, path)
if indexed is not None:
return indexed
# Fallback to regex parsing
if self.language_id == "python":
symbols = _parse_python_symbols_regex(text)
relationships = _parse_python_relationships_regex(text, path)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols_regex(text)
relationships = _parse_js_ts_relationships_regex(text, path)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
relationships = []
elif self.language_id == "go":
symbols = _parse_go_symbols(text)
relationships = []
elif self.language_id == "markdown":
symbols = _parse_markdown_symbols(text)
relationships = []
elif self.language_id == "text":
symbols = _parse_text_symbols(text)
relationships = []
else:
symbols = _parse_generic_symbols(text)
relationships = []
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
class ParserFactory:
def __init__(self, config: Config) -> None:
self.config = config
self._parsers: Dict[str, Parser] = {}
def get_parser(self, language_id: str) -> Parser:
if language_id not in self._parsers:
self._parsers[language_id] = SimpleRegexParser(language_id)
return self._parsers[language_id]
# Regex-based fallback parsers
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_PY_IMPORT_RE = re.compile(r"^(?:from\s+([\w.]+)\s+)?import\s+([\w.,\s]+)")
_PY_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_python_symbols(text: str) -> List[Symbol]:
"""Parse Python symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser("python")
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Optional[Path] = None,
) -> List[Symbol]:
"""Parse JS/TS symbols, using tree-sitter if available, regex fallback."""
ts_parser = TreeSitterSymbolParser(language_id, path)
if ts_parser.is_available():
symbols = ts_parser.parse_symbols(text)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
current_class_indent: Optional[int] = None
for i, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_class_indent = len(line) - len(line.lstrip(" "))
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
indent = len(line) - len(line.lstrip(" "))
kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
continue
if current_class_indent is not None:
indent = len(line) - len(line.lstrip(" "))
if line.strip() and indent <= current_class_indent:
current_class_indent = None
return symbols
def _parse_python_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
current_scope = def_match.group(1)
continue
if current_scope is None:
continue
import_match = _PY_IMPORT_RE.search(line)
if import_match:
import_target = import_match.group(1) or import_match.group(2)
if import_target:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_target.strip(),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _PY_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {
"if",
"for",
"while",
"return",
"print",
"len",
"str",
"int",
"float",
"list",
"dict",
"set",
"tuple",
current_scope,
}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
_JS_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]([^'\"]+)['\"]")
_JS_CALL_RE = re.compile(r"(?<![.\w])(\w+)\s*\(")
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
in_class = False
class_brace_depth = 0
brace_depth = 0
for i, line in enumerate(text.splitlines(), start=1):
brace_depth += line.count("{") - line.count("}")
class_match = _JS_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
in_class = True
class_brace_depth = brace_depth
continue
if in_class and brace_depth < class_brace_depth:
in_class = False
func_match = _JS_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
continue
if in_class:
method_match = _JS_METHOD_RE.match(line)
if method_match:
name = method_match.group(1)
if name != "constructor":
symbols.append(Symbol(name=name, kind="method", range=(i, i)))
return symbols
def _parse_js_ts_relationships_regex(text: str, path: Path) -> List[CodeRelationship]:
relationships: List[CodeRelationship] = []
current_scope: str | None = None
source_file = str(path.resolve())
for line_num, line in enumerate(text.splitlines(), start=1):
class_match = _JS_CLASS_RE.match(line)
if class_match:
current_scope = class_match.group(1)
continue
func_match = _JS_FUNC_RE.match(line)
if func_match:
current_scope = func_match.group(1)
continue
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
current_scope = arrow_match.group(1)
continue
if current_scope is None:
continue
import_match = _JS_IMPORT_RE.search(line)
if import_match:
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=import_match.group(1),
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
for call_match in _JS_CALL_RE.finditer(line):
call_name = call_match.group(1)
if call_name in {current_scope}:
continue
relationships.append(
CodeRelationship(
source_symbol=current_scope,
target_symbol=call_name,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=line_num,
)
)
return relationships
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
)
def _parse_java_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
class_match = _JAVA_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
method_match = _JAVA_METHOD_RE.match(line)
if method_match:
symbols.append(Symbol(name=method_match.group(1), kind="method", range=(i, i)))
return symbols
_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")
def _parse_go_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
type_match = _GO_TYPE_RE.match(line)
if type_match:
symbols.append(Symbol(name=type_match.group(1), kind="class", range=(i, i)))
continue
func_match = _GO_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
return symbols
_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")
def _parse_generic_symbols(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
class_match = _GENERIC_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _GENERIC_DEF_RE.match(line)
if def_match:
symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
return symbols
# Markdown heading regex: # Heading, ## Heading, etc.
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")
def _parse_markdown_symbols(text: str) -> List[Symbol]:
"""Parse Markdown headings as symbols.
Extracts # headings as 'section' symbols with heading level as kind suffix.
"""
symbols: List[Symbol] = []
for i, line in enumerate(text.splitlines(), start=1):
heading_match = _MD_HEADING_RE.match(line)
if heading_match:
level = len(heading_match.group(1))
title = heading_match.group(2).strip()
# Use 'section' kind with level indicator
kind = f"h{level}"
symbols.append(Symbol(name=title, kind=kind, range=(i, i)))
return symbols
def _parse_text_symbols(text: str) -> List[Symbol]:
"""Parse plain text files - no symbols, just index content."""
# Text files don't have structured symbols, return empty list
# The file content will still be indexed for FTS search
return []

View File

@@ -0,0 +1,98 @@
"""Token counting utilities for CodexLens.
Provides accurate token counting using tiktoken with character count fallback.
"""
from __future__ import annotations
from typing import Optional
try:
import tiktoken
TIKTOKEN_AVAILABLE = True
except ImportError:
TIKTOKEN_AVAILABLE = False
class Tokenizer:
"""Token counter with tiktoken primary and character count fallback."""
def __init__(self, encoding_name: str = "cl100k_base") -> None:
"""Initialize tokenizer.
Args:
encoding_name: Tiktoken encoding name (default: cl100k_base for GPT-4)
"""
self._encoding: Optional[object] = None
self._encoding_name = encoding_name
if TIKTOKEN_AVAILABLE:
try:
self._encoding = tiktoken.get_encoding(encoding_name)
except Exception:
# Fallback to character counting if encoding fails
self._encoding = None
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Uses tiktoken if available, otherwise falls back to character count / 4.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if not text:
return 0
if self._encoding is not None:
try:
return len(self._encoding.encode(text)) # type: ignore[attr-defined]
except Exception:
# Fall through to character count fallback
pass
# Fallback: rough estimate using character count
# Average of ~4 characters per token for English text
return max(1, len(text) // 4)
def is_using_tiktoken(self) -> bool:
"""Check if tiktoken is being used.
Returns:
True if tiktoken is available and initialized
"""
return self._encoding is not None
# Global default tokenizer instance
_default_tokenizer: Optional[Tokenizer] = None
def get_default_tokenizer() -> Tokenizer:
"""Get the global default tokenizer instance.
Returns:
Shared Tokenizer instance
"""
global _default_tokenizer
if _default_tokenizer is None:
_default_tokenizer = Tokenizer()
return _default_tokenizer
def count_tokens(text: str, tokenizer: Optional[Tokenizer] = None) -> int:
"""Count tokens in text using default or provided tokenizer.
Args:
text: Text to count tokens for
tokenizer: Optional tokenizer instance (uses default if None)
Returns:
Estimated token count
"""
if tokenizer is None:
tokenizer = get_default_tokenizer()
return tokenizer.count_tokens(text)

View File

@@ -0,0 +1,809 @@
"""Tree-sitter based parser for CodexLens.
Provides precise AST-level parsing via tree-sitter.
Note: This module does not provide a regex fallback inside `TreeSitterSymbolParser`.
If tree-sitter (or a language binding) is unavailable, `parse()`/`parse_symbols()`
return `None`; callers should use a regex-based fallback such as
`codexlens.parsers.factory.SimpleRegexParser`.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, IndexedFile, RelationshipType, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
class TreeSitterSymbolParser:
"""Parser using tree-sitter for AST-level symbol extraction."""
def __init__(self, language_id: str, path: Optional[Path] = None) -> None:
"""Initialize tree-sitter parser for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
path: Optional file path for language variant detection (e.g., .tsx)
"""
self.language_id = language_id
self.path = path
self._parser: Optional[object] = None
self._language: Optional[TreeSitterLanguage] = None
self._tokenizer = get_default_tokenizer()
if TREE_SITTER_AVAILABLE:
self._initialize_parser()
def _initialize_parser(self) -> None:
"""Initialize tree-sitter parser and language."""
if TreeSitterParser is None or TreeSitterLanguage is None:
return
try:
# Load language grammar
if self.language_id == "python":
import tree_sitter_python
self._language = TreeSitterLanguage(tree_sitter_python.language())
elif self.language_id == "javascript":
import tree_sitter_javascript
self._language = TreeSitterLanguage(tree_sitter_javascript.language())
elif self.language_id == "typescript":
import tree_sitter_typescript
# Detect TSX files by extension
if self.path is not None and self.path.suffix.lower() == ".tsx":
self._language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
self._language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
else:
return
# Create parser
self._parser = TreeSitterParser()
if hasattr(self._parser, "set_language"):
self._parser.set_language(self._language) # type: ignore[attr-defined]
else:
self._parser.language = self._language # type: ignore[assignment]
except Exception:
# Gracefully handle missing language bindings
self._parser = None
self._language = None
def is_available(self) -> bool:
"""Check if tree-sitter parser is available.
Returns:
True if parser is initialized and ready
"""
return self._parser is not None and self._language is not None
def _parse_tree(self, text: str) -> Optional[tuple[bytes, TreeSitterNode]]:
if not self.is_available() or self._parser is None:
return None
try:
source_bytes = text.encode("utf8")
tree = self._parser.parse(source_bytes) # type: ignore[attr-defined]
return source_bytes, tree.root_node
except Exception:
return None
def parse_symbols(self, text: str) -> Optional[List[Symbol]]:
"""Parse source code and extract symbols without creating IndexedFile.
Args:
text: Source code text
Returns:
List of symbols if parsing succeeds, None if tree-sitter unavailable
"""
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
return self._extract_symbols(source_bytes, root)
except Exception:
# Gracefully handle extraction errors
return None
def parse(self, text: str, path: Path) -> Optional[IndexedFile]:
"""Parse source code and extract symbols.
Args:
text: Source code text
path: File path
Returns:
IndexedFile if parsing succeeds, None if tree-sitter unavailable
"""
parsed = self._parse_tree(text)
if parsed is None:
return None
source_bytes, root = parsed
try:
symbols = self._extract_symbols(source_bytes, root)
relationships = self._extract_relationships(source_bytes, root, path)
return IndexedFile(
path=str(path.resolve()),
language=self.language_id,
symbols=symbols,
chunks=[],
relationships=relationships,
)
except Exception:
# Gracefully handle parsing errors
return None
def _extract_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of extracted symbols
"""
if self.language_id == "python":
return self._extract_python_symbols(source_bytes, root)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_symbols(source_bytes, root)
else:
return []
def _extract_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, path)
if self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, path)
return []
def _extract_python_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"self", "cls"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"class_definition", "function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "class_definition" and pushed_scope:
superclasses = node.child_by_field_name("superclasses")
if superclasses is not None:
for child in superclasses.children:
dotted = self._python_expression_to_dotted(source_bytes, child)
if not dotted:
continue
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type in {"import_statement", "import_from_statement"}:
updates, imported_targets = self._python_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
if node.type == "call":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._python_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _extract_js_ts_relationships(
self,
source_bytes: bytes,
root: TreeSitterNode,
path: Path,
) -> List[CodeRelationship]:
source_file = str(path.resolve())
relationships: List[CodeRelationship] = []
scope_stack: List[str] = []
alias_stack: List[Dict[str, str]] = [{}]
def record_import(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.IMPORTS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_call(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
base = target_symbol.split(".", 1)[0]
if base in {"this", "super"}:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.CALL,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def record_inherits(target_symbol: str, source_line: int) -> None:
if not target_symbol.strip() or not scope_stack:
return
relationships.append(
CodeRelationship(
source_symbol=scope_stack[-1],
target_symbol=target_symbol,
relationship_type=RelationshipType.INHERITS,
source_file=source_file,
target_file=None,
source_line=source_line,
)
)
def visit(node: TreeSitterNode) -> None:
pushed_scope = False
pushed_aliases = False
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if pushed_scope:
superclass = node.child_by_field_name("superclass")
if superclass is not None:
dotted = self._js_expression_to_dotted(source_bytes, superclass)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_inherits(resolved, self._node_start_line(node))
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type in {"identifier", "property_identifier"}
and value_node.type == "arrow_function"
):
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name:
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is not None:
scope_name = self._node_text(source_bytes, name_node).strip()
if scope_name and scope_name != "constructor":
scope_stack.append(scope_name)
pushed_scope = True
alias_stack.append(dict(alias_stack[-1]))
pushed_aliases = True
if node.type in {"import_declaration", "import_statement"}:
updates, imported_targets = self._js_import_aliases_and_targets(source_bytes, node)
if updates:
alias_stack[-1].update(updates)
for target_symbol in imported_targets:
record_import(target_symbol, self._node_start_line(node))
# Best-effort support for CommonJS require() imports:
# const fs = require("fs")
if node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is not None
and value_node is not None
and name_node.type == "identifier"
and value_node.type == "call_expression"
):
callee = value_node.child_by_field_name("function")
args = value_node.child_by_field_name("arguments")
if (
callee is not None
and self._node_text(source_bytes, callee).strip() == "require"
and args is not None
):
module_name = self._js_first_string_argument(source_bytes, args)
if module_name:
alias_stack[-1][self._node_text(source_bytes, name_node).strip()] = module_name
record_import(module_name, self._node_start_line(node))
if node.type == "call_expression":
fn_node = node.child_by_field_name("function")
if fn_node is not None:
dotted = self._js_expression_to_dotted(source_bytes, fn_node)
if dotted:
resolved = self._resolve_alias_dotted(dotted, alias_stack[-1])
record_call(resolved, self._node_start_line(node))
for child in node.children:
visit(child)
if pushed_aliases:
alias_stack.pop()
if pushed_scope:
scope_stack.pop()
visit(root)
return relationships
def _node_start_line(self, node: TreeSitterNode) -> int:
return node.start_point[0] + 1
def _resolve_alias_dotted(self, dotted: str, aliases: Dict[str, str]) -> str:
dotted = (dotted or "").strip()
if not dotted:
return ""
base, sep, rest = dotted.partition(".")
resolved_base = aliases.get(base, base)
if not rest:
return resolved_base
if resolved_base and rest:
return f"{resolved_base}.{rest}"
return resolved_base
def _python_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"identifier", "dotted_name"}:
return self._node_text(source_bytes, node).strip()
if node.type == "attribute":
obj = node.child_by_field_name("object")
attr = node.child_by_field_name("attribute")
obj_text = self._python_expression_to_dotted(source_bytes, obj) if obj is not None else ""
attr_text = self._node_text(source_bytes, attr).strip() if attr is not None else ""
if obj_text and attr_text:
return f"{obj_text}.{attr_text}"
return obj_text or attr_text
return ""
def _python_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
if node.type == "import_statement":
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
module_name = self._node_text(source_bytes, name_node).strip()
if not module_name:
continue
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else module_name.split(".", 1)[0]
)
if bound_name:
aliases[bound_name] = module_name
targets.append(module_name)
elif child.type == "dotted_name":
module_name = self._node_text(source_bytes, child).strip()
if not module_name:
continue
bound_name = module_name.split(".", 1)[0]
if bound_name:
aliases[bound_name] = bound_name
targets.append(module_name)
if node.type == "import_from_statement":
module_name = ""
module_node = node.child_by_field_name("module_name")
if module_node is None:
for child in node.children:
if child.type == "dotted_name":
module_node = child
break
if module_node is not None:
module_name = self._node_text(source_bytes, module_node).strip()
for child in node.children:
if child.type == "aliased_import":
name_node = child.child_by_field_name("name")
alias_node = child.child_by_field_name("alias")
if name_node is None:
continue
imported_name = self._node_text(source_bytes, name_node).strip()
if not imported_name or imported_name == "*":
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
bound_name = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported_name
)
if bound_name:
aliases[bound_name] = target
targets.append(target)
elif child.type == "identifier":
imported_name = self._node_text(source_bytes, child).strip()
if not imported_name or imported_name in {"from", "import", "*"}:
continue
target = f"{module_name}.{imported_name}" if module_name else imported_name
aliases[imported_name] = target
targets.append(target)
return aliases, targets
def _js_expression_to_dotted(self, source_bytes: bytes, node: TreeSitterNode) -> str:
if node.type in {"this", "super"}:
return node.type
if node.type in {"identifier", "property_identifier"}:
return self._node_text(source_bytes, node).strip()
if node.type == "member_expression":
obj = node.child_by_field_name("object")
prop = node.child_by_field_name("property")
obj_text = self._js_expression_to_dotted(source_bytes, obj) if obj is not None else ""
prop_text = self._js_expression_to_dotted(source_bytes, prop) if prop is not None else ""
if obj_text and prop_text:
return f"{obj_text}.{prop_text}"
return obj_text or prop_text
return ""
def _js_import_aliases_and_targets(
self,
source_bytes: bytes,
node: TreeSitterNode,
) -> tuple[Dict[str, str], List[str]]:
aliases: Dict[str, str] = {}
targets: List[str] = []
module_name = ""
source_node = node.child_by_field_name("source")
if source_node is not None:
module_name = self._node_text(source_bytes, source_node).strip().strip("\"'").strip()
if module_name:
targets.append(module_name)
for child in node.children:
if child.type == "import_clause":
for clause_child in child.children:
if clause_child.type == "identifier":
# Default import: import React from "react"
local = self._node_text(source_bytes, clause_child).strip()
if local and module_name:
aliases[local] = module_name
if clause_child.type == "namespace_import":
# Namespace import: import * as fs from "fs"
name_node = clause_child.child_by_field_name("name")
if name_node is not None and module_name:
local = self._node_text(source_bytes, name_node).strip()
if local:
aliases[local] = module_name
if clause_child.type == "named_imports":
for spec in clause_child.children:
if spec.type != "import_specifier":
continue
name_node = spec.child_by_field_name("name")
alias_node = spec.child_by_field_name("alias")
if name_node is None:
continue
imported = self._node_text(source_bytes, name_node).strip()
if not imported:
continue
local = (
self._node_text(source_bytes, alias_node).strip()
if alias_node is not None
else imported
)
if local and module_name:
aliases[local] = f"{module_name}.{imported}"
targets.append(f"{module_name}.{imported}")
return aliases, targets
def _js_first_string_argument(self, source_bytes: bytes, args_node: TreeSitterNode) -> str:
for child in args_node.children:
if child.type == "string":
return self._node_text(source_bytes, child).strip().strip("\"'").strip()
return ""
def _extract_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract Python symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of Python symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind=self._python_function_kind(node),
range=self._node_range(node),
))
return symbols
def _extract_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[Symbol]:
"""Extract JavaScript/TypeScript symbols from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of JS/TS symbols (classes, functions, methods)
"""
symbols: List[Symbol] = []
for node in self._iter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="class",
range=self._node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=self._node_text(source_bytes, name_node),
kind="function",
range=self._node_range(node),
))
elif node.type == "method_definition" and self._has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = self._node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=self._node_range(node),
))
return symbols
def _python_function_kind(self, node: TreeSitterNode) -> str:
"""Determine if Python function is a method or standalone function.
Args:
node: Function definition node
Returns:
'method' if inside a class, 'function' otherwise
"""
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _has_class_ancestor(self, node: TreeSitterNode) -> bool:
"""Check if node has a class ancestor.
Args:
node: AST node to check
Returns:
True if node is inside a class
"""
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(self, node: TreeSitterNode) -> tuple[int, int]:
"""Get line range for a node.
Args:
node: AST node
Returns:
(start_line, end_line) tuple, 1-based inclusive
"""
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def count_tokens(self, text: str) -> int:
"""Count tokens in text.
Args:
text: Text to count tokens for
Returns:
Token count
"""
return self._tokenizer.count_tokens(text)