Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-09 02:24:11 +08:00)
Add comprehensive tests for tokenizer, performance benchmarks, and TreeSitter parser functionality
- Implemented unit tests for the Tokenizer class, covering various text inputs, edge cases, and fallback mechanisms.
- Created performance benchmarks comparing tiktoken and pure Python implementations for token counting.
- Developed extensive tests for TreeSitterSymbolParser across Python, JavaScript, and TypeScript, ensuring accurate symbol extraction and parsing.
- Added configuration documentation for MCP integration and custom prompts, enhancing usability and flexibility.
- Introduced a refactor script for GraphAnalyzer to streamline future improvements.
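(Editorial illustration, not part of the commit.) A minimal sketch of the kind of token-counting benchmark the message describes, comparing tiktoken against a pure Python fallback. The cl100k_base encoding and the whitespace-split stand-in for the project's fallback tokenizer are assumptions for illustration only.

import time

import tiktoken  # assumed installed; the commit benchmarks it against a pure Python path


def pure_python_count(text: str) -> int:
    # Hypothetical stand-in for the project's fallback tokenizer.
    return len(text.split())


def bench(fn, text: str, runs: int = 1000) -> float:
    start = time.perf_counter()
    for _ in range(runs):
        fn(text)
    return time.perf_counter() - start


sample = "def add(a, b):\n    return a + b\n" * 50
enc = tiktoken.get_encoding("cl100k_base")  # encoding choice is an assumption
print("tiktoken:", bench(lambda s: len(enc.encode(s)), sample))
print("fallback:", bench(pure_python_count, sample))
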
@@ -4,9 +4,10 @@ from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer


@dataclass
@@ -14,6 +15,7 @@ class ChunkConfig:
    """Configuration for chunking strategies."""

    max_chunk_size: int = 1000  # Max characters per chunk
    overlap: int = 100  # Overlap for sliding window
    strategy: str = "auto"  # Chunking strategy: auto, symbol, sliding_window, hybrid
    min_chunk_size: int = 50  # Minimum chunk size
@@ -22,6 +24,7 @@ class Chunker:
    def __init__(self, config: ChunkConfig | None = None) -> None:
        self.config = config or ChunkConfig()
        self._tokenizer = get_default_tokenizer()

    def chunk_by_symbol(
        self,
@@ -29,10 +32,18 @@ class Chunker:
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk code by extracted symbols (functions, classes).

        Each symbol becomes one chunk with its full content.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []
        lines = content.splitlines(keepends=True)
@@ -47,6 +58,13 @@ class Chunker:
            if len(chunk_content.strip()) < self.config.min_chunk_size:
                continue

            # Calculate token count if not provided
            token_count = None
            if symbol_token_counts and symbol.name in symbol_token_counts:
                token_count = symbol_token_counts[symbol.name]
            else:
                token_count = self._tokenizer.count_tokens(chunk_content)

            chunks.append(SemanticChunk(
                content=chunk_content,
                embedding=None,
@@ -58,6 +76,7 @@ class Chunker:
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "symbol",
                    "token_count": token_count,
                }
            ))
@@ -68,10 +87,19 @@ class Chunker:
        content: str,
        file_path: str | Path,
        language: str,
        line_mapping: Optional[List[int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk code using sliding window approach.

        Used for files without clear symbol boundaries or very long functions.

        Args:
            content: Source code content
            file_path: Path to source file
            language: Programming language
            line_mapping: Optional list mapping content line indices to original line numbers
                (1-indexed). If provided, line_mapping[i] is the original line number
                for the i-th line in content.
        """
        chunks: List[SemanticChunk] = []
        lines = content.splitlines(keepends=True)
@@ -92,6 +120,18 @@ class Chunker:
            chunk_content = "".join(lines[start:end])

            if len(chunk_content.strip()) >= self.config.min_chunk_size:
                token_count = self._tokenizer.count_tokens(chunk_content)

                # Calculate correct line numbers
                if line_mapping:
                    # Use line mapping to get original line numbers
                    start_line = line_mapping[start]
                    end_line = line_mapping[end - 1]
                else:
                    # Default behavior: treat content as starting at line 1
                    start_line = start + 1
                    end_line = end

                chunks.append(SemanticChunk(
                    content=chunk_content,
                    embedding=None,
@@ -99,9 +139,10 @@ class Chunker:
                        "file": str(file_path),
                        "language": language,
                        "chunk_index": chunk_idx,
                        "start_line": start_line,
                        "end_line": end_line,
                        "strategy": "sliding_window",
                        "token_count": token_count,
                    }
                ))
                chunk_idx += 1
@@ -119,12 +160,239 @@ class Chunker:
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk a file using the best strategy.

        Uses symbol-based chunking if symbols are available,
        falls back to sliding window for files without symbols.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        if symbols:
            return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
        return self.chunk_sliding_window(content, file_path, language)

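(Editorial illustration, not part of the diff.) A minimal usage sketch of the chunker above. The import path codexlens.semantic.chunker and the sample values are assumptions; the sliding-window strategy is used here so no pre-extracted symbols are needed.

from codexlens.semantic.chunker import ChunkConfig, Chunker  # assumed module path

source = "def add(a, b):\n    return a + b\n"
chunker = Chunker(ChunkConfig(max_chunk_size=500, overlap=50, min_chunk_size=10))
chunks = chunker.chunk_sliding_window(source, file_path="example.py", language="python")
for chunk in chunks:
    meta = chunk.metadata
    print(meta["start_line"], meta["end_line"], meta["token_count"])
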
class DocstringExtractor:
    """Extract docstrings from source code."""

    @staticmethod
    def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
        """Extract Python docstrings with their line ranges.

        Returns: List of (docstring_content, start_line, end_line) tuples
        """
        docstrings: List[Tuple[str, int, int]] = []
        lines = content.splitlines(keepends=True)

        i = 0
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote_type = '"""' if stripped.startswith('"""') else "'''"
                start_line = i + 1

                if stripped.count(quote_type) >= 2:
                    docstring_content = line
                    end_line = i + 1
                    docstrings.append((docstring_content, start_line, end_line))
                    i += 1
                    continue

                docstring_lines = [line]
                i += 1
                while i < len(lines):
                    docstring_lines.append(lines[i])
                    if quote_type in lines[i]:
                        break
                    i += 1

                end_line = i + 1
                docstring_content = "".join(docstring_lines)
                docstrings.append((docstring_content, start_line, end_line))

            i += 1

        return docstrings

    @staticmethod
    def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
        """Extract JSDoc comments with their line ranges.

        Returns: List of (comment_content, start_line, end_line) tuples
        """
        comments: List[Tuple[str, int, int]] = []
        lines = content.splitlines(keepends=True)

        i = 0
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()

            if stripped.startswith('/**'):
                start_line = i + 1
                comment_lines = [line]
                i += 1

                while i < len(lines):
                    comment_lines.append(lines[i])
                    if '*/' in lines[i]:
                        break
                    i += 1

                end_line = i + 1
                comment_content = "".join(comment_lines)
                comments.append((comment_content, start_line, end_line))

            i += 1

        return comments

    @classmethod
    def extract_docstrings(
        cls,
        content: str,
        language: str
    ) -> List[Tuple[str, int, int]]:
        """Extract docstrings based on language.

        Returns: List of (docstring_content, start_line, end_line) tuples
        """
        if language == "python":
            return cls.extract_python_docstrings(content)
        elif language in {"javascript", "typescript"}:
            return cls.extract_jsdoc_comments(content)
        return []


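(Editorial illustration, not part of the diff.) A small sketch of the extractor above; the import path is an assumption.

from codexlens.semantic.chunker import DocstringExtractor  # assumed module path

code = '''"""Module docstring."""

def f():
    """Return nothing."""
    return None
'''
for text, start, end in DocstringExtractor.extract_docstrings(code, "python"):
    print(f"lines {start}-{end}: {text.strip()[:40]}")
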
class HybridChunker:
    """Hybrid chunker that prioritizes docstrings before symbol-based chunking.

    Composition-based strategy that:
    1. Extracts docstrings as dedicated chunks
    2. For remaining code, uses base chunker (symbol or sliding window)
    """

    def __init__(
        self,
        base_chunker: Chunker | None = None,
        config: ChunkConfig | None = None
    ) -> None:
        """Initialize hybrid chunker.

        Args:
            base_chunker: Chunker to use for non-docstring content
            config: Configuration for chunking
        """
        self.config = config or ChunkConfig()
        self.base_chunker = base_chunker or Chunker(self.config)
        self.docstring_extractor = DocstringExtractor()

    def _get_excluded_line_ranges(
        self,
        docstrings: List[Tuple[str, int, int]]
    ) -> set[int]:
        """Get set of line numbers that are part of docstrings."""
        excluded_lines: set[int] = set()
        for _, start_line, end_line in docstrings:
            for line_num in range(start_line, end_line + 1):
                excluded_lines.add(line_num)
        return excluded_lines

    def _filter_symbols_outside_docstrings(
        self,
        symbols: List[Symbol],
        excluded_lines: set[int]
    ) -> List[Symbol]:
        """Filter symbols to exclude those completely within docstrings."""
        filtered: List[Symbol] = []
        for symbol in symbols:
            start_line, end_line = symbol.range
            symbol_lines = set(range(start_line, end_line + 1))
            if not symbol_lines.issubset(excluded_lines):
                filtered.append(symbol)
        return filtered

    def chunk_file(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk file using hybrid strategy.

        Extracts docstrings first, then chunks remaining code.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []
        tokenizer = get_default_tokenizer()

        # Step 1: Extract docstrings as dedicated chunks
        docstrings = self.docstring_extractor.extract_docstrings(content, language)

        for docstring_content, start_line, end_line in docstrings:
            if len(docstring_content.strip()) >= self.config.min_chunk_size:
                token_count = tokenizer.count_tokens(docstring_content)
                chunks.append(SemanticChunk(
                    content=docstring_content,
                    embedding=None,
                    metadata={
                        "file": str(file_path),
                        "language": language,
                        "chunk_type": "docstring",
                        "start_line": start_line,
                        "end_line": end_line,
                        "strategy": "hybrid",
                        "token_count": token_count,
                    }
                ))

        # Step 2: Get line ranges occupied by docstrings
        excluded_lines = self._get_excluded_line_ranges(docstrings)

        # Step 3: Filter symbols to exclude docstring-only ranges
        filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)

        # Step 4: Chunk remaining content using base chunker
        if filtered_symbols:
            base_chunks = self.base_chunker.chunk_by_symbol(
                content, filtered_symbols, file_path, language, symbol_token_counts
            )
            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
                chunks.append(chunk)
        else:
            lines = content.splitlines(keepends=True)
            remaining_lines: List[str] = []

            for i, line in enumerate(lines, start=1):
                if i not in excluded_lines:
                    remaining_lines.append(line)

            if remaining_lines:
                remaining_content = "".join(remaining_lines)
                if len(remaining_content.strip()) >= self.config.min_chunk_size:
                    base_chunks = self.base_chunker.chunk_sliding_window(
                        remaining_content, file_path, language
                    )
                    for chunk in base_chunks:
                        chunk.metadata["strategy"] = "hybrid"
                        chunk.metadata["chunk_type"] = "code"
                        chunks.append(chunk)

        return chunks

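(Editorial illustration, not part of the diff.) A usage sketch of the hybrid strategy above; the import path is an assumption, and with no pre-extracted symbols the non-docstring code falls through to the sliding-window path of the base chunker.

from codexlens.semantic.chunker import HybridChunker  # assumed module path

source = '''"""A module docstring long enough to clear the default minimum chunk size,
kept as its own chunk by the hybrid strategy."""

def handler(event):
    return {"ok": True, "event": event}
'''

hybrid = HybridChunker()
for chunk in hybrid.chunk_file(source, symbols=[], file_path="handler.py", language="python"):
    print(chunk.metadata["chunk_type"], chunk.metadata["start_line"], chunk.metadata["end_line"])
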
codex-lens/src/codexlens/semantic/graph_analyzer.py (new file, 531 lines)
@@ -0,0 +1,531 @@
"""Graph analyzer for extracting code relationships using tree-sitter.

Provides AST-based analysis to identify function calls, method invocations,
and class inheritance relationships within source files.
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional

try:
    from tree_sitter import Node as TreeSitterNode
    TREE_SITTER_AVAILABLE = True
except ImportError:
    TreeSitterNode = None  # type: ignore[assignment]
    TREE_SITTER_AVAILABLE = False

from codexlens.entities import CodeRelationship, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser


class GraphAnalyzer:
    """Analyzer for extracting semantic relationships from code using AST traversal."""

    def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
        """Initialize graph analyzer for a language.

        Args:
            language_id: Language identifier (python, javascript, typescript, etc.)
            parser: Optional TreeSitterSymbolParser instance for dependency injection.
                If None, creates a new parser instance (backward compatibility).
        """
        self.language_id = language_id
        self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)

    def is_available(self) -> bool:
        """Check if graph analyzer is available.

        Returns:
            True if tree-sitter parser is initialized and ready
        """
        return self._parser.is_available()

    def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
        """Analyze source code and extract relationships.

        Args:
            text: Source code text
            file_path: File path for relationship context

        Returns:
            List of CodeRelationship objects representing intra-file relationships
        """
        if not self.is_available() or self._parser._parser is None:
            return []

        try:
            source_bytes = text.encode("utf8")
            tree = self._parser._parser.parse(source_bytes)  # type: ignore[attr-defined]
            root = tree.root_node

            relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))

            return relationships
        except Exception:
            # Gracefully handle parsing errors
            return []

    def analyze_with_symbols(
        self, text: str, file_path: Path, symbols: List[Symbol]
    ) -> List[CodeRelationship]:
        """Analyze source code using pre-parsed symbols to avoid duplicate parsing.

        Args:
            text: Source code text
            file_path: File path for relationship context
            symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser

        Returns:
            List of CodeRelationship objects representing intra-file relationships
        """
        if not self.is_available() or self._parser._parser is None:
            return []

        try:
            source_bytes = text.encode("utf8")
            tree = self._parser._parser.parse(source_bytes)  # type: ignore[attr-defined]
            root = tree.root_node

            # Convert Symbol objects to internal symbol format
            defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)

            # Extract relationships using provided symbols
            relationships = self._extract_relationships_with_symbols(
                source_bytes, root, str(file_path.resolve()), defined_symbols
            )

            return relationships
        except Exception:
            # Gracefully handle parsing errors
            return []

    def _convert_symbols_to_dict(
        self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
    ) -> List[dict]:
        """Convert Symbol objects to internal dict format for relationship extraction.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            symbols: Pre-parsed Symbol objects

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbol_dicts = []
        symbol_names = {s.name for s in symbols}

        # Find AST nodes corresponding to symbols
        for node in self._iter_nodes(root):
            node_type = node.type

            # Check if this node matches any of our symbols
            if node_type in {"function_definition", "async_function_definition"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })
            elif node_type == "class_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "class"
                        })
            elif node_type in {"function_declaration", "generator_function_declaration"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })
            elif node_type == "method_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "method"
                        })
            elif node_type in {"class_declaration", "class"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "class"
                        })
            elif node_type == "variable_declarator":
                name_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if name_node and value_node and value_node.type == "arrow_function":
                    name = self._node_text(source_bytes, name_node)
                    if name in symbol_names:
                        symbol_dicts.append({
                            "name": name,
                            "node": node,
                            "type": "function"
                        })

        return symbol_dicts

    def _extract_relationships_with_symbols(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
    ) -> List[CodeRelationship]:
        """Extract relationships from AST using pre-parsed symbols.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path
            defined_symbols: Pre-parsed symbol dicts

        Returns:
            List of extracted relationships
        """
        relationships: List[CodeRelationship] = []

        # Determine call node type based on language
        if self.language_id == "python":
            call_node_type = "call"
            extract_target = self._extract_call_target
        elif self.language_id in {"javascript", "typescript"}:
            call_node_type = "call_expression"
            extract_target = self._extract_js_call_target
        else:
            return []

        # Find call expressions and match to defined symbols
        for node in self._iter_nodes(root):
            if node.type == call_node_type:
                # Extract caller context (enclosing function/method/class)
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    # Call at module level, use "<module>" as source
                    source_symbol = "<module>"

                # Extract callee (function/method being called)
                target_symbol = extract_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,  # Intra-file only
                        source_line=line_number,
                    )
                )

        return relationships

    def _extract_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of extracted relationships
        """
        if self.language_id == "python":
            return self._extract_python_relationships(source_bytes, root, file_path)
        elif self.language_id in {"javascript", "typescript"}:
            return self._extract_js_ts_relationships(source_bytes, root, file_path)
        else:
            return []

    def _extract_python_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract Python relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of Python relationships (function/method calls)
        """
        relationships: List[CodeRelationship] = []

        # First pass: collect all defined symbols with their scopes
        defined_symbols = self._collect_python_symbols(source_bytes, root)

        # Second pass: find call expressions and match to defined symbols
        for node in self._iter_nodes(root):
            if node.type == "call":
                # Extract caller context (enclosing function/method/class)
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    # Call at module level, use "<module>" as source
                    source_symbol = "<module>"

                # Extract callee (function/method being called)
                target_symbol = self._extract_call_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,  # Intra-file only
                        source_line=line_number,
                    )
                )

        return relationships

    def _extract_js_ts_relationships(
        self, source_bytes: bytes, root: TreeSitterNode, file_path: str
    ) -> List[CodeRelationship]:
        """Extract JavaScript/TypeScript relationships from AST.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node
            file_path: Absolute file path

        Returns:
            List of JS/TS relationships (function/method calls)
        """
        relationships: List[CodeRelationship] = []

        # First pass: collect all defined symbols
        defined_symbols = self._collect_js_ts_symbols(source_bytes, root)

        # Second pass: find call expressions
        for node in self._iter_nodes(root):
            if node.type == "call_expression":
                # Extract caller context
                source_symbol = self._find_enclosing_symbol(node, defined_symbols)
                if source_symbol is None:
                    source_symbol = "<module>"

                # Extract callee
                target_symbol = self._extract_js_call_target(source_bytes, node)
                if target_symbol is None:
                    continue

                # Create relationship
                line_number = node.start_point[0] + 1
                relationships.append(
                    CodeRelationship(
                        source_symbol=source_symbol,
                        target_symbol=target_symbol,
                        relationship_type="call",
                        source_file=file_path,
                        target_file=None,
                        source_line=line_number,
                    )
                )

        return relationships

    def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
        """Collect all Python function/method/class definitions.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbols = []
        for node in self._iter_nodes(root):
            if node.type in {"function_definition", "async_function_definition"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
            elif node.type == "class_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "class"
                    })
        return symbols

    def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
        """Collect all JS/TS function/method/class definitions.

        Args:
            source_bytes: Source code as bytes
            root: Root AST node

        Returns:
            List of symbol info dicts with name, node, and type
        """
        symbols = []
        for node in self._iter_nodes(root):
            if node.type in {"function_declaration", "generator_function_declaration"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
            elif node.type == "method_definition":
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "method"
                    })
            elif node.type in {"class_declaration", "class"}:
                name_node = node.child_by_field_name("name")
                if name_node:
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "class"
                    })
            elif node.type == "variable_declarator":
                name_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if name_node and value_node and value_node.type == "arrow_function":
                    symbols.append({
                        "name": self._node_text(source_bytes, name_node),
                        "node": node,
                        "type": "function"
                    })
        return symbols

    def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
        """Find the enclosing function/method/class for a node.

        Args:
            node: AST node to find enclosure for
            symbols: List of defined symbols

        Returns:
            Name of enclosing symbol, or None if at module level
        """
        # Walk up the tree to find enclosing symbol
        parent = node.parent
        while parent is not None:
            for symbol in symbols:
                if symbol["node"] == parent:
                    return symbol["name"]
            parent = parent.parent
        return None

    def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
        """Extract the target function name from a Python call expression.

        Args:
            source_bytes: Source code as bytes
            node: Call expression node

        Returns:
            Target function name, or None if it cannot be determined
        """
        function_node = node.child_by_field_name("function")
        if function_node is None:
            return None

        # Handle simple identifiers (e.g., "foo()")
        if function_node.type == "identifier":
            return self._node_text(source_bytes, function_node)

        # Handle attribute access (e.g., "obj.method()")
        if function_node.type == "attribute":
            attr_node = function_node.child_by_field_name("attribute")
            if attr_node:
                return self._node_text(source_bytes, attr_node)

        return None

    def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
        """Extract the target function name from a JS/TS call expression.

        Args:
            source_bytes: Source code as bytes
            node: Call expression node

        Returns:
            Target function name, or None if it cannot be determined
        """
        function_node = node.child_by_field_name("function")
        if function_node is None:
            return None

        # Handle simple identifiers
        if function_node.type == "identifier":
            return self._node_text(source_bytes, function_node)

        # Handle member expressions (e.g., "obj.method()")
        if function_node.type == "member_expression":
            property_node = function_node.child_by_field_name("property")
            if property_node:
                return self._node_text(source_bytes, property_node)

        return None

    def _iter_nodes(self, root: TreeSitterNode):
        """Iterate over all nodes in the AST.

        Args:
            root: Root node to start iteration

        Yields:
            AST nodes in depth-first order
        """
        stack = [root]
        while stack:
            node = stack.pop()
            yield node
            for child in reversed(node.children):
                stack.append(child)

    def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
        """Extract text for a node.

        Args:
            source_bytes: Source code as bytes
            node: AST node

        Returns:
            Text content of node
        """
        return source_bytes[node.start_byte:node.end_byte].decode("utf8")

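(Editorial illustration, not part of the diff.) A usage sketch of GraphAnalyzer; the import path is an assumption, and the example requires tree-sitter plus the Python grammar to be installed, otherwise is_available() returns False and nothing is printed. It also assumes CodeRelationship exposes its constructor fields as attributes.

from pathlib import Path

from codexlens.semantic.graph_analyzer import GraphAnalyzer  # assumed module path

source = "def helper():\n    return 1\n\ndef main():\n    return helper()\n"

analyzer = GraphAnalyzer("python")
if analyzer.is_available():
    for rel in analyzer.analyze_file(source, Path("example.py")):
        # Expected to report a call relationship from main to helper.
        print(rel.source_symbol, "->", rel.target_symbol, "line", rel.source_line)
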
@@ -75,6 +75,34 @@ class LLMEnhancer:
    external LLM tools (gemini, qwen) via CCW CLI subprocess.
    """

    CHUNK_REFINEMENT_PROMPT = '''PURPOSE: Identify optimal semantic split points in code chunk
TASK:
- Analyze the code structure to find natural semantic boundaries
- Identify logical groupings (functions, classes, related statements)
- Suggest split points that maintain semantic cohesion
MODE: analysis
EXPECTED: JSON format with split positions

=== CODE CHUNK ===
{code_chunk}

=== OUTPUT FORMAT ===
Return ONLY valid JSON (no markdown, no explanation):
{{
  "split_points": [
    {{
      "line": <line_number>,
      "reason": "brief reason for split (e.g., 'start of new function', 'end of class definition')"
    }}
  ]
}}

Rules:
- Split at function/class/method boundaries
- Keep related code together (don't split mid-function)
- Aim for chunks between 500-2000 characters
- Return empty split_points if no good splits found'''

    PROMPT_TEMPLATE = '''PURPOSE: Generate semantic summaries and search keywords for code files
TASK:
- For each code block, generate a concise summary (1-2 sentences)

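(Editorial illustration, not part of the diff.) A response that satisfies the CHUNK_REFINEMENT_PROMPT contract above would parse like this; the concrete line number and reason are hypothetical.

import json

raw = '{"split_points": [{"line": 42, "reason": "start of new function"}]}'
split_lines = sorted(
    p["line"]
    for p in json.loads(raw).get("split_points", [])
    if isinstance(p, dict) and isinstance(p.get("line"), int) and p["line"] > 0
)
print(split_lines)  # [42]
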
@@ -168,42 +196,246 @@ Return ONLY valid JSON (no markdown, no explanation):
        return results

    def enhance_file(
        self,
        path: str,
        content: str,
        language: str,
        working_dir: Optional[Path] = None,
    ) -> SemanticMetadata:
        """Enhance a single file with LLM-generated semantic metadata.

        Convenience method that wraps enhance_files for single file processing.

        Args:
            path: File path
            content: File content
            language: Programming language
            working_dir: Optional working directory for CCW CLI

        Returns:
            SemanticMetadata for the file

        Raises:
            ValueError: If enhancement fails
        """
        file_data = FileData(path=path, content=content, language=language)
        results = self.enhance_files([file_data], working_dir)

        if path not in results:
            # Return default metadata if enhancement failed
            return SemanticMetadata(
                summary=f"Code file written in {language}",
                keywords=[language, "code"],
                purpose="unknown",
                file_path=path,
                llm_tool=self.config.tool,
            )

        return results[path]

    def refine_chunk_boundaries(
        self,
        chunk: SemanticChunk,
        max_chunk_size: int = 2000,
        working_dir: Optional[Path] = None,
    ) -> List[SemanticChunk]:
        """Refine chunk boundaries using LLM for large code chunks.

        Uses LLM to identify semantic split points in large chunks,
        breaking them into smaller, more cohesive pieces.

        Args:
            chunk: Original chunk to refine
            max_chunk_size: Maximum characters before triggering refinement
            working_dir: Optional working directory for CCW CLI

        Returns:
            List of refined chunks (original chunk if no splits or refinement fails)
        """
        # Skip if chunk is small enough
        if len(chunk.content) <= max_chunk_size:
            return [chunk]

        # Skip if LLM enhancement disabled or unavailable
        if not self.config.enabled or not self.check_available():
            return [chunk]

        # Skip docstring chunks - only refine code chunks
        if chunk.metadata.get("chunk_type") == "docstring":
            return [chunk]

        try:
            # Build refinement prompt
            prompt = self.CHUNK_REFINEMENT_PROMPT.format(code_chunk=chunk.content)

            # Invoke LLM
            result = self._invoke_ccw_cli(
                prompt,
                tool=self.config.tool,
                working_dir=working_dir,
            )

            # Fallback if primary tool fails
            if not result["success"] and self.config.fallback_tool:
                result = self._invoke_ccw_cli(
                    prompt,
                    tool=self.config.fallback_tool,
                    working_dir=working_dir,
                )

            if not result["success"]:
                logger.debug("LLM refinement failed, returning original chunk")
                return [chunk]

            # Parse split points
            split_points = self._parse_split_points(result["stdout"])
            if not split_points:
                logger.debug("No split points identified, returning original chunk")
                return [chunk]

            # Split chunk at identified boundaries
            refined_chunks = self._split_chunk_at_points(chunk, split_points)
            logger.debug(
                "Refined chunk into %d smaller chunks (was %d chars)",
                len(refined_chunks),
                len(chunk.content),
            )
            return refined_chunks

        except Exception as e:
            logger.warning("Chunk refinement error: %s, returning original chunk", e)
            return [chunk]

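(Editorial illustration, not part of the diff.) A hedged sketch of how the refinement step might be wired in; the bare LLMEnhancer() construction is an assumption since its constructor arguments are not shown in this diff, and chunks stands for output from Chunker or HybridChunker.

enhancer = LLMEnhancer()  # hypothetical construction; real config options are not shown here
refined = []
for chunk in chunks:
    # Oversized code chunks may be split at LLM-suggested boundaries;
    # split chunks carry metadata["refined_by_llm"] = True.
    refined.extend(enhancer.refine_chunk_boundaries(chunk, max_chunk_size=2000))
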
    def _parse_split_points(self, stdout: str) -> List[int]:
        """Parse split points from LLM response.

        Args:
            stdout: Raw stdout from CCW CLI

        Returns:
            List of line numbers where splits should occur (sorted)
        """
        # Extract JSON from response
        json_str = self._extract_json(stdout)
        if not json_str:
            return []

        try:
            data = json.loads(json_str)
            split_points_data = data.get("split_points", [])

            # Extract line numbers
            lines = []
            for point in split_points_data:
                if isinstance(point, dict) and "line" in point:
                    line_num = point["line"]
                    if isinstance(line_num, int) and line_num > 0:
                        lines.append(line_num)

            return sorted(set(lines))

        except (json.JSONDecodeError, ValueError, TypeError) as e:
            logger.debug("Failed to parse split points: %s", e)
            return []

    def _split_chunk_at_points(
        self,
        chunk: SemanticChunk,
        split_points: List[int],
    ) -> List[SemanticChunk]:
        """Split chunk at specified line numbers.

        Args:
            chunk: Original chunk to split
            split_points: Sorted list of line numbers to split at

        Returns:
            List of smaller chunks
        """
        lines = chunk.content.splitlines(keepends=True)
        chunks: List[SemanticChunk] = []

        # Get original metadata
        base_metadata = dict(chunk.metadata)
        original_start = base_metadata.get("start_line", 1)

        # Add start and end boundaries
        boundaries = [0] + split_points + [len(lines)]

        for i in range(len(boundaries) - 1):
            start_idx = boundaries[i]
            end_idx = boundaries[i + 1]

            # Skip empty sections
            if start_idx >= end_idx:
                continue

            # Extract content
            section_lines = lines[start_idx:end_idx]
            section_content = "".join(section_lines)

            # Skip if too small
            if len(section_content.strip()) < 50:
                continue

            # Create new chunk with updated metadata
            new_metadata = base_metadata.copy()
            new_metadata["start_line"] = original_start + start_idx
            new_metadata["end_line"] = original_start + end_idx - 1
            new_metadata["refined_by_llm"] = True
            new_metadata["original_chunk_size"] = len(chunk.content)

            chunks.append(
                SemanticChunk(
                    content=section_content,
                    embedding=None,  # Embeddings will be regenerated
                    metadata=new_metadata,
                )
            )

        # If no valid chunks created, return original
        if not chunks:
            return [chunk]

        return chunks

    def _process_batch(