From 40e61b30d697c4c7c4cdf74cac475ea98e9e4806 Mon Sep 17 00:00:00 2001 From: catlog22 Date: Thu, 25 Dec 2025 11:01:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=A4=9A=E7=AB=AF?= =?UTF-8?q?=E7=82=B9=E6=94=AF=E6=8C=81=E5=92=8C=E8=B4=9F=E8=BD=BD=E5=9D=87?= =?UTF-8?q?=E8=A1=A1=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=A2=9E=E5=BC=BA=20LiteLL?= =?UTF-8?q?M=20=E5=B5=8C=E5=85=A5=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ccw/src/tools/smart-search.ts | 5 +- codex-lens/src/codexlens/cli/commands.py | 56 ++- .../src/codexlens/cli/embedding_manager.py | 174 ++++++- codex-lens/src/codexlens/cli/output.py | 13 +- codex-lens/src/codexlens/config.py | 31 +- codex-lens/src/codexlens/semantic/factory.py | 43 +- .../codexlens/semantic/rotational_embedder.py | 434 ++++++++++++++++++ 7 files changed, 727 insertions(+), 29 deletions(-) create mode 100644 codex-lens/src/codexlens/semantic/rotational_embedder.py diff --git a/ccw/src/tools/smart-search.ts b/ccw/src/tools/smart-search.ts index 5e9276a1..5d54ac48 100644 --- a/ccw/src/tools/smart-search.ts +++ b/ccw/src/tools/smart-search.ts @@ -290,7 +290,7 @@ interface IndexStatus { file_count?: number; embeddings_coverage_percent?: number; total_chunks?: number; - model_info?: ModelInfo; + model_info?: ModelInfo | null; warning?: string; } @@ -359,7 +359,8 @@ async function checkIndexStatus(path: string = '.'): Promise { file_count: status.total_files, embeddings_coverage_percent: embeddingsCoverage, total_chunks: totalChunks, - model_info: modelInfo, + // Ensure model_info is null instead of undefined so it's included in JSON + model_info: modelInfo ?? null, warning, }; } catch { diff --git a/codex-lens/src/codexlens/cli/commands.py b/codex-lens/src/codexlens/cli/commands.py index fec2e552..111ed5db 100644 --- a/codex-lens/src/codexlens/cli/commands.py +++ b/codex-lens/src/codexlens/cli/commands.py @@ -759,6 +759,16 @@ def status( console.print(f" Coverage: {embeddings_info['coverage_percent']:.1f}%") console.print(f" Total Chunks: {embeddings_info['total_chunks']}") + # Display model information if available + model_info = embeddings_info.get('model_info') + if model_info: + console.print("\n[bold]Embedding Model:[/bold]") + console.print(f" Backend: [cyan]{model_info.get('backend', 'unknown')}[/cyan]") + console.print(f" Model: [cyan]{model_info.get('model_profile', 'unknown')}[/cyan] ({model_info.get('model_name', '')})") + console.print(f" Dimensions: {model_info.get('embedding_dim', 'unknown')}") + if model_info.get('updated_at'): + console.print(f" Last Updated: {model_info['updated_at']}") + except StorageError as exc: if json_mode: print_json(success=False, error=f"Storage error: {exc}") @@ -1878,7 +1888,7 @@ def embeddings_generate( """ _configure_logging(verbose, json_mode) - from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive + from codexlens.cli.embedding_manager import generate_embeddings, generate_embeddings_recursive, scan_for_model_conflicts # Validate backend valid_backends = ["fastembed", "litellm"] @@ -1946,6 +1956,50 @@ def embeddings_generate( console.print(f"Concurrency: [cyan]{max_workers} workers[/cyan]") console.print() + # Pre-check for model conflicts (only if not forcing) + if not force: + # Determine the index root for conflict scanning + scan_root = index_root if use_recursive else (index_path.parent if index_path else None) + + if scan_root: + conflict_result = scan_for_model_conflicts(scan_root, backend, 
model) + + if conflict_result["has_conflict"]: + existing = conflict_result["existing_config"] + conflict_count = len(conflict_result["conflicts"]) + + if json_mode: + # JSON mode: return structured error for UI handling + print_json( + success=False, + error="Model conflict detected", + code="MODEL_CONFLICT", + existing_config=existing, + target_config=conflict_result["target_config"], + conflict_count=conflict_count, + conflicts=conflict_result["conflicts"][:5], # Show first 5 conflicts + hint="Use --force to overwrite existing embeddings with the new model", + ) + raise typer.Exit(code=1) + else: + # Interactive mode: show warning and ask for confirmation + console.print("[yellow]⚠ Model Conflict Detected[/yellow]") + console.print(f" Existing: [red]{existing['backend']}/{existing['model']}[/red] ({existing.get('embedding_dim', '?')} dim)") + console.print(f" Requested: [green]{backend}/{model}[/green]") + console.print(f" Affected indexes: [yellow]{conflict_count}[/yellow]") + console.print() + console.print("[dim]Mixing different embedding models in the same index is not supported.[/dim]") + console.print("[dim]Overwriting will delete all existing embeddings and regenerate with the new model.[/dim]") + console.print() + + # Ask for confirmation + if typer.confirm("Overwrite existing embeddings with the new model?", default=False): + force = True + console.print("[green]Confirmed.[/green] Proceeding with overwrite...\n") + else: + console.print("[yellow]Cancelled.[/yellow] Use --force to skip this prompt.") + raise typer.Exit(code=0) + if use_recursive: result = generate_embeddings_recursive( index_root, diff --git a/codex-lens/src/codexlens/cli/embedding_manager.py b/codex-lens/src/codexlens/cli/embedding_manager.py index 85fc8dbb..e28439ca 100644 --- a/codex-lens/src/codexlens/cli/embedding_manager.py +++ b/codex-lens/src/codexlens/cli/embedding_manager.py @@ -235,18 +235,25 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]: } -def _get_embedding_defaults() -> tuple[str, str, bool]: +def _get_embedding_defaults() -> tuple[str, str, bool, List, str, float]: """Get default embedding settings from config. Returns: - Tuple of (backend, model, use_gpu) + Tuple of (backend, model, use_gpu, endpoints, strategy, cooldown) """ try: from codexlens.config import Config config = Config.load() - return config.embedding_backend, config.embedding_model, config.embedding_use_gpu + return ( + config.embedding_backend, + config.embedding_model, + config.embedding_use_gpu, + config.embedding_endpoints, + config.embedding_strategy, + config.embedding_cooldown, + ) except Exception: - return "fastembed", "code", True + return "fastembed", "code", True, [], "latency_aware", 60.0 def generate_embeddings( @@ -260,6 +267,9 @@ def generate_embeddings( use_gpu: Optional[bool] = None, max_tokens_per_batch: Optional[int] = None, max_workers: Optional[int] = None, + endpoints: Optional[List] = None, + strategy: Optional[str] = None, + cooldown: Optional[float] = None, ) -> Dict[str, any]: """Generate embeddings for an index using memory-efficient batch processing. @@ -284,14 +294,18 @@ def generate_embeddings( If None, attempts to get from embedder.max_tokens, then falls back to 8000. If set, overrides automatic detection. max_workers: Maximum number of concurrent API calls. - If None, uses dynamic defaults: 1 for fastembed (CPU bound), - 4 for litellm (network I/O bound). + If None, uses dynamic defaults based on backend and endpoint count. 
+ endpoints: Optional list of endpoint configurations for multi-API load balancing. + Each dict has keys: model, api_key, api_base, weight. + strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware). + cooldown: Default cooldown seconds for rate-limited endpoints. Returns: Result dictionary with generation statistics """ # Get defaults from config if not specified - default_backend, default_model, default_gpu = _get_embedding_defaults() + (default_backend, default_model, default_gpu, + default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults() if embedding_backend is None: embedding_backend = default_backend @@ -299,13 +313,26 @@ def generate_embeddings( model_profile = default_model if use_gpu is None: use_gpu = default_gpu + if endpoints is None: + endpoints = default_endpoints + if strategy is None: + strategy = default_strategy + if cooldown is None: + cooldown = default_cooldown - # Set dynamic max_workers default based on backend type + # Calculate endpoint count for worker scaling + endpoint_count = len(endpoints) if endpoints else 1 + + # Set dynamic max_workers default based on backend type and endpoint count # - FastEmbed: CPU-bound, sequential is optimal (1 worker) - # - LiteLLM: Network I/O bound, concurrent calls improve throughput (4 workers) + # - LiteLLM single endpoint: 4 workers default + # - LiteLLM multi-endpoint: workers = endpoint_count * 2 (to saturate all APIs) if max_workers is None: if embedding_backend == "litellm": - max_workers = 4 + if endpoint_count > 1: + max_workers = min(endpoint_count * 2, 16) # Cap at 16 workers + else: + max_workers = 4 else: max_workers = 1 @@ -354,13 +381,20 @@ def generate_embeddings( from codexlens.semantic.vector_store import VectorStore from codexlens.semantic.chunker import Chunker, ChunkConfig - # Initialize embedder using factory (supports both fastembed and litellm) + # Initialize embedder using factory (supports fastembed, litellm, and rotational) # For fastembed: model_profile is a profile name (fast/code/multilingual/balanced) # For litellm: model_profile is a model name (e.g., qwen3-embedding) + # For multi-endpoint: endpoints list enables load balancing if embedding_backend == "fastembed": embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu) elif embedding_backend == "litellm": - embedder = get_embedder_factory(backend="litellm", model=model_profile) + embedder = get_embedder_factory( + backend="litellm", + model=model_profile, + endpoints=endpoints if endpoints else None, + strategy=strategy, + cooldown=cooldown, + ) else: return { "success": False, @@ -375,7 +409,10 @@ def generate_embeddings( skip_token_count=True )) + # Log embedder info with endpoint count for multi-endpoint mode if progress_callback: + if endpoint_count > 1: + progress_callback(f"Using {endpoint_count} API endpoints with {strategy} strategy") progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)") except Exception as e: @@ -684,6 +721,9 @@ def generate_embeddings_recursive( use_gpu: Optional[bool] = None, max_tokens_per_batch: Optional[int] = None, max_workers: Optional[int] = None, + endpoints: Optional[List] = None, + strategy: Optional[str] = None, + cooldown: Optional[float] = None, ) -> Dict[str, any]: """Generate embeddings for all index databases in a project recursively. @@ -704,14 +744,17 @@ def generate_embeddings_recursive( If None, attempts to get from embedder.max_tokens, then falls back to 8000. 
If set, overrides automatic detection. max_workers: Maximum number of concurrent API calls. - If None, uses dynamic defaults: 1 for fastembed (CPU bound), - 4 for litellm (network I/O bound). + If None, uses dynamic defaults based on backend and endpoint count. + endpoints: Optional list of endpoint configurations for multi-API load balancing. + strategy: Selection strategy for multi-endpoint mode. + cooldown: Default cooldown seconds for rate-limited endpoints. Returns: Aggregated result dictionary with generation statistics """ # Get defaults from config if not specified - default_backend, default_model, default_gpu = _get_embedding_defaults() + (default_backend, default_model, default_gpu, + default_endpoints, default_strategy, default_cooldown) = _get_embedding_defaults() if embedding_backend is None: embedding_backend = default_backend @@ -719,11 +762,23 @@ def generate_embeddings_recursive( model_profile = default_model if use_gpu is None: use_gpu = default_gpu + if endpoints is None: + endpoints = default_endpoints + if strategy is None: + strategy = default_strategy + if cooldown is None: + cooldown = default_cooldown - # Set dynamic max_workers default based on backend type + # Calculate endpoint count for worker scaling + endpoint_count = len(endpoints) if endpoints else 1 + + # Set dynamic max_workers default based on backend type and endpoint count if max_workers is None: if embedding_backend == "litellm": - max_workers = 4 + if endpoint_count > 1: + max_workers = min(endpoint_count * 2, 16) + else: + max_workers = 4 else: max_workers = 1 @@ -765,6 +820,9 @@ def generate_embeddings_recursive( use_gpu=use_gpu, max_tokens_per_batch=max_tokens_per_batch, max_workers=max_workers, + endpoints=endpoints, + strategy=strategy, + cooldown=cooldown, ) all_results.append({ @@ -958,3 +1016,85 @@ def get_embedding_stats_summary(index_root: Path) -> Dict[str, any]: "indexes": index_stats, }, } + + +def scan_for_model_conflicts( + index_root: Path, + target_backend: str, + target_model: str, +) -> Dict[str, any]: + """Scan for model conflicts across all indexes in a directory. + + Checks if any existing embeddings were generated with a different + backend or model than the target configuration. 
+ + Args: + index_root: Root index directory to scan + target_backend: Target embedding backend (fastembed or litellm) + target_model: Target model profile/name + + Returns: + Dictionary with: + - has_conflict: True if any index has different model config + - existing_config: Config from first index with embeddings (if any) + - target_config: The requested configuration + - conflicts: List of conflicting index paths with their configs + - indexes_with_embeddings: Count of indexes that have embeddings + """ + index_files = discover_all_index_dbs(index_root) + + if not index_files: + return { + "has_conflict": False, + "existing_config": None, + "target_config": {"backend": target_backend, "model": target_model}, + "conflicts": [], + "indexes_with_embeddings": 0, + } + + conflicts = [] + existing_config = None + indexes_with_embeddings = 0 + + for index_path in index_files: + try: + from codexlens.semantic.vector_store import VectorStore + + with VectorStore(index_path) as vs: + config = vs.get_model_config() + if config and config.get("model_profile"): + indexes_with_embeddings += 1 + + # Store first existing config as reference + if existing_config is None: + existing_config = { + "backend": config.get("backend"), + "model": config.get("model_profile"), + "model_name": config.get("model_name"), + "embedding_dim": config.get("embedding_dim"), + } + + # Check for conflict: different backend OR different model + existing_backend = config.get("backend", "") + existing_model = config.get("model_profile", "") + + if existing_backend != target_backend or existing_model != target_model: + conflicts.append({ + "path": str(index_path), + "existing": { + "backend": existing_backend, + "model": existing_model, + "model_name": config.get("model_name"), + }, + }) + except Exception as e: + logger.debug(f"Failed to check model config for {index_path}: {e}") + continue + + return { + "has_conflict": len(conflicts) > 0, + "existing_config": existing_config, + "target_config": {"backend": target_backend, "model": target_model}, + "conflicts": conflicts, + "indexes_with_embeddings": indexes_with_embeddings, + } \ No newline at end of file diff --git a/codex-lens/src/codexlens/cli/output.py b/codex-lens/src/codexlens/cli/output.py index 15659441..1abfb4d2 100644 --- a/codex-lens/src/codexlens/cli/output.py +++ b/codex-lens/src/codexlens/cli/output.py @@ -35,12 +35,23 @@ def _to_jsonable(value: Any) -> Any: return value -def print_json(*, success: bool, result: Any = None, error: str | None = None) -> None: +def print_json(*, success: bool, result: Any = None, error: str | None = None, **kwargs: Any) -> None: + """Print JSON output with optional additional fields. 
+ + Args: + success: Whether the operation succeeded + result: Result data (used when success=True) + error: Error message (used when success=False) + **kwargs: Additional fields to include in the payload (e.g., code, details) + """ payload: dict[str, Any] = {"success": success} if success: payload["result"] = _to_jsonable(result) else: payload["error"] = error or "Unknown error" + # Include additional error details if provided + for key, value in kwargs.items(): + payload[key] = _to_jsonable(value) console.print_json(json.dumps(payload, ensure_ascii=False)) diff --git a/codex-lens/src/codexlens/config.py b/codex-lens/src/codexlens/config.py index 0a919496..4a869a62 100644 --- a/codex-lens/src/codexlens/config.py +++ b/codex-lens/src/codexlens/config.py @@ -100,6 +100,12 @@ class Config: # For litellm: model name from config (e.g., "qwen3-embedding") embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration + # Multi-endpoint configuration for litellm backend + embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list) + # List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}] + embedding_strategy: str = "latency_aware" # round_robin, latency_aware, weighted_random + embedding_cooldown: float = 60.0 # Default cooldown seconds for rate-limited endpoints + def __post_init__(self) -> None: try: self.data_dir = self.data_dir.expanduser().resolve() @@ -151,12 +157,19 @@ class Config: def save_settings(self) -> None: """Save embedding and other settings to file.""" + embedding_config = { + "backend": self.embedding_backend, + "model": self.embedding_model, + "use_gpu": self.embedding_use_gpu, + } + # Include multi-endpoint config if present + if self.embedding_endpoints: + embedding_config["endpoints"] = self.embedding_endpoints + embedding_config["strategy"] = self.embedding_strategy + embedding_config["cooldown"] = self.embedding_cooldown + settings = { - "embedding": { - "backend": self.embedding_backend, - "model": self.embedding_model, - "use_gpu": self.embedding_use_gpu, - }, + "embedding": embedding_config, "llm": { "enabled": self.llm_enabled, "tool": self.llm_tool, @@ -185,6 +198,14 @@ class Config: if "use_gpu" in embedding: self.embedding_use_gpu = embedding["use_gpu"] + # Load multi-endpoint configuration + if "endpoints" in embedding: + self.embedding_endpoints = embedding["endpoints"] + if "strategy" in embedding: + self.embedding_strategy = embedding["strategy"] + if "cooldown" in embedding: + self.embedding_cooldown = embedding["cooldown"] + # Load LLM settings llm = settings.get("llm", {}) if "enabled" in llm: diff --git a/codex-lens/src/codexlens/semantic/factory.py b/codex-lens/src/codexlens/semantic/factory.py index 5baacdd4..fe360539 100644 --- a/codex-lens/src/codexlens/semantic/factory.py +++ b/codex-lens/src/codexlens/semantic/factory.py @@ -5,7 +5,7 @@ Provides a unified interface for instantiating different embedder backends. from __future__ import annotations -from typing import Any +from typing import Any, Dict, List, Optional from .base import BaseEmbedder @@ -15,6 +15,9 @@ def get_embedder( profile: str = "code", model: str = "default", use_gpu: bool = True, + endpoints: Optional[List[Dict[str, Any]]] = None, + strategy: str = "latency_aware", + cooldown: float = 60.0, **kwargs: Any, ) -> BaseEmbedder: """Factory function to create embedder based on backend. @@ -29,6 +32,13 @@ def get_embedder( Used only when backend="litellm". 
Default: "default" use_gpu: Whether to use GPU acceleration when available (default: True). Used only when backend="fastembed". + endpoints: Optional list of endpoint configurations for multi-endpoint load balancing. + Each endpoint is a dict with keys: model, api_key, api_base, weight. + Used only when backend="litellm" and multiple endpoints provided. + strategy: Selection strategy for multi-endpoint mode: + "round_robin", "latency_aware", "weighted_random". + Default: "latency_aware" + cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0) **kwargs: Additional backend-specific arguments Returns: @@ -47,13 +57,40 @@ def get_embedder( Create litellm embedder: >>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small") + + Create rotational embedder with multiple endpoints: + >>> endpoints = [ + ... {"model": "openai/text-embedding-3-small", "api_key": "sk-..."}, + ... {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."}, + ... ] + >>> embedder = get_embedder(backend="litellm", endpoints=endpoints) """ if backend == "fastembed": from .embedder import Embedder return Embedder(profile=profile, use_gpu=use_gpu, **kwargs) elif backend == "litellm": - from .litellm_embedder import LiteLLMEmbedderWrapper - return LiteLLMEmbedderWrapper(model=model, **kwargs) + # Check if multi-endpoint mode is requested + if endpoints and len(endpoints) > 1: + from .rotational_embedder import create_rotational_embedder + return create_rotational_embedder( + endpoints_config=endpoints, + strategy=strategy, + default_cooldown=cooldown, + ) + elif endpoints and len(endpoints) == 1: + # Single endpoint in list - use it directly + ep = endpoints[0] + ep_kwargs = {**kwargs} + if "api_key" in ep: + ep_kwargs["api_key"] = ep["api_key"] + if "api_base" in ep: + ep_kwargs["api_base"] = ep["api_base"] + from .litellm_embedder import LiteLLMEmbedderWrapper + return LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs) + else: + # No endpoints list - use model parameter + from .litellm_embedder import LiteLLMEmbedderWrapper + return LiteLLMEmbedderWrapper(model=model, **kwargs) else: raise ValueError( f"Unknown backend: {backend}. " diff --git a/codex-lens/src/codexlens/semantic/rotational_embedder.py b/codex-lens/src/codexlens/semantic/rotational_embedder.py new file mode 100644 index 00000000..ff0f41ac --- /dev/null +++ b/codex-lens/src/codexlens/semantic/rotational_embedder.py @@ -0,0 +1,434 @@ +"""Rotational embedder for multi-endpoint API load balancing. + +Provides intelligent load balancing across multiple LiteLLM embedding endpoints +to maximize throughput while respecting rate limits. +""" + +from __future__ import annotations + +import logging +import random +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional + +import numpy as np + +from .base import BaseEmbedder + +logger = logging.getLogger(__name__) + + +class EndpointStatus(Enum): + """Status of an API endpoint.""" + AVAILABLE = "available" + COOLING = "cooling" # Rate limited, temporarily unavailable + FAILED = "failed" # Permanent failure (auth error, etc.) 
+ + +class SelectionStrategy(Enum): + """Strategy for selecting endpoints.""" + ROUND_ROBIN = "round_robin" + LATENCY_AWARE = "latency_aware" + WEIGHTED_RANDOM = "weighted_random" + + +@dataclass +class EndpointConfig: + """Configuration for a single API endpoint.""" + model: str + api_key: Optional[str] = None + api_base: Optional[str] = None + weight: float = 1.0 # Higher weight = more requests + max_concurrent: int = 4 # Max concurrent requests to this endpoint + + +@dataclass +class EndpointState: + """Runtime state for an endpoint.""" + config: EndpointConfig + embedder: Any = None # LiteLLMEmbedderWrapper instance + + # Health metrics + status: EndpointStatus = EndpointStatus.AVAILABLE + cooldown_until: float = 0.0 # Unix timestamp when cooldown ends + + # Performance metrics + total_requests: int = 0 + total_failures: int = 0 + avg_latency_ms: float = 0.0 + last_latency_ms: float = 0.0 + + # Concurrency tracking + active_requests: int = 0 + lock: threading.Lock = field(default_factory=threading.Lock) + + def is_available(self) -> bool: + """Check if endpoint is available for requests.""" + if self.status == EndpointStatus.FAILED: + return False + if self.status == EndpointStatus.COOLING: + if time.time() >= self.cooldown_until: + self.status = EndpointStatus.AVAILABLE + return True + return False + return True + + def set_cooldown(self, seconds: float) -> None: + """Put endpoint in cooldown state.""" + self.status = EndpointStatus.COOLING + self.cooldown_until = time.time() + seconds + logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s") + + def mark_failed(self) -> None: + """Mark endpoint as permanently failed.""" + self.status = EndpointStatus.FAILED + logger.error(f"Endpoint {self.config.model} marked as failed") + + def record_success(self, latency_ms: float) -> None: + """Record successful request.""" + self.total_requests += 1 + self.last_latency_ms = latency_ms + # Exponential moving average for latency + alpha = 0.3 + if self.avg_latency_ms == 0: + self.avg_latency_ms = latency_ms + else: + self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms + + def record_failure(self) -> None: + """Record failed request.""" + self.total_requests += 1 + self.total_failures += 1 + + @property + def health_score(self) -> float: + """Calculate health score (0-1) based on metrics.""" + if not self.is_available(): + return 0.0 + + # Base score from success rate + if self.total_requests > 0: + success_rate = 1 - (self.total_failures / self.total_requests) + else: + success_rate = 1.0 + + # Latency factor (faster = higher score) + # Normalize: 100ms = 1.0, 1000ms = 0.1 + if self.avg_latency_ms > 0: + latency_factor = min(1.0, 100 / self.avg_latency_ms) + else: + latency_factor = 1.0 + + # Availability factor (less concurrent = more available) + if self.config.max_concurrent > 0: + availability = 1 - (self.active_requests / self.config.max_concurrent) + else: + availability = 1.0 + + # Combined score with weights + return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight + + +class RotationalEmbedder(BaseEmbedder): + """Embedder that load balances across multiple API endpoints. 
+ + Features: + - Intelligent endpoint selection based on latency and health + - Automatic failover on rate limits (429) and server errors + - Cooldown management to respect rate limits + - Thread-safe concurrent request handling + + Args: + endpoints: List of endpoint configurations + strategy: Selection strategy (default: latency_aware) + default_cooldown: Default cooldown seconds for rate limits (default: 60) + max_retries: Maximum retry attempts across all endpoints (default: 3) + """ + + def __init__( + self, + endpoints: List[EndpointConfig], + strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE, + default_cooldown: float = 60.0, + max_retries: int = 3, + ) -> None: + if not endpoints: + raise ValueError("At least one endpoint must be provided") + + self.strategy = strategy + self.default_cooldown = default_cooldown + self.max_retries = max_retries + + # Initialize endpoint states + self._endpoints: List[EndpointState] = [] + self._lock = threading.Lock() + self._round_robin_index = 0 + + # Create embedder instances for each endpoint + from .litellm_embedder import LiteLLMEmbedderWrapper + + for config in endpoints: + # Build kwargs for LiteLLMEmbedderWrapper + kwargs: Dict[str, Any] = {} + if config.api_key: + kwargs["api_key"] = config.api_key + if config.api_base: + kwargs["api_base"] = config.api_base + + try: + embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs) + state = EndpointState(config=config, embedder=embedder) + self._endpoints.append(state) + logger.info(f"Initialized endpoint: {config.model}") + except Exception as e: + logger.error(f"Failed to initialize endpoint {config.model}: {e}") + + if not self._endpoints: + raise ValueError("Failed to initialize any endpoints") + + # Cache embedding properties from first endpoint + self._embedding_dim = self._endpoints[0].embedder.embedding_dim + self._model_name = f"rotational({len(self._endpoints)} endpoints)" + self._max_tokens = self._endpoints[0].embedder.max_tokens + + @property + def embedding_dim(self) -> int: + """Return embedding dimensions.""" + return self._embedding_dim + + @property + def model_name(self) -> str: + """Return model name.""" + return self._model_name + + @property + def max_tokens(self) -> int: + """Return maximum token limit.""" + return self._max_tokens + + @property + def endpoint_count(self) -> int: + """Return number of configured endpoints.""" + return len(self._endpoints) + + @property + def available_endpoint_count(self) -> int: + """Return number of available endpoints.""" + return sum(1 for ep in self._endpoints if ep.is_available()) + + def get_endpoint_stats(self) -> List[Dict[str, Any]]: + """Get statistics for all endpoints.""" + stats = [] + for ep in self._endpoints: + stats.append({ + "model": ep.config.model, + "status": ep.status.value, + "total_requests": ep.total_requests, + "total_failures": ep.total_failures, + "avg_latency_ms": round(ep.avg_latency_ms, 2), + "health_score": round(ep.health_score, 3), + "active_requests": ep.active_requests, + }) + return stats + + def _select_endpoint(self) -> Optional[EndpointState]: + """Select best available endpoint based on strategy.""" + available = [ep for ep in self._endpoints if ep.is_available()] + + if not available: + return None + + if self.strategy == SelectionStrategy.ROUND_ROBIN: + with self._lock: + self._round_robin_index = (self._round_robin_index + 1) % len(available) + return available[self._round_robin_index] + + elif self.strategy == SelectionStrategy.LATENCY_AWARE: + # Sort by health score 
(descending) and pick top candidate + # Add small random factor to prevent thundering herd + scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available] + scored.sort(key=lambda x: x[1], reverse=True) + return scored[0][0] + + elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM: + # Weighted random selection based on health scores + scores = [ep.health_score for ep in available] + total = sum(scores) + if total == 0: + return random.choice(available) + + weights = [s / total for s in scores] + return random.choices(available, weights=weights, k=1)[0] + + return available[0] + + def _parse_retry_after(self, error: Exception) -> Optional[float]: + """Extract Retry-After value from error if available.""" + error_str = str(error) + + # Try to find Retry-After in error message + import re + match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str) + if match: + return float(match.group(1)) + + return None + + def _is_rate_limit_error(self, error: Exception) -> bool: + """Check if error is a rate limit error.""" + error_str = str(error).lower() + return any(x in error_str for x in ["429", "rate limit", "too many requests"]) + + def _is_retryable_error(self, error: Exception) -> bool: + """Check if error is retryable (not auth/config error).""" + error_str = str(error).lower() + # Retryable errors + if any(x in error_str for x in ["429", "rate limit", "502", "503", "504", + "timeout", "connection", "service unavailable"]): + return True + # Non-retryable errors (auth, config) + if any(x in error_str for x in ["401", "403", "invalid", "authentication", + "unauthorized", "api key"]): + return False + # Default to retryable for unknown errors + return True + + def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray: + """Embed texts using load-balanced endpoint selection. + + Args: + texts: Single text or iterable of texts to embed. + **kwargs: Additional arguments passed to underlying embedder. + + Returns: + numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings. + + Raises: + RuntimeError: If all endpoints fail after retries. 
+ """ + if isinstance(texts, str): + texts = [texts] + else: + texts = list(texts) + + last_error: Optional[Exception] = None + tried_endpoints: set = set() + + for attempt in range(self.max_retries + 1): + endpoint = self._select_endpoint() + + if endpoint is None: + # All endpoints unavailable, wait for shortest cooldown + min_cooldown = min( + (ep.cooldown_until - time.time() for ep in self._endpoints + if ep.status == EndpointStatus.COOLING), + default=self.default_cooldown + ) + if min_cooldown > 0 and attempt < self.max_retries: + wait_time = min(min_cooldown, 30) # Cap wait at 30s + logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...") + time.sleep(wait_time) + continue + break + + # Track tried endpoints to avoid infinite loops + endpoint_id = id(endpoint) + if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints): + # Already tried all endpoints + break + tried_endpoints.add(endpoint_id) + + # Acquire slot + with endpoint.lock: + endpoint.active_requests += 1 + + try: + start_time = time.time() + result = endpoint.embedder.embed_to_numpy(texts, **kwargs) + latency_ms = (time.time() - start_time) * 1000 + + # Record success + endpoint.record_success(latency_ms) + + return result + + except Exception as e: + last_error = e + endpoint.record_failure() + + if self._is_rate_limit_error(e): + # Rate limited - set cooldown + retry_after = self._parse_retry_after(e) or self.default_cooldown + endpoint.set_cooldown(retry_after) + logger.warning(f"Endpoint {endpoint.config.model} rate limited, " + f"cooling for {retry_after}s") + + elif not self._is_retryable_error(e): + # Permanent failure (auth error, etc.) + endpoint.mark_failed() + logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}") + + else: + # Temporary error - short cooldown + endpoint.set_cooldown(5.0) + logger.warning(f"Endpoint {endpoint.config.model} error: {e}") + + finally: + with endpoint.lock: + endpoint.active_requests -= 1 + + # All retries exhausted + available = self.available_endpoint_count + raise RuntimeError( + f"All embedding attempts failed after {self.max_retries + 1} tries. " + f"Available endpoints: {available}/{len(self._endpoints)}. " + f"Last error: {last_error}" + ) + + +def create_rotational_embedder( + endpoints_config: List[Dict[str, Any]], + strategy: str = "latency_aware", + default_cooldown: float = 60.0, +) -> RotationalEmbedder: + """Factory function to create RotationalEmbedder from config dicts. 
+ + Args: + endpoints_config: List of endpoint configuration dicts with keys: + - model: Model identifier (required) + - api_key: API key (optional) + - api_base: API base URL (optional) + - weight: Request weight (optional, default 1.0) + - max_concurrent: Max concurrent requests (optional, default 4) + strategy: Selection strategy name (round_robin, latency_aware, weighted_random) + default_cooldown: Default cooldown seconds for rate limits + + Returns: + Configured RotationalEmbedder instance + + Example config: + endpoints_config = [ + {"model": "openai/text-embedding-3-small", "api_key": "sk-..."}, + {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."}, + ] + """ + endpoints = [] + for cfg in endpoints_config: + endpoints.append(EndpointConfig( + model=cfg["model"], + api_key=cfg.get("api_key"), + api_base=cfg.get("api_base"), + weight=cfg.get("weight", 1.0), + max_concurrent=cfg.get("max_concurrent", 4), + )) + + strategy_enum = SelectionStrategy[strategy.upper()] + + return RotationalEmbedder( + endpoints=endpoints, + strategy=strategy_enum, + default_cooldown=default_cooldown, + )
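
A minimal usage sketch (illustration only, not part of the diff), assuming the Config fields and factory
behaviour introduced above; the model names, keys, and URLs are placeholder values copied from the
docstring examples in this patch:

    # Enable multi-endpoint load balancing for the litellm backend via persisted settings.
    from codexlens.config import Config

    config = Config.load()
    config.embedding_backend = "litellm"
    config.embedding_model = "qwen3-embedding"  # used only when no endpoints are configured
    config.embedding_endpoints = [
        {"model": "openai/text-embedding-3-small", "api_key": "sk-...", "weight": 1.0},
        {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "...", "weight": 2.0},
    ]
    config.embedding_strategy = "latency_aware"  # or "round_robin" / "weighted_random"
    config.embedding_cooldown = 60.0             # seconds to cool down a rate-limited endpoint
    config.save_settings()

With endpoints saved, generate_embeddings() picks them up through _get_embedding_defaults(), passes them
to get_embedder(backend="litellm", endpoints=...), and, when more than one endpoint is configured, scales
max_workers to min(endpoint_count * 2, 16).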