Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-05 01:50:27 +08:00
feat: Add reranker model configuration, support max input tokens, and optimize API batching
@@ -102,6 +102,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         """Embedding vector size."""
         return self._model_config.dimensions
 
+    @property
+    def max_input_tokens(self) -> int:
+        """Maximum token limit for embeddings.
+
+        Returns the configured max_input_tokens from model config,
+        enabling adaptive batch sizing based on actual model capacity.
+        """
+        return self._model_config.max_input_tokens
+
     def _estimate_tokens(self, text: str) -> int:
         """Estimate token count for a text using fast heuristic.
 
@@ -162,7 +171,7 @@ class LiteLLMEmbedder(AbstractEmbedder):
         texts: str | Sequence[str],
         *,
         batch_size: int | None = None,
-        max_tokens_per_batch: int = 30000,
+        max_tokens_per_batch: int | None = None,
         **kwargs: Any,
     ) -> NDArray[np.floating]:
         """Embed one or more texts.
@@ -170,7 +179,8 @@ class LiteLLMEmbedder(AbstractEmbedder):
         Args:
             texts: Single text or sequence of texts
             batch_size: Batch size for processing (deprecated, use max_tokens_per_batch)
-            max_tokens_per_batch: Maximum estimated tokens per API call (default: 30000)
+            max_tokens_per_batch: Maximum estimated tokens per API call.
+                If None, uses 90% of model's max_input_tokens for safety margin.
             **kwargs: Additional arguments for litellm.embedding()
 
         Returns:
@@ -196,6 +206,15 @@ class LiteLLMEmbedder(AbstractEmbedder):
         if self._provider_config.api_base and "encoding_format" not in embedding_kwargs:
             embedding_kwargs["encoding_format"] = "float"
 
+        # Determine adaptive max_tokens_per_batch
+        # Use 90% of model's max_input_tokens as safety margin
+        if max_tokens_per_batch is None:
+            max_tokens_per_batch = int(self.max_input_tokens * 0.9)
+            logger.debug(
+                f"Using adaptive batch size: {max_tokens_per_batch} tokens "
+                f"(90% of {self.max_input_tokens})"
+            )
+
         # Split into token-aware batches
         batches = self._create_batches(text_list, max_tokens_per_batch)
 
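For orientation, a minimal sketch of the arithmetic introduced above, assuming the default max_input_tokens of 8192: an embed() call that omits max_tokens_per_batch now caps each API request at int(8192 * 0.9) = 7372 estimated tokens, while an explicit value is used unchanged. The standalone helper below only mirrors that logic for illustration; it is not part of the codebase.

def adaptive_max_tokens_per_batch(
    max_tokens_per_batch: int | None, max_input_tokens: int
) -> int:
    """Return the explicit cap, or 90% of the model limit as a safety margin."""
    if max_tokens_per_batch is None:
        return int(max_input_tokens * 0.9)
    return max_tokens_per_batch


print(adaptive_max_tokens_per_batch(None, 8192))   # 7372 (adaptive default)
print(adaptive_max_tokens_per_batch(30000, 8192))  # 30000 (explicit value wins)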
@@ -109,6 +109,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
     providers: dict[str, Any] = {}
     llm_models: dict[str, Any] = {}
     embedding_models: dict[str, Any] = {}
+    reranker_models: dict[str, Any] = {}
     default_provider: str | None = None
 
     for provider in json_config.get("providers", []):
@@ -186,6 +187,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
                 "provider": provider_id,
                 "model": model.get("name", ""),
                 "dimensions": model.get("capabilities", {}).get("embeddingDimension", 1536),
+                "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
             }
             # Add model-specific endpoint settings
             endpoint = model.get("endpointSettings", {})
@@ -196,6 +198,29 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
 
             embedding_models[model_id] = embedding_model_config
 
+        # Convert reranker models
+        for model in provider.get("rerankerModels", []):
+            if not model.get("enabled", True):
+                continue
+            model_id = model.get("id", "")
+            if not model_id:
+                continue
+
+            reranker_model_config: dict[str, Any] = {
+                "provider": provider_id,
+                "model": model.get("name", ""),
+                "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
+                "top_k": model.get("capabilities", {}).get("topK", 50),
+            }
+            # Add model-specific endpoint settings
+            endpoint = model.get("endpointSettings", {})
+            if endpoint.get("baseUrl"):
+                reranker_model_config["api_base"] = endpoint["baseUrl"]
+            if endpoint.get("timeout"):
+                reranker_model_config["timeout"] = endpoint["timeout"]
+
+            reranker_models[model_id] = reranker_model_config
+
     # Ensure we have defaults if no models found
     if not llm_models:
         llm_models["default"] = {
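As a sketch of what this new branch consumes and produces, the snippet below runs one hypothetical rerankerModels entry through the same field mapping; the provider id, model name, and endpoint URL are invented for illustration and do not come from the repository's configuration.

from typing import Any

provider_id = "siliconflow"  # hypothetical provider id for this example
model: dict[str, Any] = {
    "id": "bge-reranker",                         # made-up model id
    "name": "BAAI/bge-reranker-v2-m3",            # made-up model name
    "enabled": True,
    "capabilities": {"maxInputTokens": 4096, "topK": 20},
    "endpointSettings": {"baseUrl": "https://api.example.com/v1", "timeout": 30},
}

# Same mapping as the loop above, applied to the single entry.
reranker_model_config: dict[str, Any] = {
    "provider": provider_id,
    "model": model.get("name", ""),
    "max_input_tokens": model.get("capabilities", {}).get("maxInputTokens", 8192),
    "top_k": model.get("capabilities", {}).get("topK", 50),
    "api_base": model["endpointSettings"]["baseUrl"],
    "timeout": model["endpointSettings"]["timeout"],
}
print(reranker_model_config)
# {'provider': 'siliconflow', 'model': 'BAAI/bge-reranker-v2-m3',
#  'max_input_tokens': 4096, 'top_k': 20,
#  'api_base': 'https://api.example.com/v1', 'timeout': 30}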
@@ -208,6 +233,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
             "provider": default_provider or "openai",
             "model": "text-embedding-3-small",
             "dimensions": 1536,
+            "max_input_tokens": 8191,
         }
 
     return {
@@ -216,6 +242,7 @@ def _convert_json_to_internal_format(json_config: dict[str, Any]) -> dict[str, Any]:
         "providers": providers,
         "llm_models": llm_models,
         "embedding_models": embedding_models,
+        "reranker_models": reranker_models,
     }
 
 
@@ -34,6 +34,18 @@ class EmbeddingModelConfig(BaseModel):
     provider: str  # "openai", "fastembed", "ollama", etc.
     model: str
     dimensions: int
+    max_input_tokens: int = 8192  # Maximum tokens per embedding request
 
     model_config = {"extra": "allow"}
 
+
+class RerankerModelConfig(BaseModel):
+    """Reranker model configuration."""
+
+    provider: str  # "siliconflow", "cohere", "jina", etc.
+    model: str
+    max_input_tokens: int = 8192  # Maximum tokens per reranking request
+    top_k: int = 50  # Default top_k for reranking
+
+    model_config = {"extra": "allow"}
+
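A short sketch of how the new model behaves, assuming the Pydantic v2 semantics implied by model_config = {"extra": "allow"}: required fields must be supplied, the remaining fields fall back to the defaults above, and extra keys such as api_base (set by the converter) are retained. The example values are invented.

from pydantic import BaseModel


class RerankerModelConfig(BaseModel):
    """Reranker model configuration (copied from the hunk above)."""

    provider: str
    model: str
    max_input_tokens: int = 8192
    top_k: int = 50

    model_config = {"extra": "allow"}


cfg = RerankerModelConfig(
    provider="jina",                     # example value only
    model="jina-reranker-v2",            # made-up model name
    api_base="https://api.example.com",  # extra key kept because extra="allow"
)
print(cfg.max_input_tokens, cfg.top_k)   # 8192 50
print(cfg.model_dump()["api_base"])      # https://api.example.com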
@@ -69,6 +81,7 @@ class LiteLLMConfig(BaseModel):
     providers: dict[str, ProviderConfig] = Field(default_factory=dict)
     llm_models: dict[str, LLMModelConfig] = Field(default_factory=dict)
     embedding_models: dict[str, EmbeddingModelConfig] = Field(default_factory=dict)
+    reranker_models: dict[str, RerankerModelConfig] = Field(default_factory=dict)
 
     model_config = {"extra": "allow"}
 
@@ -110,6 +123,25 @@ class LiteLLMConfig(BaseModel):
             )
         return self.embedding_models[model]
 
+    def get_reranker_model(self, model: str = "default") -> RerankerModelConfig:
+        """Get reranker model configuration by name.
+
+        Args:
+            model: Model name or "default"
+
+        Returns:
+            Reranker model configuration
+
+        Raises:
+            ValueError: If model not found
+        """
+        if model not in self.reranker_models:
+            raise ValueError(
+                f"Reranker model '{model}' not found in configuration. "
+                f"Available models: {list(self.reranker_models.keys())}"
+            )
+        return self.reranker_models[model]
+
     def get_provider(self, provider: str) -> ProviderConfig:
         """Get provider configuration by name.
 
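Finally, a usage sketch for the new accessor. The import path is hypothetical (adjust to wherever LiteLLMConfig is defined in the repo), the model values are invented, and it assumes any LiteLLMConfig fields not shown in this diff have defaults. Pydantic coerces the nested dict into RerankerModelConfig, and unknown names raise ValueError listing the available keys.

from my_project.config import LiteLLMConfig  # hypothetical import path

config = LiteLLMConfig(
    reranker_models={
        "default": {
            "provider": "jina",            # example values only
            "model": "jina-reranker-v2",
            "max_input_tokens": 4096,
            "top_k": 20,
        }
    }
)

reranker = config.get_reranker_model()            # resolves "default"
print(reranker.max_input_tokens, reranker.top_k)  # 4096 20

try:
    config.get_reranker_model("missing")
except ValueError as exc:
    print(exc)  # Reranker model 'missing' not found ... Available models: ['default']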