feat: Add support for Tree-Sitter parsing and enhance SQLite storage performance

catlog22
2025-12-12 18:40:24 +08:00
parent 92d2085b64
commit c42f91a7fe
5 changed files with 519 additions and 27 deletions

View File

@@ -17,6 +17,9 @@ dependencies = [
"rich>=13",
"pydantic>=2.0",
"tree-sitter>=0.20",
"tree-sitter-python>=0.25",
"tree-sitter-javascript>=0.25",
"tree-sitter-typescript>=0.23",
"pathspec>=0.11",
]
@@ -24,6 +27,7 @@ dependencies = [
semantic = [
"numpy>=1.24",
"sentence-transformers>=2.2",
"fastembed>=0.2",
]
[project.urls]
@@ -31,4 +35,3 @@ Homepage = "https://github.com/openai/codex-lens"
[tool.setuptools]
package-dir = { "" = "src" }

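The grammar wheels added above are optional at runtime: the parser factory only uses them if they import cleanly and falls back to regex parsing otherwise. A minimal loading sketch (assumes a recent py-tree-sitter where Language() wraps the pointer returned by a grammar wheel; older bindings use parser.set_language()):

    from tree_sitter import Language, Parser
    import tree_sitter_python

    PY_LANGUAGE = Language(tree_sitter_python.language())
    parser = Parser()
    parser.language = PY_LANGUAGE  # older bindings: parser.set_language(PY_LANGUAGE)
    tree = parser.parse(b"def hello():\n    pass\n")
    print(tree.root_node.type)  # "module"
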
View File

@@ -62,25 +62,42 @@ def _iter_source_files(
languages: Optional[List[str]] = None,
) -> Iterable[Path]:
ignore_dirs = {".git", ".venv", "venv", "node_modules", "__pycache__", ".codexlens"}
ignore_patterns = _load_gitignore(base_path)
pathspec = None
if ignore_patterns:
# Cache for PathSpec objects per directory
pathspec_cache: Dict[Path, Optional[Any]] = {}
def get_pathspec_for_dir(dir_path: Path) -> Optional[Any]:
"""Get PathSpec for a directory, loading .gitignore if present."""
if dir_path in pathspec_cache:
return pathspec_cache[dir_path]
ignore_patterns = _load_gitignore(dir_path)
if not ignore_patterns:
pathspec_cache[dir_path] = None
return None
try:
from pathspec import PathSpec
from pathspec.patterns.gitwildmatch import GitWildMatchPattern
pathspec = PathSpec.from_lines(GitWildMatchPattern, ignore_patterns)
pathspec_cache[dir_path] = pathspec
return pathspec
except Exception:
pathspec = None
pathspec_cache[dir_path] = None
return None
for root, dirs, files in os.walk(base_path):
dirs[:] = [d for d in dirs if d not in ignore_dirs and not d.startswith(".")]
root_path = Path(root)
# Get pathspec for current directory
pathspec = get_pathspec_for_dir(root_path)
for file in files:
if file.startswith("."):
continue
full_path = root_path / file
rel = full_path.relative_to(base_path)
rel = full_path.relative_to(root_path)
if pathspec and pathspec.match_file(str(rel)):
continue
language_id = config.language_for_path(full_path)
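For reference, the per-directory handling above leans on the pathspec package: each directory's .gitignore lines become a PathSpec, and file names are matched relative to that directory. A small illustrative sketch (patterns are examples only):

    from pathspec import PathSpec
    from pathspec.patterns.gitwildmatch import GitWildMatchPattern

    spec = PathSpec.from_lines(GitWildMatchPattern, ["*.log", "build/"])
    print(spec.match_file("debug.log"))    # True  -> skipped by the walk
    print(spec.match_file("src/main.py"))  # False -> handed to the parser
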
@@ -112,6 +129,25 @@ def _get_store_for_path(path: Path, use_global: bool = False) -> tuple[SQLiteSto
return SQLiteStore(config.db_path), config.db_path
def _is_safe_to_clean(target_dir: Path) -> bool:
"""Verify directory is a CodexLens directory before deletion.
Checks for presence of .codexlens directory or index.db file.
"""
if not target_dir.exists():
return True
# Check if it's the .codexlens directory itself
if target_dir.name == ".codexlens":
# Verify it contains index.db or cache directory
return (target_dir / "index.db").exists() or (target_dir / "cache").exists()
# Check if it contains .codexlens subdirectory
return (target_dir / ".codexlens").exists()
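A behavior sketch of the safety check (paths are illustrative):

    from pathlib import Path

    _is_safe_to_clean(Path("missing/dir"))      # True: nothing exists, nothing to protect
    _is_safe_to_clean(Path("proj/.codexlens"))  # True only if it holds index.db or cache/
    _is_safe_to_clean(Path("proj"))             # True only if proj/.codexlens exists
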
@app.command()
def init(
path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to index."),
@@ -469,12 +505,16 @@ def clean(
config = Config()
import shutil
if config.index_dir.exists():
if not _is_safe_to_clean(config.index_dir):
raise CodexLensError(f"Safety check failed: {config.index_dir} does not appear to be a CodexLens directory")
shutil.rmtree(config.index_dir)
result = {"cleaned": str(config.index_dir), "type": "global"}
else:
workspace = WorkspaceConfig.from_path(base_path)
if workspace and workspace.codexlens_dir.exists():
import shutil
if not _is_safe_to_clean(workspace.codexlens_dir):
raise CodexLensError(f"Safety check failed: {workspace.codexlens_dir} does not appear to be a CodexLens directory")
shutil.rmtree(workspace.codexlens_dir)
result = {"cleaned": str(workspace.codexlens_dir), "type": "workspace"}
else:

View File

@@ -1,8 +1,8 @@
"""Parser factory for CodexLens.
The project currently ships lightweight regex-based parsers per language.
They can be swapped for tree-sitter based parsers later without changing
CLI or storage interfaces.
Python and JavaScript/TypeScript parsing use Tree-Sitter grammars when
available. Regex fallbacks are retained to preserve the existing parser
interface and behavior in minimal environments.
"""
from __future__ import annotations
@@ -10,7 +10,16 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Protocol
from typing import Dict, Iterable, List, Optional, Protocol
try:
from tree_sitter import Language as TreeSitterLanguage
from tree_sitter import Node as TreeSitterNode
from tree_sitter import Parser as TreeSitterParser
except Exception: # pragma: no cover
TreeSitterLanguage = None # type: ignore[assignment]
TreeSitterNode = None # type: ignore[assignment]
TreeSitterParser = None # type: ignore[assignment]
from codexlens.config import Config
from codexlens.entities import IndexedFile, Symbol
@@ -25,11 +34,10 @@ class SimpleRegexParser:
language_id: str
def parse(self, text: str, path: Path) -> IndexedFile:
symbols: List[Symbol] = []
if self.language_id == "python":
symbols = _parse_python_symbols(text)
elif self.language_id in {"javascript", "typescript"}:
symbols = _parse_js_ts_symbols(text)
symbols = _parse_js_ts_symbols(text, self.language_id, path)
elif self.language_id == "java":
symbols = _parse_java_symbols(text)
elif self.language_id == "go":
@@ -57,24 +65,135 @@ class ParserFactory:
_PY_CLASS_RE = re.compile(r"^\s*class\s+([A-Za-z_]\w*)\b")
_PY_DEF_RE = re.compile(r"^\s*def\s+([A-Za-z_]\w*)\s*\(")
_PY_DEF_RE = re.compile(r"^\s*(?:async\s+)?def\s+([A-Za-z_]\w*)\s*\(")
_TREE_SITTER_LANGUAGE_CACHE: Dict[str, TreeSitterLanguage] = {}
def _parse_python_symbols(text: str) -> List[Symbol]:
def _get_tree_sitter_language(language_id: str, path: Path | None = None) -> TreeSitterLanguage | None:
if TreeSitterLanguage is None:
return None
cache_key = language_id
if language_id == "typescript" and path is not None and path.suffix.lower() == ".tsx":
cache_key = "tsx"
cached = _TREE_SITTER_LANGUAGE_CACHE.get(cache_key)
if cached is not None:
return cached
try:
if cache_key == "python":
import tree_sitter_python # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_python.language())
elif cache_key == "javascript":
import tree_sitter_javascript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_javascript.language())
elif cache_key == "typescript":
import tree_sitter_typescript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_typescript.language_typescript())
elif cache_key == "tsx":
import tree_sitter_typescript # type: ignore[import-not-found]
language = TreeSitterLanguage(tree_sitter_typescript.language_tsx())
else:
return None
except Exception:
return None
_TREE_SITTER_LANGUAGE_CACHE[cache_key] = language
return language
def _iter_tree_sitter_nodes(root: TreeSitterNode) -> Iterable[TreeSitterNode]:
stack: List[TreeSitterNode] = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(source_bytes: bytes, node: TreeSitterNode) -> str:
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
def _node_range(node: TreeSitterNode) -> tuple[int, int]:
start_line = node.start_point[0] + 1
end_line = node.end_point[0] + 1
return (start_line, max(start_line, end_line))
def _python_kind_for_function_node(node: TreeSitterNode) -> str:
parent = node.parent
while parent is not None:
if parent.type in {"function_definition", "async_function_definition"}:
return "function"
if parent.type == "class_definition":
return "method"
parent = parent.parent
return "function"
def _parse_python_symbols_tree_sitter(text: str) -> List[Symbol] | None:
if TreeSitterParser is None:
return None
language = _get_tree_sitter_language("python")
if language is None:
return None
parser = TreeSitterParser()
if hasattr(parser, "set_language"):
parser.set_language(language) # type: ignore[attr-defined]
else:
parser.language = language # type: ignore[assignment]
source_bytes = text.encode("utf8")
tree = parser.parse(source_bytes)
root = tree.root_node
symbols: List[Symbol] = []
for node in _iter_tree_sitter_nodes(root):
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="class",
range=_node_range(node),
))
elif node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind=_python_kind_for_function_node(node),
range=_node_range(node),
))
return symbols
def _parse_python_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
current_class_indent: Optional[int] = None
for i, line in enumerate(text.splitlines(), start=1):
if _PY_CLASS_RE.match(line):
name = _PY_CLASS_RE.match(line).group(1)
class_match = _PY_CLASS_RE.match(line)
if class_match:
current_class_indent = len(line) - len(line.lstrip(" "))
symbols.append(Symbol(name=name, kind="class", range=(i, i)))
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
continue
def_match = _PY_DEF_RE.match(line)
if def_match:
name = def_match.group(1)
indent = len(line) - len(line.lstrip(" "))
kind = "method" if current_class_indent is not None and indent > current_class_indent else "function"
symbols.append(Symbol(name=name, kind=kind, range=(i, i)))
symbols.append(Symbol(name=def_match.group(1), kind=kind, range=(i, i)))
continue
if current_class_indent is not None:
indent = len(line) - len(line.lstrip(" "))
@@ -83,23 +202,153 @@ def _parse_python_symbols(text: str) -> List[Symbol]:
return symbols
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
def _parse_python_symbols(text: str) -> List[Symbol]:
symbols = _parse_python_symbols_tree_sitter(text)
if symbols is not None:
return symbols
return _parse_python_symbols_regex(text)
_JS_FUNC_RE = re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(")
_JS_CLASS_RE = re.compile(r"^\s*(?:export\s+)?class\s+([A-Za-z_$][\w$]*)\b")
_JS_ARROW_RE = re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?\(?[^)]*\)?\s*=>"
)
_JS_METHOD_RE = re.compile(r"^\s+(?:async\s+)?([A-Za-z_$][\w$]*)\s*\([^)]*\)\s*\{")
def _parse_js_ts_symbols(text: str) -> List[Symbol]:
def _js_has_class_ancestor(node: TreeSitterNode) -> bool:
parent = node.parent
while parent is not None:
if parent.type in {"class_declaration", "class"}:
return True
parent = parent.parent
return False
def _parse_js_ts_symbols_tree_sitter(
text: str,
language_id: str,
path: Path | None = None,
) -> List[Symbol] | None:
if TreeSitterParser is None:
return None
language = _get_tree_sitter_language(language_id, path)
if language is None:
return None
parser = TreeSitterParser()
if hasattr(parser, "set_language"):
parser.set_language(language) # type: ignore[attr-defined]
else:
parser.language = language # type: ignore[assignment]
source_bytes = text.encode("utf8")
tree = parser.parse(source_bytes)
root = tree.root_node
symbols: List[Symbol] = []
for node in _iter_tree_sitter_nodes(root):
if node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="class",
range=_node_range(node),
))
elif node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node is None:
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="function",
range=_node_range(node),
))
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if (
name_node is None
or value_node is None
or name_node.type not in {"identifier", "property_identifier"}
or value_node.type != "arrow_function"
):
continue
symbols.append(Symbol(
name=_node_text(source_bytes, name_node),
kind="function",
range=_node_range(node),
))
elif node.type == "method_definition" and _js_has_class_ancestor(node):
name_node = node.child_by_field_name("name")
if name_node is None:
continue
name = _node_text(source_bytes, name_node)
if name == "constructor":
continue
symbols.append(Symbol(
name=name,
kind="method",
range=_node_range(node),
))
return symbols
def _parse_js_ts_symbols_regex(text: str) -> List[Symbol]:
symbols: List[Symbol] = []
in_class = False
class_brace_depth = 0
brace_depth = 0
for i, line in enumerate(text.splitlines(), start=1):
brace_depth += line.count("{") - line.count("}")
class_match = _JS_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
in_class = True
class_brace_depth = brace_depth
continue
if in_class and brace_depth < class_brace_depth:
in_class = False
func_match = _JS_FUNC_RE.match(line)
if func_match:
symbols.append(Symbol(name=func_match.group(1), kind="function", range=(i, i)))
continue
class_match = _JS_CLASS_RE.match(line)
if class_match:
symbols.append(Symbol(name=class_match.group(1), kind="class", range=(i, i)))
arrow_match = _JS_ARROW_RE.match(line)
if arrow_match:
symbols.append(Symbol(name=arrow_match.group(1), kind="function", range=(i, i)))
continue
if in_class:
method_match = _JS_METHOD_RE.match(line)
if method_match:
name = method_match.group(1)
if name != "constructor":
symbols.append(Symbol(name=name, kind="method", range=(i, i)))
return symbols
def _parse_js_ts_symbols(
text: str,
language_id: str = "javascript",
path: Path | None = None,
) -> List[Symbol]:
symbols = _parse_js_ts_symbols_tree_sitter(text, language_id, path)
if symbols is not None:
return symbols
return _parse_js_ts_symbols_regex(text)
_JAVA_CLASS_RE = re.compile(r"^\s*(?:public\s+)?class\s+([A-Za-z_]\w*)\b")
_JAVA_METHOD_RE = re.compile(
r"^\s*(?:public|private|protected|static|\s)+[\w<>\[\]]+\s+([A-Za-z_]\w*)\s*\("
@@ -151,4 +400,3 @@ def _parse_generic_symbols(text: str) -> List[Symbol]:
if def_match:
symbols.append(Symbol(name=def_match.group(1), kind="function", range=(i, i)))
return symbols

View File

@@ -118,6 +118,59 @@ class SQLiteStore:
)
conn.commit()
def add_files(self, files_data: List[tuple[IndexedFile, str]]) -> None:
"""Add multiple files in a single transaction for better performance.
Args:
files_data: List of (indexed_file, content) tuples
"""
with self._lock:
conn = self._get_connection()
try:
conn.execute("BEGIN")
for indexed_file, content in files_data:
path = str(Path(indexed_file.path).resolve())
language = indexed_file.language
mtime = Path(path).stat().st_mtime if Path(path).exists() else None
line_count = content.count(chr(10)) + 1
conn.execute(
"""
INSERT INTO files(path, language, content, mtime, line_count)
VALUES(?, ?, ?, ?, ?)
ON CONFLICT(path) DO UPDATE SET
language=excluded.language,
content=excluded.content,
mtime=excluded.mtime,
line_count=excluded.line_count
""",
(path, language, content, mtime, line_count),
)
row = conn.execute("SELECT id FROM files WHERE path=?", (path,)).fetchone()
if not row:
raise StorageError(f"Failed to read file id for {path}")
file_id = int(row["id"])
conn.execute("DELETE FROM symbols WHERE file_id=?", (file_id,))
if indexed_file.symbols:
conn.executemany(
"""
INSERT INTO symbols(file_id, name, kind, start_line, end_line)
VALUES(?, ?, ?, ?, ?)
""",
[
(file_id, s.name, s.kind, s.range[0], s.range[1])
for s in indexed_file.symbols
],
)
conn.commit()
except Exception:
conn.rollback()
raise
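A hypothetical call against the new batch API; the IndexedFile and Symbol keyword fields below are inferred from the columns written above, not confirmed signatures:

    from pathlib import Path
    from codexlens.entities import IndexedFile, Symbol

    store = SQLiteStore(Path(".codexlens/index.db"))
    store.add_files([
        (
            IndexedFile(path="src/app.py", language="python",
                        symbols=[Symbol(name="main", kind="function", range=(1, 10))]),
            "def main():\n    pass\n",
        ),
    ])  # one BEGIN/COMMIT for the whole batch instead of one commit per file
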
def remove_file(self, path: str | Path) -> bool:
"""Remove a file from the index."""
with self._lock:
@@ -178,7 +231,7 @@ class SQLiteStore:
results: List[SearchResult] = []
for row in rows:
rank = float(row["rank"]) if row["rank"] is not None else 0.0
score = max(0.0, -rank)
score = abs(rank) if rank < 0 else 0.0
results.append(
SearchResult(
path=row["path"],

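The reworked score keeps the same mapping as before, assuming bm25-style FTS5 ranking where matches rank negative (more relevant is more negative): the stored score is the rank's magnitude, and non-negative ranks collapse to 0.0. A worked example of the expression above:

    for rank in (-3.2, -0.5, 0.0):
        score = abs(rank) if rank < 0 else 0.0
        print(rank, "->", score)  # -3.2 -> 3.2, -0.5 -> 0.5, 0.0 -> 0.0
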
View File

@@ -0,0 +1,148 @@
"""Tests for CodexLens parsers."""
from pathlib import Path
import pytest
from codexlens.parsers.factory import (
SimpleRegexParser,
_parse_js_ts_symbols,
_parse_python_symbols,
)
TREE_SITTER_JS_AVAILABLE = True
try:
import tree_sitter_javascript # type: ignore[import-not-found] # noqa: F401
except Exception:
TREE_SITTER_JS_AVAILABLE = False
class TestPythonParser:
"""Tests for Python symbol parsing."""
def test_parse_function(self):
code = "def hello():\n pass"
symbols = _parse_python_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "hello"
assert symbols[0].kind == "function"
def test_parse_async_function(self):
code = "async def fetch_data():\n pass"
symbols = _parse_python_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "fetch_data"
assert symbols[0].kind == "function"
def test_parse_class(self):
code = "class MyClass:\n pass"
symbols = _parse_python_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "MyClass"
assert symbols[0].kind == "class"
def test_parse_method(self):
code = "class MyClass:\n def method(self):\n pass"
symbols = _parse_python_symbols(code)
assert len(symbols) == 2
assert symbols[0].name == "MyClass"
assert symbols[0].kind == "class"
assert symbols[1].name == "method"
assert symbols[1].kind == "method"
def test_parse_async_method(self):
code = "class MyClass:\n async def async_method(self):\n pass"
symbols = _parse_python_symbols(code)
assert len(symbols) == 2
assert symbols[1].name == "async_method"
assert symbols[1].kind == "method"
class TestJavaScriptParser:
"""Tests for JavaScript/TypeScript symbol parsing."""
def test_parse_function(self):
code = "function hello() {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "hello"
assert symbols[0].kind == "function"
def test_parse_async_function(self):
code = "async function fetchData() {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "fetchData"
assert symbols[0].kind == "function"
def test_parse_arrow_function(self):
code = "const hello = () => {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "hello"
assert symbols[0].kind == "function"
def test_parse_async_arrow_function(self):
code = "const fetchData = async () => {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "fetchData"
assert symbols[0].kind == "function"
def test_parse_class(self):
code = "class MyClass {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "MyClass"
assert symbols[0].kind == "class"
def test_parse_export_function(self):
code = "export function hello() {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "hello"
assert symbols[0].kind == "function"
def test_parse_export_class(self):
code = "export class MyClass {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "MyClass"
assert symbols[0].kind == "class"
def test_parse_export_arrow_function(self):
code = "export const hello = () => {}"
symbols = _parse_js_ts_symbols(code)
assert len(symbols) == 1
assert symbols[0].name == "hello"
assert symbols[0].kind == "function"
@pytest.mark.skipif(not TREE_SITTER_JS_AVAILABLE, reason="tree-sitter-javascript not installed")
def test_parse_class_methods(self):
code = (
"class MyClass {\n"
" method() {}\n"
" async asyncMethod() {}\n"
" static staticMethod() {}\n"
" constructor() {}\n"
"}"
)
symbols = _parse_js_ts_symbols(code)
names_kinds = [(s.name, s.kind) for s in symbols]
assert ("MyClass", "class") in names_kinds
assert ("method", "method") in names_kinds
assert ("asyncMethod", "method") in names_kinds
assert ("staticMethod", "method") in names_kinds
assert all(name != "constructor" for name, _ in names_kinds)
class TestParserInterface:
"""High-level interface tests."""
def test_simple_parser_parse(self):
parser = SimpleRegexParser("python")
indexed = parser.parse("def hello():\n pass", Path("test.py"))
assert indexed.language == "python"
assert len(indexed.symbols) == 1
assert indexed.symbols[0].name == "hello"
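Most of these tests assert only symbol names and kinds, so they pass against either backend; the one Tree-Sitter-specific case is marked skipif above. The visible difference between backends is the range: Tree-Sitter reports the full node span while the regex fallback records single-line ranges. A quick way to see which path ran (output values illustrative):

    from codexlens.parsers.factory import _parse_python_symbols

    cls = _parse_python_symbols("class MyClass:\n    def method(self):\n        pass\n")[0]
    print(cls.range)  # e.g. (1, 3) via Tree-Sitter, (1, 1) via the regex fallback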