feat: add multi-endpoint support and load balancing, enhancing LiteLLM embedding management

catlog22
2025-12-25 11:01:08 +08:00
parent 3c3ce55842
commit 40e61b30d6
7 changed files with 727 additions and 29 deletions

View File

@@ -759,6 +759,16 @@ def status(
console.print(f" Coverage: {embeddings_info['coverage_percent']:.1f}%")
console.print(f" Total Chunks: {embeddings_info['total_chunks']}")
# Display model information if available
model_info = embeddings_info.get('model_info')
if model_info:
console.print("\n[bold]Embedding Model:[/bold]")
console.print(f" Backend: [cyan]{model_info.get('backend', 'unknown')}[/cyan]")
console.print(f" Model: [cyan]{model_info.get('model_profile', 'unknown')}[/cyan] ({model_info.get('model_name', '')})")
console.print(f" Dimensions: {model_info.get('embedding_dim', 'unknown')}")
if model_info.get('updated_at'):
console.print(f" Last Updated: {model_info['updated_at']}")
except StorageError as exc:
if json_mode:
print_json(success=False, error=f"Storage error: {exc}")
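
For reference, model_info here is the model record attached to the embeddings info and read by the block above; its shape is roughly the following sketch (the concrete model name, dimension, and timestamp are placeholders, not taken from this commit):

    model_info = {
        "backend": "litellm",                 # or "fastembed"
        "model_profile": "qwen3-embedding",   # profile or model name (placeholder)
        "model_name": "qwen3-embedding",      # resolved model name (placeholder)
        "embedding_dim": 1024,                # placeholder dimension
        "updated_at": "2025-12-25T11:01:08",  # optional; printed only when present
    }
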
@@ -1878,7 +1888,7 @@ def embeddings_generate(
"""
_configure_logging(verbose, json_mode)
from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive
from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive, scan_for_model_conflicts
# Validate backend
valid_backends = ["fastembed", "litellm"]
@@ -1946,6 +1956,50 @@ def embeddings_generate(
console.print(f"Concurrency: [cyan]{max_workers} workers[/cyan]")
console.print()
# Pre-check for model conflicts (only if not forcing)
if not force:
# Determine the index root for conflict scanning
scan_root = index_root if use_recursive else (index_path.parent if index_path else None)
if scan_root:
conflict_result = scan_for_model_conflicts(scan_root, backend, model)
if conflict_result["has_conflict"]:
existing = conflict_result["existing_config"]
conflict_count = len(conflict_result["conflicts"])
if json_mode:
# JSON mode: return structured error for UI handling
print_json(
success=False,
error="Model conflict detected",
code="MODEL_CONFLICT",
existing_config=existing,
target_config=conflict_result["target_config"],
conflict_count=conflict_count,
conflicts=conflict_result["conflicts"][:5], # Show first 5 conflicts
hint="Use --force to overwrite existing embeddings with the new model",
)
raise typer.Exit(code=1)
else:
# Interactive mode: show warning and ask for confirmation
console.print("[yellow]⚠ Model Conflict Detected[/yellow]")
console.print(f" Existing: [red]{existing['backend']}/{existing['model']}[/red] ({existing.get('embedding_dim', '?')} dim)")
console.print(f" Requested: [green]{backend}/{model}[/green]")
console.print(f" Affected indexes: [yellow]{conflict_count}[/yellow]")
console.print()
console.print("[dim]Mixing different embedding models in the same index is not supported.[/dim]")
console.print("[dim]Overwriting will delete all existing embeddings and regenerate with the new model.[/dim]")
console.print()
# Ask for confirmation
if typer.confirm("Overwrite existing embeddings with the new model?", default=False):
force = True
console.print("[green]Confirmed.[/green] Proceeding with overwrite...\n")
else:
console.print("[yellow]Cancelled.[/yellow] Use --force to skip this prompt.")
raise typer.Exit(code=0)
if use_recursive:
result = generate_embeddings_recursive(
index_root,

View File

@@ -235,18 +235,25 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]:
}
def _get_embedding_defaults() -> tuple[str, str, bool]:
def _get_embedding_defaults() -> tuple[str, str, bool, List, str, float]:
"""Get default embedding settings from config.
Returns:
Tuple of (backend, model, use_gpu)
Tuple of (backend, model, use_gpu, endpoints, strategy, cooldown)
"""
try:
from codexlens.config import Config
config = Config.load()
return config.embedding_backend, config.embedding_model, config.embedding_use_gpu
return (
config.embedding_backend,
config.embedding_model,
config.embedding_use_gpu,
config.embedding_endpoints,
config.embedding_strategy,
config.embedding_cooldown,
)
except Exception:
return "fastembed", "code", True
return "fastembed", "code", True, [], "latency_aware", 60.0
def generate_embeddings(
@@ -260,6 +267,9 @@ def generate_embeddings(
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: Optional[int] = None,
endpoints: Optional[List] = None,
strategy: Optional[str] = None,
cooldown: Optional[float] = None,
) -> Dict[str, any]:
"""Generate embeddings for an index using memory-efficient batch processing.
@@ -284,14 +294,18 @@ def generate_embeddings(
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls.
If None, uses dynamic defaults: 1 for fastembed (CPU bound),
4 for litellm (network I/O bound).
If None, uses dynamic defaults based on backend and endpoint count.
endpoints: Optional list of endpoint configurations for multi-API load balancing.
Each dict has keys: model, api_key, api_base, weight.
strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware).
cooldown: Default cooldown seconds for rate-limited endpoints.
Returns:
Result dictionary with generation statistics
"""
# Get defaults from config if not specified
default_backend, default_model, default_gpu = _get_embedding_defaults()
(default_backend, default_model, default_gpu,
default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults()
if embedding_backend is None:
embedding_backend = default_backend
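
As a concrete, hypothetical example of the new parameters, a multi-endpoint litellm run could be invoked roughly as below. The index path, API bases, and keys are placeholders, the index path is assumed to be the first positional argument, and `weight` biases selection toward the larger deployment:

    from pathlib import Path
    from codexlens.cli.embedding_manager import generate_embeddings

    index_path = Path("path/to/index.db")  # placeholder index database path

    # Two OpenAI-compatible deployments sharing the embedding load (placeholder values).
    endpoints = [
        {"model": "qwen3-embedding", "api_key": "sk-aaa", "api_base": "https://api-1.example.com/v1", "weight": 2},
        {"model": "qwen3-embedding", "api_key": "sk-bbb", "api_base": "https://api-2.example.com/v1", "weight": 1},
    ]

    result = generate_embeddings(
        index_path,
        embedding_backend="litellm",
        model_profile="qwen3-embedding",
        endpoints=endpoints,
        strategy="latency_aware",  # or "round_robin"
        cooldown=60.0,             # seconds a rate-limited endpoint sits out
    )
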
@@ -299,13 +313,26 @@ def generate_embeddings(
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
if endpoints is None:
endpoints = default_endpoints
if strategy is None:
strategy = default_strategy
if cooldown is None:
cooldown = default_cooldown
# Set dynamic max_workers default based on backend type
# Calculate endpoint count for worker scaling
endpoint_count = len(endpoints) if endpoints else 1
# Set dynamic max_workers default based on backend type and endpoint count
# - FastEmbed: CPU-bound, sequential is optimal (1 worker)
# - LiteLLM: Network I/O bound, concurrent calls improve throughput (4 workers)
# - LiteLLM single endpoint: 4 workers default
# - LiteLLM multi-endpoint: workers = endpoint_count * 2 (to saturate all APIs)
if max_workers is None:
if embedding_backend == "litellm":
max_workers = 4
if endpoint_count > 1:
max_workers = min(endpoint_count * 2, 16) # Cap at 16 workers
else:
max_workers = 4
else:
max_workers = 1
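
To make the scaling rule concrete, here is a standalone restatement of the same default-worker logic and the values it yields:

    # Illustrative restatement of the default max_workers rule above.
    def _default_workers(backend: str, endpoint_count: int) -> int:
        if backend == "litellm":
            if endpoint_count > 1:
                return min(endpoint_count * 2, 16)  # two workers per endpoint, capped at 16
            return 4  # single litellm endpoint
        return 1  # fastembed: CPU-bound, sequential is optimal

    assert _default_workers("fastembed", 1) == 1
    assert _default_workers("litellm", 1) == 4
    assert _default_workers("litellm", 3) == 6
    assert _default_workers("litellm", 10) == 16  # capped
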
@@ -354,13 +381,20 @@ def generate_embeddings(
from codexlens.semantic.vector_store import VectorStore
from codexlens.semantic.chunker import Chunker, ChunkConfig
# Initialize embedder using factory (supports both fastembed and litellm)
# Initialize embedder using factory (supports fastembed, litellm, and rotational)
# For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
# For litellm: model_profile is a model name (e.g., qwen3-embedding)
# For multi-endpoint: endpoints list enables load balancing
if embedding_backend == "fastembed":
embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
elif embedding_backend == "litellm":
embedder = get_embedder_factory(backend="litellm", model=model_profile)
embedder = get_embedder_factory(
backend="litellm",
model=model_profile,
endpoints=endpoints if endpoints else None,
strategy=strategy,
cooldown=cooldown,
)
else:
return {
"success": False,
@@ -375,7 +409,10 @@ def generate_embeddings(
skip_token_count=True
))
# Log embedder info with endpoint count for multi-endpoint mode
if progress_callback:
if endpoint_count > 1:
progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
except Exception as e:
@@ -684,6 +721,9 @@ def generate_embeddings_recursive(
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: Optional[int] = None,
endpoints: Optional[List] = None,
strategy: Optional[str] = None,
cooldown: Optional[float] = None,
) -> Dict[str, any]:
"""Generate embeddings for all index databases in a project recursively.
@@ -704,14 +744,17 @@ def generate_embeddings_recursive(
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls.
If None, uses dynamic defaults: 1 for fastembed (CPU bound),
4 for litellm (network I/O bound).
If None, uses dynamic defaults based on backend and endpoint count.
endpoints: Optional list of endpoint configurations for multi-API load balancing.
strategy: Selection strategy for multi-endpoint mode.
cooldown: Default cooldown seconds for rate-limited endpoints.
Returns:
Aggregated result dictionary with generation statistics
"""
# Get defaults from config if not specified
default_backend, default_model, default_gpu = _get_embedding_defaults()
(default_backend, default_model, default_gpu,
default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults()
if embedding_backend is None:
embedding_backend = default_backend
@@ -719,11 +762,23 @@ def generate_embeddings_recursive(
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
if endpoints is None:
endpoints = default_endpoints
if strategy is None:
strategy = default_strategy
if cooldown is None:
cooldown = default_cooldown
# Set dynamic max_workers default based on backend type
# Calculate endpoint count for worker scaling
endpoint_count = len(endpoints) if endpoints else 1
# Set dynamic max_workers default based on backend type and endpoint count
if max_workers is None:
if embedding_backend == "litellm":
max_workers = 4
if endpoint_count > 1:
max_workers = min(endpoint_count * 2, 16)
else:
max_workers = 4
else:
max_workers = 1
@@ -765,6 +820,9 @@ def generate_embeddings_recursive(
use_gpu=use_gpu,
max_tokens_per_batch=max_tokens_per_batch,
max_workers=max_workers,
endpoints=endpoints,
strategy=strategy,
cooldown=cooldown,
)
all_results.append({
@@ -958,3 +1016,85 @@ def get_embedding_stats_summary(index_root: Path) -> Dict[str, any]:
"indexes": index_stats,
},
}
def scan_for_model_conflicts(
index_root: Path,
target_backend: str,
target_model: str,
) -> Dict[str, any]:
"""Scan for model conflicts across all indexes in a directory.
Checks if any existing embeddings were generated with a different
backend or model than the target configuration.
Args:
index_root: Root index directory to scan
target_backend: Target embedding backend (fastembed or litellm)
target_model: Target model profile/name
Returns:
Dictionary with:
- has_conflict: True if any index has different model config
- existing_config: Config from first index with embeddings (if any)
- target_config: The requested configuration
- conflicts: List of conflicting index paths with their configs
- indexes_with_embeddings: Count of indexes that have embeddings
"""
index_files = discover_all_index_dbs(index_root)
if not index_files:
return {
"has_conflict": False,
"existing_config": None,
"target_config": {"backend": target_backend, "model": target_model},
"conflicts": [],
"indexes_with_embeddings": 0,
}
conflicts = []
existing_config = None
indexes_with_embeddings = 0
for index_path in index_files:
try:
from codexlens.semantic.vector_store import VectorStore
with VectorStore(index_path) as vs:
config = vs.get_model_config()
if config and config.get("model_profile"):
indexes_with_embeddings += 1
# Store first existing config as reference
if existing_config is None:
existing_config = {
"backend": config.get("backend"),
"model": config.get("model_profile"),
"model_name": config.get("model_name"),
"embedding_dim": config.get("embedding_dim"),
}
# Check for conflict: different backend OR different model
existing_backend = config.get("backend", "")
existing_model = config.get("model_profile", "")
if existing_backend != target_backend or existing_model != target_model:
conflicts.append({
"path": str(index_path),
"existing": {
"backend": existing_backend,
"model": existing_model,
"model_name": config.get("model_name"),
},
})
except Exception as e:
logger.debug(f"Failed to check model config for {index_path}: {e}")
continue
return {
"has_conflict": len(conflicts) > 0,
"existing_config": existing_config,
"target_config": {"backend": target_backend, "model": target_model},
"conflicts": conflicts,
"indexes_with_embeddings": indexes_with_embeddings,
}
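
A short usage sketch of the helper above (the index root path and target model are placeholders):

    from pathlib import Path
    from codexlens.cli.embedding_manager import scan_for_model_conflicts

    report = scan_for_model_conflicts(Path("path/to/index_root"), "litellm", "qwen3-embedding")
    if report["has_conflict"]:
        existing = report["existing_config"]
        print(f"Existing: {existing['backend']}/{existing['model']} "
              f"({existing.get('embedding_dim', '?')} dim)")
        print(f"Conflicting indexes: {len(report['conflicts'])}")
    else:
        print(f"No conflict across {report['indexes_with_embeddings']} embedded index(es).")
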

View File

@@ -35,12 +35,23 @@ def _to_jsonable(value: Any) -> Any:
return value
def print_json(*, success: bool, result: Any = None, error: str | None = None) -> None:
def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
"""Print JSON output with optional additional fields.
Args:
success: Whether the operation succeeded
result: Result data (used when success=True)
error: Error message (used when success=False)
**kwargs: Additional fields to include in the payload (e.g., code, details)
"""
payload: dict[str, Any] = {"success": success}
if success:
payload["result"] = _to_jsonable(result)
else:
payload["error"] = error or "Unknown error"
# Include additional error details if provided
for key, value in kwargs.items():
payload[key] = _to_jsonable(value)
console.print_json(json.dumps(payload, ensure_ascii=False))
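
For illustration, the widened signature keeps the existing success/error shapes while letting callers attach extra fields through **kwargs, as the MODEL_CONFLICT path in `embeddings generate` does (values below are illustrative):

    # Success payload: {"success": true, "result": {...}}
    print_json(success=True, result={"total_chunks": 123})

    # Failure payload with extra fields passed through **kwargs:
    print_json(
        success=False,
        error="Model conflict detected",
        code="MODEL_CONFLICT",
        conflict_count=2,
        hint="Use --force to overwrite existing embeddings with the new model",
    )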