Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-11 02:33:51 +08:00)
feat: add reranker model configuration with max input token support and improved API batching
@@ -140,6 +140,7 @@ class Config:
     reranker_backend: str = "onnx"
     reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
     reranker_top_k: int = 50
+    reranker_max_input_tokens: int = 8192  # Maximum tokens for reranker API batching

     # Cascade search configuration (two-stage retrieval)
     enable_cascade_search: bool = False  # Enable cascade search (coarse + fine ranking)
@@ -277,6 +278,7 @@ class Config:
             "backend": self.reranker_backend,
             "model": self.reranker_model,
             "top_k": self.reranker_top_k,
+            "max_input_tokens": self.reranker_max_input_tokens,
             "pool_enabled": self.reranker_pool_enabled,
             "strategy": self.reranker_strategy,
             "cooldown": self.reranker_cooldown,
@@ -359,6 +361,8 @@ class Config:
                 self.reranker_model = reranker["model"]
             if "top_k" in reranker:
                 self.reranker_top_k = reranker["top_k"]
+            if "max_input_tokens" in reranker:
+                self.reranker_max_input_tokens = reranker["max_input_tokens"]
             if "pool_enabled" in reranker:
                 self.reranker_pool_enabled = reranker["pool_enabled"]
             if "strategy" in reranker:
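For reference, the loader above implies a config shape roughly like the following. This is a minimal sketch: only max_input_tokens is the field added by this commit, and the surrounding keys and values are taken from the diff or assumed for illustration.

# Illustrative only: the exact config file format is not shown in this diff.
reranker_config = {
    "reranker": {
        "backend": "api",
        "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "top_k": 50,
        "max_input_tokens": 8192,  # new: caps tokens per rerank API request
    }
}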
@@ -1798,6 +1798,11 @@ class ChainSearchEngine:
         kwargs = {}
         if backend == "onnx":
             kwargs["use_gpu"] = use_gpu
+        elif backend == "api":
+            # Pass max_input_tokens for adaptive batching
+            max_tokens = getattr(self._config, "reranker_max_input_tokens", None)
+            if max_tokens:
+                kwargs["max_input_tokens"] = max_tokens

         reranker = get_reranker(backend=backend, model_name=model_name, **kwargs)
@@ -400,6 +400,11 @@ class HybridSearchEngine:
         elif backend == "legacy":
             if not bool(getattr(self._config, "embedding_use_gpu", True)):
                 device = "cpu"
+        elif backend == "api":
+            # Pass max_input_tokens for adaptive batching
+            max_tokens = getattr(self._config, "reranker_max_input_tokens", None)
+            if max_tokens:
+                kwargs["max_input_tokens"] = max_tokens

         return get_reranker(
             backend=backend,
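Both engines end up forwarding the setting through the factory. A minimal sketch of the resulting call for the API backend, with the default limit; the model name here is a placeholder, not a value from the codebase:

# Hypothetical call shape; get_reranker's other parameters are unchanged by this commit.
reranker = get_reranker(
    backend="api",
    model_name="example-reranker-model",  # placeholder
    max_input_tokens=8192,                # taken from Config.reranker_max_input_tokens
)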
@@ -69,13 +69,13 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):

         Returns:
             int: Maximum number of tokens that can be embedded at once.
-                Inferred from model config or model name patterns.
+                Reads from LiteLLM config's max_input_tokens property.
         """
-        # Try to get from LiteLLM config first
-        if hasattr(self._embedder, 'max_input_tokens') and self._embedder.max_input_tokens:
+        # Get from LiteLLM embedder's max_input_tokens property (now exposed)
+        if hasattr(self._embedder, 'max_input_tokens'):
             return self._embedder.max_input_tokens

-        # Infer from model name
+        # Fallback: infer from model name
         model_name_lower = self.model_name.lower()

         # Large models (8B or "large" in name)
@@ -78,6 +78,7 @@ class APIReranker(BaseReranker):
         backoff_max_s: float = 8.0,
         env_api_key: str = _DEFAULT_ENV_API_KEY,
         workspace_root: Path | str | None = None,
+        max_input_tokens: int | None = None,
     ) -> None:
         ok, err = check_httpx_available()
         if not ok:  # pragma: no cover - exercised via factory availability tests
@@ -135,6 +136,22 @@ class APIReranker(BaseReranker):
             timeout=self.timeout_s,
         )

+        # Store max_input_tokens with model-aware defaults
+        if max_input_tokens is not None:
+            self._max_input_tokens = max_input_tokens
+        else:
+            # Infer from model name
+            model_lower = self.model_name.lower()
+            if '8b' in model_lower or 'large' in model_lower:
+                self._max_input_tokens = 32768
+            else:
+                self._max_input_tokens = 8192
+
+    @property
+    def max_input_tokens(self) -> int:
+        """Return maximum token limit for reranking."""
+        return self._max_input_tokens
+
     def close(self) -> None:
         try:
             self._client.close()
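The default inference is purely name-based. A standalone restatement of the branch above, as a sketch; the helper function name is ours, not part of the codebase, and the model names are just examples of strings containing '8b' or 'large':

# "Qwen3-Reranker-8B"  -> '8b' in name    -> 32768
# "bge-reranker-large" -> 'large' in name -> 32768
# anything else                           -> 8192
def default_max_input_tokens(model_name: str) -> int:
    name = model_name.lower()
    return 32768 if "8b" in name or "large" in name else 8192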
@@ -276,15 +293,78 @@ class APIReranker(BaseReranker):
         }
         return payload

+    def _estimate_tokens(self, text: str) -> int:
+        """Estimate token count using fast heuristic (len/4)."""
+        return len(text) // 4
+
+    def _create_token_aware_batches(
+        self,
+        query: str,
+        documents: Sequence[str],
+    ) -> list[list[tuple[int, str]]]:
+        """Split documents into batches that fit within token limits.
+
+        Uses 90% of max_input_tokens as safety margin.
+        Each batch includes the query tokens overhead.
+        """
+        max_tokens = int(self._max_input_tokens * 0.9)
+        query_tokens = self._estimate_tokens(query)
+
+        batches: list[list[tuple[int, str]]] = []
+        current_batch: list[tuple[int, str]] = []
+        current_tokens = query_tokens  # Start with query overhead
+
+        for idx, doc in enumerate(documents):
+            doc_tokens = self._estimate_tokens(doc)
+
+            # If single doc + query exceeds limit, include it anyway (will be truncated by API)
+            if current_tokens + doc_tokens > max_tokens and current_batch:
+                batches.append(current_batch)
+                current_batch = []
+                current_tokens = query_tokens
+
+            current_batch.append((idx, doc))
+            current_tokens += doc_tokens
+
+        if current_batch:
+            batches.append(current_batch)
+
+        return batches
+
     def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
         if not documents:
             return []

-        payload = self._build_payload(query=query, documents=documents)
-        data = self._request_json(payload)
+        # Create token-aware batches
+        batches = self._create_token_aware_batches(query, documents)

-        results = data.get("results")
-        return self._extract_scores_from_results(results, expected=len(documents))
+        if len(batches) == 1:
+            # Single batch - original behavior
+            payload = self._build_payload(query=query, documents=documents)
+            data = self._request_json(payload)
+            results = data.get("results")
+            return self._extract_scores_from_results(results, expected=len(documents))
+
+        # Multiple batches - process each and merge results
+        logger.info(
+            f"Splitting {len(documents)} documents into {len(batches)} batches "
+            f"(max_input_tokens: {self._max_input_tokens})"
+        )
+
+        all_scores: list[float] = [0.0] * len(documents)
+
+        for batch in batches:
+            batch_docs = [doc for _, doc in batch]
+            payload = self._build_payload(query=query, documents=batch_docs)
+            data = self._request_json(payload)
+            results = data.get("results")
+            batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
+
+            # Map scores back to original indices
+            for (orig_idx, _), score in zip(batch, batch_scores):
+                all_scores[orig_idx] = score
+
+        return all_scores
+
     def score_pairs(
         self,
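As a rough worked example of the len/4 heuristic and the 90% margin (numbers are illustrative, not from the codebase):

# With the default limit of 8192 tokens, the usable budget is int(8192 * 0.9) = 7372.
# A 200-character query estimates to 50 tokens; a 2000-character document to 500 tokens.
query_tokens = 200 // 4                                  # 50
doc_tokens = 2000 // 4                                   # 500
budget = int(8192 * 0.9)                                 # 7372
docs_per_batch = (budget - query_tokens) // doc_tokens   # 14 documents before a new batch starts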
@@ -16,6 +16,16 @@ class BaseReranker(ABC):
    the abstract methods to ensure a consistent interface.
    """

+    @property
+    def max_input_tokens(self) -> int:
+        """Return maximum token limit for reranking.
+
+        Returns:
+            int: Maximum number of tokens that can be processed at once.
+                Default is 8192 if not overridden by implementation.
+        """
+        return 8192
+
     @abstractmethod
     def score_pairs(
         self,
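Concrete rerankers can override the new property. A minimal sketch of a hypothetical subclass (not part of this commit; the score_pairs signature is elided in the diff, so a placeholder is used here):

class FixedWindowReranker(BaseReranker):
    """Hypothetical example: a reranker with a smaller fixed window."""

    @property
    def max_input_tokens(self) -> int:
        return 4096  # override the 8192 default from BaseReranker

    def score_pairs(self, *args, **kwargs):
        # Placeholder: the real abstract signature is not shown in this diff.
        raise NotImplementedError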