Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-05 01:50:27 +08:00)

feat: Add multi-endpoint support and load balancing; enhance LiteLLM embedding management
@@ -290,7 +290,7 @@ interface IndexStatus {
  file_count?: number;
  embeddings_coverage_percent?: number;
  total_chunks?: number;
- model_info?: ModelInfo;
+ model_info?: ModelInfo | null;
  warning?: string;
}

@@ -359,7 +359,8 @@ async function checkIndexStatus(path: string = '.'): Promise<IndexStatus> {
      file_count: status.total_files,
      embeddings_coverage_percent: embeddingsCoverage,
      total_chunks: totalChunks,
-     model_info: modelInfo,
+     // Ensure model_info is null instead of undefined so it's included in JSON
+     model_info: modelInfo ?? null,
      warning,
    };
  } catch {

@@ -759,6 +759,16 @@ def status(
        console.print(f" Coverage: {embeddings_info['coverage_percent']:.1f}%")
        console.print(f" Total Chunks: {embeddings_info['total_chunks']}")

+       # Display model information if available
+       model_info = embeddings_info.get('model_info')
+       if model_info:
+           console.print("\n[bold]Embedding Model:[/bold]")
+           console.print(f" Backend: [cyan]{model_info.get('backend', 'unknown')}[/cyan]")
+           console.print(f" Model: [cyan]{model_info.get('model_profile', 'unknown')}[/cyan] ({model_info.get('model_name', '')})")
+           console.print(f" Dimensions: {model_info.get('embedding_dim', 'unknown')}")
+           if model_info.get('updated_at'):
+               console.print(f" Last Updated: {model_info['updated_at']}")
+
    except StorageError as exc:
        if json_mode:
            print_json(success=False, error=f"Storage error: {exc}")

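For reference, the block above only reads a handful of keys from `model_info` via `.get()`. An illustrative value (all field values below are made up) would be:

```python
# Illustrative only -- any key missing from model_info falls back to "unknown"
# or is simply skipped by the display code above.
model_info = {
    "backend": "litellm",
    "model_profile": "qwen3-embedding",
    "model_name": "openai/qwen3-embedding",   # hypothetical provider-qualified name
    "embedding_dim": 1024,
    "updated_at": "2026-02-04T12:00:00Z",     # optional; printed only when present
}
```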
@@ -1878,7 +1888,7 @@ def embeddings_generate(
    """
    _configure_logging(verbose, json_mode)

-   from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive
+   from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive, scan_for_model_conflicts

    # Validate backend
    valid_backends = ["fastembed", "litellm"]

@@ -1946,6 +1956,50 @@ def embeddings_generate(
    console.print(f"Concurrency: [cyan]{max_workers} workers[/cyan]")
    console.print()

+   # Pre-check for model conflicts (only if not forcing)
+   if not force:
+       # Determine the index root for conflict scanning
+       scan_root = index_root if use_recursive else (index_path.parent if index_path else None)
+
+       if scan_root:
+           conflict_result = scan_for_model_conflicts(scan_root, backend, model)
+
+           if conflict_result["has_conflict"]:
+               existing = conflict_result["existing_config"]
+               conflict_count = len(conflict_result["conflicts"])
+
+               if json_mode:
+                   # JSON mode: return structured error for UI handling
+                   print_json(
+                       success=False,
+                       error="Model conflict detected",
+                       code="MODEL_CONFLICT",
+                       existing_config=existing,
+                       target_config=conflict_result["target_config"],
+                       conflict_count=conflict_count,
+                       conflicts=conflict_result["conflicts"][:5],  # Show first 5 conflicts
+                       hint="Use --force to overwrite existing embeddings with the new model",
+                   )
+                   raise typer.Exit(code=1)
+               else:
+                   # Interactive mode: show warning and ask for confirmation
+                   console.print("[yellow]⚠ Model Conflict Detected[/yellow]")
+                   console.print(f" Existing: [red]{existing['backend']}/{existing['model']}[/red] ({existing.get('embedding_dim', '?')} dim)")
+                   console.print(f" Requested: [green]{backend}/{model}[/green]")
+                   console.print(f" Affected indexes: [yellow]{conflict_count}[/yellow]")
+                   console.print()
+                   console.print("[dim]Mixing different embedding models in the same index is not supported.[/dim]")
+                   console.print("[dim]Overwriting will delete all existing embeddings and regenerate with the new model.[/dim]")
+                   console.print()
+
+                   # Ask for confirmation
+                   if typer.confirm("Overwrite existing embeddings with the new model?", default=False):
+                       force = True
+                       console.print("[green]Confirmed.[/green] Proceeding with overwrite...\n")
+                   else:
+                       console.print("[yellow]Cancelled.[/yellow] Use --force to skip this prompt.")
+                       raise typer.Exit(code=0)
+
    if use_recursive:
        result = generate_embeddings_recursive(
            index_root,

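For callers driving the CLI in JSON mode, the `MODEL_CONFLICT` payload above is designed to be machine-readable. A rough caller-side sketch follows; the command and flag names are assumptions, not taken from this commit:

```python
# Hypothetical wrapper around the CLI's JSON mode; invocation details are assumed.
import json
import subprocess

proc = subprocess.run(
    ["codexlens", "embeddings", "generate", "--json"],  # assumed command/flag names
    capture_output=True, text=True,
)
payload = json.loads(proc.stdout)
if not payload.get("success") and payload.get("code") == "MODEL_CONFLICT":
    print("Existing config:", payload["existing_config"])
    print("Requested config:", payload["target_config"])
    # Only retry with --force after an explicit user decision, per the hint field.
```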
@@ -235,18 +235,25 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]:
    }


-def _get_embedding_defaults() -> tuple[str, str, bool]:
+def _get_embedding_defaults() -> tuple[str, str, bool, List, str, float]:
    """Get default embedding settings from config.

    Returns:
-       Tuple of (backend, model, use_gpu)
+       Tuple of (backend, model, use_gpu, endpoints, strategy, cooldown)
    """
    try:
        from codexlens.config import Config
        config = Config.load()
-       return config.embedding_backend, config.embedding_model, config.embedding_use_gpu
+       return (
+           config.embedding_backend,
+           config.embedding_model,
+           config.embedding_use_gpu,
+           config.embedding_endpoints,
+           config.embedding_strategy,
+           config.embedding_cooldown,
+       )
    except Exception:
-       return "fastembed", "code", True
+       return "fastembed", "code", True, [], "latency_aware", 60.0


def generate_embeddings(

@@ -260,6 +267,9 @@ def generate_embeddings(
    use_gpu: Optional[bool] = None,
    max_tokens_per_batch: Optional[int] = None,
    max_workers: Optional[int] = None,
+   endpoints: Optional[List] = None,
+   strategy: Optional[str] = None,
+   cooldown: Optional[float] = None,
) -> Dict[str, any]:
    """Generate embeddings for an index using memory-efficient batch processing.

@@ -284,14 +294,18 @@ def generate_embeddings(
            If None, attempts to get from embedder.max_tokens,
            then falls back to 8000. If set, overrides automatic detection.
        max_workers: Maximum number of concurrent API calls.
-           If None, uses dynamic defaults: 1 for fastembed (CPU bound),
-           4 for litellm (network I/O bound).
+           If None, uses dynamic defaults based on backend and endpoint count.
+       endpoints: Optional list of endpoint configurations for multi-API load balancing.
+           Each dict has keys: model, api_key, api_base, weight.
+       strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware).
+       cooldown: Default cooldown seconds for rate-limited endpoints.

    Returns:
        Result dictionary with generation statistics
    """
    # Get defaults from config if not specified
-   default_backend, default_model, default_gpu = _get_embedding_defaults()
+   (default_backend, default_model, default_gpu,
+    default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults()

    if embedding_backend is None:
        embedding_backend = default_backend

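A minimal sketch of the new `endpoints` parameter in use; model names, keys, and URLs are placeholders, and the first positional argument is a placeholder for the target index path (see the full signature above):

```python
# Placeholder endpoint list matching the documented keys: model, api_key, api_base, weight.
endpoints = [
    {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "weight": 1.0},
    {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "...", "weight": 2.0},
]

result = generate_embeddings(
    index_path,                 # placeholder for the target index database path
    endpoints=endpoints,
    strategy="latency_aware",
    cooldown=60.0,
)
```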
@@ -299,13 +313,26 @@ def generate_embeddings(
        model_profile = default_model
    if use_gpu is None:
        use_gpu = default_gpu
+   if endpoints is None:
+       endpoints = default_endpoints
+   if strategy is None:
+       strategy = default_strategy
+   if cooldown is None:
+       cooldown = default_cooldown

-   # Set dynamic max_workers default based on backend type
+   # Calculate endpoint count for worker scaling
+   endpoint_count = len(endpoints) if endpoints else 1
+
+   # Set dynamic max_workers default based on backend type and endpoint count
    # - FastEmbed: CPU-bound, sequential is optimal (1 worker)
-   # - LiteLLM: Network I/O bound, concurrent calls improve throughput (4 workers)
+   # - LiteLLM single endpoint: 4 workers default
+   # - LiteLLM multi-endpoint: workers = endpoint_count * 2 (to saturate all APIs)
    if max_workers is None:
        if embedding_backend == "litellm":
-           max_workers = 4
+           if endpoint_count > 1:
+               max_workers = min(endpoint_count * 2, 16)  # Cap at 16 workers
+           else:
+               max_workers = 4
        else:
            max_workers = 1

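A quick worked example of the default worker count chosen above for the litellm backend:

```python
# Mirrors the max_workers default above: 4 for a single endpoint,
# endpoint_count * 2 capped at 16 for multi-endpoint mode.
for endpoint_count in (1, 3, 10):
    max_workers = min(endpoint_count * 2, 16) if endpoint_count > 1 else 4
    print(endpoint_count, "->", max_workers)   # 1 -> 4, 3 -> 6, 10 -> 16
```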
@@ -354,13 +381,20 @@ def generate_embeddings(
        from codexlens.semantic.vector_store import VectorStore
        from codexlens.semantic.chunker import Chunker, ChunkConfig

-       # Initialize embedder using factory (supports both fastembed and litellm)
+       # Initialize embedder using factory (supports fastembed, litellm, and rotational)
        # For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
        # For litellm: model_profile is a model name (e.g., qwen3-embedding)
+       # For multi-endpoint: endpoints list enables load balancing
        if embedding_backend == "fastembed":
            embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
        elif embedding_backend == "litellm":
-           embedder = get_embedder_factory(backend="litellm", model=model_profile)
+           embedder = get_embedder_factory(
+               backend="litellm",
+               model=model_profile,
+               endpoints=endpoints if endpoints else None,
+               strategy=strategy,
+               cooldown=cooldown,
+           )
        else:
            return {
                "success": False,

@@ -375,7 +409,10 @@ def generate_embeddings(
            skip_token_count=True
        ))

+       # Log embedder info with endpoint count for multi-endpoint mode
        if progress_callback:
+           if endpoint_count > 1:
+               progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
            progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")

    except Exception as e:

@@ -684,6 +721,9 @@ def generate_embeddings_recursive(
    use_gpu: Optional[bool] = None,
    max_tokens_per_batch: Optional[int] = None,
    max_workers: Optional[int] = None,
+   endpoints: Optional[List] = None,
+   strategy: Optional[str] = None,
+   cooldown: Optional[float] = None,
) -> Dict[str, any]:
    """Generate embeddings for all index databases in a project recursively.

@@ -704,14 +744,17 @@ def generate_embeddings_recursive(
            If None, attempts to get from embedder.max_tokens,
            then falls back to 8000. If set, overrides automatic detection.
        max_workers: Maximum number of concurrent API calls.
-           If None, uses dynamic defaults: 1 for fastembed (CPU bound),
-           4 for litellm (network I/O bound).
+           If None, uses dynamic defaults based on backend and endpoint count.
+       endpoints: Optional list of endpoint configurations for multi-API load balancing.
+       strategy: Selection strategy for multi-endpoint mode.
+       cooldown: Default cooldown seconds for rate-limited endpoints.

    Returns:
        Aggregated result dictionary with generation statistics
    """
    # Get defaults from config if not specified
-   default_backend, default_model, default_gpu = _get_embedding_defaults()
+   (default_backend, default_model, default_gpu,
+    default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults()

    if embedding_backend is None:
        embedding_backend = default_backend

@@ -719,11 +762,23 @@ def generate_embeddings_recursive(
        model_profile = default_model
    if use_gpu is None:
        use_gpu = default_gpu
+   if endpoints is None:
+       endpoints = default_endpoints
+   if strategy is None:
+       strategy = default_strategy
+   if cooldown is None:
+       cooldown = default_cooldown

-   # Set dynamic max_workers default based on backend type
+   # Calculate endpoint count for worker scaling
+   endpoint_count = len(endpoints) if endpoints else 1
+
+   # Set dynamic max_workers default based on backend type and endpoint count
    if max_workers is None:
        if embedding_backend == "litellm":
-           max_workers = 4
+           if endpoint_count > 1:
+               max_workers = min(endpoint_count * 2, 16)
+           else:
+               max_workers = 4
        else:
            max_workers = 1

@@ -765,6 +820,9 @@ def generate_embeddings_recursive(
            use_gpu=use_gpu,
            max_tokens_per_batch=max_tokens_per_batch,
            max_workers=max_workers,
+           endpoints=endpoints,
+           strategy=strategy,
+           cooldown=cooldown,
        )

        all_results.append({

@@ -958,3 +1016,85 @@ def get_embedding_stats_summary(index_root: Path) -> Dict[str, any]:
            "indexes": index_stats,
        },
    }
+
+
+def scan_for_model_conflicts(
+    index_root: Path,
+    target_backend: str,
+    target_model: str,
+) -> Dict[str, any]:
+    """Scan for model conflicts across all indexes in a directory.
+
+    Checks if any existing embeddings were generated with a different
+    backend or model than the target configuration.
+
+    Args:
+        index_root: Root index directory to scan
+        target_backend: Target embedding backend (fastembed or litellm)
+        target_model: Target model profile/name
+
+    Returns:
+        Dictionary with:
+        - has_conflict: True if any index has different model config
+        - existing_config: Config from first index with embeddings (if any)
+        - target_config: The requested configuration
+        - conflicts: List of conflicting index paths with their configs
+        - indexes_with_embeddings: Count of indexes that have embeddings
+    """
+    index_files = discover_all_index_dbs(index_root)
+
+    if not index_files:
+        return {
+            "has_conflict": False,
+            "existing_config": None,
+            "target_config": {"backend": target_backend, "model": target_model},
+            "conflicts": [],
+            "indexes_with_embeddings": 0,
+        }
+
+    conflicts = []
+    existing_config = None
+    indexes_with_embeddings = 0
+
+    for index_path in index_files:
+        try:
+            from codexlens.semantic.vector_store import VectorStore
+
+            with VectorStore(index_path) as vs:
+                config = vs.get_model_config()
+                if config and config.get("model_profile"):
+                    indexes_with_embeddings += 1
+
+                    # Store first existing config as reference
+                    if existing_config is None:
+                        existing_config = {
+                            "backend": config.get("backend"),
+                            "model": config.get("model_profile"),
+                            "model_name": config.get("model_name"),
+                            "embedding_dim": config.get("embedding_dim"),
+                        }
+
+                    # Check for conflict: different backend OR different model
+                    existing_backend = config.get("backend", "")
+                    existing_model = config.get("model_profile", "")
+
+                    if existing_backend != target_backend or existing_model != target_model:
+                        conflicts.append({
+                            "path": str(index_path),
+                            "existing": {
+                                "backend": existing_backend,
+                                "model": existing_model,
+                                "model_name": config.get("model_name"),
+                            },
+                        })
+        except Exception as e:
+            logger.debug(f"Failed to check model config for {index_path}: {e}")
+            continue
+
+    return {
+        "has_conflict": len(conflicts) > 0,
+        "existing_config": existing_config,
+        "target_config": {"backend": target_backend, "model": target_model},
+        "conflicts": conflicts,
+        "indexes_with_embeddings": indexes_with_embeddings,
+    }

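A small usage sketch of `scan_for_model_conflicts()`; the index root path is a placeholder:

```python
from pathlib import Path

# Placeholder root; in the CLI this comes from the resolved index_root/scan_root.
report = scan_for_model_conflicts(Path("/path/to/index-root"), "litellm", "qwen3-embedding")
if report["has_conflict"]:
    print(f"{len(report['conflicts'])} of {report['indexes_with_embeddings']} "
          f"indexes were built with {report['existing_config']}")
```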
@@ -35,12 +35,23 @@ def _to_jsonable(value: Any) -> Any:
    return value


-def print_json(*, success: bool, result: Any = None, error: str | None = None) -> None:
+def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
+    """Print JSON output with optional additional fields.
+
+    Args:
+        success: Whether the operation succeeded
+        result: Result data (used when success=True)
+        error: Error message (used when success=False)
+        **kwargs: Additional fields to include in the payload (e.g., code, details)
+    """
    payload: dict[str, Any] = {"success": success}
    if success:
        payload["result"] = _to_jsonable(result)
    else:
        payload["error"] = error or "Unknown error"
+       # Include additional error details if provided
+       for key, value in kwargs.items():
+           payload[key] = _to_jsonable(value)
    console.print_json(json.dumps(payload, ensure_ascii=False))

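With the `**kwargs` extension, a failure payload can now carry structured fields; for example:

```python
# The extra keyword arguments end up as top-level keys next to "success" and "error".
print_json(
    success=False,
    error="Model conflict detected",
    code="MODEL_CONFLICT",
    hint="Use --force to overwrite existing embeddings with the new model",
)
# -> {"success": false, "error": "Model conflict detected",
#     "code": "MODEL_CONFLICT", "hint": "Use --force to overwrite ..."}
```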
@@ -100,6 +100,12 @@ class Config:
    # For litellm: model name from config (e.g., "qwen3-embedding")
    embedding_use_gpu: bool = True  # For fastembed: whether to use GPU acceleration

+   # Multi-endpoint configuration for litellm backend
+   embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
+   # List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
+   embedding_strategy: str = "latency_aware"  # round_robin, latency_aware, weighted_random
+   embedding_cooldown: float = 60.0  # Default cooldown seconds for rate-limited endpoints
+
    def __post_init__(self) -> None:
        try:
            self.data_dir = self.data_dir.expanduser().resolve()

@@ -151,12 +157,19 @@ class Config:

    def save_settings(self) -> None:
        """Save embedding and other settings to file."""
+       embedding_config = {
+           "backend": self.embedding_backend,
+           "model": self.embedding_model,
+           "use_gpu": self.embedding_use_gpu,
+       }
+       # Include multi-endpoint config if present
+       if self.embedding_endpoints:
+           embedding_config["endpoints"] = self.embedding_endpoints
+           embedding_config["strategy"] = self.embedding_strategy
+           embedding_config["cooldown"] = self.embedding_cooldown
+
        settings = {
-           "embedding": {
-               "backend": self.embedding_backend,
-               "model": self.embedding_model,
-               "use_gpu": self.embedding_use_gpu,
-           },
+           "embedding": embedding_config,
            "llm": {
                "enabled": self.llm_enabled,
                "tool": self.llm_tool,

@@ -185,6 +198,14 @@ class Config:
            if "use_gpu" in embedding:
                self.embedding_use_gpu = embedding["use_gpu"]

+           # Load multi-endpoint configuration
+           if "endpoints" in embedding:
+               self.embedding_endpoints = embedding["endpoints"]
+           if "strategy" in embedding:
+               self.embedding_strategy = embedding["strategy"]
+           if "cooldown" in embedding:
+               self.embedding_cooldown = embedding["cooldown"]
+
            # Load LLM settings
            llm = settings.get("llm", {})
            if "enabled" in llm:

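Putting the three Config hunks together, the persisted settings written by `save_settings()` take roughly this shape when endpoints are configured (all values below are placeholders):

```python
# Shape of the "embedding" section written by save_settings(); values are placeholders.
settings = {
    "embedding": {
        "backend": "litellm",
        "model": "qwen3-embedding",
        "use_gpu": True,
        "endpoints": [
            {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "weight": 1.0},
        ],
        "strategy": "latency_aware",
        "cooldown": 60.0,
    },
}
```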
@@ -5,7 +5,7 @@ Provides a unified interface for instantiating different embedder backends.

from __future__ import annotations

-from typing import Any
+from typing import Any, Dict, List, Optional

from .base import BaseEmbedder

@@ -15,6 +15,9 @@ def get_embedder(
    profile: str = "code",
    model: str = "default",
    use_gpu: bool = True,
+   endpoints: Optional[List[Dict[str, Any]]] = None,
+   strategy: str = "latency_aware",
+   cooldown: float = 60.0,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create embedder based on backend.

@@ -29,6 +32,13 @@ def get_embedder(
            Used only when backend="litellm". Default: "default"
        use_gpu: Whether to use GPU acceleration when available (default: True).
            Used only when backend="fastembed".
+       endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
+           Each endpoint is a dict with keys: model, api_key, api_base, weight.
+           Used only when backend="litellm" and multiple endpoints provided.
+       strategy: Selection strategy for multi-endpoint mode:
+           "round_robin", "latency_aware", "weighted_random".
+           Default: "latency_aware"
+       cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
        **kwargs: Additional backend-specific arguments

    Returns:

@@ -47,13 +57,40 @@ def get_embedder(

    Create litellm embedder:
        >>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")
+
+   Create rotational embedder with multiple endpoints:
+       >>> endpoints = [
+       ...     {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
+       ...     {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
+       ... ]
+       >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
    """
    if backend == "fastembed":
        from .embedder import Embedder
        return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
    elif backend == "litellm":
-       from .litellm_embedder import LiteLLMEmbedderWrapper
-       return LiteLLMEmbedderWrapper(model=model, **kwargs)
+       # Check if multi-endpoint mode is requested
+       if endpoints and len(endpoints) > 1:
+           from .rotational_embedder import create_rotational_embedder
+           return create_rotational_embedder(
+               endpoints_config=endpoints,
+               strategy=strategy,
+               default_cooldown=cooldown,
+           )
+       elif endpoints and len(endpoints) == 1:
+           # Single endpoint in list - use it directly
+           ep = endpoints[0]
+           ep_kwargs = {**kwargs}
+           if "api_key" in ep:
+               ep_kwargs["api_key"] = ep["api_key"]
+           if "api_base" in ep:
+               ep_kwargs["api_base"] = ep["api_base"]
+           from .litellm_embedder import LiteLLMEmbedderWrapper
+           return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
+       else:
+           # No endpoints list - use model parameter
+           from .litellm_embedder import LiteLLMEmbedderWrapper
+           return LiteLLMEmbedderWrapper(model=model, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. "

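A minimal sketch of wiring the loaded Config into this factory, mirroring the dispatch above (a multi-endpoint list selects the rotational embedder, otherwise the plain LiteLLM wrapper is used):

```python
# Sketch only: assumes the Config fields added in this commit are populated.
from codexlens.config import Config

config = Config.load()
embedder = get_embedder(
    backend="litellm",
    model=config.embedding_model,
    endpoints=config.embedding_endpoints or None,
    strategy=config.embedding_strategy,
    cooldown=config.embedding_cooldown,
)
```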
codex-lens/src/codexlens/semantic/rotational_embedder.py (new file, 434 lines)
@@ -0,0 +1,434 @@
"""Rotational embedder for multi-endpoint API load balancing.
|
||||
|
||||
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
|
||||
to maximize throughput while respecting rate limits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndpointStatus(Enum):
|
||||
"""Status of an API endpoint."""
|
||||
AVAILABLE = "available"
|
||||
COOLING = "cooling" # Rate limited, temporarily unavailable
|
||||
FAILED = "failed" # Permanent failure (auth error, etc.)
|
||||
|
||||
|
||||
class SelectionStrategy(Enum):
|
||||
"""Strategy for selecting endpoints."""
|
||||
ROUND_ROBIN = "round_robin"
|
||||
LATENCY_AWARE = "latency_aware"
|
||||
WEIGHTED_RANDOM = "weighted_random"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointConfig:
|
||||
"""Configuration for a single API endpoint."""
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
weight: float = 1.0 # Higher weight = more requests
|
||||
max_concurrent: int = 4 # Max concurrent requests to this endpoint
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointState:
|
||||
"""Runtime state for an endpoint."""
|
||||
config: EndpointConfig
|
||||
embedder: Any = None # LiteLLMEmbedderWrapper instance
|
||||
|
||||
# Health metrics
|
||||
status: EndpointStatus = EndpointStatus.AVAILABLE
|
||||
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
|
||||
|
||||
# Performance metrics
|
||||
total_requests: int = 0
|
||||
total_failures: int = 0
|
||||
avg_latency_ms: float = 0.0
|
||||
last_latency_ms: float = 0.0
|
||||
|
||||
# Concurrency tracking
|
||||
active_requests: int = 0
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if endpoint is available for requests."""
|
||||
if self.status == EndpointStatus.FAILED:
|
||||
return False
|
||||
if self.status == EndpointStatus.COOLING:
|
||||
if time.time() >= self.cooldown_until:
|
||||
self.status = EndpointStatus.AVAILABLE
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_cooldown(self, seconds: float) -> None:
|
||||
"""Put endpoint in cooldown state."""
|
||||
self.status = EndpointStatus.COOLING
|
||||
self.cooldown_until = time.time() + seconds
|
||||
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
|
||||
|
||||
def mark_failed(self) -> None:
|
||||
"""Mark endpoint as permanently failed."""
|
||||
self.status = EndpointStatus.FAILED
|
||||
logger.error(f"Endpoint {self.config.model} marked as failed")
|
||||
|
||||
def record_success(self, latency_ms: float) -> None:
|
||||
"""Record successful request."""
|
||||
self.total_requests += 1
|
||||
self.last_latency_ms = latency_ms
|
||||
# Exponential moving average for latency
|
||||
alpha = 0.3
|
||||
if self.avg_latency_ms == 0:
|
||||
self.avg_latency_ms = latency_ms
|
||||
else:
|
||||
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
|
||||
|
||||
def record_failure(self) -> None:
|
||||
"""Record failed request."""
|
||||
self.total_requests += 1
|
||||
self.total_failures += 1
|
||||
|
||||
@property
|
||||
def health_score(self) -> float:
|
||||
"""Calculate health score (0-1) based on metrics."""
|
||||
if not self.is_available():
|
||||
return 0.0
|
||||
|
||||
# Base score from success rate
|
||||
if self.total_requests > 0:
|
||||
success_rate = 1 - (self.total_failures / self.total_requests)
|
||||
else:
|
||||
success_rate = 1.0
|
||||
|
||||
# Latency factor (faster = higher score)
|
||||
# Normalize: 100ms = 1.0, 1000ms = 0.1
|
||||
if self.avg_latency_ms > 0:
|
||||
latency_factor = min(1.0, 100 / self.avg_latency_ms)
|
||||
else:
|
||||
latency_factor = 1.0
|
||||
|
||||
# Availability factor (less concurrent = more available)
|
||||
if self.config.max_concurrent > 0:
|
||||
availability = 1 - (self.active_requests / self.config.max_concurrent)
|
||||
else:
|
||||
availability = 1.0
|
||||
|
||||
# Combined score with weights
|
||||
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
|
||||
|
||||
|
||||
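As an aside (not part of the new file), a worked example of the health score above:

```python
# Illustrative numbers: 2 failures in 20 requests, 250 ms average latency,
# 1 of 4 concurrency slots in use, weight 1.0.
success_rate = 1 - 2 / 20             # 0.9
latency_factor = min(1.0, 100 / 250)  # 0.4
availability = 1 - 1 / 4              # 0.75
score = (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * 1.0
print(round(score, 3))                # 0.705
```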
class RotationalEmbedder(BaseEmbedder):
    """Embedder that load balances across multiple API endpoints.

    Features:
    - Intelligent endpoint selection based on latency and health
    - Automatic failover on rate limits (429) and server errors
    - Cooldown management to respect rate limits
    - Thread-safe concurrent request handling

    Args:
        endpoints: List of endpoint configurations
        strategy: Selection strategy (default: latency_aware)
        default_cooldown: Default cooldown seconds for rate limits (default: 60)
        max_retries: Maximum retry attempts across all endpoints (default: 3)
    """

    def __init__(
        self,
        endpoints: List[EndpointConfig],
        strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
        default_cooldown: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        if not endpoints:
            raise ValueError("At least one endpoint must be provided")

        self.strategy = strategy
        self.default_cooldown = default_cooldown
        self.max_retries = max_retries

        # Initialize endpoint states
        self._endpoints: List[EndpointState] = []
        self._lock = threading.Lock()
        self._round_robin_index = 0

        # Create embedder instances for each endpoint
        from .litellm_embedder import LiteLLMEmbedderWrapper

        for config in endpoints:
            # Build kwargs for LiteLLMEmbedderWrapper
            kwargs: Dict[str, Any] = {}
            if config.api_key:
                kwargs["api_key"] = config.api_key
            if config.api_base:
                kwargs["api_base"] = config.api_base

            try:
                embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
                state = EndpointState(config=config, embedder=embedder)
                self._endpoints.append(state)
                logger.info(f"Initialized endpoint: {config.model}")
            except Exception as e:
                logger.error(f"Failed to initialize endpoint {config.model}: {e}")

        if not self._endpoints:
            raise ValueError("Failed to initialize any endpoints")

        # Cache embedding properties from first endpoint
        self._embedding_dim = self._endpoints[0].embedder.embedding_dim
        self._model_name = f"rotational({len(self._endpoints)} endpoints)"
        self._max_tokens = self._endpoints[0].embedder.max_tokens

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimensions."""
        return self._embedding_dim

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit."""
        return self._max_tokens

    @property
    def endpoint_count(self) -> int:
        """Return number of configured endpoints."""
        return len(self._endpoints)

    @property
    def available_endpoint_count(self) -> int:
        """Return number of available endpoints."""
        return sum(1 for ep in self._endpoints if ep.is_available())

    def get_endpoint_stats(self) -> List[Dict[str, Any]]:
        """Get statistics for all endpoints."""
        stats = []
        for ep in self._endpoints:
            stats.append({
                "model": ep.config.model,
                "status": ep.status.value,
                "total_requests": ep.total_requests,
                "total_failures": ep.total_failures,
                "avg_latency_ms": round(ep.avg_latency_ms, 2),
                "health_score": round(ep.health_score, 3),
                "active_requests": ep.active_requests,
            })
        return stats

    def _select_endpoint(self) -> Optional[EndpointState]:
        """Select best available endpoint based on strategy."""
        available = [ep for ep in self._endpoints if ep.is_available()]

        if not available:
            return None

        if self.strategy == SelectionStrategy.ROUND_ROBIN:
            with self._lock:
                self._round_robin_index = (self._round_robin_index + 1) % len(available)
                return available[self._round_robin_index]

        elif self.strategy == SelectionStrategy.LATENCY_AWARE:
            # Sort by health score (descending) and pick top candidate
            # Add small random factor to prevent thundering herd
            scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
            scored.sort(key=lambda x: x[1], reverse=True)
            return scored[0][0]

        elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
            # Weighted random selection based on health scores
            scores = [ep.health_score for ep in available]
            total = sum(scores)
            if total == 0:
                return random.choice(available)

            weights = [s / total for s in scores]
            return random.choices(available, weights=weights, k=1)[0]

        return available[0]

    def _parse_retry_after(self, error: Exception) -> Optional[float]:
        """Extract Retry-After value from error if available."""
        error_str = str(error)

        # Try to find Retry-After in error message
        import re
        match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
        if match:
            return float(match.group(1))

        return None

    def _is_rate_limit_error(self, error: Exception) -> bool:
        """Check if error is a rate limit error."""
        error_str = str(error).lower()
        return any(x in error_str for x in ["429", "rate limit", "too many requests"])

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if error is retryable (not auth/config error)."""
        error_str = str(error).lower()
        # Retryable errors
        if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
                                        "timeout", "connection", "service unavailable"]):
            return True
        # Non-retryable errors (auth, config)
        if any(x in error_str for x in ["401", "403", "invalid", "authentication",
                                        "unauthorized", "api key"]):
            return False
        # Default to retryable for unknown errors
        return True

    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
        """Embed texts using load-balanced endpoint selection.

        Args:
            texts: Single text or iterable of texts to embed.
            **kwargs: Additional arguments passed to underlying embedder.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.

        Raises:
            RuntimeError: If all endpoints fail after retries.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        last_error: Optional[Exception] = None
        tried_endpoints: set = set()

        for attempt in range(self.max_retries + 1):
            endpoint = self._select_endpoint()

            if endpoint is None:
                # All endpoints unavailable, wait for shortest cooldown
                min_cooldown = min(
                    (ep.cooldown_until - time.time() for ep in self._endpoints
                     if ep.status == EndpointStatus.COOLING),
                    default=self.default_cooldown
                )
                if min_cooldown > 0 and attempt < self.max_retries:
                    wait_time = min(min_cooldown, 30)  # Cap wait at 30s
                    logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    continue
                break

            # Track tried endpoints to avoid infinite loops
            endpoint_id = id(endpoint)
            if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
                # Already tried all endpoints
                break
            tried_endpoints.add(endpoint_id)

            # Acquire slot
            with endpoint.lock:
                endpoint.active_requests += 1

            try:
                start_time = time.time()
                result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
                latency_ms = (time.time() - start_time) * 1000

                # Record success
                endpoint.record_success(latency_ms)

                return result

            except Exception as e:
                last_error = e
                endpoint.record_failure()

                if self._is_rate_limit_error(e):
                    # Rate limited - set cooldown
                    retry_after = self._parse_retry_after(e) or self.default_cooldown
                    endpoint.set_cooldown(retry_after)
                    logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
                                   f"cooling for {retry_after}s")

                elif not self._is_retryable_error(e):
                    # Permanent failure (auth error, etc.)
                    endpoint.mark_failed()
                    logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")

                else:
                    # Temporary error - short cooldown
                    endpoint.set_cooldown(5.0)
                    logger.warning(f"Endpoint {endpoint.config.model} error: {e}")

            finally:
                with endpoint.lock:
                    endpoint.active_requests -= 1

        # All retries exhausted
        available = self.available_endpoint_count
        raise RuntimeError(
            f"All embedding attempts failed after {self.max_retries + 1} tries. "
            f"Available endpoints: {available}/{len(self._endpoints)}. "
            f"Last error: {last_error}"
        )


def create_rotational_embedder(
    endpoints_config: List[Dict[str, Any]],
    strategy: str = "latency_aware",
    default_cooldown: float = 60.0,
) -> RotationalEmbedder:
    """Factory function to create RotationalEmbedder from config dicts.

    Args:
        endpoints_config: List of endpoint configuration dicts with keys:
            - model: Model identifier (required)
            - api_key: API key (optional)
            - api_base: API base URL (optional)
            - weight: Request weight (optional, default 1.0)
            - max_concurrent: Max concurrent requests (optional, default 4)
        strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
        default_cooldown: Default cooldown seconds for rate limits

    Returns:
        Configured RotationalEmbedder instance

    Example config:
        endpoints_config = [
            {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
            {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ]
    """
    endpoints = []
    for cfg in endpoints_config:
        endpoints.append(EndpointConfig(
            model=cfg["model"],
            api_key=cfg.get("api_key"),
            api_base=cfg.get("api_base"),
            weight=cfg.get("weight", 1.0),
            max_concurrent=cfg.get("max_concurrent", 4),
        ))

    strategy_enum = SelectionStrategy[strategy.upper()]

    return RotationalEmbedder(
        endpoints=endpoints,
        strategy=strategy_enum,
        default_cooldown=default_cooldown,
    )
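For completeness, an end-to-end usage sketch of the new module (endpoint entries and keys are placeholders):

```python
# Placeholder endpoints; any provider/model strings supported by LiteLLM would work here.
endpoints_config = [
    {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
    {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "...", "weight": 2.0},
]

embedder = create_rotational_embedder(endpoints_config, strategy="round_robin", default_cooldown=30.0)
vectors = embedder.embed_to_numpy(["def add(a, b): return a + b"])
print(vectors.shape)                  # (1, embedder.embedding_dim)
print(embedder.get_endpoint_stats())  # per-endpoint request/latency/health counters
```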