"""Rotational embedder for multi-endpoint API load balancing.
|
|
|
|
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
|
|
to maximize throughput while respecting rate limits.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import random
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any, Dict, Iterable, List, Optional
|
|
|
|
import numpy as np
|
|
|
|
from .base import BaseEmbedder
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|


class EndpointStatus(Enum):
    """Status of an API endpoint."""
    AVAILABLE = "available"
    COOLING = "cooling"  # Rate limited, temporarily unavailable
    FAILED = "failed"  # Permanent failure (auth error, etc.)


class SelectionStrategy(Enum):
    """Strategy for selecting endpoints."""
    ROUND_ROBIN = "round_robin"
    LATENCY_AWARE = "latency_aware"
    WEIGHTED_RANDOM = "weighted_random"


@dataclass
class EndpointConfig:
    """Configuration for a single API endpoint."""
    model: str
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    weight: float = 1.0  # Higher weight = more requests
    max_concurrent: int = 4  # Max concurrent requests to this endpoint


@dataclass
class EndpointState:
    """Runtime state for an endpoint."""
    config: EndpointConfig
    embedder: Any = None  # LiteLLMEmbedderWrapper instance

    # Health metrics
    status: EndpointStatus = EndpointStatus.AVAILABLE
    cooldown_until: float = 0.0  # Unix timestamp when cooldown ends

    # Performance metrics
    total_requests: int = 0
    total_failures: int = 0
    avg_latency_ms: float = 0.0
    last_latency_ms: float = 0.0

    # Concurrency tracking
    active_requests: int = 0
    lock: threading.Lock = field(default_factory=threading.Lock)

    def is_available(self) -> bool:
        """Check if endpoint is available for requests."""
        if self.status == EndpointStatus.FAILED:
            return False
        if self.status == EndpointStatus.COOLING:
            if time.time() >= self.cooldown_until:
                self.status = EndpointStatus.AVAILABLE
                return True
            return False
        return True

    def set_cooldown(self, seconds: float) -> None:
        """Put endpoint in cooldown state."""
        self.status = EndpointStatus.COOLING
        self.cooldown_until = time.time() + seconds
        logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")

    def mark_failed(self) -> None:
        """Mark endpoint as permanently failed."""
        self.status = EndpointStatus.FAILED
        logger.error(f"Endpoint {self.config.model} marked as failed")

    def record_success(self, latency_ms: float) -> None:
        """Record successful request."""
        self.total_requests += 1
        self.last_latency_ms = latency_ms
        # Exponential moving average for latency
        alpha = 0.3
        if self.avg_latency_ms == 0:
            self.avg_latency_ms = latency_ms
        else:
            self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
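
    # Worked example of the moving average above (numbers are illustrative,
    # not from the source): with alpha = 0.3, a current average of 200 ms and
    # a new sample of 500 ms give 0.3 * 500 + 0.7 * 200 = 290 ms, so a single
    # slow request nudges the average instead of replacing it.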

    def record_failure(self) -> None:
        """Record failed request."""
        self.total_requests += 1
        self.total_failures += 1

    @property
    def health_score(self) -> float:
        """Weighted health score based on success rate, latency, and free capacity.

        Higher is better; the result is scaled by the endpoint's configured
        weight, so it can exceed 1.0 for weights above 1.
        """
        if not self.is_available():
            return 0.0

        # Base score from success rate
        if self.total_requests > 0:
            success_rate = 1 - (self.total_failures / self.total_requests)
        else:
            success_rate = 1.0

        # Latency factor (faster = higher score)
        # Normalize: 100ms = 1.0, 1000ms = 0.1
        if self.avg_latency_ms > 0:
            latency_factor = min(1.0, 100 / self.avg_latency_ms)
        else:
            latency_factor = 1.0

        # Availability factor (fewer active requests = more headroom)
        if self.config.max_concurrent > 0:
            availability = 1 - (self.active_requests / self.config.max_concurrent)
        else:
            availability = 1.0

        # Combined score with fixed weights
        return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
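
    # Worked example of the score above (numbers are illustrative): an
    # endpoint with a 95% success rate, 200 ms average latency, 1 of 4
    # slots busy, and weight 1.0 scores
    # (0.95 * 0.4 + (100 / 200) * 0.3 + 0.75 * 0.3) * 1.0 = 0.755.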


class RotationalEmbedder(BaseEmbedder):
    """Embedder that load balances across multiple API endpoints.

    Features:
    - Intelligent endpoint selection based on latency and health
    - Automatic failover on rate limits (429) and server errors
    - Cooldown management to respect rate limits
    - Thread-safe concurrent request handling

    Args:
        endpoints: List of endpoint configurations
        strategy: Selection strategy (default: latency_aware)
        default_cooldown: Default cooldown seconds for rate limits (default: 60)
        max_retries: Maximum retry attempts across all endpoints (default: 3)
    """

    def __init__(
        self,
        endpoints: List[EndpointConfig],
        strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
        default_cooldown: float = 60.0,
        max_retries: int = 3,
    ) -> None:
        if not endpoints:
            raise ValueError("At least one endpoint must be provided")

        self.strategy = strategy
        self.default_cooldown = default_cooldown
        self.max_retries = max_retries

        # Initialize endpoint states
        self._endpoints: List[EndpointState] = []
        self._lock = threading.Lock()
        self._round_robin_index = 0

        # Create embedder instances for each endpoint
        from .litellm_embedder import LiteLLMEmbedderWrapper

        for config in endpoints:
            # Build kwargs for LiteLLMEmbedderWrapper
            kwargs: Dict[str, Any] = {}
            if config.api_key:
                kwargs["api_key"] = config.api_key
            if config.api_base:
                kwargs["api_base"] = config.api_base

            try:
                embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
                state = EndpointState(config=config, embedder=embedder)
                self._endpoints.append(state)
                logger.info(f"Initialized endpoint: {config.model}")
            except Exception as e:
                logger.error(f"Failed to initialize endpoint {config.model}: {e}")

        if not self._endpoints:
            raise ValueError("Failed to initialize any endpoints")

        # Cache embedding properties from the first endpoint; all endpoints
        # are assumed to produce embeddings of the same dimensionality.
        self._embedding_dim = self._endpoints[0].embedder.embedding_dim
        self._model_name = f"rotational({len(self._endpoints)} endpoints)"
        self._max_tokens = self._endpoints[0].embedder.max_tokens

    @property
    def embedding_dim(self) -> int:
        """Return embedding dimensions."""
        return self._embedding_dim

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit."""
        return self._max_tokens

    @property
    def endpoint_count(self) -> int:
        """Return number of configured endpoints."""
        return len(self._endpoints)

    @property
    def available_endpoint_count(self) -> int:
        """Return number of available endpoints."""
        return sum(1 for ep in self._endpoints if ep.is_available())

    def get_endpoint_stats(self) -> List[Dict[str, Any]]:
        """Get statistics for all endpoints."""
        stats = []
        for ep in self._endpoints:
            stats.append({
                "model": ep.config.model,
                "status": ep.status.value,
                "total_requests": ep.total_requests,
                "total_failures": ep.total_failures,
                "avg_latency_ms": round(ep.avg_latency_ms, 2),
                "health_score": round(ep.health_score, 3),
                "active_requests": ep.active_requests,
            })
        return stats
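
    # An illustrative stats entry (values invented for the example):
    #   {"model": "openai/text-embedding-3-small", "status": "available",
    #    "total_requests": 120, "total_failures": 2, "avg_latency_ms": 182.45,
    #    "health_score": 0.783, "active_requests": 1}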

    def _select_endpoint(self) -> Optional[EndpointState]:
        """Select best available endpoint based on strategy."""
        available = [ep for ep in self._endpoints if ep.is_available()]

        if not available:
            return None

        if self.strategy == SelectionStrategy.ROUND_ROBIN:
            # Read the index before incrementing so the first endpoint is
            # not skipped on the first call.
            with self._lock:
                index = self._round_robin_index % len(available)
                self._round_robin_index += 1
            return available[index]

        elif self.strategy == SelectionStrategy.LATENCY_AWARE:
            # Sort by health score (descending) and pick top candidate
            # Add small random factor to prevent thundering herd
            scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
            scored.sort(key=lambda x: x[1], reverse=True)
            return scored[0][0]

        elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
            # Weighted random selection based on health scores
            scores = [ep.health_score for ep in available]
            total = sum(scores)
            if total == 0:
                return random.choice(available)

            weights = [s / total for s in scores]
            return random.choices(available, weights=weights, k=1)[0]

        return available[0]

    def _parse_retry_after(self, error: Exception) -> Optional[float]:
        """Extract Retry-After value from error if available."""
        error_str = str(error)

        # Try to find Retry-After in error message
        match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
        if match:
            return float(match.group(1))

        return None

    def _is_rate_limit_error(self, error: Exception) -> bool:
        """Check if error is a rate limit error."""
        error_str = str(error).lower()
        return any(x in error_str for x in ["429", "rate limit", "too many requests"])

    def _is_retryable_error(self, error: Exception) -> bool:
        """Check if error is retryable (not an auth/config error)."""
        error_str = str(error).lower()
        # Retryable errors
        if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
                                        "timeout", "connection", "service unavailable"]):
            return True
        # Non-retryable errors (auth, config)
        if any(x in error_str for x in ["401", "403", "invalid", "authentication",
                                        "unauthorized", "api key"]):
            return False
        # Default to retryable for unknown errors
        return True
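
    # Illustrative classification (sample messages, not from the source):
    #   "429 Too Many Requests"   -> rate limited; cooldown honors Retry-After
    #   "503 service unavailable" -> retryable; short cooldown, then rotation
    #   "401 invalid api key"     -> permanent; endpoint removed from rotation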

    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
        """Embed texts using load-balanced endpoint selection.

        Args:
            texts: Single text or iterable of texts to embed.
            **kwargs: Additional arguments passed to underlying embedder.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.

        Raises:
            RuntimeError: If all endpoints fail after retries.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        last_error: Optional[Exception] = None
        tried_endpoints: set = set()

        for attempt in range(self.max_retries + 1):
            endpoint = self._select_endpoint()

            if endpoint is None:
                # All endpoints unavailable, wait for shortest cooldown
                min_cooldown = min(
                    (ep.cooldown_until - time.time() for ep in self._endpoints
                     if ep.status == EndpointStatus.COOLING),
                    default=self.default_cooldown
                )
                if min_cooldown > 0 and attempt < self.max_retries:
                    wait_time = min(min_cooldown, 30)  # Cap wait at 30s
                    logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    continue
                break

            # Track tried endpoints to avoid infinite loops
            endpoint_id = id(endpoint)
            if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
                # Already tried all endpoints
                break
            tried_endpoints.add(endpoint_id)

            # Acquire slot
            with endpoint.lock:
                endpoint.active_requests += 1

            try:
                start_time = time.time()
                result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
                latency_ms = (time.time() - start_time) * 1000

                # Record success
                endpoint.record_success(latency_ms)

                return result

            except Exception as e:
                last_error = e
                endpoint.record_failure()

                if self._is_rate_limit_error(e):
                    # Rate limited - set cooldown
                    retry_after = self._parse_retry_after(e) or self.default_cooldown
                    endpoint.set_cooldown(retry_after)
                    logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
                                   f"cooling for {retry_after}s")

                elif not self._is_retryable_error(e):
                    # Permanent failure (auth error, etc.)
                    endpoint.mark_failed()
                    logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")

                else:
                    # Temporary error - short cooldown
                    endpoint.set_cooldown(5.0)
                    logger.warning(f"Endpoint {endpoint.config.model} error: {e}")

            finally:
                with endpoint.lock:
                    endpoint.active_requests -= 1

        # All retries exhausted
        available = self.available_endpoint_count
        raise RuntimeError(
            f"All embedding attempts failed after {self.max_retries + 1} tries. "
            f"Available endpoints: {available}/{len(self._endpoints)}. "
            f"Last error: {last_error}"
        )


def create_rotational_embedder(
    endpoints_config: List[Dict[str, Any]],
    strategy: str = "latency_aware",
    default_cooldown: float = 60.0,
) -> RotationalEmbedder:
    """Factory function to create RotationalEmbedder from config dicts.

    Args:
        endpoints_config: List of endpoint configuration dicts with keys:
            - model: Model identifier (required)
            - api_key: API key (optional)
            - api_base: API base URL (optional)
            - weight: Request weight (optional, default 1.0)
            - max_concurrent: Max concurrent requests (optional, default 4)
        strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
        default_cooldown: Default cooldown seconds for rate limits

    Returns:
        Configured RotationalEmbedder instance

    Example config:
        endpoints_config = [
            {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
            {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ]
    """
    endpoints = []
    for cfg in endpoints_config:
        endpoints.append(EndpointConfig(
            model=cfg["model"],
            api_key=cfg.get("api_key"),
            api_base=cfg.get("api_base"),
            weight=cfg.get("weight", 1.0),
            max_concurrent=cfg.get("max_concurrent", 4),
        ))

    strategy_enum = SelectionStrategy[strategy.upper()]

    return RotationalEmbedder(
        endpoints=endpoints,
        strategy=strategy_enum,
        default_cooldown=default_cooldown,
    )
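

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The model names, keys, and base URL below
# are placeholders, not values shipped with this module; run as a module
# inside the package (python -m ...) with real credentials to exercise the
# load balancer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedder = create_rotational_embedder(
        endpoints_config=[
            {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
            {"model": "azure/my-embedding", "api_base": "https://example.invalid", "api_key": "..."},
        ],
        strategy="round_robin",
    )
    vectors = embedder.embed_to_numpy(["hello", "world"])
    print(vectors.shape)  # (2, embedding_dim)
    for entry in embedder.get_endpoint_stats():
        print(entry)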