diff --git a/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py b/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py
index 31a86fd0..ee84ee1e 100644
--- a/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py
+++ b/ccw-litellm/src/ccw_litellm/clients/litellm_embedder.py
@@ -102,6 +102,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         """Embedding vector size."""
         return self._model_config.dimensions
 
+    @property
+    def max_input_tokens(self) -> int:
+        """Maximum token limit for embeddings.
+
+        Returns the configured max_input_tokens from model config,
+        enabling adaptive batch sizing based on actual model capacity.
+        """
+        return self._model_config.max_input_tokens
+
     def _estimate_tokens(self, text: str) -> int:
         """Estimate token count for a text using fast heuristic.
 
@@ -162,7 +171,7 @@ class LiteLLMEmbedder(AbstractEmbedder):
         texts: str | Sequence[str],
         *,
         batch_size: int | None = None,
-        max_tokens_per_batch: int = 30000,
+        max_tokens_per_batch: int | None = None,
         **kwargs: Any,
     ) -> NDArray[np.floating]:
         """Embed one or more texts.
@@ -170,7 +179,8 @@ class LiteLLMEmbedder(AbstractEmbedder):
         Args:
             texts: Single text or sequence of texts
             batch_size: Batch size for processing (deprecated, use max_tokens_per_batch)
-            max_tokens_per_batch: Maximum estimated tokens per API call (default: 30000)
+            max_tokens_per_batch: Maximum estimated tokens per API call.
+                If None, uses 90% of model's max_input_tokens for safety margin.
             **kwargs: Additional arguments for litellm.embedding()
 
         Returns:
@@ -196,6 +206,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         if self._provider_config.api_base and "encoding_format" not in embedding_kwargs:
             embedding_kwargs["encoding_format"] = "float"
 
+        # Determine adaptive max_tokens_per_batch
+        # Use 90% of model's max_input_tokens as safety margin
+        if max_tokens_per_batch is None:
+            max_tokens_per_batch = int(self.max_input_tokens * 0.9)
+            logger.debug(
+                f"Using adaptive batch size: {max_tokens_per_batch} tokens "
+                f"(90% of {self.max_input_tokens})"
+            )
+
         # Split into token-aware batches
         batches = self._create_batches(text_list, max_tokens_per_batch)
 
diff --git a/ccw-litellm/src/ccw_litellm/config/loader.py b/ccw-litellm/src/ccw_litellm/config/loader.py
index f5c7ec21..43ba90dd 100644
--- a/ccw-litellm/src/ccw_litellm/config/loader.py
+++ b/ccw-litellm/src/ccw_litellm/config/loader.py
@@ -109,6 +109,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
     providers: dict[str, Any] = {}
     llm_models: dict[str, Any] = {}
     embedding_models: dict[str, Any] = {}
+    reranker_models: dict[str, Any] = {}
     default_provider: str | None = None
 
     for provider in json_config.get("providers", []):
@@ -186,6 +187,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
                 "provider": provider_id,
                 "model": model.get("name", ""),
                 "dimensions": model.get("capabilities", {}).get("embeddingDimension", 1536),
+                "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
             }
             # Add model-specific endpoint settings
             endpoint = model.get("endpointSettings", {})
@@ -196,6 +198,29 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
 
             embedding_models[model_id] = embedding_model_config
 
+        # Convert reranker models
+        for model in provider.get("rerankerModels", []):
+            if not model.get("enabled", True):
+                continue
+            model_id = model.get("id", "")
+            if not model_id:
+                continue
+
+            reranker_model_config: dict[str, Any] = {
+                "provider": provider_id,
+                "model": model.get("name", ""),
"max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192), + "top_k": model.get("capabilities", {}).get("topK", 50), + } + # Add model-specific endpoint settings + endpoint = model.get("endpointSettings", {}) + if endpoint.get("baseUrl"): + reranker_model_config["api_base"] = endpoint["baseUrl"] + if endpoint.get("timeout"): + reranker_model_config["timeout"] = endpoint["timeout"] + + reranker_models[model_id] = reranker_model_config + # Ensure we have defaults if no models found if not llm_models: llm_models["default"] = { @@ -208,6 +233,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A "provider": default_provider or "openai", "model": "text-embedding-3-small", "dimensions": 1536, + "max_input_tokens": 8191, } return { @@ -216,6 +242,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A "providers": providers, "llm_models": llm_models, "embedding_models": embedding_models, + "reranker_models": reranker_models, } diff --git a/ccw-litellm/src/ccw_litellm/config/models.py b/ccw-litellm/src/ccw_litellm/config/models.py index ee76cdc8..33e63a0e 100644 --- a/ccw-litellm/src/ccw_litellm/config/models.py +++ b/ccw-litellm/src/ccw_litellm/config/models.py @@ -34,6 +34,18 @@ class EmbeddingModelConfig(BaseModel): provider: str # "openai", "fastembed", "ollama", etc. model: str dimensions: int + max_input_tokens: int = 8192 # Maximum tokens per embedding request + + model_config = {"extra": "allow"} + + +class RerankerModelConfig(BaseModel): + """Reranker model configuration.""" + + provider: str # "siliconflow", "cohere", "jina", etc. + model: str + max_input_tokens: int = 8192 # Maximum tokens per reranking request + top_k: int = 50 # Default top_k for reranking model_config = {"extra": "allow"} @@ -69,6 +81,7 @@ class LiteLLMConfig(BaseModel): providers: dict[str, ProviderConfig] = Field(default_factory=dict) llm_models: dict[str, LLMModelConfig] = Field(default_factory=dict) embedding_models: dict[str, EmbeddingModelConfig] = Field(default_factory=dict) + reranker_models: dict[str, RerankerModelConfig] = Field(default_factory=dict) model_config = {"extra": "allow"} @@ -110,6 +123,25 @@ class LiteLLMConfig(BaseModel): ) return self.embedding_models[model] + def get_reranker_model(self, model: str = "default") -> RerankerModelConfig: + """Get reranker model configuration by name. + + Args: + model: Model name or "default" + + Returns: + Reranker model configuration + + Raises: + ValueError: If model not found + """ + if model not in self.reranker_models: + raise ValueError( + f"Reranker model '{model}' not found in configuration. " + f"Available models: {list(self.reranker_models.keys())}" + ) + return self.reranker_models[model] + def get_provider(self, provider: str) -> ProviderConfig: """Get provider configuration by name. 
diff --git a/ccw/src/templates/dashboard-js/i18n.js b/ccw/src/templates/dashboard-js/i18n.js
index 49f3bd25..c1056ce0 100644
--- a/ccw/src/templates/dashboard-js/i18n.js
+++ b/ccw/src/templates/dashboard-js/i18n.js
@@ -1672,6 +1672,7 @@ const i18n = {
     // Embedding models
     'apiSettings.embeddingDimensions': 'Dimensions',
     'apiSettings.embeddingMaxTokens': 'Max Tokens',
+    'apiSettings.rerankerTopK': 'Top K',
     'apiSettings.selectEmbeddingModel': 'Select Embedding Model',
 
     // Model modal
@@ -3698,6 +3699,7 @@ const i18n = {
     // Embedding models
     'apiSettings.embeddingDimensions': '向量维度',
     'apiSettings.embeddingMaxTokens': '最大 Token',
+    'apiSettings.rerankerTopK': 'Top K',
     'apiSettings.selectEmbeddingModel': '选择嵌入模型',
 
     // Model modal
diff --git a/ccw/src/templates/dashboard-js/views/api-settings.js b/ccw/src/templates/dashboard-js/views/api-settings.js
index f9a35427..1ba3e119 100644
--- a/ccw/src/templates/dashboard-js/views/api-settings.js
+++ b/ccw/src/templates/dashboard-js/views/api-settings.js
@@ -1163,7 +1163,7 @@ function renderProviderDetail(providerId) {
   var maskedKey = provider.apiKey ? '••••••••••••••••' + provider.apiKey.slice(-4) : '••••••••';
   var currentApiBase = provider.apiBase || getDefaultApiBase(provider.type);
   // Show full endpoint URL preview based on active model tab
-  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : '/chat/completions';
+  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : activeModelTab === 'reranker' ? '/rerank' : '/chat/completions';
   var apiBasePreview = currentApiBase + endpointPath;
   var html = '' +
@@ -1322,10 +1322,17 @@ function renderModelTree(provider) {
     var embeddingBadge = model.capabilities && model.capabilities.embeddingDimension
       ? model.capabilities.embeddingDimension + 'd'
       : '';
-    var displayBadge = activeModelTab === 'llm' ? badge : embeddingBadge;
+
+    // Badge for reranker models shows max tokens
+    var rerankerBadge = model.capabilities && model.capabilities.maxInputTokens
+      ? formatContextWindow(model.capabilities.maxInputTokens)
+      : '';
+
+    var displayBadge = activeModelTab === 'llm' ? badge : activeModelTab === 'reranker' ? rerankerBadge : embeddingBadge;
+    var iconName = activeModelTab === 'llm' ? 'sparkles' : activeModelTab === 'reranker' ? 'arrow-up-down' : 'box';
 
     html += '' +
-      '' +
+      '' +
      '' + escapeHtml(model.name) + '' +
       (displayBadge ? '' + displayBadge + '' : '') +
       '' +
@@ -1966,14 +1973,25 @@ function showModelSettingsModal(providerId, modelId, modelType) {
       '' +
       '' +
       ''
+    ) : isReranker ? (
+      // Reranker capabilities - only maxInputTokens and topK
+      '' +
+      '' +
+      '' +
+      '' +
+      '' +
+      '' +
+      '' +
+      ''
     ) : (
+      // Embedding capabilities - embeddingDimension and maxInputTokens
       '' +
       '' +
       '' +
       '' +
       '' +
       '' +
-      '' +
+      '' +
       ''
     )) +
     '' +
@@ -2070,14 +2088,14 @@ function saveModelSettings(event, providerId, modelId, modelType) {
        vision: document.getElementById('model-settings-vision').checked
      };
    } else if (isReranker) {
-     var topKEl = document.getElementById('model-settings-top-k');
      models[modelIndex].capabilities = {
-       topK: topKEl ? parseInt(topKEl.value) || 10 : 10
+       maxInputTokens: parseInt(document.getElementById('model-settings-max-tokens').value) || 8192,
+       topK: parseInt(document.getElementById('model-settings-top-k').value) || 50
      };
    } else {
      models[modelIndex].capabilities = {
        embeddingDimension: parseInt(document.getElementById('model-settings-dimensions').value) || 1536,
-       contextWindow: parseInt(document.getElementById('model-settings-max-tokens').value) || 8192
+       maxInputTokens: parseInt(document.getElementById('model-settings-max-tokens').value) || 8192
      };
    }
@@ -2218,7 +2236,7 @@ function updateApiBasePreview(apiBase) {
   if (base.endsWith('/')) {
     base = base.slice(0, -1);
   }
-  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : '/chat/completions';
+  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : activeModelTab === 'reranker' ? '/rerank' : '/chat/completions';
   preview.textContent = t('apiSettings.preview') + ': ' + base + endpointPath;
 }
diff --git a/codex-lens/src/codexlens/config.py b/codex-lens/src/codexlens/config.py
index 184b0843..95136e2b 100644
--- a/codex-lens/src/codexlens/config.py
+++ b/codex-lens/src/codexlens/config.py
@@ -140,6 +140,7 @@ class Config:
     reranker_backend: str = "onnx"
     reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
     reranker_top_k: int = 50
+    reranker_max_input_tokens: int = 8192  # Maximum tokens for reranker API batching
 
     # Cascade search configuration (two-stage retrieval)
     enable_cascade_search: bool = False  # Enable cascade search (coarse + fine ranking)
@@ -277,6 +278,7 @@ class Config:
                 "backend": self.reranker_backend,
                 "model": self.reranker_model,
                 "top_k": self.reranker_top_k,
+                "max_input_tokens": self.reranker_max_input_tokens,
                 "pool_enabled": self.reranker_pool_enabled,
                 "strategy": self.reranker_strategy,
                 "cooldown": self.reranker_cooldown,
@@ -359,6 +361,8 @@ class Config:
             self.reranker_model = reranker["model"]
         if "top_k" in reranker:
             self.reranker_top_k = reranker["top_k"]
+        if "max_input_tokens" in reranker:
+            self.reranker_max_input_tokens = reranker["max_input_tokens"]
         if "pool_enabled" in reranker:
             self.reranker_pool_enabled = reranker["pool_enabled"]
         if "strategy" in reranker:
diff --git a/codex-lens/src/codexlens/search/chain_search.py b/codex-lens/src/codexlens/search/chain_search.py
index d23cdef4..2e417c2d 100644
--- a/codex-lens/src/codexlens/search/chain_search.py
+++ b/codex-lens/src/codexlens/search/chain_search.py
@@ -1798,6 +1798,11 @@ class ChainSearchEngine:
         kwargs = {}
         if backend == "onnx":
             kwargs["use_gpu"] = use_gpu
+        elif backend == "api":
+            # Pass max_input_tokens for adaptive batching
+            max_tokens = getattr(self._config, "reranker_max_input_tokens", None)
+            if max_tokens:
+                kwargs["max_input_tokens"] = max_tokens
 
         reranker = get_reranker(backend=backend, model_name=model_name, **kwargs)
 
diff --git a/codex-lens/src/codexlens/search/hybrid_search.py b/codex-lens/src/codexlens/search/hybrid_search.py
index 6be4d637..4df15c10 100644
--- a/codex-lens/src/codexlens/search/hybrid_search.py
+++ b/codex-lens/src/codexlens/search/hybrid_search.py
@@ -400,6 +400,11 @@ class HybridSearchEngine:
         elif backend == "legacy":
             if not bool(getattr(self._config, "embedding_use_gpu", True)):
                 device = "cpu"
+        elif backend == "api":
+            # Pass max_input_tokens for adaptive batching
+            max_tokens = getattr(self._config, "reranker_max_input_tokens", None)
+            if max_tokens:
+                kwargs["max_input_tokens"] = max_tokens
 
         return get_reranker(
             backend=backend,
diff --git a/codex-lens/src/codexlens/semantic/litellm_embedder.py b/codex-lens/src/codexlens/semantic/litellm_embedder.py
index 27a6137c..ee4284dd 100644
--- a/codex-lens/src/codexlens/semantic/litellm_embedder.py
+++ b/codex-lens/src/codexlens/semantic/litellm_embedder.py
@@ -69,13 +69,13 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):
         Returns:
             int: Maximum number of tokens that can be embedded at once.
-                Inferred from model config or model name patterns.
+                Reads from LiteLLM config's max_input_tokens property.
         """
-        # Try to get from LiteLLM config first
-        if hasattr(self._embedder, 'max_input_tokens') and self._embedder.max_input_tokens:
+        # Get from LiteLLM embedder's max_input_tokens property (now exposed)
+        if hasattr(self._embedder, 'max_input_tokens'):
             return self._embedder.max_input_tokens
 
-        # Infer from model name
+        # Fallback: infer from model name
         model_name_lower = self.model_name.lower()
 
         # Large models (8B or "large" in name)
diff --git a/codex-lens/src/codexlens/semantic/reranker/api_reranker.py b/codex-lens/src/codexlens/semantic/reranker/api_reranker.py
index 88cf34a3..be0c1503 100644
--- a/codex-lens/src/codexlens/semantic/reranker/api_reranker.py
+++ b/codex-lens/src/codexlens/semantic/reranker/api_reranker.py
@@ -78,6 +78,7 @@ class APIReranker(BaseReranker):
         backoff_max_s: float = 8.0,
         env_api_key: str = _DEFAULT_ENV_API_KEY,
         workspace_root: Path | str | None = None,
+        max_input_tokens: int | None = None,
     ) -> None:
         ok, err = check_httpx_available()
         if not ok:  # pragma: no cover - exercised via factory availability tests
@@ -135,6 +136,22 @@ class APIReranker(BaseReranker):
             timeout=self.timeout_s,
         )
 
+        # Store max_input_tokens with model-aware defaults
+        if max_input_tokens is not None:
+            self._max_input_tokens = max_input_tokens
+        else:
+            # Infer from model name
+            model_lower = self.model_name.lower()
+            if '8b' in model_lower or 'large' in model_lower:
+                self._max_input_tokens = 32768
+            else:
+                self._max_input_tokens = 8192
+
+    @property
+    def max_input_tokens(self) -> int:
+        """Return maximum token limit for reranking."""
+        return self._max_input_tokens
+
     def close(self) -> None:
         try:
             self._client.close()
@@ -276,15 +293,78 @@ class APIReranker(BaseReranker):
         }
         return payload
 
+    def _estimate_tokens(self, text: str) -> int:
+        """Estimate token count using fast heuristic (len/4)."""
+        return len(text) // 4
+
+    def _create_token_aware_batches(
+        self,
+        query: str,
+        documents: Sequence[str],
+    ) -> list[list[tuple[int, str]]]:
+        """Split documents into batches that fit within token limits.
+
+        Uses 90% of max_input_tokens as safety margin.
+        Each batch includes the query tokens overhead.
+ """ + max_tokens = int(self._max_input_tokens * 0.9) + query_tokens = self._estimate_tokens(query) + + batches: list[list[tuple[int, str]]] = [] + current_batch: list[tuple[int, str]] = [] + current_tokens = query_tokens # Start with query overhead + + for idx, doc in enumerate(documents): + doc_tokens = self._estimate_tokens(doc) + + # If single doc + query exceeds limit, include it anyway (will be truncated by API) + if current_tokens + doc_tokens > max_tokens and current_batch: + batches.append(current_batch) + current_batch = [] + current_tokens = query_tokens + + current_batch.append((idx, doc)) + current_tokens += doc_tokens + + if current_batch: + batches.append(current_batch) + + return batches + def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]: if not documents: return [] - payload = self._build_payload(query=query, documents=documents) - data = self._request_json(payload) + # Create token-aware batches + batches = self._create_token_aware_batches(query, documents) - results = data.get("results") - return self._extract_scores_from_results(results, expected=len(documents)) + if len(batches) == 1: + # Single batch - original behavior + payload = self._build_payload(query=query, documents=documents) + data = self._request_json(payload) + results = data.get("results") + return self._extract_scores_from_results(results, expected=len(documents)) + + # Multiple batches - process each and merge results + logger.info( + f"Splitting {len(documents)} documents into {len(batches)} batches " + f"(max_input_tokens: {self._max_input_tokens})" + ) + + all_scores: list[float] = [0.0] * len(documents) + + for batch in batches: + batch_docs = [doc for _, doc in batch] + payload = self._build_payload(query=query, documents=batch_docs) + data = self._request_json(payload) + results = data.get("results") + batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs)) + + # Map scores back to original indices + for (orig_idx, _), score in zip(batch, batch_scores): + all_scores[orig_idx] = score + + return all_scores def score_pairs( self, diff --git a/codex-lens/src/codexlens/semantic/reranker/base.py b/codex-lens/src/codexlens/semantic/reranker/base.py index 870aca84..65c2d837 100644 --- a/codex-lens/src/codexlens/semantic/reranker/base.py +++ b/codex-lens/src/codexlens/semantic/reranker/base.py @@ -16,6 +16,16 @@ class BaseReranker(ABC): the abstract methods to ensure a consistent interface. """ + @property + def max_input_tokens(self) -> int: + """Return maximum token limit for reranking. + + Returns: + int: Maximum number of tokens that can be processed at once. + Default is 8192 if not overridden by implementation. + """ + return 8192 + @abstractmethod def score_pairs( self,