From c42f91a7fe3b05f45adfd70052ef7ccc8fccfe77 Mon Sep 17 00:00:00 2001 From: catlog22 Date: Fri, 12 Dec 2025 18:40:24 +0800 Subject: [PATCH] feat: Add support for Tree-Sitter parsing and enhance SQLite storage performance --- codex-lens/pyproject.toml | 5 +- codex-lens/src/codexlens/cli/commands.py | 52 +++- codex-lens/src/codexlens/parsers/factory.py | 286 ++++++++++++++++-- .../src/codexlens/storage/sqlite_store.py | 55 +++- codex-lens/tests/test_parsers.py | 148 +++++++++ 5 files changed, 519 insertions(+), 27 deletions(-) create mode 100644 codex-lens/tests/test_parsers.py diff --git a/codex-lens/pyproject.toml b/codex-lens/pyproject.toml index 038bad66..8a7ae599 100644 --- a/codex-lens/pyproject.toml +++ b/codex-lens/pyproject.toml @@ -17,6 +17,9 @@ dependencies = [ "rich>=13", "pydantic>=2.0", "tree-sitter>=0.20", + "tree-sitter-python>=0.25", + "tree-sitter-javascript>=0.25", + "tree-sitter-typescript>=0.23", "pathspec>=0.11", ] @@ -24,6 +27,7 @@ dependencies = [ semantic = [ "numpy>=1.24", "sentence-transformers>=2.2", + "fastembed>=0.2", ] [project.urls] @@ -31,4 +35,3 @@ Homepage = "https://github.com/openai/codex-lens" [tool.setuptools] package-dir = { "" = "src" } - diff --git a/codex-lens/src/codexlens/cli/commands.py b/codex-lens/src/codexlens/cli/commands.py index 40123a21..356bdcbf 100644 --- a/codex-lens/src/codexlens/cli/commands.py +++ b/codex-lens/src/codexlens/cli/commands.py @@ -62,25 +62,42 @@ def _iter_source_files( languages: Optional[List[str]] = None, ) -> Iterable[Path]: ignore_dirs = {".git", ".venv", "venv", "node_modules", "__pycache__", ".codexlens"} - ignore_patterns = _load_gitignore(base_path) - pathspec = None - if ignore_patterns: + + # Cache for PathSpec objects per directory + pathspec_cache: Dict[Path, Optional[Any]] = {} + + def get_pathspec_for_dir(dir_path: Path) -> Optional[Any]: + """Get PathSpec for a directory, loading .gitignore if present.""" + if dir_path in pathspec_cache: + return pathspec_cache[dir_path] + + ignore_patterns = _load_gitignore(dir_path) + if not ignore_patterns: + pathspec_cache[dir_path] = None + return None + try: from pathspec import PathSpec from pathspec.patterns.gitwildmatch import GitWildMatchPattern - pathspec = PathSpec.from_lines(GitWildMatchPattern, ignore_patterns) + pathspec_cache[dir_path] = pathspec + return pathspec except Exception: - pathspec = None + pathspec_cache[dir_path] = None + return None for root, dirs, files in os.walk(base_path): dirs[:] = [d for d in dirs if d not in ignore_dirs and not d.startswith(".")] root_path = Path(root) + + # Get pathspec for current directory + pathspec = get_pathspec_for_dir(root_path) + for file in files: if file.startswith("."): continue full_path = root_path / file - rel = full_path.relative_to(base_path) + rel = full_path.relative_to(root_path) if pathspec and pathspec.match_file(str(rel)): continue language_id = config.language_for_path(full_path) @@ -112,6 +129,25 @@ def _get_store_for_path(path: Path, use_global: bool = False) -> tuple[SQLiteSto return SQLiteStore(config.db_path), config.db_path + + +def _is_safe_to_clean(target_dir: Path) -> bool: + """Verify directory is a CodexLens directory before deletion. + + Checks for presence of .codexlens directory or index.db file. 
+ """ + if not target_dir.exists(): + return True + + # Check if it's the .codexlens directory itself + if target_dir.name == ".codexlens": + # Verify it contains index.db or cache directory + return (target_dir / "index.db").exists() or (target_dir / "cache").exists() + + # Check if it contains .codexlens subdirectory + return (target_dir / ".codexlens").exists() + + @app.command() def init( path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to index."), @@ -469,12 +505,16 @@ def clean( config = Config() import shutil if config.index_dir.exists(): + if not _is_safe_to_clean(config.index_dir): + raise CodexLensError(f"Safety check failed: {config.index_dir} does not appear to be a CodexLens directory") shutil.rmtree(config.index_dir) result = {"cleaned": str(config.index_dir), "type": "global"} else: workspace = WorkspaceConfig.from_path(base_path) if workspace and workspace.codexlens_dir.exists(): import shutil + if not _is_safe_to_clean(workspace.codexlens_dir): + raise CodexLensError(f"Safety check failed: {workspace.codexlens_dir} does not appear to be a CodexLens directory") shutil.rmtree(workspace.codexlens_dir) result = {"cleaned": str(workspace.codexlens_dir), "type": "workspace"} else: diff --git a/codex-lens/src/codexlens/parsers/factory.py b/codex-lens/src/codexlens/parsers/factory.py index 692d1be8..9f793d10 100644 --- a/codex-lens/src/codexlens/parsers/factory.py +++ b/codex-lens/src/codexlens/parsers/factory.py @@ -1,8 +1,8 @@ """Parser factory for CodexLens. -The project currently ships lightweight regex-based parsers per language. -They can be swapped for tree-sitter based parsers later without changing -CLI or storage interfaces. +Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when +available. Regex fallbacks are retained to preserve the existing parser +interface and behavior in minimal environments. 
""" from __future__ import annotations @@ -10,7 +10,16 @@ from __future__ import annotations import re from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Protocol +from typing import Dict, Iterable, List, Optional, Protocol + +try: + from tree_sitter import Language as TreeSitterLanguage + from tree_sitter import Node as TreeSitterNode + from tree_sitter import Parser as TreeSitterParser +except Exception: # pragma: no cover + TreeSitterLanguage = None # type: ignore[assignment] + TreeSitterNode = None # type: ignore[assignment] + TreeSitterParser = None # type: ignore[assignment] from codexlens.config import Config from codexlens.entities import IndexedFile, Symbol @@ -25,11 +34,10 @@ class SimpleRegexParser: language_id: str def parse(self, text: str, path: Path) -> IndexedFile: - symbols: List[Symbol] = [] if self.language_id == "python": symbols = _parse_python_symbols(text) elif self.language_id in {"javascript", "typescript"}: - symbols = _parse_js_ts_symbols(text) + symbols = _parse_js_ts_symbols(text, self.language_id, path) elif self.language_id == "java": symbols = _parse_java_symbols(text) elif self.language_id == "go": @@ -57,24 +65,135 @@ class ParserFactory: _PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b") -_PY_DEF_RE = re.compile(r"^\s*def\s+([A-Za-z_]\w*)\s*\(") +_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(") + +_TREE_SITTER_LANGUAGE_CACHE: Dict[str, TreeSitterLanguage] = {} -def _parse_python_symbols(text: str) -> List[Symbol]: +def _get_tree_sitter_language(language_id: str, path: Path | None = None) -> TreeSitterLanguage | None: + if TreeSitterLanguage is None: + return None + + cache_key = language_id + if language_id == "typescript" and path is not None and path.suffix.lower() == ".tsx": + cache_key = "tsx" + + cached = _TREE_SITTER_LANGUAGE_CACHE.get(cache_key) + if cached is not None: + return cached + + try: + if cache_key == "python": + import tree_sitter_python # type: ignore[import-not-found] + + language = TreeSitterLanguage(tree_sitter_python.language()) + elif cache_key == "javascript": + import tree_sitter_javascript # type: ignore[import-not-found] + + language = TreeSitterLanguage(tree_sitter_javascript.language()) + elif cache_key == "typescript": + import tree_sitter_typescript # type: ignore[import-not-found] + + language = TreeSitterLanguage(tree_sitter_typescript.language_typescript()) + elif cache_key == "tsx": + import tree_sitter_typescript # type: ignore[import-not-found] + + language = TreeSitterLanguage(tree_sitter_typescript.language_tsx()) + else: + return None + except Exception: + return None + + _TREE_SITTER_LANGUAGE_CACHE[cache_key] = language + return language + + +def _iter_tree_sitter_nodes(root: TreeSitterNode) -> Iterable[TreeSitterNode]: + stack: List[TreeSitterNode] = [root] + while stack: + node = stack.pop() + yield node + for child in reversed(node.children): + stack.append(child) + + +def _node_text(source_bytes: bytes, node: TreeSitterNode) -> str: + return source_bytes[node.start_byte:node.end_byte].decode("utf8") + + +def _node_range(node: TreeSitterNode) -> tuple[int, int]: + start_line = node.start_point[0] + 1 + end_line = node.end_point[0] + 1 + return (start_line, max(start_line, end_line)) + + +def _python_kind_for_function_node(node: TreeSitterNode) -> str: + parent = node.parent + while parent is not None: + if parent.type in {"function_definition", "async_function_definition"}: + return "function" + if parent.type == 
"class_definition": + return "method" + parent = parent.parent + return "function" + + +def _parse_python_symbols_tree_sitter(text: str) -> List[Symbol] | None: + if TreeSitterParser is None: + return None + + language = _get_tree_sitter_language("python") + if language is None: + return None + + parser = TreeSitterParser() + if hasattr(parser, "set_language"): + parser.set_language(language) # type: ignore[attr-defined] + else: + parser.language = language # type: ignore[assignment] + + source_bytes = text.encode("utf8") + tree = parser.parse(source_bytes) + root = tree.root_node + + symbols: List[Symbol] = [] + for node in _iter_tree_sitter_nodes(root): + if node.type == "class_definition": + name_node = node.child_by_field_name("name") + if name_node is None: + continue + symbols.append(Symbol( + name=_node_text(source_bytes, name_node), + kind="class", + range=_node_range(node), + )) + elif node.type in {"function_definition", "async_function_definition"}: + name_node = node.child_by_field_name("name") + if name_node is None: + continue + symbols.append(Symbol( + name=_node_text(source_bytes, name_node), + kind=_python_kind_for_function_node(node), + range=_node_range(node), + )) + + return symbols + + +def _parse_python_symbols_regex(text: str) -> List[Symbol]: symbols: List[Symbol] = [] current_class_indent: Optional[int] = None for i, line in enumerate(text.splitlines(), start=1): - if _PY_CLASS_RE.match(line): - name = _PY_CLASS_RE.match(line).group(1) + class_match = _PY_CLASS_RE.match(line) + if class_match: current_class_indent = len(line) - len(line.lstrip(" ")) - symbols.append(Symbol(name=name, kind="class", range=(i, i))) + symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i))) continue def_match = _PY_DEF_RE.match(line) if def_match: - name = def_match.group(1) indent = len(line) - len(line.lstrip(" ")) kind = "method" if current_class_indent is not None and indent > current_class_indent else "function" - symbols.append(Symbol(name=name, kind=kind, range=(i, i))) + symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i))) continue if current_class_indent is not None: indent = len(line) - len(line.lstrip(" ")) @@ -83,23 +202,153 @@ def _parse_python_symbols(text: str) -> List[Symbol]: return symbols -_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(") +def _parse_python_symbols(text: str) -> List[Symbol]: + symbols = _parse_python_symbols_tree_sitter(text) + if symbols is not None: + return symbols + return _parse_python_symbols_regex(text) + + +_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(") _JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b") +_JS_ARROW_RE = re.compile( + r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>" +) +_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{") -def _parse_js_ts_symbols(text: str) -> List[Symbol]: +def _js_has_class_ancestor(node: TreeSitterNode) -> bool: + parent = node.parent + while parent is not None: + if parent.type in {"class_declaration", "class"}: + return True + parent = parent.parent + return False + + +def _parse_js_ts_symbols_tree_sitter( + text: str, + language_id: str, + path: Path | None = None, +) -> List[Symbol] | None: + if TreeSitterParser is None: + return None + + language = _get_tree_sitter_language(language_id, path) + if language is None: + return None + + parser = TreeSitterParser() + if hasattr(parser, 
"set_language"): + parser.set_language(language) # type: ignore[attr-defined] + else: + parser.language = language # type: ignore[assignment] + + source_bytes = text.encode("utf8") + tree = parser.parse(source_bytes) + root = tree.root_node + symbols: List[Symbol] = [] + for node in _iter_tree_sitter_nodes(root): + if node.type in {"class_declaration", "class"}: + name_node = node.child_by_field_name("name") + if name_node is None: + continue + symbols.append(Symbol( + name=_node_text(source_bytes, name_node), + kind="class", + range=_node_range(node), + )) + elif node.type in {"function_declaration", "generator_function_declaration"}: + name_node = node.child_by_field_name("name") + if name_node is None: + continue + symbols.append(Symbol( + name=_node_text(source_bytes, name_node), + kind="function", + range=_node_range(node), + )) + elif node.type == "variable_declarator": + name_node = node.child_by_field_name("name") + value_node = node.child_by_field_name("value") + if ( + name_node is None + or value_node is None + or name_node.type not in {"identifier", "property_identifier"} + or value_node.type != "arrow_function" + ): + continue + symbols.append(Symbol( + name=_node_text(source_bytes, name_node), + kind="function", + range=_node_range(node), + )) + elif node.type == "method_definition" and _js_has_class_ancestor(node): + name_node = node.child_by_field_name("name") + if name_node is None: + continue + name = _node_text(source_bytes, name_node) + if name == "constructor": + continue + symbols.append(Symbol( + name=name, + kind="method", + range=_node_range(node), + )) + + return symbols + + +def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]: + symbols: List[Symbol] = [] + in_class = False + class_brace_depth = 0 + brace_depth = 0 + for i, line in enumerate(text.splitlines(), start=1): + brace_depth += line.count("{") - line.count("}") + + class_match = _JS_CLASS_RE.match(line) + if class_match: + symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i))) + in_class = True + class_brace_depth = brace_depth + continue + + if in_class and brace_depth < class_brace_depth: + in_class = False + func_match = _JS_FUNC_RE.match(line) if func_match: symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i))) continue - class_match = _JS_CLASS_RE.match(line) - if class_match: - symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i))) + + arrow_match = _JS_ARROW_RE.match(line) + if arrow_match: + symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i))) + continue + + if in_class: + method_match = _JS_METHOD_RE.match(line) + if method_match: + name = method_match.group(1) + if name != "constructor": + symbols.append(Symbol(name=name, kind="method", range=(i, i))) + return symbols +def _parse_js_ts_symbols( + text: str, + language_id: str = "javascript", + path: Path | None = None, +) -> List[Symbol]: + symbols = _parse_js_ts_symbols_tree_sitter(text, language_id, path) + if symbols is not None: + return symbols + return _parse_js_ts_symbols_regex(text) + + _JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b") _JAVA_METHOD_RE = re.compile( r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\(" @@ -151,4 +400,3 @@ def _parse_generic_symbols(text: str) -> List[Symbol]: if def_match: symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i))) return symbols - diff --git a/codex-lens/src/codexlens/storage/sqlite_store.py 
b/codex-lens/src/codexlens/storage/sqlite_store.py index c5540767..f600c89d 100644 --- a/codex-lens/src/codexlens/storage/sqlite_store.py +++ b/codex-lens/src/codexlens/storage/sqlite_store.py @@ -118,6 +118,59 @@ class SQLiteStore: ) conn.commit() + def add_files(self, files_data: List[tuple[IndexedFile, str]]) -> None: + """Add multiple files in a single transaction for better performance. + + Args: + files_data: List of (indexed_file, content) tuples + """ + with self._lock: + conn = self._get_connection() + try: + conn.execute("BEGIN") + + for indexed_file, content in files_data: + path = str(Path(indexed_file.path).resolve()) + language = indexed_file.language + mtime = Path(path).stat().st_mtime if Path(path).exists() else None + line_count = content.count(chr(10)) + 1 + + conn.execute( + """ + INSERT INTO files(path, language, content, mtime, line_count) + VALUES(?, ?, ?, ?, ?) + ON CONFLICT(path) DO UPDATE SET + language=excluded.language, + content=excluded.content, + mtime=excluded.mtime, + line_count=excluded.line_count + """, + (path, language, content, mtime, line_count), + ) + + row = conn.execute("SELECT id FROM files WHERE path=?", (path,)).fetchone() + if not row: + raise StorageError(f"Failed to read file id for {path}") + file_id = int(row["id"]) + + conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,)) + if indexed_file.symbols: + conn.executemany( + """ + INSERT INTO symbols(file_id, name, kind, start_line, end_line) + VALUES(?, ?, ?, ?, ?) + """, + [ + (file_id, s.name, s.kind, s.range[0], s.range[1]) + for s in indexed_file.symbols + ], + ) + + conn.commit() + except Exception: + conn.rollback() + raise + def remove_file(self, path: str | Path) -> bool: """Remove a file from the index.""" with self._lock: @@ -178,7 +231,7 @@ class SQLiteStore: results: List[SearchResult] = [] for row in rows: rank = float(row["rank"]) if row["rank"] is not None else 0.0 - score = max(0.0, -rank) + score = abs(rank) if rank < 0 else 0.0 results.append( SearchResult( path=row["path"], diff --git a/codex-lens/tests/test_parsers.py b/codex-lens/tests/test_parsers.py new file mode 100644 index 00000000..43696318 --- /dev/null +++ b/codex-lens/tests/test_parsers.py @@ -0,0 +1,148 @@ +"""Tests for CodexLens parsers.""" + +from pathlib import Path + +import pytest + +from codexlens.parsers.factory import ( + SimpleRegexParser, + _parse_js_ts_symbols, + _parse_python_symbols, +) + + +TREE_SITTER_JS_AVAILABLE = True +try: + import tree_sitter_javascript # type: ignore[import-not-found] # noqa: F401 +except Exception: + TREE_SITTER_JS_AVAILABLE = False + + +class TestPythonParser: + """Tests for Python symbol parsing.""" + + def test_parse_function(self): + code = "def hello():\n pass" + symbols = _parse_python_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "hello" + assert symbols[0].kind == "function" + + def test_parse_async_function(self): + code = "async def fetch_data():\n pass" + symbols = _parse_python_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "fetch_data" + assert symbols[0].kind == "function" + + def test_parse_class(self): + code = "class MyClass:\n pass" + symbols = _parse_python_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "MyClass" + assert symbols[0].kind == "class" + + def test_parse_method(self): + code = "class MyClass:\n def method(self):\n pass" + symbols = _parse_python_symbols(code) + assert len(symbols) == 2 + assert symbols[0].name == "MyClass" + assert symbols[0].kind == "class" + assert 
symbols[1].name == "method" + assert symbols[1].kind == "method" + + def test_parse_async_method(self): + code = "class MyClass:\n async def async_method(self):\n pass" + symbols = _parse_python_symbols(code) + assert len(symbols) == 2 + assert symbols[1].name == "async_method" + assert symbols[1].kind == "method" + + +class TestJavaScriptParser: + """Tests for JavaScript/TypeScript symbol parsing.""" + + def test_parse_function(self): + code = "function hello() {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "hello" + assert symbols[0].kind == "function" + + def test_parse_async_function(self): + code = "async function fetchData() {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "fetchData" + assert symbols[0].kind == "function" + + def test_parse_arrow_function(self): + code = "const hello = () => {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "hello" + assert symbols[0].kind == "function" + + def test_parse_async_arrow_function(self): + code = "const fetchData = async () => {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "fetchData" + assert symbols[0].kind == "function" + + def test_parse_class(self): + code = "class MyClass {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "MyClass" + assert symbols[0].kind == "class" + + def test_parse_export_function(self): + code = "export function hello() {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "hello" + assert symbols[0].kind == "function" + + def test_parse_export_class(self): + code = "export class MyClass {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "MyClass" + assert symbols[0].kind == "class" + + def test_parse_export_arrow_function(self): + code = "export const hello = () => {}" + symbols = _parse_js_ts_symbols(code) + assert len(symbols) == 1 + assert symbols[0].name == "hello" + assert symbols[0].kind == "function" + + @pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed") + def test_parse_class_methods(self): + code = ( + "class MyClass {\n" + " method() {}\n" + " async asyncMethod() {}\n" + " static staticMethod() {}\n" + " constructor() {}\n" + "}" + ) + symbols = _parse_js_ts_symbols(code) + names_kinds = [(s.name, s.kind) for s in symbols] + assert ("MyClass", "class") in names_kinds + assert ("method", "method") in names_kinds + assert ("asyncMethod", "method") in names_kinds + assert ("staticMethod", "method") in names_kinds + assert all(name != "constructor" for name, _ in names_kinds) + + +class TestParserInterface: + """High-level interface tests.""" + + def test_simple_parser_parse(self): + parser = SimpleRegexParser("python") + indexed = parser.parse("def hello():\n pass", Path("test.py")) + assert indexed.language == "python" + assert len(indexed.symbols) == 1 + assert indexed.symbols[0].name == "hello"
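
A minimal usage sketch of the parser path introduced by this patch, assuming tree-sitter-typescript is installed: the "typescript" language id selects the TSX grammar purely from the file suffix, and the regex fallback is expected to yield the same top-level symbol when the grammar is missing. The component source and file name below are hypothetical.

from pathlib import Path

from codexlens.parsers.factory import SimpleRegexParser

# The same "typescript" parser handles .ts and .tsx; _get_tree_sitter_language
# switches to the TSX grammar when the path ends in .tsx.
parser = SimpleRegexParser("typescript")
indexed = parser.parse("export const App = () => <div/>;\n", Path("App.tsx"))
print([(s.name, s.kind) for s in indexed.symbols])  # e.g. [("App", "function")]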
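
A hedged sketch of the new batch write path, combining SimpleRegexParser with SQLiteStore.add_files so a set of parsed files is committed in a single transaction rather than once per file; the index location and source tree are hypothetical.

from pathlib import Path

from codexlens.parsers.factory import SimpleRegexParser
from codexlens.storage.sqlite_store import SQLiteStore

store = SQLiteStore(Path(".codexlens/index.db"))  # hypothetical index location
parser = SimpleRegexParser("python")

files_data = []
for source in sorted(Path("src").rglob("*.py")):  # hypothetical source tree
    text = source.read_text(encoding="utf-8")
    files_data.append((parser.parse(text, source), text))

# One transaction for the whole batch instead of a commit per file.
store.add_files(files_data)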
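
The search scoring tweak in sqlite_store.py maps the FTS5 rank (more negative means a better match under the default bm25 weighting) to a non-negative score; a small sketch of that mapping in isolation, with illustrative values:

def rank_to_score(rank: float | None) -> float:
    # Mirrors the logic in SQLiteStore.search: negative bm25 ranks become
    # positive scores, and a missing or non-negative rank collapses to 0.0.
    value = float(rank) if rank is not None else 0.0
    return abs(value) if value < 0 else 0.0

assert rank_to_score(-3.2) == 3.2
assert rank_to_score(None) == 0.0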