mirror of https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-09 02:24:11 +08:00

feat: Enhance configuration management and embedding capabilities

- Added JSON-based settings management in Config class for embedding and LLM configurations.
- Introduced methods to save and load settings from a JSON file.
- Updated BaseEmbedder and its subclasses to include max_tokens property for better token management.
- Enhanced chunking strategy to support recursive splitting of large symbols with improved overlap handling.
- Implemented comprehensive tests for recursive splitting and chunking behavior.
- Added CLI tools configuration management for better integration with external tools.
- Introduced a new command for compacting session memory into structured text for recovery.
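Before the diff itself, a minimal usage sketch of the reworked generate_embeddings entry point. The keyword arguments are the ones added in this commit; the import path and the index location are assumptions for illustration only, not something this commit confirms.

from pathlib import Path

# Hypothetical import path -- the diff does not show the module name.
from codexlens.semantic.embeddings import generate_embeddings

# Backend, model and GPU settings are omitted here, so the config-file
# defaults resolved by _get_embedding_defaults() apply.
result = generate_embeddings(
    index_path=Path(".codexlens/_index.db"),   # hypothetical index location
    overlap=200,                 # sliding-window overlap in characters
    max_tokens_per_batch=8000,   # token budget per embedding batch
    max_workers=4,               # concurrent API calls (useful for LiteLLM)
    progress_callback=print,
)
print(result)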
@@ -4,8 +4,10 @@ import gc
 import logging
 import sqlite3
 import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from itertools import islice
 from pathlib import Path
+from threading import Lock
 from typing import Dict, Generator, List, Optional, Tuple
 
 try:
@@ -79,6 +81,44 @@ def _generate_chunks_from_cursor(
             failed_files.append((file_path, str(e)))
 
 
+def _create_token_aware_batches(
+    chunk_generator: Generator,
+    max_tokens_per_batch: int = 8000,
+) -> Generator[List[Tuple], None, None]:
+    """Group chunks by total token count instead of fixed count.
+
+    Uses fast token estimation (len(content) // 4) for efficiency.
+    Yields batches when approaching the token limit.
+
+    Args:
+        chunk_generator: Generator yielding (chunk, file_path) tuples
+        max_tokens_per_batch: Maximum tokens per batch (default: 8000)
+
+    Yields:
+        List of (chunk, file_path) tuples representing a batch
+    """
+    current_batch = []
+    current_tokens = 0
+
+    for chunk, file_path in chunk_generator:
+        # Fast token estimation: len(content) // 4
+        chunk_tokens = len(chunk.content) // 4
+
+        # If adding this chunk would exceed limit and we have items, yield current batch
+        if current_tokens + chunk_tokens > max_tokens_per_batch and current_batch:
+            yield current_batch
+            current_batch = []
+            current_tokens = 0
+
+        # Add chunk to current batch
+        current_batch.append((chunk, file_path))
+        current_tokens += chunk_tokens
+
+    # Yield final batch if not empty
+    if current_batch:
+        yield current_batch
+
+
 def _get_path_column(conn: sqlite3.Connection) -> str:
     """Detect whether files table uses 'path' or 'full_path' column.
 
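The grouping rule the new helper implements can be shown in isolation. The following self-contained sketch is not the library code: Chunk is a stand-in for the real chunk objects (the generator above only reads .content), and the loop simply restates the same estimate-and-yield logic.

from collections import namedtuple

# Stand-in for the chunk objects; the generator above only reads .content.
Chunk = namedtuple("Chunk", ["content"])

chunks = [Chunk("x" * 10_000) for _ in range(5)]   # ~2500 estimated tokens each
max_tokens_per_batch = 8000

batches, current, current_tokens = [], [], 0
for chunk in chunks:
    tokens = len(chunk.content) // 4               # fast token estimate
    if current_tokens + tokens > max_tokens_per_batch and current:
        batches.append(current)
        current, current_tokens = [], 0
    current.append(chunk)
    current_tokens += tokens
if current:
    batches.append(current)

print([len(b) for b in batches])   # [3, 2]: three ~2500-token chunks fit under the 8000 budget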
@@ -189,31 +229,69 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]:
     }
 
 
+def _get_embedding_defaults() -> tuple[str, str, bool]:
+    """Get default embedding settings from config.
+
+    Returns:
+        Tuple of (backend, model, use_gpu)
+    """
+    try:
+        from codexlens.config import Config
+        config = Config.load()
+        return config.embedding_backend, config.embedding_model, config.embedding_use_gpu
+    except Exception:
+        return "fastembed", "code", True
+
+
 def generate_embeddings(
     index_path: Path,
-    embedding_backend: str = "fastembed",
-    model_profile: str = "code",
+    embedding_backend: Optional[str] = None,
+    model_profile: Optional[str] = None,
     force: bool = False,
     chunk_size: int = 2000,
+    overlap: int = 200,
     progress_callback: Optional[callable] = None,
+    use_gpu: Optional[bool] = None,
+    max_tokens_per_batch: Optional[int] = None,
+    max_workers: int = 1,
 ) -> Dict[str, any]:
     """Generate embeddings for an index using memory-efficient batch processing.
 
     This function processes files in small batches to keep memory usage under 2GB,
-    regardless of the total project size.
+    regardless of the total project size. Supports concurrent API calls for the
+    LiteLLM backend to improve throughput.
 
     Args:
         index_path: Path to _index.db file
-        embedding_backend: Embedding backend to use (fastembed or litellm)
+        embedding_backend: Embedding backend to use (fastembed or litellm).
+            Defaults to config setting.
         model_profile: Model profile for fastembed (fast, code, multilingual, balanced)
-            or model name for litellm (e.g., text-embedding-3-small)
+            or model name for litellm (e.g., qwen3-embedding).
+            Defaults to config setting.
         force: If True, regenerate even if embeddings exist
         chunk_size: Maximum chunk size in characters
+        overlap: Overlap size in characters for sliding window chunking (default: 200)
         progress_callback: Optional callback for progress updates
+        use_gpu: Whether to use GPU acceleration (fastembed only).
+            Defaults to config setting.
+        max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
+            If None, attempts to get from embedder.max_tokens,
+            then falls back to 8000. If set, overrides automatic detection.
+        max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
+            Recommended: 2-4 for LiteLLM API backends.
 
     Returns:
         Result dictionary with generation statistics
     """
+    # Get defaults from config if not specified
+    default_backend, default_model, default_gpu = _get_embedding_defaults()
+
+    if embedding_backend is None:
+        embedding_backend = default_backend
+    if model_profile is None:
+        model_profile = default_model
+    if use_gpu is None:
+        use_gpu = default_gpu
     if not SEMANTIC_AVAILABLE:
         return {
             "success": False,
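The max_tokens_per_batch precedence described in the docstring (explicit argument, then embedder.max_tokens, then 8000) can be illustrated with a small self-contained sketch; DummyEmbedder and resolve_batch_budget are illustrative names, not part of the codebase.

class DummyEmbedder:
    """Illustrative embedder exposing the max_tokens property added in this commit."""
    max_tokens = 4096

def resolve_batch_budget(max_tokens_per_batch, embedder):
    # Priority: explicit parameter > embedder.max_tokens > default 8000
    if max_tokens_per_batch is None:
        return getattr(embedder, "max_tokens", 8000)
    return max_tokens_per_batch

print(resolve_batch_budget(None, DummyEmbedder()))   # 4096, taken from the embedder
print(resolve_batch_budget(2000, DummyEmbedder()))   # 2000, explicit override wins
print(resolve_batch_budget(None, object()))          # 8000, fallback default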
@@ -261,9 +339,9 @@ def generate_embeddings(
 
     # Initialize embedder using factory (supports both fastembed and litellm)
     # For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
-    # For litellm: model_profile is a model name (e.g., text-embedding-3-small)
+    # For litellm: model_profile is a model name (e.g., qwen3-embedding)
     if embedding_backend == "fastembed":
-        embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=True)
+        embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
     elif embedding_backend == "litellm":
         embedder = get_embedder_factory(backend="litellm", model=model_profile)
     else:
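The comments above spell out how model_profile is interpreted per backend. A self-contained sketch of that dispatch follows; the local get_embedder_factory here is a stub standing in for the real factory used in the diff, so only the call shapes are meaningful.

# Illustrative stand-in: the real get_embedder_factory comes from the surrounding module.
def get_embedder_factory(backend, profile=None, model=None, use_gpu=False):
    return {"backend": backend, "profile": profile, "model": model, "use_gpu": use_gpu}

def make_embedder(backend: str, model_profile: str, use_gpu: bool):
    if backend == "fastembed":
        # model_profile is a profile name: fast / code / multilingual / balanced
        return get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
    if backend == "litellm":
        # model_profile is a model name, e.g. "qwen3-embedding"
        return get_embedder_factory(backend="litellm", model=model_profile)
    raise ValueError(f"Unknown embedding backend: {backend}")

print(make_embedder("fastembed", "code", use_gpu=True))
print(make_embedder("litellm", "qwen3-embedding", use_gpu=False))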
@@ -274,7 +352,11 @@ def generate_embeddings(
 
     # skip_token_count=True: Use fast estimation (len/4) instead of expensive tiktoken
     # This significantly reduces CPU usage with minimal impact on metadata accuracy
-    chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size, skip_token_count=True))
+    chunker = Chunker(config=ChunkConfig(
+        max_chunk_size=chunk_size,
+        overlap=overlap,
+        skip_token_count=True
+    ))
 
     if progress_callback:
         progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
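The new overlap argument feeds the sliding-window chunking mentioned in the docstrings. The sketch below is not the Chunker implementation, just an illustration of what a 200-character overlap between consecutive windows of at most 2000 characters produces.

def sliding_window(text: str, max_chunk_size: int = 2000, overlap: int = 200):
    """Split text into windows of at most max_chunk_size chars, each sharing
    `overlap` trailing characters with the next window (illustrative only)."""
    chunks, start, step = [], 0, max_chunk_size - overlap
    while start < len(text):
        chunks.append(text[start:start + max_chunk_size])
        start += step
    return chunks

text = "x" * 5000
print([len(c) for c in sliding_window(text)])   # [2000, 2000, 1400]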
@@ -336,43 +418,105 @@ def generate_embeddings(
         cursor, chunker, path_column, FILE_BATCH_SIZE, failed_files
     )
 
+    # Determine max tokens per batch
+    # Priority: explicit parameter > embedder.max_tokens > default 8000
+    if max_tokens_per_batch is None:
+        max_tokens_per_batch = getattr(embedder, 'max_tokens', 8000)
+
+    # Create token-aware batches or fall back to fixed-size batching
+    if max_tokens_per_batch:
+        batch_generator = _create_token_aware_batches(
+            chunk_generator, max_tokens_per_batch
+        )
+    else:
+        # Fallback to fixed-size batching for backward compatibility
+        def fixed_size_batches():
+            while True:
+                batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
+                if not batch:
+                    break
+                yield batch
+        batch_generator = fixed_size_batches()
+
     batch_number = 0
     files_seen = set()
 
-    while True:
-        # Get a small batch of chunks from the generator (EMBEDDING_BATCH_SIZE at a time)
-        chunk_batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
-        if not chunk_batch:
-            break
+    # Thread-safe counters for concurrent processing
+    counter_lock = Lock()
 
-        batch_number += 1
+    def process_batch(batch_data: Tuple[int, List[Tuple]]) -> Tuple[int, set, Optional[str]]:
+        """Process a single batch: generate embeddings and store.
 
-        # Track unique files for progress
-        for _, file_path in chunk_batch:
-            files_seen.add(file_path)
+        Args:
+            batch_data: Tuple of (batch_number, chunk_batch)
+
+        Returns:
+            Tuple of (chunks_created, files_in_batch, error_message)
+        """
+        batch_num, chunk_batch = batch_data
+        batch_files = set()
 
-        # Generate embeddings directly to numpy (no tolist() conversion)
         try:
+            # Track files in this batch
+            for _, file_path in chunk_batch:
+                batch_files.add(file_path)
+
+            # Generate embeddings
             batch_contents = [chunk.content for chunk, _ in chunk_batch]
             # Pass batch_size to fastembed for optimal GPU utilization
             embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
 
-            # Use add_chunks_batch_numpy to avoid numpy->list->numpy roundtrip
+            # Store embeddings (thread-safe via SQLite's serialized mode)
             vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
 
-            total_chunks_created += len(chunk_batch)
-            total_files_processed = len(files_seen)
-
-            if progress_callback and batch_number % 10 == 0:
-                progress_callback(f" Batch {batch_number}: {total_chunks_created} chunks, {total_files_processed} files")
-
             # Cleanup intermediate data
-            del batch_contents, embeddings_numpy, chunk_batch
+            del batch_contents, embeddings_numpy  # keep chunk_batch alive for the return below
+            return len(chunk_batch), batch_files, None
 
         except Exception as e:
-            logger.error(f"Failed to process embedding batch {batch_number}: {str(e)}")
-            # Continue to next batch instead of failing entirely
-            continue
+            error_msg = f"Batch {batch_num}: {str(e)}"
+            logger.error(f"Failed to process embedding batch {batch_num}: {str(e)}")
+            return 0, batch_files, error_msg
+
+    # Collect batches for concurrent processing
+    all_batches = []
+    for chunk_batch in batch_generator:
+        batch_number += 1
+        all_batches.append((batch_number, chunk_batch))
+
+    # Process batches (sequential or concurrent based on max_workers)
+    if max_workers <= 1:
+        # Sequential processing (original behavior)
+        for batch_num, chunk_batch in all_batches:
+            chunks_created, batch_files, error = process_batch((batch_num, chunk_batch))
+            files_seen.update(batch_files)
+            total_chunks_created += chunks_created
+            total_files_processed = len(files_seen)
+
+            if progress_callback and batch_num % 10 == 0:
+                progress_callback(f" Batch {batch_num}: {total_chunks_created} chunks, {total_files_processed} files")
+    else:
+        # Concurrent processing for API backends
+        if progress_callback:
+            progress_callback(f"Processing {len(all_batches)} batches with {max_workers} concurrent workers...")
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = {executor.submit(process_batch, batch): batch[0] for batch in all_batches}
+
+            completed = 0
+            for future in as_completed(futures):
+                batch_num = futures[future]
+                try:
+                    chunks_created, batch_files, error = future.result()
+
+                    with counter_lock:
+                        files_seen.update(batch_files)
+                        total_chunks_created += chunks_created
+                        total_files_processed = len(files_seen)
+                        completed += 1
+
+                    if progress_callback and completed % 10 == 0:
+                        progress_callback(f" Completed {completed}/{len(all_batches)} batches: {total_chunks_created} chunks")
+
+                except Exception as e:
+                    logger.error(f"Batch {batch_num} raised exception: {str(e)}")
 
     # Notify before ANN index finalization (happens when bulk_insert context exits)
     if progress_callback:
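The concurrent branch above combines ThreadPoolExecutor, as_completed and a Lock around shared counters. Here is that pattern reduced to a self-contained sketch, with a dummy process_batch standing in for the embed-and-store step.

from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

def process_batch(batch):
    # Stand-in for "embed the batch and store the vectors".
    return len(batch)

batches = [[object()] * n for n in (3, 5, 2, 4)]
counter_lock = Lock()
total_chunks = 0

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(process_batch, b): i for i, b in enumerate(batches, 1)}
    for future in as_completed(futures):
        batch_num = futures[future]
        chunks_created = future.result()
        with counter_lock:          # protect the shared counter, as the diff does
            total_chunks += chunks_created

print(total_chunks)   # 14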
@@ -445,26 +589,49 @@ def find_all_indexes(scan_dir: Path) -> List[Path]:
 
 def generate_embeddings_recursive(
     index_root: Path,
-    embedding_backend: str = "fastembed",
-    model_profile: str = "code",
+    embedding_backend: Optional[str] = None,
+    model_profile: Optional[str] = None,
     force: bool = False,
     chunk_size: int = 2000,
+    overlap: int = 200,
     progress_callback: Optional[callable] = None,
+    use_gpu: Optional[bool] = None,
+    max_tokens_per_batch: Optional[int] = None,
+    max_workers: int = 1,
 ) -> Dict[str, any]:
     """Generate embeddings for all index databases in a project recursively.
 
     Args:
         index_root: Root index directory containing _index.db files
-        embedding_backend: Embedding backend to use (fastembed or litellm)
+        embedding_backend: Embedding backend to use (fastembed or litellm).
+            Defaults to config setting.
         model_profile: Model profile for fastembed (fast, code, multilingual, balanced)
-            or model name for litellm (e.g., text-embedding-3-small)
+            or model name for litellm (e.g., qwen3-embedding).
+            Defaults to config setting.
         force: If True, regenerate even if embeddings exist
         chunk_size: Maximum chunk size in characters
+        overlap: Overlap size in characters for sliding window chunking (default: 200)
        progress_callback: Optional callback for progress updates
+        use_gpu: Whether to use GPU acceleration (fastembed only).
+            Defaults to config setting.
+        max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
+            If None, attempts to get from embedder.max_tokens,
+            then falls back to 8000. If set, overrides automatic detection.
+        max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
+            Recommended: 2-4 for LiteLLM API backends.
 
     Returns:
         Aggregated result dictionary with generation statistics
     """
+    # Get defaults from config if not specified
+    default_backend, default_model, default_gpu = _get_embedding_defaults()
+
+    if embedding_backend is None:
+        embedding_backend = default_backend
+    if model_profile is None:
+        model_profile = default_model
+    if use_gpu is None:
+        use_gpu = default_gpu
     # Discover all _index.db files
     index_files = discover_all_index_dbs(index_root)
 
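generate_embeddings_recursive walks every _index.db under the index root and merges per-index results. A minimal sketch of that discover-and-aggregate shape follows; the rglob-based discovery and the result field names are assumptions for illustration, since the diff only shows the discover_all_index_dbs call, not its internals or the exact result keys.

from pathlib import Path
from typing import Dict, List

def discover_index_dbs(index_root: Path) -> List[Path]:
    # Assumed behaviour of discover_all_index_dbs(): find every _index.db below the root.
    return sorted(index_root.rglob("_index.db"))

def aggregate(results: List[Dict[str, int]]) -> Dict[str, int]:
    # Merge per-index statistics into one summary; field names are illustrative.
    summary = {"indexes": len(results), "chunks_created": 0, "files_processed": 0}
    for r in results:
        summary["chunks_created"] += r.get("chunks_created", 0)
        summary["files_processed"] += r.get("files_processed", 0)
    return summary

print(aggregate([{"chunks_created": 120, "files_processed": 14},
                 {"chunks_created": 45, "files_processed": 6}]))
# {'indexes': 2, 'chunks_created': 165, 'files_processed': 20}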
@@ -498,7 +665,11 @@ def generate_embeddings_recursive(
             model_profile=model_profile,
             force=force,
             chunk_size=chunk_size,
+            overlap=overlap,
             progress_callback=None,  # Don't cascade callbacks
+            use_gpu=use_gpu,
+            max_tokens_per_batch=max_tokens_per_batch,
+            max_workers=max_workers,
         )
 
         all_results.append({