Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-10 02:24:35 +08:00)
feat: Add dynamic batch size calculation; improve embedding management and the configuration system
@@ -43,6 +43,73 @@ logger = logging.getLogger(__name__)

EMBEDDING_BATCH_SIZE = 256


def calculate_dynamic_batch_size(config, embedder) -> int:
    """Calculate batch size dynamically based on model token capacity.

    This function computes an optimal batch size by considering:
    - Maximum chunk character size from parsing rules
    - Estimated tokens per chunk (chars / chars_per_token_estimate)
    - Model's maximum token capacity
    - Utilization factor (default 80% to leave headroom)

    Args:
        config: Config object with api_batch_size_* settings
        embedder: Embedding model object with max_tokens property

    Returns:
        Calculated batch size, clamped to [1, api_batch_size_max]
    """
    # If dynamic calculation is disabled, return static value
    if not getattr(config, 'api_batch_size_dynamic', False):
        return getattr(config, 'api_batch_size', 8)

    # Get maximum chunk character size from parsing rules
    parsing_rules = getattr(config, 'parsing_rules', {})
    default_rules = parsing_rules.get('default', {})
    max_chunk_chars = default_rules.get('max_chunk_chars', 4000)

    # Get characters per token estimate
    chars_per_token = getattr(config, 'chars_per_token_estimate', 4)
    if chars_per_token <= 0:
        chars_per_token = 4  # Safe default

    # Estimate tokens per chunk
    estimated_tokens_per_chunk = max_chunk_chars / chars_per_token

    # Prevent division by zero
    if estimated_tokens_per_chunk <= 0:
        return getattr(config, 'api_batch_size', 8)

    # Get model's maximum token capacity
    model_max_tokens = getattr(embedder, 'max_tokens', 8192)

    # Get utilization factor (default 80%)
    utilization_factor = getattr(config, 'api_batch_size_utilization_factor', 0.8)
    if utilization_factor <= 0 or utilization_factor > 1:
        utilization_factor = 0.8

    # Calculate safe token limit
    safe_token_limit = model_max_tokens * utilization_factor

    # Calculate dynamic batch size
    dynamic_batch_size = int(safe_token_limit / estimated_tokens_per_chunk)

    # Get maximum batch size limit
    batch_size_max = getattr(config, 'api_batch_size_max', 2048)

    # Clamp to [1, batch_size_max]
    result = max(1, min(dynamic_batch_size, batch_size_max))

    logger.debug(
        "Dynamic batch size calculated: %d (max_chunk_chars=%d, chars_per_token=%d, "
        "model_max_tokens=%d, utilization=%.1f%%, limit=%d)",
        result, max_chunk_chars, chars_per_token, model_max_tokens,
        utilization_factor * 100, batch_size_max
    )

    return result
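Editor's note: for orientation, a quick worked example of the formula above using the defaults it falls back to (max_chunk_chars=4000, chars_per_token=4, max_tokens=8192, utilization 0.8). The numbers are illustrative, not measured.

# Worked example with the documented default values (illustrative only):
estimated_tokens_per_chunk = 4000 / 4        # 1000 tokens per chunk
safe_token_limit = 8192 * 0.8                # 6553.6 tokens per request
dynamic_batch_size = int(6553.6 / 1000)      # 6
result = max(1, min(6, 2048))                # clamped result: 6 chunks per batch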
def _build_categories_from_batch(chunk_batch: List[Tuple[Any, str]]) -> List[str]:
    """Build categories list from chunk batch for index-level category filtering.
@@ -464,6 +531,14 @@ def generate_embeddings(
            progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
            progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")

        # Calculate dynamic batch size based on model capacity
        from codexlens.config import Config
        batch_config = Config.load()
        effective_batch_size = calculate_dynamic_batch_size(batch_config, embedder)

        if progress_callback and batch_config.api_batch_size_dynamic:
            progress_callback(f"Dynamic batch size: {effective_batch_size} (model max_tokens={getattr(embedder, 'max_tokens', 8192)})")

    except Exception as e:
        return {
            "success": False,
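Editor's note: when api_batch_size_dynamic is left at its default of False, the call above degrades to the static setting. A minimal sketch, assuming an embedder object is already constructed:

# Sketch: with dynamic sizing disabled (the default), the static setting is used.
batch_config = Config.load()
batch_config.api_batch_size_dynamic = False
effective_batch_size = calculate_dynamic_batch_size(batch_config, embedder)  # == batch_config.api_batch_size (default 8)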
@@ -477,7 +552,7 @@ def generate_embeddings(
    total_chunks_created = 0
    total_files_processed = 0
    FILE_BATCH_SIZE = 100  # Process 100 files at a time
    # EMBEDDING_BATCH_SIZE is defined at module level (default: 256)
    # effective_batch_size is calculated above (dynamic or EMBEDDING_BATCH_SIZE fallback)

    try:
        with VectorStore(index_path) as vector_store:
@@ -535,7 +610,7 @@ def generate_embeddings(
            # Fallback to fixed-size batching for backward compatibility
            def fixed_size_batches():
                while True:
                    batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
                    batch = list(islice(chunk_generator, effective_batch_size))
                    if not batch:
                        break
                    yield batch
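Editor's note: the generator above simply slices a fixed number of chunks off the stream per iteration. A self-contained sketch of the same pattern, with a hypothetical chunk stream and the batch size of 6 from the earlier worked example:

from itertools import islice

def fixed_size_batches(chunk_generator, effective_batch_size):
    # Pull up to effective_batch_size items at a time until the stream is exhausted.
    while True:
        batch = list(islice(chunk_generator, effective_batch_size))
        if not batch:
            break
        yield batch

chunks = iter(range(14))                                  # stand-in for the real chunk stream
sizes = [len(b) for b in fixed_size_batches(chunks, 6)]
print(sizes)                                              # [6, 6, 2]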
@@ -566,7 +641,7 @@ def generate_embeddings(
            for attempt in range(max_retries + 1):
                try:
                    batch_contents = [chunk.content for chunk, _ in chunk_batch]
                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=effective_batch_size)
                    return batch_num, chunk_batch, embeddings_numpy, batch_files, None

                except Exception as e:
@@ -614,7 +689,7 @@ def generate_embeddings(
            try:
                # Generate embeddings
                batch_contents = [chunk.content for chunk, _ in chunk_batch]
                embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
                embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=effective_batch_size)

                # Store embeddings with category
                categories = _build_categories_from_batch(chunk_batch)
@@ -1227,6 +1302,14 @@ def generate_dense_embeddings_centralized(
            progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
            progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")

        # Calculate dynamic batch size based on model capacity
        from codexlens.config import Config
        batch_config = Config.load()
        effective_batch_size = calculate_dynamic_batch_size(batch_config, embedder)

        if progress_callback and batch_config.api_batch_size_dynamic:
            progress_callback(f"Dynamic batch size: {effective_batch_size} (model max_tokens={getattr(embedder, 'max_tokens', 8192)})")

    except Exception as e:
        return {
            "success": False,
@@ -170,6 +170,10 @@ class Config:
    # API concurrency settings
    api_max_workers: int = 4  # Max concurrent API calls for embedding/reranking
    api_batch_size: int = 8  # Batch size for API requests
    api_batch_size_dynamic: bool = False  # Enable dynamic batch size calculation
    api_batch_size_utilization_factor: float = 0.8  # Use 80% of model token capacity
    api_batch_size_max: int = 2048  # Absolute upper limit for batch size
    chars_per_token_estimate: int = 4  # Characters per token estimation ratio

    def __post_init__(self) -> None:
        try:
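Editor's note: a hedged sketch of how these fields map onto the "api" section of settings.json once saved (the key names come from the save hunk below; the values here are the defaults above, with dynamic sizing switched on purely for illustration):

# Hypothetical "api" section written to settings.json (illustrative values).
api_section = {
    "max_workers": 4,
    "batch_size": 8,
    "batch_size_dynamic": True,
    "batch_size_utilization_factor": 0.8,
    "batch_size_max": 2048,
    "chars_per_token_estimate": 4,
}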
@@ -291,6 +295,10 @@ class Config:
            "api": {
                "max_workers": self.api_max_workers,
                "batch_size": self.api_batch_size,
                "batch_size_dynamic": self.api_batch_size_dynamic,
                "batch_size_utilization_factor": self.api_batch_size_utilization_factor,
                "batch_size_max": self.api_batch_size_max,
                "chars_per_token_estimate": self.chars_per_token_estimate,
            },
        }
        with open(self.settings_path, "w", encoding="utf-8") as f:
@@ -309,13 +317,16 @@ class Config:
            embedding = settings.get("embedding", {})
            if "backend" in embedding:
                backend = embedding["backend"]
                # Support 'api' as alias for 'litellm'
                if backend == "api":
                    backend = "litellm"
                if backend in {"fastembed", "litellm"}:
                    self.embedding_backend = backend
                else:
                    log.warning(
                        "Invalid embedding backend in %s: %r (expected 'fastembed' or 'litellm')",
                        self.settings_path,
                        backend,
                        embedding["backend"],
                    )
            if "model" in embedding:
                self.embedding_model = embedding["model"]
@@ -393,6 +404,14 @@ class Config:
                self.api_max_workers = api["max_workers"]
            if "batch_size" in api:
                self.api_batch_size = api["batch_size"]
            if "batch_size_dynamic" in api:
                self.api_batch_size_dynamic = api["batch_size_dynamic"]
            if "batch_size_utilization_factor" in api:
                self.api_batch_size_utilization_factor = api["batch_size_utilization_factor"]
            if "batch_size_max" in api:
                self.api_batch_size_max = api["batch_size_max"]
            if "chars_per_token_estimate" in api:
                self.chars_per_token_estimate = api["chars_per_token_estimate"]
        except Exception as exc:
            log.warning(
                "Failed to load settings from %s (%s): %s",
@@ -409,7 +428,7 @@ class Config:

        Priority: default → settings.json → .env (highest)

        Supported variables:
        Supported variables (with or without CODEXLENS_ prefix):
            EMBEDDING_MODEL: Override embedding model/profile
            EMBEDDING_BACKEND: Override embedding backend (fastembed/litellm)
            EMBEDDING_POOL_ENABLED: Enable embedding high availability pool
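Editor's note: to make the stated priority concrete, a minimal sketch using EMBEDDING_MODEL; all names and values below are hypothetical placeholders, not real models.

# Illustration of default -> settings.json -> .env precedence (hypothetical values).
embedding_model = "model-from-default"                         # 1. dataclass default
settings = {"embedding": {"model": "model-from-settings"}}     # 2. settings.json
env_vars = {"CODEXLENS_EMBEDDING_MODEL": "model-from-env"}     # 3. .env (highest priority)

if "model" in settings.get("embedding", {}):
    embedding_model = settings["embedding"]["model"]           # settings.json overrides the default
value = env_vars.get("CODEXLENS_EMBEDDING_MODEL") or env_vars.get("EMBEDDING_MODEL")
if value:
    embedding_model = value                                    # .env wins over both
assert embedding_model == "model-from-env"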
@@ -428,83 +447,103 @@ class Config:
        if not env_vars:
            return

        def get_env(key: str) -> str | None:
            """Get env var with or without CODEXLENS_ prefix."""
            # Check prefixed version first (Dashboard format), then unprefixed
            return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)
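Editor's note: a self-contained sketch of the lookup order this helper implements; the env_vars dict and its values are made up for illustration.

# Illustrative only: prefixed names win, unprefixed names still work, unset keys fall through.
env_vars = {
    "CODEXLENS_EMBEDDING_BACKEND": "api",   # Dashboard-style prefixed form
    "EMBEDDING_POOL_ENABLED": "true",       # plain form is also accepted
}

def get_env(key):
    return env_vars.get(f"CODEXLENS_{key}") or env_vars.get(key)

assert get_env("EMBEDDING_BACKEND") == "api"        # later normalized to "litellm" by the alias handling
assert get_env("EMBEDDING_POOL_ENABLED") == "true"
assert get_env("RERANKER_MODEL") is None            # unset variables stay unset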
        # Embedding overrides
        if "EMBEDDING_MODEL" in env_vars:
            self.embedding_model = env_vars["EMBEDDING_MODEL"]
        embedding_model = get_env("EMBEDDING_MODEL")
        if embedding_model:
            self.embedding_model = embedding_model
            log.debug("Overriding embedding_model from .env: %s", self.embedding_model)

        if "EMBEDDING_BACKEND" in env_vars:
            backend = env_vars["EMBEDDING_BACKEND"].lower()
        embedding_backend = get_env("EMBEDDING_BACKEND")
        if embedding_backend:
            backend = embedding_backend.lower()
            # Support 'api' as alias for 'litellm'
            if backend == "api":
                backend = "litellm"
            if backend in {"fastembed", "litellm"}:
                self.embedding_backend = backend
                log.debug("Overriding embedding_backend from .env: %s", backend)
            else:
                log.warning("Invalid EMBEDDING_BACKEND in .env: %r", backend)
                log.warning("Invalid EMBEDDING_BACKEND in .env: %r", embedding_backend)

        if "EMBEDDING_POOL_ENABLED" in env_vars:
            value = env_vars["EMBEDDING_POOL_ENABLED"].lower()
        embedding_pool = get_env("EMBEDDING_POOL_ENABLED")
        if embedding_pool:
            value = embedding_pool.lower()
            self.embedding_pool_enabled = value in {"true", "1", "yes", "on"}
            log.debug("Overriding embedding_pool_enabled from .env: %s", self.embedding_pool_enabled)

        if "EMBEDDING_STRATEGY" in env_vars:
            strategy = env_vars["EMBEDDING_STRATEGY"].lower()
        embedding_strategy = get_env("EMBEDDING_STRATEGY")
        if embedding_strategy:
            strategy = embedding_strategy.lower()
            if strategy in {"round_robin", "latency_aware", "weighted_random"}:
                self.embedding_strategy = strategy
                log.debug("Overriding embedding_strategy from .env: %s", strategy)
            else:
                log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", strategy)
                log.warning("Invalid EMBEDDING_STRATEGY in .env: %r", embedding_strategy)

        if "EMBEDDING_COOLDOWN" in env_vars:
        embedding_cooldown = get_env("EMBEDDING_COOLDOWN")
        if embedding_cooldown:
            try:
                self.embedding_cooldown = float(env_vars["EMBEDDING_COOLDOWN"])
                self.embedding_cooldown = float(embedding_cooldown)
                log.debug("Overriding embedding_cooldown from .env: %s", self.embedding_cooldown)
            except ValueError:
                log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", env_vars["EMBEDDING_COOLDOWN"])
                log.warning("Invalid EMBEDDING_COOLDOWN in .env: %r", embedding_cooldown)

        # Reranker overrides
        if "RERANKER_MODEL" in env_vars:
            self.reranker_model = env_vars["RERANKER_MODEL"]
        reranker_model = get_env("RERANKER_MODEL")
        if reranker_model:
            self.reranker_model = reranker_model
            log.debug("Overriding reranker_model from .env: %s", self.reranker_model)

        if "RERANKER_BACKEND" in env_vars:
            backend = env_vars["RERANKER_BACKEND"].lower()
        reranker_backend = get_env("RERANKER_BACKEND")
        if reranker_backend:
            backend = reranker_backend.lower()
            if backend in {"fastembed", "onnx", "api", "litellm", "legacy"}:
                self.reranker_backend = backend
                log.debug("Overriding reranker_backend from .env: %s", backend)
            else:
                log.warning("Invalid RERANKER_BACKEND in .env: %r", backend)
                log.warning("Invalid RERANKER_BACKEND in .env: %r", reranker_backend)

        if "RERANKER_ENABLED" in env_vars:
            value = env_vars["RERANKER_ENABLED"].lower()
        reranker_enabled = get_env("RERANKER_ENABLED")
        if reranker_enabled:
            value = reranker_enabled.lower()
            self.enable_cross_encoder_rerank = value in {"true", "1", "yes", "on"}
            log.debug("Overriding reranker_enabled from .env: %s", self.enable_cross_encoder_rerank)

        if "RERANKER_POOL_ENABLED" in env_vars:
            value = env_vars["RERANKER_POOL_ENABLED"].lower()
        reranker_pool = get_env("RERANKER_POOL_ENABLED")
        if reranker_pool:
            value = reranker_pool.lower()
            self.reranker_pool_enabled = value in {"true", "1", "yes", "on"}
            log.debug("Overriding reranker_pool_enabled from .env: %s", self.reranker_pool_enabled)

        if "RERANKER_STRATEGY" in env_vars:
            strategy = env_vars["RERANKER_STRATEGY"].lower()
        reranker_strategy = get_env("RERANKER_STRATEGY")
        if reranker_strategy:
            strategy = reranker_strategy.lower()
            if strategy in {"round_robin", "latency_aware", "weighted_random"}:
                self.reranker_strategy = strategy
                log.debug("Overriding reranker_strategy from .env: %s", strategy)
            else:
                log.warning("Invalid RERANKER_STRATEGY in .env: %r", strategy)
                log.warning("Invalid RERANKER_STRATEGY in .env: %r", reranker_strategy)

        if "RERANKER_COOLDOWN" in env_vars:
        reranker_cooldown = get_env("RERANKER_COOLDOWN")
        if reranker_cooldown:
            try:
                self.reranker_cooldown = float(env_vars["RERANKER_COOLDOWN"])
                self.reranker_cooldown = float(reranker_cooldown)
                log.debug("Overriding reranker_cooldown from .env: %s", self.reranker_cooldown)
            except ValueError:
                log.warning("Invalid RERANKER_COOLDOWN in .env: %r", env_vars["RERANKER_COOLDOWN"])
                log.warning("Invalid RERANKER_COOLDOWN in .env: %r", reranker_cooldown)

        if "RERANKER_MAX_INPUT_TOKENS" in env_vars:
        reranker_max_tokens = get_env("RERANKER_MAX_INPUT_TOKENS")
        if reranker_max_tokens:
            try:
                self.reranker_max_input_tokens = int(env_vars["RERANKER_MAX_INPUT_TOKENS"])
                self.reranker_max_input_tokens = int(reranker_max_tokens)
                log.debug("Overriding reranker_max_input_tokens from .env: %s", self.reranker_max_input_tokens)
            except ValueError:
                log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", env_vars["RERANKER_MAX_INPUT_TOKENS"])
                log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)

    @classmethod
    def load(cls) -> "Config":