Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-05 01:50:27 +08:00)
feat: add reranker model configuration, support max input tokens, and improve API batching
@@ -102,6 +102,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         """Embedding vector size."""
         return self._model_config.dimensions

+    @property
+    def max_input_tokens(self) -> int:
+        """Maximum token limit for embeddings.
+
+        Returns the configured max_input_tokens from model config,
+        enabling adaptive batch sizing based on actual model capacity.
+        """
+        return self._model_config.max_input_tokens
+
     def _estimate_tokens(self, text: str) -> int:
         """Estimate token count for a text using fast heuristic.

@@ -162,7 +171,7 @@ class LiteLLMEmbedder(AbstractEmbedder):
         texts: str | Sequence[str],
         *,
         batch_size: int | None = None,
-        max_tokens_per_batch: int = 30000,
+        max_tokens_per_batch: int | None = None,
         **kwargs: Any,
     ) -> NDArray[np.floating]:
         """Embed one or more texts.
@@ -170,7 +179,8 @@ class LiteLLMEmbedder(AbstractEmbedder):
         Args:
             texts: Single text or sequence of texts
             batch_size: Batch size for processing (deprecated, use max_tokens_per_batch)
-            max_tokens_per_batch: Maximum estimated tokens per API call (default: 30000)
+            max_tokens_per_batch: Maximum estimated tokens per API call.
+                If None, uses 90% of model's max_input_tokens for safety margin.
             **kwargs: Additional arguments for litellm.embedding()

         Returns:
@@ -196,6 +206,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         if self._provider_config.api_base and "encoding_format" not in embedding_kwargs:
             embedding_kwargs["encoding_format"] = "float"

+        # Determine adaptive max_tokens_per_batch
+        # Use 90% of model's max_input_tokens as safety margin
+        if max_tokens_per_batch is None:
+            max_tokens_per_batch = int(self.max_input_tokens * 0.9)
+            logger.debug(
+                f"Using adaptive batch size: {max_tokens_per_batch} tokens "
+                f"(90% of {self.max_input_tokens})"
+            )
+
         # Split into token-aware batches
         batches = self._create_batches(text_list, max_tokens_per_batch)
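For context, a minimal standalone sketch (not part of this commit) of the adaptive batching idea the embedder hunks above implement: cap each API call at 90% of the model's max_input_tokens and pack texts greedily using the same len(text) // 4 heuristic. All names below are illustrative only.

from collections.abc import Sequence


def estimate_tokens(text: str) -> int:
    """Fast heuristic used throughout this commit: roughly 4 characters per token."""
    return len(text) // 4


def create_token_aware_batches(texts: Sequence[str], max_input_tokens: int) -> list[list[str]]:
    """Greedily pack texts into batches that stay under 90% of the model limit."""
    budget = int(max_input_tokens * 0.9)  # safety margin, as in the diff above
    batches: list[list[str]] = []
    current: list[str] = []
    used = 0
    for text in texts:
        cost = estimate_tokens(text)
        # Start a new batch when adding this text would exceed the budget
        # (an oversized single text still gets its own batch).
        if current and used + cost > budget:
            batches.append(current)
            current, used = [], 0
        current.append(text)
        used += cost
    if current:
        batches.append(current)
    return batches


if __name__ == "__main__":
    docs = ["a" * 4000, "b" * 4000, "c" * 4000]  # ~1000 estimated tokens each
    print([len(b) for b in create_token_aware_batches(docs, max_input_tokens=2048)])
    # -> [1, 1, 1]: any two ~1000-token texts together exceed the ~1843-token budget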
@@ -109,6 +109,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
     providers: dict[str, Any] = {}
     llm_models: dict[str, Any] = {}
     embedding_models: dict[str, Any] = {}
+    reranker_models: dict[str, Any] = {}
     default_provider: str | None = None

     for provider in json_config.get("providers", []):
@@ -186,6 +187,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
                 "provider": provider_id,
                 "model": model.get("name", ""),
                 "dimensions": model.get("capabilities", {}).get("embeddingDimension", 1536),
+                "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
             }
             # Add model-specific endpoint settings
             endpoint = model.get("endpointSettings", {})
@@ -196,6 +198,29 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A

             embedding_models[model_id] = embedding_model_config

+        # Convert reranker models
+        for model in provider.get("rerankerModels", []):
+            if not model.get("enabled", True):
+                continue
+            model_id = model.get("id", "")
+            if not model_id:
+                continue
+
+            reranker_model_config: dict[str, Any] = {
+                "provider": provider_id,
+                "model": model.get("name", ""),
+                "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
+                "top_k": model.get("capabilities", {}).get("topK", 50),
+            }
+            # Add model-specific endpoint settings
+            endpoint = model.get("endpointSettings", {})
+            if endpoint.get("baseUrl"):
+                reranker_model_config["api_base"] = endpoint["baseUrl"]
+            if endpoint.get("timeout"):
+                reranker_model_config["timeout"] = endpoint["timeout"]
+
+            reranker_models[model_id] = reranker_model_config
+
     # Ensure we have defaults if no models found
     if not llm_models:
         llm_models["default"] = {
@@ -208,6 +233,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
             "provider": default_provider or "openai",
             "model": "text-embedding-3-small",
             "dimensions": 1536,
+            "max_input_tokens": 8191,
         }

     return {
@@ -216,6 +242,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, A
         "providers": providers,
         "llm_models": llm_models,
         "embedding_models": embedding_models,
+        "reranker_models": reranker_models,
     }
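For reference, a hand-written sketch (not taken from the repository) of the provider JSON shape this converter now reads and the internal mapping it produces; key names follow the hunks above, while the ids, model names, and URL are made-up examples.

# Provider JSON as accepted by the converter (rerankerModels is the new section).
provider_json = {
    "id": "example-provider",
    "rerankerModels": [
        {
            "id": "example-reranker",
            "name": "example/reranker-v1",
            "enabled": True,
            "capabilities": {"maxInputTokens": 8192, "topK": 50},
            "endpointSettings": {"baseUrl": "https://api.example.com/v1", "timeout": 30},
        }
    ],
}

# Internal form built by the loop added in the diff above (illustrative, not executed here).
expected_internal = {
    "reranker_models": {
        "example-reranker": {
            "provider": "example-provider",
            "model": "example/reranker-v1",
            "max_input_tokens": 8192,
            "top_k": 50,
            "api_base": "https://api.example.com/v1",
            "timeout": 30,
        }
    }
}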
@@ -34,6 +34,18 @@ class EmbeddingModelConfig(BaseModel):
     provider: str  # "openai", "fastembed", "ollama", etc.
     model: str
     dimensions: int
+    max_input_tokens: int = 8192  # Maximum tokens per embedding request

+    model_config = {"extra": "allow"}
+
+
+class RerankerModelConfig(BaseModel):
+    """Reranker model configuration."""
+
+    provider: str  # "siliconflow", "cohere", "jina", etc.
+    model: str
+    max_input_tokens: int = 8192  # Maximum tokens per reranking request
+    top_k: int = 50  # Default top_k for reranking
+
     model_config = {"extra": "allow"}

@@ -69,6 +81,7 @@ class LiteLLMConfig(BaseModel):
     providers: dict[str, ProviderConfig] = Field(default_factory=dict)
     llm_models: dict[str, LLMModelConfig] = Field(default_factory=dict)
     embedding_models: dict[str, EmbeddingModelConfig] = Field(default_factory=dict)
+    reranker_models: dict[str, RerankerModelConfig] = Field(default_factory=dict)

     model_config = {"extra": "allow"}

@@ -110,6 +123,25 @@ class LiteLLMConfig(BaseModel):
         )
         return self.embedding_models[model]

+    def get_reranker_model(self, model: str = "default") -> RerankerModelConfig:
+        """Get reranker model configuration by name.
+
+        Args:
+            model: Model name or "default"
+
+        Returns:
+            Reranker model configuration
+
+        Raises:
+            ValueError: If model not found
+        """
+        if model not in self.reranker_models:
+            raise ValueError(
+                f"Reranker model '{model}' not found in configuration. "
+                f"Available models: {list(self.reranker_models.keys())}"
+            )
+        return self.reranker_models[model]
+
     def get_provider(self, provider: str) -> ProviderConfig:
         """Get provider configuration by name.
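A small usage sketch, assuming pydantic v2 and a trimmed-down copy of the models above; it shows the lookup path added by get_reranker_model. Model names here are placeholders.

from pydantic import BaseModel, Field


class RerankerModelConfig(BaseModel):
    provider: str
    model: str
    max_input_tokens: int = 8192
    top_k: int = 50
    model_config = {"extra": "allow"}


class LiteLLMConfig(BaseModel):
    reranker_models: dict[str, RerankerModelConfig] = Field(default_factory=dict)
    model_config = {"extra": "allow"}

    def get_reranker_model(self, model: str = "default") -> RerankerModelConfig:
        # Mirrors the method added in the diff: unknown names raise ValueError.
        if model not in self.reranker_models:
            raise ValueError(
                f"Reranker model '{model}' not found in configuration. "
                f"Available models: {list(self.reranker_models.keys())}"
            )
        return self.reranker_models[model]


cfg = LiteLLMConfig(
    reranker_models={
        "default": RerankerModelConfig(provider="example-provider", model="example/reranker-v1")
    }
)
print(cfg.get_reranker_model().max_input_tokens)  # 8192 (field default)
# cfg.get_reranker_model("missing") would raise ValueError listing available models.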
@@ -1672,6 +1672,7 @@ const i18n = {
     // Embedding models
     'apiSettings.embeddingDimensions': 'Dimensions',
     'apiSettings.embeddingMaxTokens': 'Max Tokens',
+    'apiSettings.rerankerTopK': 'Top K',
     'apiSettings.selectEmbeddingModel': 'Select Embedding Model',

     // Model modal
@@ -3698,6 +3699,7 @@ const i18n = {
     // Embedding models
     'apiSettings.embeddingDimensions': '向量维度',
     'apiSettings.embeddingMaxTokens': '最大 Token',
+    'apiSettings.rerankerTopK': 'Top K',
     'apiSettings.selectEmbeddingModel': '选择嵌入模型',

     // Model modal
@@ -1163,7 +1163,7 @@ function renderProviderDetail(providerId) {
   var maskedKey = provider.apiKey ? '••••••••••••••••' + provider.apiKey.slice(-4) : '••••••••';
   var currentApiBase = provider.apiBase || getDefaultApiBase(provider.type);
   // Show full endpoint URL preview based on active model tab
-  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : '/chat/completions';
+  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : activeModelTab === 'reranker' ? '/rerank' : '/chat/completions';
   var apiBasePreview = currentApiBase + endpointPath;

   var html = '<div class="provider-detail-header">' +
@@ -1322,10 +1322,17 @@ function renderModelTree(provider) {
     var embeddingBadge = model.capabilities && model.capabilities.embeddingDimension
       ? model.capabilities.embeddingDimension + 'd'
       : '';
-    var displayBadge = activeModelTab === 'llm' ? badge : embeddingBadge;
+
+    // Badge for reranker models shows max tokens
+    var rerankerBadge = model.capabilities && model.capabilities.maxInputTokens
+      ? formatContextWindow(model.capabilities.maxInputTokens)
+      : '';
+
+    var displayBadge = activeModelTab === 'llm' ? badge : activeModelTab === 'reranker' ? rerankerBadge : embeddingBadge;
+    var iconName = activeModelTab === 'llm' ? 'sparkles' : activeModelTab === 'reranker' ? 'arrow-up-down' : 'box';

     html += '<div class="model-item" data-model-id="' + model.id + '">' +
-      '<i data-lucide="' + (activeModelTab === 'llm' ? 'sparkles' : 'box') + '" class="model-item-icon"></i>' +
+      '<i data-lucide="' + iconName + '" class="model-item-icon"></i>' +
       '<span class="model-item-name">' + escapeHtml(model.name) + '</span>' +
       (displayBadge ? '<span class="model-item-badge">' + displayBadge + '</span>' : '') +
       '<div class="model-item-actions">' +
@@ -1966,14 +1973,25 @@ function showModelSettingsModal(providerId, modelId, modelType) {
       '<label class="checkbox-label"><input type="checkbox" id="model-settings-function-calling"' + (capabilities.functionCalling ? ' checked' : '') + '> ' + t('apiSettings.functionCalling') + '</label>' +
       '<label class="checkbox-label"><input type="checkbox" id="model-settings-vision"' + (capabilities.vision ? ' checked' : '') + '> ' + t('apiSettings.vision') + '</label>' +
       '</div>'
+    ) : isReranker ? (
+      // Reranker capabilities - only maxInputTokens and topK
+      '<div class="form-group">' +
+      '<label>' + t('apiSettings.embeddingMaxTokens') + '</label>' +
+      '<input type="number" id="model-settings-max-tokens" class="cli-input" value="' + (capabilities.maxInputTokens || 8192) + '" min="128">' +
+      '</div>' +
+      '<div class="form-group">' +
+      '<label>' + t('apiSettings.rerankerTopK') + '</label>' +
+      '<input type="number" id="model-settings-top-k" class="cli-input" value="' + (capabilities.topK || 50) + '" min="1" max="1000">' +
+      '</div>'
     ) : (
+      // Embedding capabilities - embeddingDimension and maxInputTokens
       '<div class="form-group">' +
       '<label>' + t('apiSettings.embeddingDimensions') + '</label>' +
       '<input type="number" id="model-settings-dimensions" class="cli-input" value="' + (capabilities.embeddingDimension || 1536) + '" min="64">' +
       '</div>' +
       '<div class="form-group">' +
       '<label>' + t('apiSettings.embeddingMaxTokens') + '</label>' +
-      '<input type="number" id="model-settings-max-tokens" class="cli-input" value="' + (capabilities.contextWindow || 8192) + '" min="128">' +
+      '<input type="number" id="model-settings-max-tokens" class="cli-input" value="' + (capabilities.maxInputTokens || 8192) + '" min="128">' +
       '</div>'
     )) +
     '</div>' +
@@ -2070,14 +2088,14 @@ function saveModelSettings(event, providerId, modelId, modelType) {
       vision: document.getElementById('model-settings-vision').checked
     };
   } else if (isReranker) {
-    var topKEl = document.getElementById('model-settings-top-k');
     models[modelIndex].capabilities = {
-      topK: topKEl ? parseInt(topKEl.value) || 10 : 10
+      maxInputTokens: parseInt(document.getElementById('model-settings-max-tokens').value) || 8192,
+      topK: parseInt(document.getElementById('model-settings-top-k').value) || 50
     };
   } else {
     models[modelIndex].capabilities = {
       embeddingDimension: parseInt(document.getElementById('model-settings-dimensions').value) || 1536,
-      contextWindow: parseInt(document.getElementById('model-settings-max-tokens').value) || 8192
+      maxInputTokens: parseInt(document.getElementById('model-settings-max-tokens').value) || 8192
    };
  }

@@ -2218,7 +2236,7 @@ function updateApiBasePreview(apiBase) {
   if (base.endsWith('/')) {
     base = base.slice(0, -1);
   }
-  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : '/chat/completions';
+  var endpointPath = activeModelTab === 'embedding' ? '/embeddings' : activeModelTab === 'reranker' ? '/rerank' : '/chat/completions';
   preview.textContent = t('apiSettings.preview') + ': ' + base + endpointPath;
 }
@@ -140,6 +140,7 @@ class Config:
     reranker_backend: str = "onnx"
     reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
     reranker_top_k: int = 50
+    reranker_max_input_tokens: int = 8192  # Maximum tokens for reranker API batching

     # Cascade search configuration (two-stage retrieval)
     enable_cascade_search: bool = False  # Enable cascade search (coarse + fine ranking)
@@ -277,6 +278,7 @@ class Config:
             "backend": self.reranker_backend,
             "model": self.reranker_model,
             "top_k": self.reranker_top_k,
+            "max_input_tokens": self.reranker_max_input_tokens,
             "pool_enabled": self.reranker_pool_enabled,
             "strategy": self.reranker_strategy,
             "cooldown": self.reranker_cooldown,
@@ -359,6 +361,8 @@ class Config:
                 self.reranker_model = reranker["model"]
             if "top_k" in reranker:
                 self.reranker_top_k = reranker["top_k"]
+            if "max_input_tokens" in reranker:
+                self.reranker_max_input_tokens = reranker["max_input_tokens"]
             if "pool_enabled" in reranker:
                 self.reranker_pool_enabled = reranker["pool_enabled"]
             if "strategy" in reranker:
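For illustration only (key names taken from the hunks above, values made up): the shape of the "reranker" config section these Config hunks read and serialize, and the read-with-default pattern that keeps 8192 as the fallback when the new key is absent.

# Illustrative workspace config fragment; not an actual config from the repository.
reranker_section = {
    "backend": "api",
    "model": "example/reranker-v1",
    "top_k": 50,
    "max_input_tokens": 32768,  # new key introduced by this commit
}

# Same fallback behavior as the Config default above.
max_input_tokens = reranker_section.get("max_input_tokens", 8192)
print(max_input_tokens)  # 32768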
@@ -1798,6 +1798,11 @@ class ChainSearchEngine:
         kwargs = {}
         if backend == "onnx":
             kwargs["use_gpu"] = use_gpu
+        elif backend == "api":
+            # Pass max_input_tokens for adaptive batching
+            max_tokens = getattr(self._config, "reranker_max_input_tokens", None)
+            if max_tokens:
+                kwargs["max_input_tokens"] = max_tokens

         reranker = get_reranker(backend=backend, model_name=model_name, **kwargs)
@@ -400,6 +400,11 @@ class HybridSearchEngine:
         elif backend == "legacy":
             if not bool(getattr(self._config, "embedding_use_gpu", True)):
                 device = "cpu"
+        elif backend == "api":
+            # Pass max_input_tokens for adaptive batching
+            max_tokens = getattr(self._config, "reranker_max_input_tokens", None)
+            if max_tokens:
+                kwargs["max_input_tokens"] = max_tokens

         return get_reranker(
             backend=backend,
@@ -69,13 +69,13 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):

         Returns:
             int: Maximum number of tokens that can be embedded at once.
-                Inferred from model config or model name patterns.
+                Reads from LiteLLM config's max_input_tokens property.
         """
-        # Try to get from LiteLLM config first
-        if hasattr(self._embedder, 'max_input_tokens') and self._embedder.max_input_tokens:
+        # Get from LiteLLM embedder's max_input_tokens property (now exposed)
+        if hasattr(self._embedder, 'max_input_tokens'):
             return self._embedder.max_input_tokens

-        # Infer from model name
+        # Fallback: infer from model name
         model_name_lower = self.model_name.lower()

         # Large models (8B or "large" in name)
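A standalone sketch of the name-based fallback that the wrapper above and APIReranker below rely on when no configured limit is available; the thresholds mirror the APIReranker hunk, while the helper name and example model names are made up.

def infer_max_input_tokens(model_name: str) -> int:
    """Guess a token limit from the model name when config gives none."""
    name = model_name.lower()
    if "8b" in name or "large" in name:
        return 32768  # larger models are assumed to accept longer inputs
    return 8192


print(infer_max_input_tokens("example-reranker"))      # 8192
print(infer_max_input_tokens("example-reranker-8b"))   # 32768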
@@ -78,6 +78,7 @@ class APIReranker(BaseReranker):
         backoff_max_s: float = 8.0,
         env_api_key: str = _DEFAULT_ENV_API_KEY,
         workspace_root: Path | str | None = None,
+        max_input_tokens: int | None = None,
     ) -> None:
         ok, err = check_httpx_available()
         if not ok:  # pragma: no cover - exercised via factory availability tests
@@ -135,6 +136,22 @@ class APIReranker(BaseReranker):
             timeout=self.timeout_s,
         )

+        # Store max_input_tokens with model-aware defaults
+        if max_input_tokens is not None:
+            self._max_input_tokens = max_input_tokens
+        else:
+            # Infer from model name
+            model_lower = self.model_name.lower()
+            if '8b' in model_lower or 'large' in model_lower:
+                self._max_input_tokens = 32768
+            else:
+                self._max_input_tokens = 8192
+
+    @property
+    def max_input_tokens(self) -> int:
+        """Return maximum token limit for reranking."""
+        return self._max_input_tokens
+
     def close(self) -> None:
         try:
             self._client.close()
@@ -276,15 +293,78 @@ class APIReranker(BaseReranker):
         }
         return payload

+    def _estimate_tokens(self, text: str) -> int:
+        """Estimate token count using fast heuristic (len/4)."""
+        return len(text) // 4
+
+    def _create_token_aware_batches(
+        self,
+        query: str,
+        documents: Sequence[str],
+    ) -> list[list[tuple[int, str]]]:
+        """Split documents into batches that fit within token limits.
+
+        Uses 90% of max_input_tokens as safety margin.
+        Each batch includes the query tokens overhead.
+        """
+        max_tokens = int(self._max_input_tokens * 0.9)
+        query_tokens = self._estimate_tokens(query)
+
+        batches: list[list[tuple[int, str]]] = []
+        current_batch: list[tuple[int, str]] = []
+        current_tokens = query_tokens  # Start with query overhead
+
+        for idx, doc in enumerate(documents):
+            doc_tokens = self._estimate_tokens(doc)
+
+            # If single doc + query exceeds limit, include it anyway (will be truncated by API)
+            if current_tokens + doc_tokens > max_tokens and current_batch:
+                batches.append(current_batch)
+                current_batch = []
+                current_tokens = query_tokens
+
+            current_batch.append((idx, doc))
+            current_tokens += doc_tokens
+
+        if current_batch:
+            batches.append(current_batch)
+
+        return batches
+
     def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
         if not documents:
             return []

-        payload = self._build_payload(query=query, documents=documents)
-        data = self._request_json(payload)
+        # Create token-aware batches
+        batches = self._create_token_aware_batches(query, documents)

-        results = data.get("results")
-        return self._extract_scores_from_results(results, expected=len(documents))
+        if len(batches) == 1:
+            # Single batch - original behavior
+            payload = self._build_payload(query=query, documents=documents)
+            data = self._request_json(payload)
+            results = data.get("results")
+            return self._extract_scores_from_results(results, expected=len(documents))
+
+        # Multiple batches - process each and merge results
+        logger.info(
+            f"Splitting {len(documents)} documents into {len(batches)} batches "
+            f"(max_input_tokens: {self._max_input_tokens})"
+        )
+
+        all_scores: list[float] = [0.0] * len(documents)
+
+        for batch in batches:
+            batch_docs = [doc for _, doc in batch]
+            payload = self._build_payload(query=query, documents=batch_docs)
+            data = self._request_json(payload)
+            results = data.get("results")
+            batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
+
+            # Map scores back to original indices
+            for (orig_idx, _), score in zip(batch, batch_scores):
+                all_scores[orig_idx] = score
+
+        return all_scores
+
     def score_pairs(
         self,
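To make the merge step concrete, a self-contained sketch (fake scorer instead of the HTTP call) of how per-batch scores are mapped back to the original document order, following the _create_token_aware_batches / _rerank_one_query logic above; all names here are illustrative.

from collections.abc import Sequence


def estimate_tokens(text: str) -> int:
    return len(text) // 4


def batch_documents(query: str, documents: Sequence[str], max_input_tokens: int) -> list[list[tuple[int, str]]]:
    """Token-aware batching that keeps each document's original index."""
    budget = int(max_input_tokens * 0.9)
    query_cost = estimate_tokens(query)
    batches, current, used = [], [], query_cost
    for idx, doc in enumerate(documents):
        cost = estimate_tokens(doc)
        if current and used + cost > budget:
            batches.append(current)
            current, used = [], query_cost
        current.append((idx, doc))
        used += cost
    if current:
        batches.append(current)
    return batches


def rerank(query: str, documents: Sequence[str], max_input_tokens: int = 8192) -> list[float]:
    """Score every document, one API-sized batch at a time, preserving order."""
    def fake_score(docs: list[str]) -> list[float]:
        # Stand-in for the real rerank API call.
        return [float(len(d)) for d in docs]

    all_scores = [0.0] * len(documents)
    for batch in batch_documents(query, documents, max_input_tokens):
        scores = fake_score([doc for _, doc in batch])
        for (orig_idx, _), score in zip(batch, scores):
            all_scores[orig_idx] = score
    return all_scores


print(rerank("q", ["aa", "bbbb", "cccccc"]))  # [2.0, 4.0, 6.0] -- order preserved across batches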
@@ -16,6 +16,16 @@ class BaseReranker(ABC):
     the abstract methods to ensure a consistent interface.
     """

+    @property
+    def max_input_tokens(self) -> int:
+        """Return maximum token limit for reranking.
+
+        Returns:
+            int: Maximum number of tokens that can be processed at once.
+                Default is 8192 if not overridden by implementation.
+        """
+        return 8192
+
     @abstractmethod
     def score_pairs(
         self,