Files
Claude-Code-Workflow/codex-lens/src/codexlens/parsers/factory.py

403 lines
14 KiB
Python

"""Parser factory for CodexLens.
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
available. Regex fallbacks are retained to preserve the existing parser
interface and behavior in minimal environments.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Protocol
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
except Exception: # pragma: no cover
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
from codexlens.config import Config
from codexlens.entities import IndexedFile, Symbol
class Parser(Protocol):
    """Structural interface for objects that can index one source file."""

    def parse(self, text: str, path: Path) -> IndexedFile:
        """Build an ``IndexedFile`` from the source ``text`` of the file at ``path``."""
        ...
@dataclass
class SimpleRegexParser:
    """Parser that dispatches to the per-language symbol extractors of this module."""

    # Language identifier; selects the extractor and is copied into the result.
    language_id: str

    def parse(self, text: str, path: Path) -> IndexedFile:
        """Extract symbols from ``text`` and wrap them in an ``IndexedFile``.

        The result carries the resolved ``path`` as a string, this parser's
        ``language_id``, the extracted symbols and an empty ``chunks`` list.
        Languages without a dedicated extractor use the generic one.
        """
        lang = self.language_id
        if lang in {"javascript", "typescript"}:
            # The JS/TS extractor also needs the language id and the path.
            found = _parse_js_ts_symbols(text, lang, path)
        else:
            text_only_extractors = {
                "python": _parse_python_symbols,
                "java": _parse_java_symbols,
                "go": _parse_go_symbols,
            }
            extract = text_only_extractors.get(lang, _parse_generic_symbols)
            found = extract(text)

        resolved = str(path.resolve())
        return IndexedFile(path=resolved, language=lang, symbols=found, chunks=[])
class ParserFactory:
    """Creates parsers on demand and reuses one instance per language id."""

    def __init__(self, config: Config) -> None:
        """Store ``config`` and start with an empty parser cache."""
        self.config = config
        self._parsers: Dict[str, Parser] = {}

    def get_parser(self, language_id: str) -> Parser:
        """Return the cached parser for ``language_id``, creating it on first use."""
        try:
            return self._parsers[language_id]
        except KeyError:
            created = SimpleRegexParser(language_id)
            self._parsers[language_id] = created
            return created
# Python ``class`` header, optionally indented; group 1 captures the class name.
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
# Python ``def`` / ``async def`` header up to its opening parenthesis,
# optionally indented; group 1 captures the function name.
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
# Module-level memo of loaded tree-sitter grammars, filled and read by
# _get_tree_sitter_language; keys are "python", "javascript", "typescript"
# or "tsx".
_TREE_SITTER_LANGUAGE_CACHE: Dict[str, TreeSitterLanguage] = {}
def _get_tree_sitter_language(language_id: str, path: Path | None = None) -> TreeSitterLanguage | None:
    """Load (and memoize) the tree-sitter grammar for ``language_id``.

    A ``typescript`` request whose ``path`` has a ``.tsx`` suffix (any case)
    selects the TSX grammar instead. Returns ``None`` when the tree-sitter
    bindings are missing, the language has no grammar here, or loading the
    grammar raises; failures are not cached.
    """
    if TreeSitterLanguage is None:
        return None

    wants_tsx = (
        language_id == "typescript"
        and path is not None
        and path.suffix.lower() == ".tsx"
    )
    grammar_key = "tsx" if wants_tsx else language_id

    loaded = _TREE_SITTER_LANGUAGE_CACHE.get(grammar_key)
    if loaded is not None:
        return loaded

    try:
        if grammar_key == "python":
            import tree_sitter_python  # type: ignore[import-not-found]

            raw_grammar = tree_sitter_python.language()
        elif grammar_key == "javascript":
            import tree_sitter_javascript  # type: ignore[import-not-found]

            raw_grammar = tree_sitter_javascript.language()
        elif grammar_key in ("typescript", "tsx"):
            import tree_sitter_typescript  # type: ignore[import-not-found]

            if grammar_key == "tsx":
                raw_grammar = tree_sitter_typescript.language_tsx()
            else:
                raw_grammar = tree_sitter_typescript.language_typescript()
        else:
            return None
        loaded = TreeSitterLanguage(raw_grammar)
    except Exception:
        # Best effort: any import or construction problem means "no grammar".
        return None

    _TREE_SITTER_LANGUAGE_CACHE[grammar_key] = loaded
    return loaded
def _iter_tree_sitter_nodes(root: TreeSitterNode) -> Iterable[TreeSitterNode]:
    """Yield ``root`` and all of its descendants in depth-first pre-order.

    Uses an explicit stack rather than recursion; children are pushed in
    reverse so that they are popped (and yielded) in their original order.
    """
    pending: List[TreeSitterNode] = [root]
    while pending:
        current = pending.pop()
        yield current
        pending.extend(reversed(current.children))
def _node_text(source_bytes: bytes, node: TreeSitterNode) -> str:
    """Return the UTF-8 decoded slice of ``source_bytes`` spanned by ``node``."""
    begin, end = node.start_byte, node.end_byte
    span = source_bytes[begin:end]
    return span.decode("utf8")
def _node_range(node: TreeSitterNode) -> tuple[int, int]:
    """Return ``(start_line, end_line)`` for ``node`` as 1-based line numbers.

    The row component of ``start_point`` / ``end_point`` is shifted by one;
    the end line is clamped so it is never smaller than the start line.
    """
    first_line = node.start_point[0] + 1
    last_line = node.end_point[0] + 1
    if last_line < first_line:
        last_line = first_line
    return (first_line, last_line)
def _python_kind_for_function_node(node: TreeSitterNode) -> str:
    """Classify a function node as ``"method"`` or ``"function"``.

    Walks up the ancestors of ``node``: the nearest enclosing definition
    decides. A class gives ``"method"``; an enclosing function, or no
    enclosing definition at all, gives ``"function"``.
    """
    ancestor = node.parent
    while ancestor is not None:
        ancestor_type = ancestor.type
        if ancestor_type == "class_definition":
            return "method"
        if ancestor_type in ("function_definition", "async_function_definition"):
            return "function"
        ancestor = ancestor.parent
    return "function"
def _parse_python_symbols_tree_sitter(text: str) -> List[Symbol] | None:
    """Extract class and function symbols from Python ``text`` via tree-sitter.

    Returns ``None`` when the tree-sitter parser class or the Python grammar
    is unavailable, so the caller can fall back to another strategy.
    Otherwise returns one ``Symbol`` per named class or function definition,
    in document order, with ranges covering the whole definition.
    """
    if TreeSitterParser is None:
        return None
    grammar = _get_tree_sitter_language("python")
    if grammar is None:
        return None

    ts_parser = TreeSitterParser()
    # Use ``set_language`` when the parser object offers it, else assign the
    # ``language`` attribute.
    if hasattr(ts_parser, "set_language"):
        ts_parser.set_language(grammar)  # type: ignore[attr-defined]
    else:
        ts_parser.language = grammar  # type: ignore[assignment]

    encoded = text.encode("utf8")
    collected: List[Symbol] = []
    for node in _iter_tree_sitter_nodes(ts_parser.parse(encoded).root_node):
        node_type = node.type
        is_class = node_type == "class_definition"
        if not is_class and node_type not in ("function_definition", "async_function_definition"):
            continue
        identifier = node.child_by_field_name("name")
        if identifier is None:
            continue
        collected.append(Symbol(
            name=_node_text(encoded, identifier),
            kind="class" if is_class else _python_kind_for_function_node(node),
            range=_node_range(node),
        ))
    return collected
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
    """Extract Python class/function symbols line by line with regexes.

    Fallback for environments without tree-sitter. Each symbol gets a
    single-line range ``(i, i)`` at its header line. A ``def`` indented
    deeper than the most recent ``class`` header is reported as a
    ``"method"``; every other ``def`` is a ``"function"``.

    Indentation is measured after expanding tabs, so tab-indented class
    bodies are classified the same way as space-indented ones. A ``def`` at
    or left of the current class indentation ends that class scope, so
    functions nested inside a later top-level function are not reported as
    methods of the earlier class.
    """

    def indent_of(line: str) -> int:
        # Width of the leading whitespace, with tabs expanded to tab stops.
        expanded = line.expandtabs()
        return len(expanded) - len(expanded.lstrip())

    symbols: List[Symbol] = []
    current_class_indent: Optional[int] = None
    for i, line in enumerate(text.splitlines(), start=1):
        class_match = _PY_CLASS_RE.match(line)
        if class_match:
            current_class_indent = indent_of(line)
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            continue

        def_match = _PY_DEF_RE.match(line)
        if def_match:
            if current_class_indent is not None and indent_of(line) > current_class_indent:
                kind = "method"
            else:
                kind = "function"
                # A def at or left of the class indentation closes the class body.
                current_class_indent = None
            symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
            continue

        # Any other non-blank line at or left of the class header ends the class.
        if (
            current_class_indent is not None
            and line.strip()
            and indent_of(line) <= current_class_indent
        ):
            current_class_indent = None
    return symbols
def _parse_python_symbols(text: str) -> List[Symbol]:
    """Extract Python symbols, preferring tree-sitter over the regex fallback.

    The regex extractor runs only when the tree-sitter extractor returns
    ``None``; an empty tree-sitter result is returned as is.
    """
    from_tree_sitter = _parse_python_symbols_tree_sitter(text)
    if from_tree_sitter is None:
        return _parse_python_symbols_regex(text)
    return from_tree_sitter
# ``function name(`` declarations, optionally prefixed by ``export`` and/or
# ``async``; group 1 captures the function name.
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
# ``class Name`` declarations, optionally prefixed by ``export``; group 1
# captures the class name.
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
# ``const|let|var name = ... =>`` arrow-function bindings whose parameter list
# and arrow sit on the same line (optionally ``export`` / ``async``); group 1
# captures the bound name.
_JS_ARROW_RE = re.compile(
    r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
# Indented ``name(...) {`` lines (optionally ``async``); group 1 captures the
# name. Leading whitespace is required, and the pattern itself does not
# exclude keywords such as ``if`` or ``for``.
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
def _js_has_class_ancestor(node: TreeSitterNode) -> bool:
    """Return ``True`` when any ancestor of ``node`` is a class node.

    Both ``class_declaration`` and ``class`` node types count; ``node``
    itself is not inspected.
    """
    ancestor = node.parent
    while ancestor is not None and ancestor.type not in ("class_declaration", "class"):
        ancestor = ancestor.parent
    return ancestor is not None
def _parse_js_ts_symbols_tree_sitter(
    text: str,
    language_id: str,
    path: Path | None = None,
) -> List[Symbol] | None:
    """Extract JavaScript/TypeScript symbols from ``text`` via tree-sitter.

    Returns ``None`` when the tree-sitter parser class or the grammar for
    ``language_id`` / ``path`` is unavailable. Otherwise reports, in document
    order:

    * named classes as ``"class"``;
    * function and generator declarations as ``"function"``;
    * variable declarators bound to an arrow function as ``"function"``;
    * method definitions with a class ancestor as ``"method"``, except
      ``constructor``.
    """
    if TreeSitterParser is None:
        return None
    grammar = _get_tree_sitter_language(language_id, path)
    if grammar is None:
        return None

    ts_parser = TreeSitterParser()
    # Use ``set_language`` when the parser object offers it, else assign the
    # ``language`` attribute.
    if hasattr(ts_parser, "set_language"):
        ts_parser.set_language(grammar)  # type: ignore[attr-defined]
    else:
        ts_parser.language = grammar  # type: ignore[assignment]

    encoded = text.encode("utf8")
    found: List[Symbol] = []
    for node in _iter_tree_sitter_nodes(ts_parser.parse(encoded).root_node):
        node_type = node.type
        if node_type in ("class_declaration", "class"):
            kind = "class"
        elif node_type in (
            "function_declaration",
            "generator_function_declaration",
            "variable_declarator",
        ):
            kind = "function"
        elif node_type == "method_definition" and _js_has_class_ancestor(node):
            kind = "method"
        else:
            continue

        name_node = node.child_by_field_name("name")
        if name_node is None:
            continue

        if node_type == "variable_declarator":
            # Only plain names bound directly to an arrow function count.
            value_node = node.child_by_field_name("value")
            is_named_arrow = (
                value_node is not None
                and name_node.type in ("identifier", "property_identifier")
                and value_node.type == "arrow_function"
            )
            if not is_named_arrow:
                continue

        symbol_name = _node_text(encoded, name_node)
        if kind == "method" and symbol_name == "constructor":
            continue
        found.append(Symbol(name=symbol_name, kind=kind, range=_node_range(node)))
    return found
# Keywords that can precede ``(...) {`` on an indented line and would
# otherwise be mistaken for a method name by _JS_METHOD_RE.
_JS_NON_METHOD_KEYWORDS = frozenset(
    {"if", "for", "while", "switch", "catch", "with", "function"}
)


def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
    """Extract JavaScript/TypeScript symbols line by line with regexes.

    Fallback for environments without tree-sitter. Each symbol gets a
    single-line range ``(i, i)``. Classes, ``function`` declarations and
    single-line arrow-function bindings are recognised anywhere; ``name(...) {``
    lines are reported as ``"method"`` only while inside a class body, and
    never for ``constructor`` or for control-flow keywords such as ``if``,
    ``for`` or ``while``.

    Class scope is tracked by counting braces per line (braces inside strings
    or comments are not distinguished). The class body depth is taken
    relative to the depth *before* the class line, so a class that opens and
    closes on one line (``class E extends Error {}``) does not leave the
    scanner inside a class for the rest of the file.
    """
    symbols: List[Symbol] = []
    in_class = False
    class_body_depth = 0
    brace_depth = 0
    for i, line in enumerate(text.splitlines(), start=1):
        depth_before_line = brace_depth
        brace_depth += line.count("{") - line.count("}")

        class_match = _JS_CLASS_RE.match(line)
        if class_match:
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            in_class = True
            # Members live one brace level deeper than the class header.
            class_body_depth = depth_before_line + 1
            continue

        if in_class and brace_depth < class_body_depth:
            in_class = False

        func_match = _JS_FUNC_RE.match(line)
        if func_match:
            symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
            continue

        arrow_match = _JS_ARROW_RE.match(line)
        if arrow_match:
            symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
            continue

        if not in_class:
            continue
        method_match = _JS_METHOD_RE.match(line)
        if method_match is None:
            continue
        name = method_match.group(1)
        # ``if (x) {`` and friends have the same shape as a method header.
        if name == "constructor" or name in _JS_NON_METHOD_KEYWORDS:
            continue
        symbols.append(Symbol(name=name, kind="method", range=(i, i)))
    return symbols
def _parse_js_ts_symbols(
    text: str,
    language_id: str = "javascript",
    path: Path | None = None,
) -> List[Symbol]:
    """Extract JS/TS symbols, preferring tree-sitter over the regex fallback.

    ``language_id`` and ``path`` are only used by the tree-sitter extractor.
    The regex extractor runs only when tree-sitter returns ``None``; an empty
    tree-sitter result is returned as is.
    """
    from_tree_sitter = _parse_js_ts_symbols_tree_sitter(text, language_id, path)
    if from_tree_sitter is None:
        return _parse_js_ts_symbols_regex(text)
    return from_tree_sitter
# ``class Name`` declarations, optionally prefixed by ``public``; group 1
# captures the class name.
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
# Lines shaped like ``<modifiers> <type> <name>(``; group 1 captures the name.
# The modifier group also accepts plain whitespace, so an indented line
# matches even without any access modifier.
_JAVA_METHOD_RE = re.compile(
    r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
)
# Java keywords that can occupy the "type" or "name" slot of _JAVA_METHOD_RE
# on an ordinary statement line (``return helper(x);``, ``new Foo(``,
# ``else if (``) and therefore must not be indexed as method declarations.
_JAVA_NON_DECLARATION_KEYWORDS = frozenset({
    "return", "new", "throw", "else", "assert", "yield", "case",
    "if", "for", "while", "switch", "catch", "do", "try",
    "synchronized", "super", "this",
})


def _parse_java_symbols(text: str) -> List[Symbol]:
    """Extract Java class and method symbols line by line with regexes.

    Each symbol gets a single-line range ``(i, i)``. Class headers are
    reported as ``"class"``; lines shaped like ``<type> <name>(`` are reported
    as ``"method"``. Because the method pattern accepts any indented
    ``word word(`` line, matches whose name or preceding token is a statement
    keyword (``return foo(``, ``new Foo(``, ``else if (``) are skipped
    instead of being indexed as methods.
    """
    symbols: List[Symbol] = []
    for i, line in enumerate(text.splitlines(), start=1):
        class_match = _JAVA_CLASS_RE.match(line)
        if class_match:
            symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
            continue

        method_match = _JAVA_METHOD_RE.match(line)
        if method_match is None:
            continue
        name = method_match.group(1)
        # The whitespace-separated token right before the name is what the
        # pattern consumed as the return type.
        leading_tokens = line[: method_match.start(1)].split()
        type_token = leading_tokens[-1] if leading_tokens else ""
        if (
            name in _JAVA_NON_DECLARATION_KEYWORDS
            or type_token in _JAVA_NON_DECLARATION_KEYWORDS
        ):
            continue
        symbols.append(Symbol(name=name, kind="method", range=(i, i)))
    return symbols
# ``func Name(`` and ``func (receiver) Name(`` declarations; group 1 captures
# the function or method name.
_GO_FUNC_RE = re.compile(r"^\s*func\s+(?:\([^)]+\)\s+)?([A-Za-z_]\w*)\s*\(")
# ``type Name struct`` / ``type Name interface`` declarations; group 1
# captures the type name.
_GO_TYPE_RE = re.compile(r"^\s*type\s+([A-Za-z_]\w*)\s+(?:struct|interface)\b")
def _parse_go_symbols(text: str) -> List[Symbol]:
    """Extract Go symbols line by line with regexes.

    ``struct`` / ``interface`` type declarations are reported as ``"class"``
    and ``func`` declarations (with or without a receiver) as ``"function"``.
    Each symbol gets a single-line range ``(line, line)``; at most one symbol
    is produced per line, with the type pattern tried first.
    """
    ordered_patterns = ((_GO_TYPE_RE, "class"), (_GO_FUNC_RE, "function"))
    found: List[Symbol] = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        for pattern, kind in ordered_patterns:
            hit = pattern.match(line)
            if hit:
                found.append(Symbol(name=hit.group(1), kind=kind, range=(line_no, line_no)))
                break
    return found
# ``def`` / ``function`` / ``func`` followed by an identifier, optionally
# indented; group 1 captures the identifier.
_GENERIC_DEF_RE = re.compile(r"^\s*(?:def|function|func)\s+([A-Za-z_]\w*)\b")
# ``class`` / ``struct`` / ``interface`` followed by an identifier, optionally
# indented; group 1 captures the identifier.
_GENERIC_CLASS_RE = re.compile(r"^\s*(?:class|struct|interface)\s+([A-Za-z_]\w*)\b")
def _parse_generic_symbols(text: str) -> List[Symbol]:
    """Extract symbols from ``text`` of a language without a dedicated extractor.

    Lines starting with ``class`` / ``struct`` / ``interface`` plus a name are
    reported as ``"class"``; lines starting with ``def`` / ``function`` /
    ``func`` plus a name as ``"function"``. Each symbol gets a single-line
    range, and the class pattern takes precedence on a line.
    """
    collected: List[Symbol] = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        hit = _GENERIC_CLASS_RE.match(line)
        kind = "class"
        if hit is None:
            hit = _GENERIC_DEF_RE.match(line)
            kind = "function"
        if hit is not None:
            collected.append(Symbol(name=hit.group(1), kind=kind, range=(line_no, line_no)))
    return collected