feat: Add dynamic batch size calculation; optimize embedding management and configuration system

Author: catlog22
Date: 2026-01-12 17:34:37 +08:00
Parent: b360e0edc7
Commit: 90a1321aac
6 changed files with 425 additions and 72 deletions


@@ -43,6 +43,73 @@ logger = logging.getLogger(__name__)
EMBEDDING_BATCH_SIZE = 256


def calculate_dynamic_batch_size(config, embedder) -> int:
    """Calculate batch size dynamically based on model token capacity.

    This function computes an optimal batch size by considering:
    - Maximum chunk character size from parsing rules
    - Estimated tokens per chunk (chars / chars_per_token_estimate)
    - Model's maximum token capacity
    - Utilization factor (default 80% to leave headroom)

    Args:
        config: Config object with api_batch_size_* settings
        embedder: Embedding model object with max_tokens property

    Returns:
        Calculated batch size, clamped to [1, api_batch_size_max]
    """
    # If dynamic calculation is disabled, return static value
    if not getattr(config, 'api_batch_size_dynamic', False):
        return getattr(config, 'api_batch_size', 8)

    # Get maximum chunk character size from parsing rules
    parsing_rules = getattr(config, 'parsing_rules', {})
    default_rules = parsing_rules.get('default', {})
    max_chunk_chars = default_rules.get('max_chunk_chars', 4000)

    # Get characters per token estimate
    chars_per_token = getattr(config, 'chars_per_token_estimate', 4)
    if chars_per_token <= 0:
        chars_per_token = 4  # Safe default

    # Estimate tokens per chunk
    estimated_tokens_per_chunk = max_chunk_chars / chars_per_token

    # Prevent division by zero
    if estimated_tokens_per_chunk <= 0:
        return getattr(config, 'api_batch_size', 8)

    # Get model's maximum token capacity
    model_max_tokens = getattr(embedder, 'max_tokens', 8192)

    # Get utilization factor (default 80%)
    utilization_factor = getattr(config, 'api_batch_size_utilization_factor', 0.8)
    if utilization_factor <= 0 or utilization_factor > 1:
        utilization_factor = 0.8

    # Calculate safe token limit
    safe_token_limit = model_max_tokens * utilization_factor

    # Calculate dynamic batch size
    dynamic_batch_size = int(safe_token_limit / estimated_tokens_per_chunk)

    # Get maximum batch size limit
    batch_size_max = getattr(config, 'api_batch_size_max', 2048)

    # Clamp to [1, batch_size_max]
    result = max(1, min(dynamic_batch_size, batch_size_max))

    logger.debug(
        "Dynamic batch size calculated: %d (max_chunk_chars=%d, chars_per_token=%d, "
        "model_max_tokens=%d, utilization=%.1f%%, limit=%d)",
        result, max_chunk_chars, chars_per_token, model_max_tokens,
        utilization_factor * 100, batch_size_max
    )

    return result


def _build_categories_from_batch(chunk_batch: List[Tuple[Any, str]]) -> List[str]:
    """Build categories list from chunk batch for index-level category filtering.
@@ -464,6 +531,14 @@ def generate_embeddings(
progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
# Calculate dynamic batch size based on model capacity
from codexlens.config import Config
batch_config = Config.load()
effective_batch_size = calculate_dynamic_batch_size(batch_config, embedder)
if progress_callback and batch_config.api_batch_size_dynamic:
progress_callback(f"Dynamic batch size: {effective_batch_size} (model max_tokens={getattr(embedder, 'max_tokens', 8192)})")
except Exception as e:
return {
"success": False,
@@ -477,7 +552,7 @@ def generate_embeddings(
    total_chunks_created = 0
    total_files_processed = 0

    FILE_BATCH_SIZE = 100  # Process 100 files at a time
    # EMBEDDING_BATCH_SIZE is defined at module level (default: 256)
    # effective_batch_size is calculated above (dynamic or EMBEDDING_BATCH_SIZE fallback)

    try:
        with VectorStore(index_path) as vector_store:
@@ -535,7 +610,7 @@ def generate_embeddings(
            # Fallback to fixed-size batching for backward compatibility
            def fixed_size_batches():
                while True:
                    batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
                    batch = list(islice(chunk_generator, effective_batch_size))
                    if not batch:
                        break
                    yield batch
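As a standalone illustration of the islice pattern used by fixed_size_batches above (not the project's code; plain integers stand in for chunks), note that the final partial batch is still yielded:

from itertools import islice

def batches(items, batch_size):
    """Yield lists of at most batch_size items until the iterator is exhausted."""
    iterator = iter(items)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            break
        yield batch

print(list(batches(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]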
@@ -566,7 +641,7 @@ def generate_embeddings(
            for attempt in range(max_retries + 1):
                try:
                    batch_contents = [chunk.content for chunk, _ in chunk_batch]
                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=effective_batch_size)
                    return batch_num, chunk_batch, embeddings_numpy, batch_files, None
                except Exception as e:
@@ -614,7 +689,7 @@ def generate_embeddings(
                try:
                    # Generate embeddings
                    batch_contents = [chunk.content for chunk, _ in chunk_batch]
                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=effective_batch_size)

                    # Store embeddings with category
                    categories = _build_categories_from_batch(chunk_batch)
@@ -1227,6 +1302,14 @@ def generate_dense_embeddings_centralized(
progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
# Calculate dynamic batch size based on model capacity
from codexlens.config import Config
batch_config = Config.load()
effective_batch_size = calculate_dynamic_batch_size(batch_config, embedder)
if progress_callback and batch_config.api_batch_size_dynamic:
progress_callback(f"Dynamic batch size: {effective_batch_size} (model max_tokens={getattr(embedder, 'max_tokens', 8192)})")
except Exception as e:
return {
"success": False,