mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-11 02:33:51 +08:00
Add graph expansion and cross-encoder reranking features
- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors. - Added CrossEncoderReranker for second-stage search ranking, allowing for improved result scoring. - Created migrations to establish necessary database tables for relationships and graph neighbors. - Developed tests for graph expansion functionality, ensuring related results are populated correctly. - Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead. - Updated schema cleanup tests to reflect changes in versioning and deprecated fields. - Added new test cases for Treesitter parser to validate relationship extraction with alias resolution.
This commit is contained in:
188
codex-lens/tests/test_graph_expansion.py
Normal file
188
codex-lens/tests/test_graph_expansion.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import CodeRelationship, RelationshipType, SearchResult, Symbol
|
||||
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
|
||||
from codexlens.search.graph_expander import GraphExpander
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
from codexlens.storage.index_tree import _compute_graph_neighbors
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def temp_paths() -> Path:
|
||||
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
|
||||
root = Path(tmpdir.name)
|
||||
yield root
|
||||
try:
|
||||
tmpdir.cleanup()
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def _create_index_with_neighbors(root: Path) -> tuple[PathMapper, Path, Path]:
|
||||
project_root = root / "project"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
index_root = root / "indexes"
|
||||
mapper = PathMapper(index_root=index_root)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
index_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
content = "\n".join(
|
||||
[
|
||||
"def a():",
|
||||
" b()",
|
||||
"",
|
||||
"def b():",
|
||||
" c()",
|
||||
"",
|
||||
"def c():",
|
||||
" return 1",
|
||||
"",
|
||||
]
|
||||
)
|
||||
file_path = project_root / "module.py"
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
|
||||
symbols = [
|
||||
Symbol(name="a", kind="function", range=(1, 2), file=str(file_path)),
|
||||
Symbol(name="b", kind="function", range=(4, 5), file=str(file_path)),
|
||||
Symbol(name="c", kind="function", range=(7, 8), file=str(file_path)),
|
||||
]
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol="a",
|
||||
target_symbol="b",
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=str(file_path),
|
||||
target_file=None,
|
||||
source_line=2,
|
||||
),
|
||||
CodeRelationship(
|
||||
source_symbol="b",
|
||||
target_symbol="c",
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=str(file_path),
|
||||
target_file=None,
|
||||
source_line=5,
|
||||
),
|
||||
]
|
||||
|
||||
config = Config(data_dir=root / "data")
|
||||
store = DirIndexStore(index_db_path, config=config)
|
||||
store.initialize()
|
||||
store.add_file(
|
||||
name=file_path.name,
|
||||
full_path=file_path,
|
||||
content=content,
|
||||
language="python",
|
||||
symbols=symbols,
|
||||
relationships=relationships,
|
||||
)
|
||||
_compute_graph_neighbors(store)
|
||||
store.close()
|
||||
|
||||
return mapper, project_root, file_path
|
||||
|
||||
|
||||
def test_graph_neighbors_precomputed_two_hop(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
|
||||
conn = sqlite3.connect(str(index_db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT s1.name AS source_name, s2.name AS neighbor_name, gn.relationship_depth
|
||||
FROM graph_neighbors gn
|
||||
JOIN symbols s1 ON s1.id = gn.source_symbol_id
|
||||
JOIN symbols s2 ON s2.id = gn.neighbor_symbol_id
|
||||
ORDER BY source_name, neighbor_name, relationship_depth
|
||||
"""
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
triples = {(r["source_name"], r["neighbor_name"], int(r["relationship_depth"])) for r in rows}
|
||||
assert ("a", "b", 1) in triples
|
||||
assert ("a", "c", 2) in triples
|
||||
assert ("b", "c", 1) in triples
|
||||
assert ("c", "b", 1) in triples
|
||||
assert file_path.exists()
|
||||
|
||||
|
||||
def test_graph_expander_returns_related_results_with_depth_metadata(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
_ = project_root
|
||||
|
||||
expander = GraphExpander(mapper, config=Config(data_dir=temp_paths / "data", graph_expansion_depth=2))
|
||||
base = SearchResult(
|
||||
path=str(file_path.resolve()),
|
||||
score=1.0,
|
||||
excerpt="",
|
||||
content=None,
|
||||
start_line=1,
|
||||
end_line=2,
|
||||
symbol_name="a",
|
||||
symbol_kind="function",
|
||||
)
|
||||
related = expander.expand([base], depth=2, max_expand=1, max_related=10)
|
||||
|
||||
depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in related}
|
||||
assert depth_by_symbol.get("b") == 1
|
||||
assert depth_by_symbol.get("c") == 2
|
||||
|
||||
|
||||
def test_chain_search_populates_related_results_when_enabled(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
_ = file_path
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
|
||||
config = Config(
|
||||
data_dir=temp_paths / "data",
|
||||
enable_graph_expansion=True,
|
||||
graph_expansion_depth=2,
|
||||
)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
try:
|
||||
options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False)
|
||||
result = engine.search("b", project_root, options)
|
||||
|
||||
assert result.results
|
||||
assert result.results[0].symbol_name == "a"
|
||||
|
||||
depth_by_symbol = {r.symbol_name: r.metadata.get("relationship_depth") for r in result.related_results}
|
||||
assert depth_by_symbol.get("b") == 1
|
||||
assert depth_by_symbol.get("c") == 2
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
|
||||
def test_chain_search_related_results_empty_when_disabled(temp_paths: Path) -> None:
|
||||
mapper, project_root, file_path = _create_index_with_neighbors(temp_paths)
|
||||
_ = file_path
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
|
||||
config = Config(
|
||||
data_dir=temp_paths / "data",
|
||||
enable_graph_expansion=False,
|
||||
)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
try:
|
||||
options = SearchOptions(depth=0, total_limit=10, enable_fuzzy=False)
|
||||
result = engine.search("b", project_root, options)
|
||||
assert result.related_results == []
|
||||
finally:
|
||||
engine.close()
|
||||
|
||||
@@ -869,3 +869,47 @@ class TestHybridSearchAdaptiveWeights:
|
||||
) as rerank_mock:
|
||||
engine_on.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
assert rerank_mock.call_count == 1
|
||||
|
||||
def test_cross_encoder_reranking_enabled(self, tmp_path):
|
||||
"""Cross-encoder stage runs only when explicitly enabled via config."""
|
||||
from unittest.mock import patch
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
class DummyEmbedder:
|
||||
def embed(self, texts):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
return [[1.0, 0.0] for _ in texts]
|
||||
|
||||
class DummyReranker:
|
||||
def score_pairs(self, pairs, batch_size=32):
|
||||
return [0.0 for _ in pairs]
|
||||
|
||||
config = Config(
|
||||
data_dir=tmp_path / "ce",
|
||||
enable_reranking=True,
|
||||
enable_cross_encoder_rerank=True,
|
||||
reranker_top_k=10,
|
||||
)
|
||||
engine = HybridSearchEngine(config=config, embedder=DummyEmbedder())
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.rerank_results",
|
||||
side_effect=lambda q, r, e, top_k=50: r,
|
||||
) as rerank_mock, patch.object(
|
||||
HybridSearchEngine,
|
||||
"_get_cross_encoder_reranker",
|
||||
return_value=DummyReranker(),
|
||||
) as get_ce_mock, patch(
|
||||
"codexlens.search.hybrid_search.cross_encoder_rerank",
|
||||
side_effect=lambda q, r, ce, top_k=50: r,
|
||||
) as ce_mock:
|
||||
engine.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
assert rerank_mock.call_count == 1
|
||||
assert get_ce_mock.call_count == 1
|
||||
assert ce_mock.call_count == 1
|
||||
|
||||
100
codex-lens/tests/test_merkle_detection.py
Normal file
100
codex-lens/tests/test_merkle_detection.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
|
||||
|
||||
def _make_merkle_config(tmp_path: Path) -> Config:
|
||||
data_dir = tmp_path / "data"
|
||||
return Config(
|
||||
data_dir=data_dir,
|
||||
venv_path=data_dir / "venv",
|
||||
enable_merkle_detection=True,
|
||||
)
|
||||
|
||||
|
||||
class TestMerkleDetection:
|
||||
def test_needs_reindex_touch_updates_mtime(self, tmp_path: Path) -> None:
|
||||
config = _make_merkle_config(tmp_path)
|
||||
source_dir = tmp_path / "src"
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = source_dir / "a.py"
|
||||
file_path.write_text("print('hi')\n", encoding="utf-8")
|
||||
original_content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
index_db = tmp_path / "_index.db"
|
||||
with DirIndexStore(index_db, config=config) as store:
|
||||
store.add_file(
|
||||
name=file_path.name,
|
||||
full_path=file_path,
|
||||
content=original_content,
|
||||
language="python",
|
||||
symbols=[],
|
||||
)
|
||||
|
||||
stored_mtime_before = store.get_file_mtime(file_path)
|
||||
assert stored_mtime_before is not None
|
||||
|
||||
# Touch file without changing content
|
||||
time.sleep(0.02)
|
||||
file_path.write_text(original_content, encoding="utf-8")
|
||||
|
||||
assert store.needs_reindex(file_path) is False
|
||||
|
||||
stored_mtime_after = store.get_file_mtime(file_path)
|
||||
assert stored_mtime_after is not None
|
||||
assert stored_mtime_after != stored_mtime_before
|
||||
|
||||
current_mtime = file_path.stat().st_mtime
|
||||
assert abs(stored_mtime_after - current_mtime) <= 0.001
|
||||
|
||||
def test_parent_root_changes_when_child_changes(self, tmp_path: Path) -> None:
|
||||
config = _make_merkle_config(tmp_path)
|
||||
|
||||
source_root = tmp_path / "project"
|
||||
child_dir = source_root / "child"
|
||||
child_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
child_file = child_dir / "child.py"
|
||||
child_file.write_text("x = 1\n", encoding="utf-8")
|
||||
|
||||
child_db = tmp_path / "child_index.db"
|
||||
parent_db = tmp_path / "parent_index.db"
|
||||
|
||||
with DirIndexStore(child_db, config=config) as child_store:
|
||||
child_store.add_file(
|
||||
name=child_file.name,
|
||||
full_path=child_file,
|
||||
content=child_file.read_text(encoding="utf-8"),
|
||||
language="python",
|
||||
symbols=[],
|
||||
)
|
||||
child_root_1 = child_store.update_merkle_root()
|
||||
assert child_root_1
|
||||
|
||||
with DirIndexStore(parent_db, config=config) as parent_store:
|
||||
parent_store.register_subdir(name="child", index_path=child_db, files_count=1)
|
||||
parent_root_1 = parent_store.update_merkle_root()
|
||||
assert parent_root_1
|
||||
|
||||
time.sleep(0.02)
|
||||
child_file.write_text("x = 2\n", encoding="utf-8")
|
||||
|
||||
with DirIndexStore(child_db, config=config) as child_store:
|
||||
child_store.add_file(
|
||||
name=child_file.name,
|
||||
full_path=child_file,
|
||||
content=child_file.read_text(encoding="utf-8"),
|
||||
language="python",
|
||||
symbols=[],
|
||||
)
|
||||
child_root_2 = child_store.update_merkle_root()
|
||||
assert child_root_2
|
||||
assert child_root_2 != child_root_1
|
||||
|
||||
with DirIndexStore(parent_db, config=config) as parent_store:
|
||||
parent_root_2 = parent_store.update_merkle_root()
|
||||
assert parent_root_2
|
||||
assert parent_root_2 != parent_root_1
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Tests for performance optimizations in CodexLens storage.
|
||||
"""Tests for performance optimizations in CodexLens.
|
||||
|
||||
This module tests the following optimizations:
|
||||
1. Normalized keywords search (migration_001)
|
||||
2. Optimized path lookup in registry
|
||||
3. Prefix-mode symbol search
|
||||
4. Graph expansion neighbor precompute overhead (<20%)
|
||||
5. Cross-encoder reranking latency (<200ms)
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -479,3 +481,113 @@ class TestPerformanceComparison:
|
||||
print(f" Substring: {substring_time*1000:.3f}ms ({len(substring_results)} results)")
|
||||
print(f" Ratio: {prefix_time/substring_time:.2f}x")
|
||||
print(f" Note: Performance benefits appear with 1000+ symbols")
|
||||
|
||||
|
||||
class TestPerformanceBenchmarks:
|
||||
"""Benchmark-style assertions for key performance requirements."""
|
||||
|
||||
def test_graph_expansion_indexing_overhead_under_20_percent(self, temp_index_db, tmp_path):
|
||||
"""Graph neighbor precompute adds <20% overhead versus indexing baseline."""
|
||||
from codexlens.entities import CodeRelationship, RelationshipType, Symbol
|
||||
from codexlens.storage.index_tree import _compute_graph_neighbors
|
||||
|
||||
store = temp_index_db
|
||||
|
||||
file_count = 60
|
||||
symbols_per_file = 8
|
||||
|
||||
start = time.perf_counter()
|
||||
for file_idx in range(file_count):
|
||||
file_path = tmp_path / f"graph_{file_idx}.py"
|
||||
lines = []
|
||||
for sym_idx in range(symbols_per_file):
|
||||
lines.append(f"def func_{file_idx}_{sym_idx}():")
|
||||
lines.append(f" return {sym_idx}")
|
||||
lines.append("")
|
||||
content = "\n".join(lines)
|
||||
|
||||
symbols = [
|
||||
Symbol(
|
||||
name=f"func_{file_idx}_{sym_idx}",
|
||||
kind="function",
|
||||
range=(sym_idx * 3 + 1, sym_idx * 3 + 2),
|
||||
file=str(file_path),
|
||||
)
|
||||
for sym_idx in range(symbols_per_file)
|
||||
]
|
||||
|
||||
relationships = [
|
||||
CodeRelationship(
|
||||
source_symbol=f"func_{file_idx}_{sym_idx}",
|
||||
target_symbol=f"func_{file_idx}_{sym_idx + 1}",
|
||||
relationship_type=RelationshipType.CALL,
|
||||
source_file=str(file_path),
|
||||
target_file=None,
|
||||
source_line=sym_idx * 3 + 2,
|
||||
)
|
||||
for sym_idx in range(symbols_per_file - 1)
|
||||
]
|
||||
|
||||
store.add_file(
|
||||
name=file_path.name,
|
||||
full_path=file_path,
|
||||
content=content,
|
||||
language="python",
|
||||
symbols=symbols,
|
||||
relationships=relationships,
|
||||
)
|
||||
baseline_time = time.perf_counter() - start
|
||||
|
||||
durations = []
|
||||
for _ in range(3):
|
||||
start = time.perf_counter()
|
||||
_compute_graph_neighbors(store)
|
||||
durations.append(time.perf_counter() - start)
|
||||
graph_time = min(durations)
|
||||
|
||||
# Sanity-check that the benchmark exercised graph neighbor generation.
|
||||
conn = store._get_connection()
|
||||
neighbor_count = conn.execute(
|
||||
"SELECT COUNT(*) as c FROM graph_neighbors"
|
||||
).fetchone()["c"]
|
||||
assert neighbor_count > 0
|
||||
|
||||
assert baseline_time > 0.0
|
||||
overhead_ratio = graph_time / baseline_time
|
||||
assert overhead_ratio < 0.2, (
|
||||
f"Graph neighbor precompute overhead too high: {overhead_ratio:.2%} "
|
||||
f"(baseline={baseline_time:.3f}s, graph={graph_time:.3f}s)"
|
||||
)
|
||||
|
||||
def test_cross_encoder_reranking_latency_under_200ms(self):
|
||||
"""Cross-encoder rerank step completes under 200ms (excluding model load)."""
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import cross_encoder_rerank
|
||||
|
||||
query = "find function"
|
||||
results = [
|
||||
SearchResult(
|
||||
path=f"file_{idx}.py",
|
||||
score=1.0 / (idx + 1),
|
||||
excerpt=f"def func_{idx}():\n return {idx}",
|
||||
symbol_name=f"func_{idx}",
|
||||
symbol_kind="function",
|
||||
)
|
||||
for idx in range(50)
|
||||
]
|
||||
|
||||
class DummyReranker:
|
||||
def score_pairs(self, pairs, batch_size=32):
|
||||
_ = batch_size
|
||||
# Return deterministic pseudo-logits to exercise sigmoid normalization.
|
||||
return [float(i) for i in range(len(pairs))]
|
||||
|
||||
reranker = DummyReranker()
|
||||
|
||||
start = time.perf_counter()
|
||||
reranked = cross_encoder_rerank(query, results, reranker, top_k=50, batch_size=32)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000.0
|
||||
|
||||
assert len(reranked) == len(results)
|
||||
assert any(r.metadata.get("cross_encoder_reranked") for r in reranked[:50])
|
||||
assert elapsed_ms < 200.0, f"Cross-encoder rerank too slow: {elapsed_ms:.1f}ms"
|
||||
|
||||
@@ -19,7 +19,7 @@ from codexlens.entities import Symbol
|
||||
|
||||
|
||||
class TestSchemaCleanupMigration:
|
||||
"""Test schema cleanup migration (v4 -> v5)."""
|
||||
"""Test schema cleanup migration (v4 -> latest)."""
|
||||
|
||||
def test_migration_from_v4_to_v5(self):
|
||||
"""Test that migration successfully removes deprecated fields."""
|
||||
@@ -129,10 +129,12 @@ class TestSchemaCleanupMigration:
|
||||
# Now initialize store - this should trigger migration
|
||||
store.initialize()
|
||||
|
||||
# Verify schema version is now 5
|
||||
# Verify schema version is now the latest
|
||||
conn = store._get_connection()
|
||||
version_row = conn.execute("PRAGMA user_version").fetchone()
|
||||
assert version_row[0] == 5, f"Expected schema version 5, got {version_row[0]}"
|
||||
assert version_row[0] == DirIndexStore.SCHEMA_VERSION, (
|
||||
f"Expected schema version {DirIndexStore.SCHEMA_VERSION}, got {version_row[0]}"
|
||||
)
|
||||
|
||||
# Check that deprecated columns are removed
|
||||
# 1. Check semantic_metadata doesn't have keywords column
|
||||
@@ -166,7 +168,7 @@ class TestSchemaCleanupMigration:
|
||||
store.close()
|
||||
|
||||
def test_new_database_has_clean_schema(self):
|
||||
"""Test that new databases are created with clean schema (v5)."""
|
||||
"""Test that new databases are created with clean schema (latest)."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "_index.db"
|
||||
store = DirIndexStore(db_path)
|
||||
@@ -174,9 +176,9 @@ class TestSchemaCleanupMigration:
|
||||
|
||||
conn = store._get_connection()
|
||||
|
||||
# Verify schema version is 5
|
||||
# Verify schema version is the latest
|
||||
version_row = conn.execute("PRAGMA user_version").fetchone()
|
||||
assert version_row[0] == 5
|
||||
assert version_row[0] == DirIndexStore.SCHEMA_VERSION
|
||||
|
||||
# Check that new schema doesn't have deprecated columns
|
||||
cursor = conn.execute("PRAGMA table_info(semantic_metadata)")
|
||||
|
||||
@@ -582,6 +582,7 @@ class TestChainSearchResult:
|
||||
)
|
||||
assert result.query == "test"
|
||||
assert result.results == []
|
||||
assert result.related_results == []
|
||||
assert result.symbols == []
|
||||
assert result.stats.dirs_searched == 0
|
||||
|
||||
|
||||
@@ -1173,6 +1173,7 @@ class TestChainSearchResultExtended:
|
||||
assert result.query == "test query"
|
||||
assert len(result.results) == 1
|
||||
assert len(result.symbols) == 1
|
||||
assert result.related_results == []
|
||||
assert result.stats.dirs_searched == 5
|
||||
|
||||
def test_result_with_empty_collections(self):
|
||||
@@ -1186,5 +1187,6 @@ class TestChainSearchResultExtended:
|
||||
|
||||
assert result.query == "no matches"
|
||||
assert result.results == []
|
||||
assert result.related_results == []
|
||||
assert result.symbols == []
|
||||
assert result.stats.dirs_searched == 0
|
||||
|
||||
@@ -110,6 +110,37 @@ class DataProcessor:
|
||||
assert result is not None
|
||||
assert len(result.symbols) == 0
|
||||
|
||||
def test_extracts_relationships_with_alias_resolution(self):
|
||||
parser = TreeSitterSymbolParser("python")
|
||||
code = """
|
||||
import os.path as osp
|
||||
from math import sqrt as sq
|
||||
|
||||
class Base:
|
||||
pass
|
||||
|
||||
class Child(Base):
|
||||
pass
|
||||
|
||||
def main():
|
||||
osp.join("a", "b")
|
||||
sq(4)
|
||||
"""
|
||||
result = parser.parse(code, Path("test.py"))
|
||||
|
||||
assert result is not None
|
||||
|
||||
rels = [r for r in result.relationships if r.source_symbol == "main"]
|
||||
targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
|
||||
assert "os.path.join" in targets
|
||||
assert "math.sqrt" in targets
|
||||
|
||||
inherits = [
|
||||
r for r in result.relationships
|
||||
if r.source_symbol == "Child" and r.relationship_type.value == "inherits"
|
||||
]
|
||||
assert any(r.target_symbol == "Base" for r in inherits)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
|
||||
class TestTreeSitterJavaScriptParser:
|
||||
@@ -175,6 +206,22 @@ export const arrowFunc = () => {}
|
||||
assert "exported" in names
|
||||
assert "arrowFunc" in names
|
||||
|
||||
def test_extracts_relationships_with_import_alias(self):
|
||||
parser = TreeSitterSymbolParser("javascript")
|
||||
code = """
|
||||
import { readFile as rf } from "fs";
|
||||
|
||||
function main() {
|
||||
rf("a");
|
||||
}
|
||||
"""
|
||||
result = parser.parse(code, Path("test.js"))
|
||||
|
||||
assert result is not None
|
||||
rels = [r for r in result.relationships if r.source_symbol == "main"]
|
||||
targets = {r.target_symbol for r in rels if r.relationship_type.value == "calls"}
|
||||
assert "fs.readFile" in targets
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TREE_SITTER_AVAILABLE, reason="tree-sitter not installed")
|
||||
class TestTreeSitterTypeScriptParser:
|
||||
|
||||
Reference in New Issue
Block a user