Add comprehensive tests for tokenizer, performance benchmarks, and TreeSitter parser functionality

- Implemented unit tests for the Tokenizer class, covering various text inputs, edge cases, and fallback mechanisms (a sketch of this kind of test follows the commit metadata below).
- Created performance benchmarks comparing tiktoken and pure Python implementations for token counting.
- Developed extensive tests for TreeSitterSymbolParser across Python, JavaScript, and TypeScript, ensuring accurate symbol extraction and parsing.
- Added configuration documentation for MCP integration and custom prompts, enhancing usability and flexibility.
- Introduced a refactor script for GraphAnalyzer to streamline future improvements.
catlog22
2025-12-15 14:36:09 +08:00
parent 82dcafff00
commit 0fe16963cd
49 changed files with 9307 additions and 438 deletions
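As a rough sketch of the kind of tokenizer test described in the first bullet above (the test name and assertions are hypothetical; only get_default_tokenizer and count_tokens appear in the diff below):

```python
# Hypothetical unit-test sketch; the real suite's fixtures and expected values may differ.
from codexlens.parsers.tokenizer import get_default_tokenizer


def test_count_tokens_basic_invariants() -> None:
    tokenizer = get_default_tokenizer()
    code = "def add(a, b):\n    return a + b\n"

    # Counts differ between tiktoken and the pure-Python fallback,
    # so assert invariants rather than exact numbers.
    assert tokenizer.count_tokens(code) > 0
    assert tokenizer.count_tokens(code * 2) >= tokenizer.count_tokens(code)
```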


@@ -4,9 +4,10 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@dataclass
@@ -14,6 +15,7 @@ class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 100 # Overlap for sliding window
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
@@ -22,6 +24,7 @@ class Chunker:
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
def chunk_by_symbol(
self,
@@ -29,10 +32,18 @@ class Chunker:
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
@@ -47,6 +58,13 @@ class Chunker:
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._tokenizer.count_tokens(chunk_content)
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
@@ -58,6 +76,7 @@ class Chunker:
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"token_count": token_count,
}
))
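When a caller has already extracted symbols, the new symbol_token_counts parameter lets token counts be computed once and reused instead of re-tokenizing every symbol body. A minimal caller-side sketch, assuming each Symbol exposes .name and a 1-indexed, inclusive .range as used elsewhere in this diff (the helper name is illustrative):

```python
from codexlens.parsers.tokenizer import get_default_tokenizer


def precompute_symbol_token_counts(content: str, symbols) -> dict[str, int]:
    """Map symbol names to token counts for their source ranges."""
    tokenizer = get_default_tokenizer()
    lines = content.splitlines(keepends=True)
    counts: dict[str, int] = {}
    for symbol in symbols:
        start_line, end_line = symbol.range  # 1-indexed, inclusive
        body = "".join(lines[start_line - 1:end_line])
        counts[symbol.name] = tokenizer.count_tokens(body)
    return counts


# counts = precompute_symbol_token_counts(content, symbols)
# chunks = chunker.chunk_by_symbol(content, symbols, path, "python",
#                                  symbol_token_counts=counts)
```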
@@ -68,10 +87,19 @@ class Chunker:
content: str,
file_path: str | Path,
language: str,
line_mapping: Optional[List[int]] = None,
) -> List[SemanticChunk]:
"""Chunk code using sliding window approach.
Used for files without clear symbol boundaries or very long functions.
Args:
content: Source code content
file_path: Path to source file
language: Programming language
line_mapping: Optional list mapping content line indices to original line numbers
(1-indexed). If provided, line_mapping[i] is the original line number
for the i-th line in content.
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
@@ -92,6 +120,18 @@ class Chunker:
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
token_count = self._tokenizer.count_tokens(chunk_content)
# Calculate correct line numbers
if line_mapping:
# Use line mapping to get original line numbers
start_line = line_mapping[start]
end_line = line_mapping[end - 1]
else:
# Default behavior: treat content as starting at line 1
start_line = start + 1
end_line = end
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
@@ -99,9 +139,10 @@ class Chunker:
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start + 1,
"end_line": end,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"token_count": token_count,
}
))
chunk_idx += 1
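The line_mapping parameter matters when the content handed to chunk_sliding_window is a filtered slice of the original file, as HybridChunker does with its non-docstring lines further below. A minimal sketch, assuming the chunker classes are importable from codexlens.parsers.chunker:

```python
from codexlens.parsers.chunker import ChunkConfig, Chunker  # assumed module path

original = "".join(f"value_{i} = {i}\n" for i in range(1, 121))
lines = original.splitlines(keepends=True)

# Suppose the first 60 lines were handled elsewhere; chunk only the remainder,
# but keep line numbers relative to the original file.
kept = list(range(60, len(lines)))
content = "".join(lines[i] for i in kept)
line_mapping = [i + 1 for i in kept]  # original, 1-indexed line numbers

chunker = Chunker(ChunkConfig(max_chunk_size=400, overlap=50, min_chunk_size=20))
for chunk in chunker.chunk_sliding_window(content, "example.py", "python",
                                          line_mapping=line_mapping):
    meta = chunk.metadata
    print(meta["start_line"], meta["end_line"], meta["token_count"])
# start_line/end_line now refer to the original file (61 onward) instead of
# restarting at 1.
```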
@@ -119,12 +160,239 @@ class Chunker:
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk a file using the best strategy.
Uses symbol-based chunking when symbols are available and
falls back to sliding-window chunking for files without symbols.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
if symbols:
return self.chunk_by_symbol(content, symbols, file_path, language)
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
return self.chunk_sliding_window(content, file_path, language)
class DocstringExtractor:
"""Extract docstrings from source code."""
@staticmethod
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
"""Extract Python docstrings with their line ranges.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
docstrings: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
start_line = i + 1
if stripped.count(quote_type) >= 2:
docstring_content = line
end_line = i + 1
docstrings.append((docstring_content, start_line, end_line))
i += 1
continue
docstring_lines = [line]
i += 1
while i < len(lines):
docstring_lines.append(lines[i])
if quote_type in lines[i]:
break
i += 1
end_line = i + 1
docstring_content = "".join(docstring_lines)
docstrings.append((docstring_content, start_line, end_line))
i += 1
return docstrings
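A small worked example for the Python extractor above (DocstringExtractor is the class being defined here, so no import is shown):

```python
sample = '''"""Module docstring spanning
two lines."""

def add(a, b):
    """Add two numbers."""
    return a + b
'''

for text, start, end in DocstringExtractor.extract_python_docstrings(sample):
    print(start, end, text.strip().splitlines()[0])
# Given the line-based scan above, this yields (1, 2) for the module docstring
# and (5, 5) for the single-line function docstring.
```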
@staticmethod
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
"""Extract JSDoc comments with their line ranges.
Returns: List of (comment_content, start_line, end_line) tuples
"""
comments: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('/**'):
start_line = i + 1
comment_lines = [line]
i += 1
while i < len(lines):
comment_lines.append(lines[i])
if '*/' in lines[i]:
break
i += 1
end_line = i + 1
comment_content = "".join(comment_lines)
comments.append((comment_content, start_line, end_line))
i += 1
return comments
@classmethod
def extract_docstrings(
cls,
content: str,
language: str
) -> List[Tuple[str, int, int]]:
"""Extract docstrings based on language.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
if language == "python":
return cls.extract_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.extract_jsdoc_comments(content)
return []
class HybridChunker:
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
Composition-based strategy that:
1. Extracts docstrings as dedicated chunks
2. Chunks the remaining code with the base chunker (symbol or sliding window)
"""
def __init__(
self,
base_chunker: Chunker | None = None,
config: ChunkConfig | None = None
) -> None:
"""Initialize hybrid chunker.
Args:
base_chunker: Chunker to use for non-docstring content
config: Configuration for chunking
"""
self.config = config or ChunkConfig()
self.base_chunker = base_chunker or Chunker(self.config)
self.docstring_extractor = DocstringExtractor()
def _get_excluded_line_ranges(
self,
docstrings: List[Tuple[str, int, int]]
) -> set[int]:
"""Get set of line numbers that are part of docstrings."""
excluded_lines: set[int] = set()
for _, start_line, end_line in docstrings:
for line_num in range(start_line, end_line + 1):
excluded_lines.add(line_num)
return excluded_lines
def _filter_symbols_outside_docstrings(
self,
symbols: List[Symbol],
excluded_lines: set[int]
) -> List[Symbol]:
"""Filter symbols to exclude those completely within docstrings."""
filtered: List[Symbol] = []
for symbol in symbols:
start_line, end_line = symbol.range
symbol_lines = set(range(start_line, end_line + 1))
if not symbol_lines.issubset(excluded_lines):
filtered.append(symbol)
return filtered
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk file using hybrid strategy.
Extracts docstrings first, then chunks remaining code.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
tokenizer = get_default_tokenizer()
# Step 1: Extract docstrings as dedicated chunks
docstrings = self.docstring_extractor.extract_docstrings(content, language)
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
token_count = tokenizer.count_tokens(docstring_content)
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
))
# Step 2: Get line ranges occupied by docstrings
excluded_lines = self._get_excluded_line_ranges(docstrings)
# Step 3: Filter symbols to exclude docstring-only ranges
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
# Step 4: Chunk remaining content using base chunker
if filtered_symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, filtered_symbols, file_path, language, symbol_token_counts
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
else:
lines = content.splitlines(keepends=True)
remaining_lines: List[str] = []
for i, line in enumerate(lines, start=1):
if i not in excluded_lines:
remaining_lines.append(line)
if remaining_lines:
remaining_content = "".join(remaining_lines)
if len(remaining_content.strip()) >= self.config.min_chunk_size:
base_chunks = self.base_chunker.chunk_sliding_window(
remaining_content, file_path, language
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
return chunks
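End-to-end, the hybrid strategy can be exercised like this (a sketch only; the module path and the lowered min_chunk_size are assumptions made so this short example produces chunks at all):

```python
from codexlens.parsers.chunker import ChunkConfig, HybridChunker  # assumed module path

source = '''"""Utilities for geometry.

This module docstring is long enough to become its own chunk once
min_chunk_size is lowered for the example.
"""

def area(w, h):
    return w * h
'''

chunker = HybridChunker(config=ChunkConfig(min_chunk_size=20))
chunks = chunker.chunk_file(source, symbols=[], file_path="geometry.py", language="python")
for chunk in chunks:
    print(chunk.metadata["chunk_type"], chunk.metadata["strategy"],
          chunk.metadata.get("token_count"))
# With no symbols supplied, the docstring becomes a dedicated "docstring" chunk and
# the remaining lines fall back to sliding-window "code" chunks, all tagged "hybrid".
```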


@@ -0,0 +1,531 @@
"""Graph analyzer for extracting code relationships using tree-sitter.
Provides AST-based analysis to identify function calls, method invocations,
and class inheritance relationships within source files.
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
try:
from tree_sitter import Node as TreeSitterNode
TREE_SITTER_AVAILABLE = True
except ImportError:
TreeSitterNode = None # type: ignore[assignment]
TREE_SITTER_AVAILABLE = False
from codexlens.entities import CodeRelationship, Symbol
from codexlens.parsers.treesitter_parser import TreeSitterSymbolParser
class GraphAnalyzer:
"""Analyzer for extracting semantic relationships from code using AST traversal."""
def __init__(self, language_id: str, parser: Optional[TreeSitterSymbolParser] = None) -> None:
"""Initialize graph analyzer for a language.
Args:
language_id: Language identifier (python, javascript, typescript, etc.)
parser: Optional TreeSitterSymbolParser instance for dependency injection.
If None, creates a new parser instance (backward compatibility).
"""
self.language_id = language_id
self._parser = parser if parser is not None else TreeSitterSymbolParser(language_id)
def is_available(self) -> bool:
"""Check if graph analyzer is available.
Returns:
True if tree-sitter parser is initialized and ready
"""
return self._parser.is_available()
def analyze_file(self, text: str, file_path: Path) -> List[CodeRelationship]:
"""Analyze source code and extract relationships.
Args:
text: Source code text
file_path: File path for relationship context
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
relationships = self._extract_relationships(source_bytes, root, str(file_path.resolve()))
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def analyze_with_symbols(
self, text: str, file_path: Path, symbols: List[Symbol]
) -> List[CodeRelationship]:
"""Analyze source code using pre-parsed symbols to avoid duplicate parsing.
Args:
text: Source code text
file_path: File path for relationship context
symbols: Pre-parsed Symbol objects from TreeSitterSymbolParser
Returns:
List of CodeRelationship objects representing intra-file relationships
"""
if not self.is_available() or self._parser._parser is None:
return []
try:
source_bytes = text.encode("utf8")
tree = self._parser._parser.parse(source_bytes) # type: ignore[attr-defined]
root = tree.root_node
# Convert Symbol objects to internal symbol format
defined_symbols = self._convert_symbols_to_dict(source_bytes, root, symbols)
# Extract relationships using provided symbols
relationships = self._extract_relationships_with_symbols(
source_bytes, root, str(file_path.resolve()), defined_symbols
)
return relationships
except Exception:
# Gracefully handle parsing errors
return []
def _convert_symbols_to_dict(
self, source_bytes: bytes, root: TreeSitterNode, symbols: List[Symbol]
) -> List[dict]:
"""Convert Symbol objects to internal dict format for relationship extraction.
Args:
source_bytes: Source code as bytes
root: Root AST node
symbols: Pre-parsed Symbol objects
Returns:
List of symbol info dicts with name, node, and type
"""
symbol_dicts = []
symbol_names = {s.name for s in symbols}
# Find AST nodes corresponding to symbols
for node in self._iter_nodes(root):
node_type = node.type
# Check if this node matches any of our symbols
if node_type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
elif node_type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "method"
})
elif node_type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "class"
})
elif node_type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
name = self._node_text(source_bytes, name_node)
if name in symbol_names:
symbol_dicts.append({
"name": name,
"node": node,
"type": "function"
})
return symbol_dicts
def _extract_relationships_with_symbols(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str, defined_symbols: List[dict]
) -> List[CodeRelationship]:
"""Extract relationships from AST using pre-parsed symbols.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
defined_symbols: Pre-parsed symbol dicts
Returns:
List of extracted relationships
"""
relationships: List[CodeRelationship] = []
# Determine call node type based on language
if self.language_id == "python":
call_node_type = "call"
extract_target = self._extract_call_target
elif self.language_id in {"javascript", "typescript"}:
call_node_type = "call_expression"
extract_target = self._extract_js_call_target
else:
return []
# Find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == call_node_type:
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = extract_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of extracted relationships
"""
if self.language_id == "python":
return self._extract_python_relationships(source_bytes, root, file_path)
elif self.language_id in {"javascript", "typescript"}:
return self._extract_js_ts_relationships(source_bytes, root, file_path)
else:
return []
def _extract_python_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract Python relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of Python relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols with their scopes
defined_symbols = self._collect_python_symbols(source_bytes, root)
# Second pass: find call expressions and match to defined symbols
for node in self._iter_nodes(root):
if node.type == "call":
# Extract caller context (enclosing function/method/class)
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
# Call at module level, use "<module>" as source
source_symbol = "<module>"
# Extract callee (function/method being called)
target_symbol = self._extract_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None, # Intra-file only
source_line=line_number,
)
)
return relationships
def _extract_js_ts_relationships(
self, source_bytes: bytes, root: TreeSitterNode, file_path: str
) -> List[CodeRelationship]:
"""Extract JavaScript/TypeScript relationships from AST.
Args:
source_bytes: Source code as bytes
root: Root AST node
file_path: Absolute file path
Returns:
List of JS/TS relationships (function/method calls)
"""
relationships: List[CodeRelationship] = []
# First pass: collect all defined symbols
defined_symbols = self._collect_js_ts_symbols(source_bytes, root)
# Second pass: find call expressions
for node in self._iter_nodes(root):
if node.type == "call_expression":
# Extract caller context
source_symbol = self._find_enclosing_symbol(node, defined_symbols)
if source_symbol is None:
source_symbol = "<module>"
# Extract callee
target_symbol = self._extract_js_call_target(source_bytes, node)
if target_symbol is None:
continue
# Create relationship
line_number = node.start_point[0] + 1
relationships.append(
CodeRelationship(
source_symbol=source_symbol,
target_symbol=target_symbol,
relationship_type="call",
source_file=file_path,
target_file=None,
source_line=line_number,
)
)
return relationships
def _collect_python_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all Python function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_definition", "async_function_definition"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
return symbols
def _collect_js_ts_symbols(self, source_bytes: bytes, root: TreeSitterNode) -> List[dict]:
"""Collect all JS/TS function/method/class definitions.
Args:
source_bytes: Source code as bytes
root: Root AST node
Returns:
List of symbol info dicts with name, node, and type
"""
symbols = []
for node in self._iter_nodes(root):
if node.type in {"function_declaration", "generator_function_declaration"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
elif node.type == "method_definition":
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "method"
})
elif node.type in {"class_declaration", "class"}:
name_node = node.child_by_field_name("name")
if name_node:
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "class"
})
elif node.type == "variable_declarator":
name_node = node.child_by_field_name("name")
value_node = node.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
symbols.append({
"name": self._node_text(source_bytes, name_node),
"node": node,
"type": "function"
})
return symbols
def _find_enclosing_symbol(self, node: TreeSitterNode, symbols: List[dict]) -> Optional[str]:
"""Find the enclosing function/method/class for a node.
Args:
node: AST node to find enclosure for
symbols: List of defined symbols
Returns:
Name of enclosing symbol, or None if at module level
"""
# Walk up the tree to find enclosing symbol
parent = node.parent
while parent is not None:
for symbol in symbols:
if symbol["node"] == parent:
return symbol["name"]
parent = parent.parent
return None
def _extract_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a Python call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers (e.g., "foo()")
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle attribute access (e.g., "obj.method()")
if function_node.type == "attribute":
attr_node = function_node.child_by_field_name("attribute")
if attr_node:
return self._node_text(source_bytes, attr_node)
return None
def _extract_js_call_target(self, source_bytes: bytes, node: TreeSitterNode) -> Optional[str]:
"""Extract the target function name from a JS/TS call expression.
Args:
source_bytes: Source code as bytes
node: Call expression node
Returns:
Target function name, or None if cannot be determined
"""
function_node = node.child_by_field_name("function")
if function_node is None:
return None
# Handle simple identifiers
if function_node.type == "identifier":
return self._node_text(source_bytes, function_node)
# Handle member expressions (e.g., "obj.method()")
if function_node.type == "member_expression":
property_node = function_node.child_by_field_name("property")
if property_node:
return self._node_text(source_bytes, property_node)
return None
def _iter_nodes(self, root: TreeSitterNode):
"""Iterate over all nodes in AST.
Args:
root: Root node to start iteration
Yields:
AST nodes in depth-first order
"""
stack = [root]
while stack:
node = stack.pop()
yield node
for child in reversed(node.children):
stack.append(child)
def _node_text(self, source_bytes: bytes, node: TreeSitterNode) -> str:
"""Extract text for a node.
Args:
source_bytes: Source code as bytes
node: AST node
Returns:
Text content of node
"""
return source_bytes[node.start_byte:node.end_byte].decode("utf8")
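A usage sketch for the analyzer (the module path is an assumption, and the optional tree-sitter grammars must be installed for is_available to return True):

```python
from pathlib import Path

from codexlens.analysis.graph_analyzer import GraphAnalyzer  # assumed module path

source = """
def helper():
    return 1

def main():
    return helper() + helper()

main()
"""

analyzer = GraphAnalyzer("python")
if analyzer.is_available():
    for rel in analyzer.analyze_file(source, Path("example.py")):
        print(rel.source_symbol, "->", rel.target_symbol, "line", rel.source_line)
# Expected, per the call extraction above: main -> helper (twice) and <module> -> main.
```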


@@ -75,6 +75,34 @@ class LLMEnhancer:
external LLM tools (gemini, qwen) via CCW CLI subprocess.
"""
CHUNK_REFINEMENT_PROMPT = '''PURPOSE: Identify optimal semantic split points in code chunk
TASK:
- Analyze the code structure to find natural semantic boundaries
- Identify logical groupings (functions, classes, related statements)
- Suggest split points that maintain semantic cohesion
MODE: analysis
EXPECTED: JSON format with split positions
=== CODE CHUNK ===
{code_chunk}
=== OUTPUT FORMAT ===
Return ONLY valid JSON (no markdown, no explanation):
{{
"split_points": [
{{
"line": <line_number>,
"reason": "brief reason for split (e.g., 'start of new function', 'end of class definition')"
}}
]
}}
Rules:
- Split at function/class/method boundaries
- Keep related code together (don't split mid-function)
- Aim for chunks between 500-2000 characters
- Return empty split_points if no good splits found'''
PROMPT_TEMPLATE = '''PURPOSE: Generate semantic summaries and search keywords for code files
TASK:
- For each code block, generate a concise summary (1-2 sentences)
@@ -168,42 +196,246 @@ Return ONLY valid JSON (no markdown, no explanation):
return results
def enhance_file(
self,
path: str,
content: str,
language: str,
working_dir: Optional[Path] = None,
) -> SemanticMetadata:
"""Enhance a single file with LLM-generated semantic metadata.
Convenience method that wraps enhance_files for single file processing.
Args:
path: File path
content: File content
language: Programming language
working_dir: Optional working directory for CCW CLI
Returns:
SemanticMetadata for the file
Raises:
ValueError: If enhancement fails
"""
file_data = FileData(path=path, content=content, language=language)
results = self.enhance_files([file_data], working_dir)
if path not in results:
# Return default metadata if enhancement failed
return SemanticMetadata(
summary=f"Code file written in {language}",
keywords=[language, "code"],
purpose="unknown",
file_path=path,
llm_tool=self.config.tool,
)
return results[path]
def refine_chunk_boundaries(
self,
chunk: SemanticChunk,
max_chunk_size: int = 2000,
working_dir: Optional[Path] = None,
) -> List[SemanticChunk]:
"""Refine chunk boundaries using LLM for large code chunks.
Uses LLM to identify semantic split points in large chunks,
breaking them into smaller, more cohesive pieces.
Args:
chunk: Original chunk to refine
max_chunk_size: Maximum characters before triggering refinement
working_dir: Optional working directory for CCW CLI
Returns:
SemanticMetadata for the file
Raises:
ValueError: If enhancement fails
List of refined chunks (original chunk if no splits or refinement fails)
"""
file_data = FileData(path=path, content=content, language=language)
results = self.enhance_files([file_data], working_dir)
# Skip if chunk is small enough
if len(chunk.content) <= max_chunk_size:
return [chunk]
if path not in results:
# Return default metadata if enhancement failed
return SemanticMetadata(
summary=f"Code file written in {language}",
keywords=[language, "code"],
purpose="unknown",
file_path=path,
llm_tool=self.config.tool,
# Skip if LLM enhancement disabled or unavailable
if not self.config.enabled or not self.check_available():
return [chunk]
# Skip docstring chunks - only refine code chunks
if chunk.metadata.get("chunk_type") == "docstring":
return [chunk]
try:
# Build refinement prompt
prompt = self.CHUNK_REFINEMENT_PROMPT.format(code_chunk=chunk.content)
# Invoke LLM
result = self._invoke_ccw_cli(
prompt,
tool=self.config.tool,
working_dir=working_dir,
)
return results[path]
# Fallback if primary tool fails
if not result["success"] and self.config.fallback_tool:
result = self._invoke_ccw_cli(
prompt,
tool=self.config.fallback_tool,
working_dir=working_dir,
)
if not result["success"]:
logger.debug("LLM refinement failed, returning original chunk")
return [chunk]
# Parse split points
split_points = self._parse_split_points(result["stdout"])
if not split_points:
logger.debug("No split points identified, returning original chunk")
return [chunk]
# Split chunk at identified boundaries
refined_chunks = self._split_chunk_at_points(chunk, split_points)
logger.debug(
"Refined chunk into %d smaller chunks (was %d chars)",
len(refined_chunks),
len(chunk.content),
)
return refined_chunks
except Exception as e:
logger.warning("Chunk refinement error: %s, returning original chunk", e)
return [chunk]
def _parse_split_points(self, stdout: str) -> List[int]:
"""Parse split points from LLM response.
Args:
stdout: Raw stdout from CCW CLI
Returns:
List of line numbers where splits should occur (sorted)
"""
# Extract JSON from response
json_str = self._extract_json(stdout)
if not json_str:
return []
try:
data = json.loads(json_str)
split_points_data = data.get("split_points", [])
# Extract line numbers
lines = []
for point in split_points_data:
if isinstance(point, dict) and "line" in point:
line_num = point["line"]
if isinstance(line_num, int) and line_num > 0:
lines.append(line_num)
return sorted(set(lines))
except (json.JSONDecodeError, ValueError, TypeError) as e:
logger.debug("Failed to parse split points: %s", e)
return []
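For reference, a reply that follows the CHUNK_REFINEMENT_PROMPT contract would be reduced as follows (a sketch: `enhancer` is assumed to be an existing LLMEnhancer instance, and _parse_split_points is a private helper):

```python
reply = '''
{
  "split_points": [
    {"line": 40, "reason": "end of class definition"},
    {"line": 12, "reason": "start of new function"},
    {"line": 12, "reason": "duplicate entry"},
    {"line": "oops", "reason": "non-integer lines are skipped"}
  ]
}
'''

# Assuming `enhancer` is an LLMEnhancer instance, this should print [12, 40]:
# the JSON payload is extracted, malformed entries are dropped, and the line
# numbers are de-duplicated and sorted.
print(enhancer._parse_split_points(reply))
```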
def _split_chunk_at_points(
self,
chunk: SemanticChunk,
split_points: List[int],
) -> List[SemanticChunk]:
"""Split chunk at specified line numbers.
Args:
chunk: Original chunk to split
split_points: Sorted list of line numbers to split at
Returns:
List of smaller chunks
"""
lines = chunk.content.splitlines(keepends=True)
chunks: List[SemanticChunk] = []
# Get original metadata
base_metadata = dict(chunk.metadata)
original_start = base_metadata.get("start_line", 1)
# Add start and end boundaries
boundaries = [0] + split_points + [len(lines)]
for i in range(len(boundaries) - 1):
start_idx = boundaries[i]
end_idx = boundaries[i + 1]
# Skip empty sections
if start_idx >= end_idx:
continue
# Extract content
section_lines = lines[start_idx:end_idx]
section_content = "".join(section_lines)
# Skip if too small
if len(section_content.strip()) < 50:
continue
# Create new chunk with updated metadata
new_metadata = base_metadata.copy()
new_metadata["start_line"] = original_start + start_idx
new_metadata["end_line"] = original_start + end_idx - 1
new_metadata["refined_by_llm"] = True
new_metadata["original_chunk_size"] = len(chunk.content)
chunks.append(
SemanticChunk(
content=section_content,
embedding=None, # Embeddings will be regenerated
metadata=new_metadata,
)
)
# If no valid chunks created, return original
if not chunks:
return [chunk]
return chunks
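Putting the refinement path together, a caller would use it roughly like this (a sketch only: it needs an enabled config and a working CCW CLI, and `enhancer` is assumed to be a configured LLMEnhancer):

```python
from codexlens.entities import SemanticChunk

oversized = SemanticChunk(
    content="\n".join(f"def f_{i}():\n    return {i}\n" for i in range(200)),
    embedding=None,
    metadata={"file": "big.py", "language": "python",
              "start_line": 1, "end_line": 600,
              "strategy": "symbol", "chunk_type": "code"},
)

refined = enhancer.refine_chunk_boundaries(oversized, max_chunk_size=2000)
for piece in refined:
    print(piece.metadata.get("start_line"), piece.metadata.get("end_line"),
          piece.metadata.get("refined_by_llm", False))
# If the LLM proposes usable split points, each piece carries refined_by_llm=True
# and original_chunk_size; on any failure the original chunk is returned unchanged.
```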
def _process_batch(