feat: Enhance configuration management and embedding capabilities

- Added JSON-based settings management to the Config class for embedding and LLM configuration.
- Introduced methods to save and load settings from a JSON file.
- Updated BaseEmbedder and its subclasses to expose a max_tokens property for better token management.
- Enhanced the chunking strategy to recursively split large symbols, with improved overlap handling.
- Implemented comprehensive tests for recursive splitting and chunking behavior.
- Added CLI tools configuration management for better integration with external tools.
- Introduced a new command for compacting session memory into structured text for recovery.
Author: catlog22
Date: 2025-12-24 16:32:27 +08:00
Parent: b00113d212
Commit: e671b45948
25 changed files with 2889 additions and 153 deletions

View File

@@ -4,8 +4,10 @@ import gc
import logging
import sqlite3
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice
from pathlib import Path
from threading import Lock
from typing import Dict, Generator, List, Optional, Tuple
try:
@@ -79,6 +81,44 @@ def _generate_chunks_from_cursor(
failed_files.append((file_path, str(e)))
def _create_token_aware_batches(
chunk_generator: Generator,
max_tokens_per_batch: int = 8000,
) -> Generator[List[Tuple], None, None]:
"""Group chunks by total token count instead of fixed count.
Uses fast token estimation (len(content) // 4) for efficiency.
Yields batches when approaching the token limit.
Args:
chunk_generator: Generator yielding (chunk, file_path) tuples
max_tokens_per_batch: Maximum tokens per batch (default: 8000)
Yields:
List of (chunk, file_path) tuples representing a batch
"""
current_batch = []
current_tokens = 0
for chunk, file_path in chunk_generator:
# Fast token estimation: len(content) // 4
chunk_tokens = len(chunk.content) // 4
# If adding this chunk would exceed limit and we have items, yield current batch
if current_tokens + chunk_tokens > max_tokens_per_batch and current_batch:
yield current_batch
current_batch = []
current_tokens = 0
# Add chunk to current batch
current_batch.append((chunk, file_path))
current_tokens += chunk_tokens
# Yield final batch if not empty
if current_batch:
yield current_batch
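A minimal usage sketch of the batching behavior, run in the same module (the SimpleNamespace chunks are hypothetical stand-ins for SemanticChunk; the helper only reads their .content attribute):

from types import SimpleNamespace

def _demo_token_aware_batching():
    # Hypothetical stand-in chunks: ~1000, ~1000, and ~5000 estimated tokens (len // 4).
    chunks = [
        (SimpleNamespace(content="x" * 4000), "a.py"),
        (SimpleNamespace(content="y" * 4000), "b.py"),
        (SimpleNamespace(content="z" * 20000), "c.py"),
    ]
    batches = list(_create_token_aware_batches(iter(chunks), max_tokens_per_batch=1500))
    # a.py fits alone; adding b.py would exceed 1500 tokens, so it starts a new batch;
    # c.py is larger than the limit by itself but still becomes its own (oversized) batch.
    assert [len(batch) for batch in batches] == [1, 1, 1]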
def _get_path_column(conn: sqlite3.Connection) -> str:
"""Detect whether files table uses 'path' or 'full_path' column.
@@ -189,31 +229,69 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]:
}
def _get_embedding_defaults() -> tuple[str, str, bool]:
"""Get default embedding settings from config.
Returns:
Tuple of (backend, model, use_gpu)
"""
try:
from codexlens.config import Config
config = Config.load()
return config.embedding_backend, config.embedding_model, config.embedding_use_gpu
except Exception:
return "fastembed", "code", True
def generate_embeddings(
index_path: Path,
embedding_backend: str = "fastembed",
model_profile: str = "code",
embedding_backend: Optional[str] = None,
model_profile: Optional[str] = None,
force: bool = False,
chunk_size: int = 2000,
overlap: int = 200,
progress_callback: Optional[callable] = None,
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: int = 1,
) -> Dict[str, any]:
"""Generate embeddings for an index using memory-efficient batch processing.
This function processes files in small batches to keep memory usage under 2GB,
regardless of the total project size.
regardless of the total project size. Supports concurrent API calls for
LiteLLM backend to improve throughput.
Args:
index_path: Path to _index.db file
embedding_backend: Embedding backend to use (fastembed or litellm)
embedding_backend: Embedding backend to use (fastembed or litellm).
Defaults to config setting.
model_profile: Model profile for fastembed (fast, code, multilingual, balanced)
or model name for litellm (e.g., text-embedding-3-small)
or model name for litellm (e.g., qwen3-embedding).
Defaults to config setting.
force: If True, regenerate even if embeddings exist
chunk_size: Maximum chunk size in characters
overlap: Overlap size in characters for sliding window chunking (default: 200)
progress_callback: Optional callback for progress updates
use_gpu: Whether to use GPU acceleration (fastembed only).
Defaults to config setting.
max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
Recommended: 2-4 for LiteLLM API backends.
Returns:
Result dictionary with generation statistics
"""
# Get defaults from config if not specified
default_backend, default_model, default_gpu = _get_embedding_defaults()
if embedding_backend is None:
embedding_backend = default_backend
if model_profile is None:
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
if not SEMANTIC_AVAILABLE:
return {
"success": False,
@@ -261,9 +339,9 @@ def generate_embeddings(
# Initialize embedder using factory (supports both fastembed and litellm)
# For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
# For litellm: model_profile is a model name (e.g., text-embedding-3-small)
# For litellm: model_profile is a model name (e.g., qwen3-embedding)
if embedding_backend == "fastembed":
embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=True)
embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
elif embedding_backend == "litellm":
embedder = get_embedder_factory(backend="litellm", model=model_profile)
else:
@@ -274,7 +352,11 @@ def generate_embeddings(
# skip_token_count=True: Use fast estimation (len/4) instead of expensive tiktoken
# This significantly reduces CPU usage with minimal impact on metadata accuracy
chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size, skip_token_count=True))
chunker = Chunker(config=ChunkConfig(
max_chunk_size=chunk_size,
overlap=overlap,
skip_token_count=True
))
if progress_callback:
progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
@@ -336,43 +418,105 @@ def generate_embeddings(
cursor, chunker, path_column, FILE_BATCH_SIZE, failed_files
)
# Determine max tokens per batch
# Priority: explicit parameter > embedder.max_tokens > default 8000
if max_tokens_per_batch is None:
max_tokens_per_batch = getattr(embedder, 'max_tokens', 8000)
# Create token-aware batches or fall back to fixed-size batching
if max_tokens_per_batch:
batch_generator = _create_token_aware_batches(
chunk_generator, max_tokens_per_batch
)
else:
# Fallback to fixed-size batching for backward compatibility
def fixed_size_batches():
while True:
batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
if not batch:
break
yield batch
batch_generator = fixed_size_batches()
batch_number = 0
files_seen = set()
while True:
# Get a small batch of chunks from the generator (EMBEDDING_BATCH_SIZE at a time)
chunk_batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
if not chunk_batch:
break
# Thread-safe counters for concurrent processing
counter_lock = Lock()
batch_number += 1
def process_batch(batch_data: Tuple[int, List[Tuple]]) -> Tuple[int, set, Optional[str]]:
"""Process a single batch: generate embeddings and store.
# Track unique files for progress
for _, file_path in chunk_batch:
files_seen.add(file_path)
Args:
batch_data: Tuple of (batch_number, chunk_batch)
Returns:
Tuple of (chunks_created, files_in_batch, error_message)
"""
batch_num, chunk_batch = batch_data
batch_files = set()
# Generate embeddings directly to numpy (no tolist() conversion)
try:
# Track files in this batch
for _, file_path in chunk_batch:
batch_files.add(file_path)
# Generate embeddings
batch_contents = [chunk.content for chunk, _ in chunk_batch]
# Pass batch_size to fastembed for optimal GPU utilization
embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
# Use add_chunks_batch_numpy to avoid numpy->list->numpy roundtrip
# Store embeddings (thread-safe via SQLite's serialized mode)
vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
total_chunks_created += len(chunk_batch)
total_files_processed = len(files_seen)
if progress_callback and batch_number % 10 == 0:
progress_callback(f" Batch {batch_number}: {total_chunks_created} chunks, {total_files_processed} files")
# Cleanup intermediate data
del batch_contents, embeddings_numpy, chunk_batch
return len(chunk_batch), batch_files, None
except Exception as e:
logger.error(f"Failed to process embedding batch {batch_number}: {str(e)}")
# Continue to next batch instead of failing entirely
continue
error_msg = f"Batch {batch_num}: {str(e)}"
logger.error(f"Failed to process embedding batch {batch_num}: {str(e)}")
return 0, batch_files, error_msg
# Collect batches for concurrent processing
all_batches = []
for chunk_batch in batch_generator:
batch_number += 1
all_batches.append((batch_number, chunk_batch))
# Process batches (sequential or concurrent based on max_workers)
if max_workers <= 1:
# Sequential processing (original behavior)
for batch_num, chunk_batch in all_batches:
chunks_created, batch_files, error = process_batch((batch_num, chunk_batch))
files_seen.update(batch_files)
total_chunks_created += chunks_created
total_files_processed = len(files_seen)
if progress_callback and batch_num % 10 == 0:
progress_callback(f" Batch {batch_num}: {total_chunks_created} chunks, {total_files_processed} files")
else:
# Concurrent processing for API backends
if progress_callback:
progress_callback(f"Processing {len(all_batches)} batches with {max_workers} concurrent workers...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(process_batch, batch): batch[0] for batch in all_batches}
completed = 0
for future in as_completed(futures):
batch_num = futures[future]
try:
chunks_created, batch_files, error = future.result()
with counter_lock:
files_seen.update(batch_files)
total_chunks_created += chunks_created
total_files_processed = len(files_seen)
completed += 1
if progress_callback and completed % 10 == 0:
progress_callback(f" Completed {completed}/{len(all_batches)} batches: {total_chunks_created} chunks")
except Exception as e:
logger.error(f"Batch {batch_num} raised exception: {str(e)}")
# Notify before ANN index finalization (happens when bulk_insert context exits)
if progress_callback:
@@ -445,26 +589,49 @@ def find_all_indexes(scan_dir: Path) -> List[Path]:
def generate_embeddings_recursive(
index_root: Path,
embedding_backend: str = "fastembed",
model_profile: str = "code",
embedding_backend: Optional[str] = None,
model_profile: Optional[str] = None,
force: bool = False,
chunk_size: int = 2000,
overlap: int = 200,
progress_callback: Optional[callable] = None,
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: int = 1,
) -> Dict[str, any]:
"""Generate embeddings for all index databases in a project recursively.
Args:
index_root: Root index directory containing _index.db files
embedding_backend: Embedding backend to use (fastembed or litellm)
embedding_backend: Embedding backend to use (fastembed or litellm).
Defaults to config setting.
model_profile: Model profile for fastembed (fast, code, multilingual, balanced)
or model name for litellm (e.g., text-embedding-3-small)
or model name for litellm (e.g., qwen3-embedding).
Defaults to config setting.
force: If True, regenerate even if embeddings exist
chunk_size: Maximum chunk size in characters
overlap: Overlap size in characters for sliding window chunking (default: 200)
progress_callback: Optional callback for progress updates
use_gpu: Whether to use GPU acceleration (fastembed only).
Defaults to config setting.
max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
Recommended: 2-4 for LiteLLM API backends.
Returns:
Aggregated result dictionary with generation statistics
"""
# Get defaults from config if not specified
default_backend, default_model, default_gpu = _get_embedding_defaults()
if embedding_backend is None:
embedding_backend = default_backend
if model_profile is None:
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
# Discover all _index.db files
index_files = discover_all_index_dbs(index_root)
@@ -498,7 +665,11 @@ def generate_embeddings_recursive(
model_profile=model_profile,
force=force,
chunk_size=chunk_size,
overlap=overlap,
progress_callback=None, # Don't cascade callbacks
use_gpu=use_gpu,
max_tokens_per_batch=max_tokens_per_batch,
max_workers=max_workers,
)
all_results.append({

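A hedged sketch of how the updated entry points might be called once the defaults come from Config; the module path, index paths, and model name below are placeholders rather than values taken from this commit:

from pathlib import Path
from codexlens.semantic.embeddings import (  # module path assumed; file names are not shown in this view
    generate_embeddings,
    generate_embeddings_recursive,
)

# Leaving backend/model/use_gpu as None picks up the settings stored in Config.
result = generate_embeddings(
    index_path=Path(".codexlens/_index.db"),  # placeholder path
    chunk_size=2000,
    overlap=200,
    max_workers=1,  # sequential; 2-4 recommended for LiteLLM API backends
)

# The recursive variant applies the same settings to every _index.db under the root.
summary = generate_embeddings_recursive(
    index_root=Path(".codexlens"),        # placeholder path
    embedding_backend="litellm",          # explicit override of the config default
    model_profile="qwen3-embedding",
    max_tokens_per_batch=None,            # None: use embedder.max_tokens, else 8000
    max_workers=4,                        # concurrent API calls
)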
View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass, field
from functools import cached_property
@@ -14,6 +15,9 @@ from .errors import ConfigError
# Workspace-local directory name
WORKSPACE_DIR_NAME = ".codexlens"
# Settings file name
SETTINGS_FILE_NAME = "settings.json"
def _default_global_dir() -> Path:
"""Get global CodexLens data directory."""
@@ -89,6 +93,13 @@ class Config:
# Hybrid chunker configuration
hybrid_max_chunk_size: int = 2000 # Max characters per chunk before LLM refinement
hybrid_llm_refinement: bool = False # Enable LLM-based semantic boundary refinement
# Embedding configuration
embedding_backend: str = "fastembed" # "fastembed" (local) or "litellm" (API)
embedding_model: str = "code" # For fastembed: profile (fast/code/multilingual/balanced)
# For litellm: model name from config (e.g., "qwen3-embedding")
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()
@@ -133,6 +144,67 @@ class Config:
"""Get parsing rules for a specific language, falling back to defaults."""
return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}
@cached_property
def settings_path(self) -> Path:
"""Path to the settings file."""
return self.data_dir / SETTINGS_FILE_NAME
def save_settings(self) -> None:
"""Save embedding and other settings to file."""
settings = {
"embedding": {
"backend": self.embedding_backend,
"model": self.embedding_model,
"use_gpu": self.embedding_use_gpu,
},
"llm": {
"enabled": self.llm_enabled,
"tool": self.llm_tool,
"timeout_ms": self.llm_timeout_ms,
"batch_size": self.llm_batch_size,
},
}
with open(self.settings_path, "w", encoding="utf-8") as f:
json.dump(settings, f, indent=2)
def load_settings(self) -> None:
"""Load settings from file if exists."""
if not self.settings_path.exists():
return
try:
with open(self.settings_path, "r", encoding="utf-8") as f:
settings = json.load(f)
# Load embedding settings
embedding = settings.get("embedding", {})
if "backend" in embedding:
self.embedding_backend = embedding["backend"]
if "model" in embedding:
self.embedding_model = embedding["model"]
if "use_gpu" in embedding:
self.embedding_use_gpu = embedding["use_gpu"]
# Load LLM settings
llm = settings.get("llm", {})
if "enabled" in llm:
self.llm_enabled = llm["enabled"]
if "tool" in llm:
self.llm_tool = llm["tool"]
if "timeout_ms" in llm:
self.llm_timeout_ms = llm["timeout_ms"]
if "batch_size" in llm:
self.llm_batch_size = llm["batch_size"]
except Exception:
pass # Silently ignore errors
@classmethod
def load(cls) -> "Config":
"""Load config with settings from file."""
config = cls()
config.load_settings()
return config
@dataclass
class WorkspaceConfig:

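A brief sketch of the new settings round-trip (assuming the configured data directory already exists; the values are only examples):

from codexlens.config import Config

config = Config.load()                 # dataclass defaults, then overlay from settings.json if present
config.embedding_backend = "litellm"
config.embedding_model = "qwen3-embedding"
config.embedding_use_gpu = False
config.save_settings()                 # writes <data_dir>/settings.json

# save_settings() produces JSON of this shape (embedding values as set above,
# llm values taken from the current config):
# {
#   "embedding": {"backend": "litellm", "model": "qwen3-embedding", "use_gpu": false},
#   "llm": {"enabled": ..., "tool": ..., "timeout_ms": ..., "batch_size": ...}
# }

reloaded = Config.load()               # picks the saved embedding/LLM settings back up
assert reloaded.embedding_backend == "litellm"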
View File

@@ -38,6 +38,16 @@ class BaseEmbedder(ABC):
"""
...
@property
def max_tokens(self) -> int:
"""Return maximum token limit for embeddings.
Returns:
int: Maximum number of tokens that can be embedded at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Embed texts to numpy array.

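The default makes max_tokens safe to read on any embedder, which is how generate_embeddings consumes it. A small sketch (both classes below are hypothetical; real backends would subclass BaseEmbedder and implement its abstract methods):

class DummyEmbedder:
    """Hypothetical backend that advertises a smaller context window."""
    @property
    def max_tokens(self) -> int:
        return 512

class LegacyEmbedder:
    """Hypothetical backend that predates the max_tokens property."""
    pass

# Mirrors the defensive lookup used for token-aware batching:
assert getattr(DummyEmbedder(), "max_tokens", 8000) == 512
assert getattr(LegacyEmbedder(), "max_tokens", 8000) == 8000  # falls back when the property is absent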
View File

@@ -39,7 +39,7 @@ from codexlens.parsers.tokenizer import get_default_tokenizer
class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 100 # Overlap for sliding window
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
@@ -80,6 +80,7 @@ class Chunker:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Large symbols exceeding max_chunk_size are recursively split using sliding window.
Args:
content: Source code content
@@ -101,27 +102,49 @@ class Chunker:
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._estimate_token_count(chunk_content)
# Check if symbol content exceeds max_chunk_size
if len(chunk_content) > self.config.max_chunk_size:
# Create line mapping for correct line number tracking
line_mapping = list(range(start_line, end_line + 1))
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"token_count": token_count,
}
))
# Use sliding window to split large symbol
sub_chunks = self.chunk_sliding_window(
chunk_content,
file_path=file_path,
language=language,
line_mapping=line_mapping
)
# Update sub_chunks with parent symbol metadata
for sub_chunk in sub_chunks:
sub_chunk.metadata["symbol_name"] = symbol.name
sub_chunk.metadata["symbol_kind"] = symbol.kind
sub_chunk.metadata["strategy"] = "symbol_split"
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
chunks.extend(sub_chunks)
else:
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._estimate_token_count(chunk_content)
chunks.append(SemanticChunk(
content=chunk_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"token_count": token_count,
}
))
return chunks

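A compact usage sketch of the new split path, mirroring the dedicated tests added later in this commit (file name and content are illustrative):

from codexlens.entities import Symbol
from codexlens.semantic.chunker import Chunker, ChunkConfig

chunker = Chunker(ChunkConfig(max_chunk_size=100, overlap=20))
# An oversized function body (32 lines, well over 100 characters) forces the recursive split.
content = "def big():\n" + "".join(f"    # filler line {i}\n" for i in range(30)) + "    pass\n"
symbols = [Symbol(name="big", kind="function", range=(1, 32))]

chunks = chunker.chunk_by_symbol(content, symbols, "example.py", "python")
assert len(chunks) > 1
assert all(c.metadata["strategy"] == "symbol_split" for c in chunks)
assert all(c.metadata["parent_symbol_range"] == (1, 32) for c in chunks)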
View File

@@ -165,6 +165,33 @@ class Embedder(BaseEmbedder):
"""Get embedding dimension for current model."""
return self.MODEL_DIMS.get(self._model_name, 768) # Default to 768 if unknown
@property
def max_tokens(self) -> int:
"""Get maximum token limit for current model.
Returns:
int: Maximum number of tokens based on model profile.
- fast: 512 (lightweight, optimized for speed)
- code: 8192 (code-optimized, larger context)
- multilingual: 512 (standard multilingual model)
- balanced: 512 (general purpose)
"""
# Determine profile from model name
profile = None
for prof, model in self.MODELS.items():
if model == self._model_name:
profile = prof
break
# Return token limit based on profile
if profile == "code":
return 8192
elif profile in ("fast", "multilingual", "balanced"):
return 512
else:
# Default for unknown models
return 512
@property
def providers(self) -> List[str]:
"""Get configured ONNX execution providers."""

View File

@@ -63,11 +63,39 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):
"""
return self._embedder.model_name
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
@property
def max_tokens(self) -> int:
"""Return maximum token limit for the embedding model.
Returns:
int: Maximum number of tokens that can be embedded at once.
Inferred from model config or model name patterns.
"""
# Try to get from LiteLLM config first
if hasattr(self._embedder, 'max_input_tokens') and self._embedder.max_input_tokens:
return self._embedder.max_input_tokens
# Infer from model name
model_name_lower = self.model_name.lower()
# Large models (8B or "large" in name)
if '8b' in model_name_lower or 'large' in model_name_lower:
return 32768
# OpenAI text-embedding-3-* models
if 'text-embedding-3' in model_name_lower:
return 8191
# Default fallback
return 8192
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts to numpy array using LiteLLMEmbedder.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments (ignored for LiteLLM backend).
Accepts batch_size for API compatibility with fastembed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
@@ -76,4 +104,5 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):
texts = [texts]
else:
texts = list(texts)
# LiteLLM handles batching internally, ignore batch_size parameter
return self._embedder.embed(texts)

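The name-based fallback can be summarized with a standalone sketch of the same heuristic (model names are illustrative; the limits come from the rules above, not from querying any provider):

def _infer_max_tokens(model_name: str) -> int:
    """Mirror of the wrapper's name-based fallback, for illustration only."""
    name = model_name.lower()
    if "8b" in name or "large" in name:
        return 32768      # large models (8B or "large" in the name)
    if "text-embedding-3" in name:
        return 8191       # OpenAI text-embedding-3-* limit
    return 8192           # generic default

assert _infer_max_tokens("Qwen3-Embedding-8B") == 32768
assert _infer_max_tokens("text-embedding-3-small") == 8191
assert _infer_max_tokens("qwen3-embedding") == 8192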
View File

@@ -0,0 +1,291 @@
"""Tests for recursive splitting of large symbols in chunker."""
import pytest
from codexlens.entities import Symbol
from codexlens.semantic.chunker import Chunker, ChunkConfig
class TestRecursiveSplitting:
"""Test cases for recursive splitting of large symbols."""
def test_small_symbol_no_split(self):
"""Test that small symbols are not split."""
config = ChunkConfig(max_chunk_size=1000, overlap=100)
chunker = Chunker(config)
content = '''def small_function():
# This is a small function
x = 1
y = 2
return x + y
'''
symbols = [Symbol(name='small_function', kind='function', range=(1, 5))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
assert len(chunks) == 1
assert chunks[0].metadata['strategy'] == 'symbol'
assert chunks[0].metadata['symbol_name'] == 'small_function'
assert chunks[0].metadata['symbol_kind'] == 'function'
assert 'parent_symbol_range' not in chunks[0].metadata
def test_large_symbol_splits(self):
"""Test that large symbols are recursively split."""
config = ChunkConfig(max_chunk_size=100, overlap=20)
chunker = Chunker(config)
content = '''def large_function():
# Line 1
# Line 2
# Line 3
# Line 4
# Line 5
# Line 6
# Line 7
# Line 8
# Line 9
# Line 10
# Line 11
# Line 12
# Line 13
# Line 14
# Line 15
pass
'''
symbols = [Symbol(name='large_function', kind='function', range=(1, 18))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Should be split into multiple chunks
assert len(chunks) > 1
# All chunks should have symbol metadata
for chunk in chunks:
assert chunk.metadata['strategy'] == 'symbol_split'
assert chunk.metadata['symbol_name'] == 'large_function'
assert chunk.metadata['symbol_kind'] == 'function'
assert chunk.metadata['parent_symbol_range'] == (1, 18)
def test_boundary_condition(self):
"""Test symbol exactly at max_chunk_size boundary."""
config = ChunkConfig(max_chunk_size=90, overlap=20)
chunker = Chunker(config)
content = '''def boundary_function():
# This function is exactly at boundary
x = 1
y = 2
return x + y
'''
symbols = [Symbol(name='boundary_function', kind='function', range=(1, 5))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Content is slightly over 90 chars, should be split
assert len(chunks) >= 1
assert chunks[0].metadata['strategy'] == 'symbol_split'
def test_multiple_symbols_mixed_sizes(self):
"""Test chunking with multiple symbols of different sizes."""
config = ChunkConfig(max_chunk_size=150, overlap=30)
chunker = Chunker(config)
content = '''def small():
return 1
def medium():
# Medium function
x = 1
y = 2
z = 3
return x + y + z
def very_large():
# Line 1
# Line 2
# Line 3
# Line 4
# Line 5
# Line 6
# Line 7
# Line 8
# Line 9
# Line 10
# Line 11
# Line 12
# Line 13
# Line 14
# Line 15
pass
'''
symbols = [
Symbol(name='small', kind='function', range=(1, 2)),
Symbol(name='medium', kind='function', range=(4, 9)),
Symbol(name='very_large', kind='function', range=(11, 28)),
]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Find chunks for each symbol
small_chunks = [c for c in chunks if c.metadata['symbol_name'] == 'small']
medium_chunks = [c for c in chunks if c.metadata['symbol_name'] == 'medium']
large_chunks = [c for c in chunks if c.metadata['symbol_name'] == 'very_large']
# Small should be filtered (< min_chunk_size)
assert len(small_chunks) == 0
# Medium should not be split
assert len(medium_chunks) == 1
assert medium_chunks[0].metadata['strategy'] == 'symbol'
# Large should be split
assert len(large_chunks) > 1
for chunk in large_chunks:
assert chunk.metadata['strategy'] == 'symbol_split'
def test_line_numbers_preserved(self):
"""Test that line numbers are correctly preserved in sub-chunks."""
config = ChunkConfig(max_chunk_size=100, overlap=20)
chunker = Chunker(config)
content = '''def large_function():
# Line 1 with some extra content to make it longer
# Line 2 with some extra content to make it longer
# Line 3 with some extra content to make it longer
# Line 4 with some extra content to make it longer
# Line 5 with some extra content to make it longer
# Line 6 with some extra content to make it longer
# Line 7 with some extra content to make it longer
# Line 8 with some extra content to make it longer
# Line 9 with some extra content to make it longer
# Line 10 with some extra content to make it longer
pass
'''
symbols = [Symbol(name='large_function', kind='function', range=(1, 13))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Verify line numbers are correct and sequential
assert len(chunks) > 1
assert chunks[0].metadata['start_line'] == 1
# Each chunk should have valid line numbers
for chunk in chunks:
assert chunk.metadata['start_line'] >= 1
assert chunk.metadata['end_line'] <= 13
assert chunk.metadata['start_line'] <= chunk.metadata['end_line']
def test_overlap_in_split_chunks(self):
"""Test that overlap is applied when splitting large symbols."""
config = ChunkConfig(max_chunk_size=100, overlap=30)
chunker = Chunker(config)
content = '''def large_function():
# Line 1
# Line 2
# Line 3
# Line 4
# Line 5
# Line 6
# Line 7
# Line 8
# Line 9
# Line 10
# Line 11
# Line 12
pass
'''
symbols = [Symbol(name='large_function', kind='function', range=(1, 14))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# With overlap, consecutive chunks should overlap
if len(chunks) > 1:
for i in range(len(chunks) - 1):
# Next chunk should start before current chunk ends (overlap)
current_end = chunks[i].metadata['end_line']
next_start = chunks[i + 1].metadata['start_line']
# Overlap should exist
assert next_start <= current_end
def test_empty_symbol_filtered(self):
"""Test that symbols smaller than min_chunk_size are filtered."""
config = ChunkConfig(max_chunk_size=1000, min_chunk_size=50)
chunker = Chunker(config)
content = '''def tiny():
pass
'''
symbols = [Symbol(name='tiny', kind='function', range=(1, 2))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Should be filtered due to min_chunk_size
assert len(chunks) == 0
def test_class_symbol_splits(self):
"""Test that large class symbols are also split correctly."""
config = ChunkConfig(max_chunk_size=120, overlap=25)
chunker = Chunker(config)
content = '''class LargeClass:
"""A large class with many methods."""
def method1(self):
return 1
def method2(self):
return 2
def method3(self):
return 3
def method4(self):
return 4
'''
symbols = [Symbol(name='LargeClass', kind='class', range=(1, 14))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Should be split
assert len(chunks) > 1
# All chunks should preserve class metadata
for chunk in chunks:
assert chunk.metadata['symbol_name'] == 'LargeClass'
assert chunk.metadata['symbol_kind'] == 'class'
assert chunk.metadata['strategy'] == 'symbol_split'
class TestLightweightMode:
"""Test recursive splitting with lightweight token counting."""
def test_large_symbol_splits_lightweight_mode(self):
"""Test that large symbols split correctly in lightweight mode."""
config = ChunkConfig(max_chunk_size=100, overlap=20, skip_token_count=True)
chunker = Chunker(config)
content = '''def large_function():
# Line 1 with some extra content to make it longer
# Line 2 with some extra content to make it longer
# Line 3 with some extra content to make it longer
# Line 4 with some extra content to make it longer
# Line 5 with some extra content to make it longer
# Line 6 with some extra content to make it longer
# Line 7 with some extra content to make it longer
# Line 8 with some extra content to make it longer
# Line 9 with some extra content to make it longer
# Line 10 with some extra content to make it longer
pass
'''
symbols = [Symbol(name='large_function', kind='function', range=(1, 13))]
chunks = chunker.chunk_by_symbol(content, symbols, 'test.py', 'python')
# Should split even in lightweight mode
assert len(chunks) > 1
# All chunks should have token_count (estimated)
for chunk in chunks:
assert 'token_count' in chunk.metadata
assert chunk.metadata['token_count'] > 0