feat: add multi-endpoint support and load balancing to enhance LiteLLM embedding management

This commit is contained in:
catlog22
2025-12-25 11:01:08 +08:00
parent 3c3ce55842
commit 40e61b30d6
7 changed files with 727 additions and 29 deletions

View File

@@ -759,6 +759,16 @@ def status(
console.print(f" Coverage: {embeddings_info['coverage_percent']:.1f}%")
console.print(f" Total Chunks: {embeddings_info['total_chunks']}")
# Display model information if available
model_info = embeddings_info.get('model_info')
if model_info:
console.print("\n[bold]Embedding Model:[/bold]")
console.print(f" Backend: [cyan]{model_info.get('backend', 'unknown')}[/cyan]")
console.print(f" Model: [cyan]{model_info.get('model_profile', 'unknown')}[/cyan] ({model_info.get('model_name', '')})")
console.print(f" Dimensions: {model_info.get('embedding_dim', 'unknown')}")
if model_info.get('updated_at'):
console.print(f" Last Updated: {model_info['updated_at']}")
except StorageError as exc:
if json_mode:
print_json(success=False, error=f"Storage error: {exc}")
@@ -1878,7 +1888,7 @@ def embeddings_generate(
"""
_configure_logging(verbose, json_mode)
from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive
from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive, scan_for_model_conflicts
# Validate backend
valid_backends = ["fastembed", "litellm"]
@@ -1946,6 +1956,50 @@ def embeddings_generate(
console.print(f"Concurrency: [cyan]{max_workers} workers[/cyan]")
console.print()
# Pre-check for model conflicts (only if not forcing)
if not force:
# Determine the index root for conflict scanning
scan_root = index_root if use_recursive else (index_path.parent if index_path else None)
if scan_root:
conflict_result = scan_for_model_conflicts(scan_root, backend, model)
if conflict_result["has_conflict"]:
existing = conflict_result["existing_config"]
conflict_count = len(conflict_result["conflicts"])
if json_mode:
# JSON mode: return structured error for UI handling
print_json(
success=False,
error="Model conflict detected",
code="MODEL_CONFLICT",
existing_config=existing,
target_config=conflict_result["target_config"],
conflict_count=conflict_count,
conflicts=conflict_result["conflicts"][:5], # Show first 5 conflicts
hint="Use --force to overwrite existing embeddings with the new model",
)
raise typer.Exit(code=1)
else:
# Interactive mode: show warning and ask for confirmation
console.print("[yellow]⚠ Model Conflict Detected[/yellow]")
console.print(f" Existing: [red]{existing['backend']}/{existing['model']}[/red] ({existing.get('embedding_dim', '?')} dim)")
console.print(f" Requested: [green]{backend}/{model}[/green]")
console.print(f" Affected indexes: [yellow]{conflict_count}[/yellow]")
console.print()
console.print("[dim]Mixing different embedding models in the same index is not supported.[/dim]")
console.print("[dim]Overwriting will delete all existing embeddings and regenerate with the new model.[/dim]")
console.print()
# Ask for confirmation
if typer.confirm("Overwrite existing embeddings with the new model?", default=False):
force = True
console.print("[green]Confirmed.[/green] Proceeding with overwrite...\n")
else:
console.print("[yellow]Cancelled.[/yellow] Use --force to skip this prompt.")
raise typer.Exit(code=0)
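# For reference, the MODEL_CONFLICT payload emitted above in --json mode has
# roughly this shape (values are illustrative; the keys mirror the print_json call):
#   {"success": false, "error": "Model conflict detected", "code": "MODEL_CONFLICT",
#    "existing_config": {"backend": "fastembed", "model": "code", "model_name": "...", "embedding_dim": 768},
#    "target_config": {"backend": "litellm", "model": "qwen3-embedding"},
#    "conflict_count": 3, "conflicts": [...],
#    "hint": "Use --force to overwrite existing embeddings with the new model"}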
if use_recursive:
result = generate_embeddings_recursive(
index_root,

View File

@@ -235,18 +235,25 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]:
}
def _get_embedding_defaults() -> tuple[str, str, bool]:
def _get_embedding_defaults() -> tuple[str, str, bool, List, str, float]:
"""Get default embedding settings from config.
Returns:
Tuple of (backend, model, use_gpu)
Tuple of (backend, model, use_gpu, endpoints, strategy, cooldown)
"""
try:
from codexlens.config import Config
config = Config.load()
return config.embedding_backend, config.embedding_model, config.embedding_use_gpu
return (
config.embedding_backend,
config.embedding_model,
config.embedding_use_gpu,
config.embedding_endpoints,
config.embedding_strategy,
config.embedding_cooldown,
)
except Exception:
return "fastembed", "code", True
return "fastembed", "code", True, [], "latency_aware", 60.0
def generate_embeddings(
@@ -260,6 +267,9 @@ def generate_embeddings(
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: Optional[int] = None,
endpoints: Optional[List] = None,
strategy: Optional[str] = None,
cooldown: Optional[float] = None,
) -> Dict[str, any]:
"""Generate embeddings for an index using memory-efficient batch processing.
@@ -284,14 +294,18 @@ def generate_embeddings(
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls.
If None, uses dynamic defaults: 1 for fastembed (CPU bound),
4 for litellm (network I/O bound).
If None, uses dynamic defaults based on backend and endpoint count.
endpoints: Optional list of endpoint configurations for multi-API load balancing.
Each dict has keys: model, api_key, api_base, weight.
strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware).
cooldown: Default cooldown seconds for rate-limited endpoints.
Returns:
Result dictionary with generation statistics
"""
# Get defaults from config if not specified
default_backend, default_model, default_gpu = _get_embedding_defaults()
(default_backend, default_model, default_gpu,
default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults()
if embedding_backend is None:
embedding_backend = default_backend
@@ -299,13 +313,26 @@ def generate_embeddings(
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
if endpoints is None:
endpoints = default_endpoints
if strategy is None:
strategy = default_strategy
if cooldown is None:
cooldown = default_cooldown
# Set dynamic max_workers default based on backend type
# Calculate endpoint count for worker scaling
endpoint_count = len(endpoints) if endpoints else 1
# Set dynamic max_workers default based on backend type and endpoint count
# - FastEmbed: CPU-bound, sequential is optimal (1 worker)
# - LiteLLM: Network I/O bound, concurrent calls improve throughput (4 workers)
# - LiteLLM single endpoint: 4 workers default
# - LiteLLM multi-endpoint: workers = endpoint_count * 2 (to saturate all APIs)
if max_workers is None:
if embedding_backend == "litellm":
max_workers = 4
if endpoint_count > 1:
max_workers = min(endpoint_count * 2, 16) # Cap at 16 workers
else:
max_workers = 4
else:
max_workers = 1
@@ -354,13 +381,20 @@ def generate_embeddings(
from codexlens.semantic.vector_store import VectorStore
from codexlens.semantic.chunker import Chunker, ChunkConfig
# Initialize embedder using factory (supports both fastembed and litellm)
# Initialize embedder using factory (supports fastembed, litellm, and rotational)
# For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
# For litellm: model_profile is a model name (e.g., qwen3-embedding)
# For multi-endpoint: endpoints list enables load balancing
if embedding_backend == "fastembed":
embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
elif embedding_backend == "litellm":
embedder = get_embedder_factory(backend="litellm", model=model_profile)
embedder = get_embedder_factory(
backend="litellm",
model=model_profile,
endpoints=endpoints if endpoints else None,
strategy=strategy,
cooldown=cooldown,
)
else:
return {
"success": False,
@@ -375,7 +409,10 @@ def generate_embeddings(
skip_token_count=True
))
# Log embedder info with endpoint count for multi-endpoint mode
if progress_callback:
if endpoint_count > 1:
progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy")
progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
except Exception as e:
@@ -684,6 +721,9 @@ def generate_embeddings_recursive(
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: Optional[int] = None,
endpoints: Optional[List] = None,
strategy: Optional[str] = None,
cooldown: Optional[float] = None,
) -> Dict[str, any]:
"""Generate embeddings for all index databases in a project recursively.
@@ -704,14 +744,17 @@ def generate_embeddings_recursive(
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls.
If None, uses dynamic defaults: 1 for fastembed (CPU bound),
4 for litellm (network I/O bound).
If None, uses dynamic defaults based on backend and endpoint count.
endpoints: Optional list of endpoint configurations for multi-API load balancing.
strategy: Selection strategy for multi-endpoint mode.
cooldown: Default cooldown seconds for rate-limited endpoints.
Returns:
Aggregated result dictionary with generation statistics
"""
# Get defaults from config if not specified
default_backend, default_model, default_gpu = _get_embedding_defaults()
(default_backend, default_model, default_gpu,
default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults()
if embedding_backend is None:
embedding_backend = default_backend
@@ -719,11 +762,23 @@ def generate_embeddings_recursive(
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
if endpoints is None:
endpoints = default_endpoints
if strategy is None:
strategy = default_strategy
if cooldown is None:
cooldown = default_cooldown
# Set dynamic max_workers default based on backend type
# Calculate endpoint count for worker scaling
endpoint_count = len(endpoints) if endpoints else 1
# Set dynamic max_workers default based on backend type and endpoint count
if max_workers is None:
if embedding_backend == "litellm":
max_workers = 4
if endpoint_count > 1:
max_workers = min(endpoint_count * 2, 16)
else:
max_workers = 4
else:
max_workers = 1
@@ -765,6 +820,9 @@ def generate_embeddings_recursive(
use_gpu=use_gpu,
max_tokens_per_batch=max_tokens_per_batch,
max_workers=max_workers,
endpoints=endpoints,
strategy=strategy,
cooldown=cooldown,
)
all_results.append({
@@ -958,3 +1016,85 @@ def get_embedding_stats_summary(index_root: Path) -> Dict[str, any]:
"indexes": index_stats,
},
}
def scan_for_model_conflicts(
index_root: Path,
target_backend: str,
target_model: str,
) -> Dict[str, any]:
"""Scan for model conflicts across all indexes in a directory.
Checks if any existing embeddings were generated with a different
backend or model than the target configuration.
Args:
index_root: Root index directory to scan
target_backend: Target embedding backend (fastembed or litellm)
target_model: Target model profile/name
Returns:
Dictionary with:
- has_conflict: True if any index has different model config
- existing_config: Config from first index with embeddings (if any)
- target_config: The requested configuration
- conflicts: List of conflicting index paths with their configs
- indexes_with_embeddings: Count of indexes that have embeddings
"""
index_files = discover_all_index_dbs(index_root)
if not index_files:
return {
"has_conflict": False,
"existing_config": None,
"target_config": {"backend": target_backend, "model": target_model},
"conflicts": [],
"indexes_with_embeddings": 0,
}
conflicts = []
existing_config = None
indexes_with_embeddings = 0
for index_path in index_files:
try:
from codexlens.semantic.vector_store import VectorStore
with VectorStore(index_path) as vs:
config = vs.get_model_config()
if config and config.get("model_profile"):
indexes_with_embeddings += 1
# Store first existing config as reference
if existing_config is None:
existing_config = {
"backend": config.get("backend"),
"model": config.get("model_profile"),
"model_name": config.get("model_name"),
"embedding_dim": config.get("embedding_dim"),
}
# Check for conflict: different backend OR different model
existing_backend = config.get("backend", "")
existing_model = config.get("model_profile", "")
if existing_backend != target_backend or existing_model != target_model:
conflicts.append({
"path": str(index_path),
"existing": {
"backend": existing_backend,
"model": existing_model,
"model_name": config.get("model_name"),
},
})
except Exception as e:
logger.debug(f"Failed to check model config for {index_path}: {e}")
continue
return {
"has_conflict": len(conflicts) > 0,
"existing_config": existing_config,
"target_config": {"backend": target_backend, "model": target_model},
"conflicts": conflicts,
"indexes_with_embeddings": indexes_with_embeddings,
}
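A minimal usage sketch of the new helper (the index root path below is hypothetical; the model name reuses the litellm example from this diff):
from pathlib import Path
from codexlens.cli.embedding_manager import scan_for_model_conflicts

result = scan_for_model_conflicts(Path("path/to/index-root"), "litellm", "qwen3-embedding")
if result["has_conflict"]:
    existing = result["existing_config"]
    print(f"Existing {existing['backend']}/{existing['model']} conflicts with "
          f"litellm/qwen3-embedding in {len(result['conflicts'])} of "
          f"{result['indexes_with_embeddings']} indexes that have embeddings")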

View File

@@ -35,12 +35,23 @@ def _to_jsonable(value: Any) -> Any:
return value
def print_json(*, success: bool, result: Any = None, error: str | None = None) -> None:
def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None:
"""Print JSON output with optional additional fields.
Args:
success: Whether the operation succeeded
result: Result data (used when success=True)
error: Error message (used when success=False)
**kwargs: Additional fields to include in the payload (e.g., code, details)
"""
payload: dict[str, Any] = {"success": success}
if success:
payload["result"] = _to_jsonable(result)
else:
payload["error"] = error or "Unknown error"
# Include additional error details if provided
for key, value in kwargs.items():
payload[key] = _to_jsonable(value)
console.print_json(json.dumps(payload, ensure_ascii=False))
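With the **kwargs extension, error payloads can now carry extra top-level fields, for example:
# Extra keyword fields are appended to the payload after "error":
print_json(
    success=False,
    error="Model conflict detected",
    code="MODEL_CONFLICT",
    conflict_count=2,
)
# Logical payload (rich pretty-prints it):
# {"success": false, "error": "Model conflict detected", "code": "MODEL_CONFLICT", "conflict_count": 2}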

View File

@@ -100,6 +100,12 @@ class Config:
# For litellm: model name from config (e.g., "qwen3-embedding")
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
# Multi-endpoint configuration for litellm backend
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
embedding_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random
embedding_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints
def __post_init__(self) -> None:
try:
self.data_dir = self.data_dir.expanduser().resolve()
@@ -151,12 +157,19 @@ class Config:
def save_settings(self) -> None:
"""Save embedding and other settings to file."""
embedding_config = {
"backend": self.embedding_backend,
"model": self.embedding_model,
"use_gpu": self.embedding_use_gpu,
}
# Include multi-endpoint config if present
if self.embedding_endpoints:
embedding_config["endpoints"] = self.embedding_endpoints
embedding_config["strategy"] = self.embedding_strategy
embedding_config["cooldown"] = self.embedding_cooldown
settings = {
"embedding": {
"backend": self.embedding_backend,
"model": self.embedding_model,
"use_gpu": self.embedding_use_gpu,
},
"embedding": embedding_config,
"llm": {
"enabled": self.llm_enabled,
"tool": self.llm_tool,
@@ -185,6 +198,14 @@ class Config:
if "use_gpu" in embedding:
self.embedding_use_gpu = embedding["use_gpu"]
# Load multi-endpoint configuration
if "endpoints" in embedding:
self.embedding_endpoints = embedding["endpoints"]
if "strategy" in embedding:
self.embedding_strategy = embedding["strategy"]
if "cooldown" in embedding:
self.embedding_cooldown = embedding["cooldown"]
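# Illustrative "embedding" settings dict consumed by this loader (and written by
# save_settings above); endpoint values are placeholders:
#   {"backend": "litellm", "model": "qwen3-embedding", "use_gpu": False,
#    "endpoints": [{"model": "openai/text-embedding-3-small", "api_key": "...", "api_base": "...", "weight": 1.0}],
#    "strategy": "latency_aware", "cooldown": 60.0}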
# Load LLM settings
llm = settings.get("llm", {})
if "enabled" in llm:

View File

@@ -5,7 +5,7 @@ Provides a unified interface for instantiating different embedder backends.
from __future__ import annotations
from typing import Any
from typing import Any, Dict, List, Optional
from .base import BaseEmbedder
@@ -15,6 +15,9 @@ def get_embedder(
profile: str = "code",
model: str = "default",
use_gpu: bool = True,
endpoints: Optional[List[Dict[str, Any]]] = None,
strategy: str = "latency_aware",
cooldown: float = 60.0,
**kwargs: Any,
) -> BaseEmbedder:
"""Factory function to create embedder based on backend.
@@ -29,6 +32,13 @@ def get_embedder(
Used only when backend="litellm". Default: "default"
use_gpu: Whether to use GPU acceleration when available (default: True).
Used only when backend="fastembed".
endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
Each endpoint is a dict with keys: model, api_key, api_base, weight.
Used only when backend="litellm" and multiple endpoints provided.
strategy: Selection strategy for multi-endpoint mode:
"round_robin", "latency_aware", "weighted_random".
Default: "latency_aware"
cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
**kwargs: Additional backend-specific arguments
Returns:
@@ -47,13 +57,40 @@ def get_embedder(
Create litellm embedder:
>>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")
Create rotational embedder with multiple endpoints:
>>> endpoints = [
... {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
... {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
... ]
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
"""
if backend == "fastembed":
from .embedder import Embedder
return Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
elif backend == "litellm":
from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=model, **kwargs)
# Check if multi-endpoint mode is requested
if endpoints and len(endpoints) > 1:
from .rotational_embedder import create_rotational_embedder
return create_rotational_embedder(
endpoints_config=endpoints,
strategy=strategy,
default_cooldown=cooldown,
)
elif endpoints and len(endpoints) == 1:
# Single endpoint in list - use it directly
ep = endpoints[0]
ep_kwargs = {**kwargs}
if "api_key" in ep:
ep_kwargs["api_key"] = ep["api_key"]
if "api_base" in ep:
ep_kwargs["api_base"] = ep["api_base"]
from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
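# Note: a single-entry endpoints list therefore behaves like passing that
# endpoint's model/api_key/api_base directly; no rotation layer is created.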
else:
# No endpoints list - use model parameter
from .litellm_embedder import LiteLLMEmbedderWrapper
return LiteLLMEmbedderWrapper(model=model, **kwargs)
else:
raise ValueError(
f"Unknown backend: {backend}. "

View File

@@ -0,0 +1,434 @@
"""Rotational embedder for multi-endpoint API load balancing.
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
to maximize throughput while respecting rate limits.
"""
from __future__ import annotations
import logging
import random
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional
import numpy as np
from .base import BaseEmbedder
logger = logging.getLogger(__name__)
class EndpointStatus(Enum):
"""Status of an API endpoint."""
AVAILABLE = "available"
COOLING = "cooling" # Rate limited, temporarily unavailable
FAILED = "failed" # Permanent failure (auth error, etc.)
class SelectionStrategy(Enum):
"""Strategy for selecting endpoints."""
ROUND_ROBIN = "round_robin"
LATENCY_AWARE = "latency_aware"
WEIGHTED_RANDOM = "weighted_random"
@dataclass
class EndpointConfig:
"""Configuration for a single API endpoint."""
model: str
api_key: Optional[str] = None
api_base: Optional[str] = None
weight: float = 1.0 # Higher weight = more requests
max_concurrent: int = 4 # Max concurrent requests to this endpoint
@dataclass
class EndpointState:
"""Runtime state for an endpoint."""
config: EndpointConfig
embedder: Any = None # LiteLLMEmbedderWrapper instance
# Health metrics
status: EndpointStatus = EndpointStatus.AVAILABLE
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
# Performance metrics
total_requests: int = 0
total_failures: int = 0
avg_latency_ms: float = 0.0
last_latency_ms: float = 0.0
# Concurrency tracking
active_requests: int = 0
lock: threading.Lock = field(default_factory=threading.Lock)
def is_available(self) -> bool:
"""Check if endpoint is available for requests."""
if self.status == EndpointStatus.FAILED:
return False
if self.status == EndpointStatus.COOLING:
if time.time() >= self.cooldown_until:
self.status = EndpointStatus.AVAILABLE
return True
return False
return True
def set_cooldown(self, seconds: float) -> None:
"""Put endpoint in cooldown state."""
self.status = EndpointStatus.COOLING
self.cooldown_until = time.time() + seconds
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
def mark_failed(self) -> None:
"""Mark endpoint as permanently failed."""
self.status = EndpointStatus.FAILED
logger.error(f"Endpoint {self.config.model} marked as failed")
def record_success(self, latency_ms: float) -> None:
"""Record successful request."""
self.total_requests += 1
self.last_latency_ms = latency_ms
# Exponential moving average for latency
alpha = 0.3
if self.avg_latency_ms == 0:
self.avg_latency_ms = latency_ms
else:
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
def record_failure(self) -> None:
"""Record failed request."""
self.total_requests += 1
self.total_failures += 1
@property
def health_score(self) -> float:
"""Calculate health score (0-1) based on metrics."""
if not self.is_available():
return 0.0
# Base score from success rate
if self.total_requests > 0:
success_rate = 1 - (self.total_failures / self.total_requests)
else:
success_rate = 1.0
# Latency factor (faster = higher score)
# Normalize: 100ms = 1.0, 1000ms = 0.1
if self.avg_latency_ms > 0:
latency_factor = min(1.0, 100 / self.avg_latency_ms)
else:
latency_factor = 1.0
# Availability factor (less concurrent = more available)
if self.config.max_concurrent > 0:
availability = 1 - (self.active_requests / self.config.max_concurrent)
else:
availability = 1.0
# Combined score with weights
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
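# Worked example of the weighting above (illustrative numbers, not runtime logic):
# with total_requests=10, total_failures=1, avg_latency_ms=200, active_requests=2,
# max_concurrent=4 and weight=1.0:
#   success_rate   = 1 - 1/10          = 0.9
#   latency_factor = min(1.0, 100/200) = 0.5
#   availability   = 1 - 2/4           = 0.5
#   health_score   = (0.9*0.4 + 0.5*0.3 + 0.5*0.3) * 1.0 = 0.66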
class RotationalEmbedder(BaseEmbedder):
"""Embedder that load balances across multiple API endpoints.
Features:
- Intelligent endpoint selection based on latency and health
- Automatic failover on rate limits (429) and server errors
- Cooldown management to respect rate limits
- Thread-safe concurrent request handling
Args:
endpoints: List of endpoint configurations
strategy: Selection strategy (default: latency_aware)
default_cooldown: Default cooldown seconds for rate limits (default: 60)
max_retries: Maximum retry attempts across all endpoints (default: 3)
"""
def __init__(
self,
endpoints: List[EndpointConfig],
strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
default_cooldown: float = 60.0,
max_retries: int = 3,
) -> None:
if not endpoints:
raise ValueError("At least one endpoint must be provided")
self.strategy = strategy
self.default_cooldown = default_cooldown
self.max_retries = max_retries
# Initialize endpoint states
self._endpoints: List[EndpointState] = []
self._lock = threading.Lock()
self._round_robin_index = 0
# Create embedder instances for each endpoint
from .litellm_embedder import LiteLLMEmbedderWrapper
for config in endpoints:
# Build kwargs for LiteLLMEmbedderWrapper
kwargs: Dict[str, Any] = {}
if config.api_key:
kwargs["api_key"] = config.api_key
if config.api_base:
kwargs["api_base"] = config.api_base
try:
embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
state = EndpointState(config=config, embedder=embedder)
self._endpoints.append(state)
logger.info(f"Initialized endpoint: {config.model}")
except Exception as e:
logger.error(f"Failed to initialize endpoint {config.model}: {e}")
if not self._endpoints:
raise ValueError("Failed to initialize any endpoints")
# Cache embedding properties from first endpoint
self._embedding_dim = self._endpoints[0].embedder.embedding_dim
self._model_name = f"rotational({len(self._endpoints)} endpoints)"
self._max_tokens = self._endpoints[0].embedder.max_tokens
@property
def embedding_dim(self) -> int:
"""Return embedding dimensions."""
return self._embedding_dim
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def max_tokens(self) -> int:
"""Return maximum token limit."""
return self._max_tokens
@property
def endpoint_count(self) -> int:
"""Return number of configured endpoints."""
return len(self._endpoints)
@property
def available_endpoint_count(self) -> int:
"""Return number of available endpoints."""
return sum(1 for ep in self._endpoints if ep.is_available())
def get_endpoint_stats(self) -> List[Dict[str, Any]]:
"""Get statistics for all endpoints."""
stats = []
for ep in self._endpoints:
stats.append({
"model": ep.config.model,
"status": ep.status.value,
"total_requests": ep.total_requests,
"total_failures": ep.total_failures,
"avg_latency_ms": round(ep.avg_latency_ms, 2),
"health_score": round(ep.health_score, 3),
"active_requests": ep.active_requests,
})
return stats
def _select_endpoint(self) -> Optional[EndpointState]:
"""Select best available endpoint based on strategy."""
available = [ep for ep in self._endpoints if ep.is_available()]
if not available:
return None
if self.strategy == SelectionStrategy.ROUND_ROBIN:
with self._lock:
self._round_robin_index = (self._round_robin_index + 1) % len(available)
return available[self._round_robin_index]
elif self.strategy == SelectionStrategy.LATENCY_AWARE:
# Sort by health score (descending) and pick top candidate
# Add small random factor to prevent thundering herd
scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
scored.sort(key=lambda x: x[1], reverse=True)
return scored[0][0]
elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
# Weighted random selection based on health scores
scores = [ep.health_score for ep in available]
total = sum(scores)
if total == 0:
return random.choice(available)
weights = [s / total for s in scores]
return random.choices(available, weights=weights, k=1)[0]
return available[0]
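# Example of the weighted_random branch above (illustrative): health scores
# [0.6, 0.3, 0.1] normalize to selection probabilities of 60%, 30% and 10%.
# latency_aware instead always takes the highest score, with up to 0.1 of
# random jitter so equally healthy endpoints do not all absorb the same burst.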
def _parse_retry_after(self, error: Exception) -> Optional[float]:
"""Extract Retry-After value from error if available."""
error_str = str(error)
# Try to find Retry-After in error message
import re
match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
if match:
return float(match.group(1))
return None
def _is_rate_limit_error(self, error: Exception) -> bool:
"""Check if error is a rate limit error."""
error_str = str(error).lower()
return any(x in error_str for x in ["429", "rate limit", "too many requests"])
def _is_retryable_error(self, error: Exception) -> bool:
"""Check if error is retryable (not auth/config error)."""
error_str = str(error).lower()
# Retryable errors
if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
"timeout", "connection", "service unavailable"]):
return True
# Non-retryable errors (auth, config)
if any(x in error_str for x in ["401", "403", "invalid", "authentication",
"unauthorized", "api key"]):
return False
# Default to retryable for unknown errors
return True
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts using load-balanced endpoint selection.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments passed to underlying embedder.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
Raises:
RuntimeError: If all endpoints fail after retries.
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
last_error: Optional[Exception] = None
tried_endpoints: set = set()
for attempt in range(self.max_retries + 1):
endpoint = self._select_endpoint()
if endpoint is None:
# All endpoints unavailable, wait for shortest cooldown
min_cooldown = min(
(ep.cooldown_until - time.time() for ep in self._endpoints
if ep.status == EndpointStatus.COOLING),
default=self.default_cooldown
)
if min_cooldown > 0 and attempt < self.max_retries:
wait_time = min(min_cooldown, 30) # Cap wait at 30s
logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
time.sleep(wait_time)
continue
break
# Track tried endpoints to avoid infinite loops
endpoint_id = id(endpoint)
if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
# Already tried all endpoints
break
tried_endpoints.add(endpoint_id)
# Acquire slot
with endpoint.lock:
endpoint.active_requests += 1
try:
start_time = time.time()
result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
latency_ms = (time.time() - start_time) * 1000
# Record success
endpoint.record_success(latency_ms)
return result
except Exception as e:
last_error = e
endpoint.record_failure()
if self._is_rate_limit_error(e):
# Rate limited - set cooldown
retry_after = self._parse_retry_after(e) or self.default_cooldown
endpoint.set_cooldown(retry_after)
logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
f"cooling for {retry_after}s")
elif not self._is_retryable_error(e):
# Permanent failure (auth error, etc.)
endpoint.mark_failed()
logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")
else:
# Temporary error - short cooldown
endpoint.set_cooldown(5.0)
logger.warning(f"Endpoint {endpoint.config.model} error: {e}")
finally:
with endpoint.lock:
endpoint.active_requests -= 1
# All retries exhausted
available = self.available_endpoint_count
raise RuntimeError(
f"All embedding attempts failed after {self.max_retries + 1} tries. "
f"Available endpoints: {available}/{len(self._endpoints)}. "
f"Last error: {last_error}"
)
def create_rotational_embedder(
endpoints_config: List[Dict[str, Any]],
strategy: str = "latency_aware",
default_cooldown: float = 60.0,
) -> RotationalEmbedder:
"""Factory function to create RotationalEmbedder from config dicts.
Args:
endpoints_config: List of endpoint configuration dicts with keys:
- model: Model identifier (required)
- api_key: API key (optional)
- api_base: API base URL (optional)
- weight: Request weight (optional, default 1.0)
- max_concurrent: Max concurrent requests (optional, default 4)
strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
default_cooldown: Default cooldown seconds for rate limits
Returns:
Configured RotationalEmbedder instance
Example config:
endpoints_config = [
{"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
{"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
]
"""
endpoints = []
for cfg in endpoints_config:
endpoints.append(EndpointConfig(
model=cfg["model"],
api_key=cfg.get("api_key"),
api_base=cfg.get("api_base"),
weight=cfg.get("weight", 1.0),
max_concurrent=cfg.get("max_concurrent", 4),
))
strategy_enum = SelectionStrategy[strategy.upper()]
return RotationalEmbedder(
endpoints=endpoints,
strategy=strategy_enum,
default_cooldown=default_cooldown,
)
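Taken together, the factory can be exercised as in the hedged sketch below (API keys, weights, and the duplicated model identifier are placeholders, and the sketch assumes the endpoints initialize successfully):
embedder = create_rotational_embedder(
    endpoints_config=[
        {"model": "openai/text-embedding-3-small", "api_key": "<key-a>"},
        {"model": "openai/text-embedding-3-small", "api_key": "<key-b>", "weight": 2.0, "max_concurrent": 8},
    ],
    strategy="round_robin",
    default_cooldown=60.0,
)

vectors = embedder.embed_to_numpy(["def add(a, b): return a + b"])
print(vectors.shape)  # (1, embedder.embedding_dim)

for stat in embedder.get_endpoint_stats():
    print(stat["model"], stat["status"], stat["avg_latency_ms"], stat["health_score"])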