diff --git a/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py b/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py
index 31a86fd0..ee84ee1e 100644
--- a/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py
+++ b/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py
@@ -102,6 +102,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         """Embedding vector size."""
         return self._model_config.dimensions
 
+    @property
+    def max_input_tokens(self) -> int:
+        """Maximum token limit for embeddings.
+
+        Returns the configured max_input_tokens from the model config,
+        enabling adaptive batch sizing based on the model's actual capacity.
+        """
+        return self._model_config.max_input_tokens
+
     def _estimate_tokens(self, text: str) -> int:
         """Estimate token count for a text using fast heuristic.
 
@@ -162,7 +171,7 @@ class LiteLLMEmbedder(AbstractEmbedder):
         texts: str | Sequence[str],
         *,
         batch_size: int | None = None,
-        max_tokens_per_batch: int = 30000,
+        max_tokens_per_batch: int | None = None,
         **kwargs: Any,
     ) -> NDArray[np.floating]:
         """Embed one or more texts.
@@ -170,7 +179,8 @@ class LiteLLMEmbedder(AbstractEmbedder):
         Args:
             texts: Single text or sequence of texts
             batch_size: Batch size for processing (deprecated, use max_tokens_per_batch)
-            max_tokens_per_batch: Maximum estimated tokens per API call (default: 30000)
+            max_tokens_per_batch: Maximum estimated tokens per API call.
+                If None, defaults to 90% of the model's max_input_tokens as a safety margin.
             **kwargs: Additional arguments for litellm.embedding()
 
         Returns:
@@ -196,6 +206,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         if self._provider_config.api_base and "encoding_format" not in embedding_kwargs:
             embedding_kwargs["encoding_format"] = "float"
 
+        # Determine the adaptive max_tokens_per_batch: default to 90% of the
+        # model's max_input_tokens as a safety margin for estimation error.
+        if max_tokens_per_batch is None:
+            max_tokens_per_batch = int(self.max_input_tokens * 0.9)
+            logger.debug(
+                f"Using adaptive batch size: {max_tokens_per_batch} tokens "
+                f"(90% of {self.max_input_tokens})"
+            )
+
         # Split into token-aware batches
         batches = self._create_batches(text_list, max_tokens_per_batch)
 
diff --git a/ccw-litellm/src/ccw_litellm/config/loader.py b/ccw-litellm/src/ccw_litellm/config/loader.py
index f5c7ec21..43ba90dd 100644
--- a/ccw-litellm/src/ccw_litellm/config/loader.py
+++ b/ccw-litellm/src/ccw_litellm/config/loader.py
@@ -109,6 +109,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
     providers: dict[str, Any] = {}
     llm_models: dict[str, Any] = {}
    embedding_models: dict[str, Any] = {}
+    reranker_models: dict[str, Any] = {}
     default_provider: str | None = None
 
     for provider in json_config.get("providers", []):
@@ -186,6 +187,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
             "provider": provider_id,
             "model": model.get("name", ""),
             "dimensions": model.get("capabilities", {}).get("embeddingDimension", 1536),
+            "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
         }
         # Add model-specific endpoint settings
         endpoint = model.get("endpointSettings", {})
@@ -196,6 +198,29 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
 
             embedding_models[model_id] = embedding_model_config
 
+        # Convert reranker models
+        for model in provider.get("rerankerModels", []):
+            if not model.get("enabled", True):
+                continue
+            model_id = model.get("id", "")
+            if not model_id:
+                continue
+
+            reranker_model_config: dict[str, Any] = {
+                "provider": provider_id,
+                "model": model.get("name", ""),
+                "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
+                "top_k": model.get("capabilities", {}).get("topK", 50),
+            }
+            # Add model-specific endpoint settings
+            endpoint = model.get("endpointSettings", {})
+            if endpoint.get("baseUrl"):
+                reranker_model_config["api_base"] = endpoint["baseUrl"]
+            if endpoint.get("timeout"):
+                reranker_model_config["timeout"] = endpoint["timeout"]
+
+            reranker_models[model_id] = reranker_model_config
+
     # Ensure we have defaults if no models found
     if not llm_models:
         llm_models["default"] = {
@@ -208,6 +233,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
             "provider": default_provider or "openai",
             "model": "text-embedding-3-small",
             "dimensions": 1536,
+            "max_input_tokens": 8191,
         }
 
     return {
@@ -216,6 +242,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
         "providers": providers,
         "llm_models": llm_models,
         "embedding_models": embedding_models,
+        "reranker_models": reranker_models,
     }
diff --git a/ccw-litellm/src/ccw_litellm/config/models.py b/ccw-litellm/src/ccw_litellm/config/models.py
index ee76cdc8..33e63a0e 100644
--- a/ccw-litellm/src/ccw_litellm/config/models.py
+++ b/ccw-litellm/src/ccw_litellm/config/models.py
@@ -34,6 +34,18 @@ class EmbeddingModelConfig(BaseModel):
     provider: str  # "openai", "fastembed", "ollama", etc.
     model: str
     dimensions: int
+    max_input_tokens: int = 8192  # Maximum tokens per embedding request
+
+    model_config = {"extra": "allow"}
+
+
+class RerankerModelConfig(BaseModel):
+    """Reranker model configuration."""
+
+    provider: str  # "siliconflow", "cohere", "jina", etc.
+    model: str
+    max_input_tokens: int = 8192  # Maximum tokens per reranking request
+    top_k: int = 50  # Default top_k for reranking
 
     model_config = {"extra": "allow"}
@@ -69,6 +81,7 @@ class LiteLLMConfig(BaseModel):
     providers: dict[str, ProviderConfig] = Field(default_factory=dict)
     llm_models: dict[str, LLMModelConfig] = Field(default_factory=dict)
     embedding_models: dict[str, EmbeddingModelConfig] = Field(default_factory=dict)
+    reranker_models: dict[str, RerankerModelConfig] = Field(default_factory=dict)
 
     model_config = {"extra": "allow"}
@@ -110,6 +123,25 @@ class LiteLLMConfig(BaseModel):
             )
         return self.embedding_models[model]
 
+    def get_reranker_model(self, model: str = "default") -> RerankerModelConfig:
+        """Get reranker model configuration by name.
+
+        Args:
+            model: Model name or "default"
+
+        Returns:
+            Reranker model configuration
+
+        Raises:
+            ValueError: If model not found
+        """
+        if model not in self.reranker_models:
+            raise ValueError(
+                f"Reranker model '{model}' not found in configuration. "
+                f"Available models: {list(self.reranker_models.keys())}"
+            )
+        return self.reranker_models[model]
+
     def get_provider(self, provider: str) -> ProviderConfig:
         """Get provider configuration by name.
diff --git a/ccw/src/templates/dashboard-js/i18n.js b/ccw/src/templates/dashboard-js/i18n.js
index 49f3bd25..c1056ce0 100644
--- a/ccw/src/templates/dashboard-js/i18n.js
+++ b/ccw/src/templates/dashboard-js/i18n.js
@@ -1672,6 +1672,7 @@ const i18n = {
     // Embedding models
     'apiSettings.embeddingDimensions': 'Dimensions',
     'apiSettings.embeddingMaxTokens': 'Max Tokens',
+    'apiSettings.rerankerTopK': 'Top K',
     'apiSettings.selectEmbeddingModel': 'Select Embedding Model',
 
     // Model modal
@@ -3698,6 +3699,7 @@ const i18n = {
     // Embedding models
     'apiSettings.embeddingDimensions': '向量维度',
     'apiSettings.embeddingMaxTokens': '最大 Token',
+    'apiSettings.rerankerTopK': 'Top K',
     'apiSettings.selectEmbeddingModel': '选择嵌入模型',
 
     // Model modal
diff --git a/ccw/src/templates/dashboard-js/views/api-settings.js b/ccw/src/templates/dashboard-js/views/api-settings.js
index f9a35427..1ba3e119 100644
--- a/ccw/src/templates/dashboard-js/views/api-settings.js
+++ b/ccw/src/templates/dashboard-js/views/api-settings.js
@@ -1163,7 +1163,7 @@ function renderProviderDetail(providerId) {
   var maskedKey = provider.apiKey ? '••••••••••••••••' + provider.apiKey.slice(-4) : '••••••••';
   var currentApiBase = provider.apiBase || getDefaultApiBase(provider.type);
   // Show full endpoint URL preview based on active model tab
-  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : '/chat/completions';
+  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : activeModelTab === 'reranker' ? '/rerank' : '/chat/completions';
   var apiBasePreview = currentApiBase + endpointPath;
 
   var html = '