feat: add reranker models to ProviderCredential and improve FastEmbedReranker scoring

- Added `rerankerModels` property to the `ProviderCredential` interface in `litellm-api-config.ts` to support additional reranker configurations.
- Introduced a numerically stable sigmoid function in `FastEmbedReranker` for score normalization.
- Updated the scoring logic in `FastEmbedReranker` to use raw float scores from the encoder and normalize them using the new sigmoid function.
- Adjusted the result mapping to maintain original document order while applying normalization.
This commit is contained in:
catlog22
2026-01-03 22:20:06 +08:00
parent 74ad2d0463
commit 504ccfebbc
6 changed files with 1277 additions and 451 deletions

View File

@@ -789,6 +789,46 @@ export async function handleLiteLLMApiRoutes(ctx: RouteContext): Promise<boolean
return true;
}
// GET /api/litellm-api/reranker-pool - Get available reranker models from all providers
if (pathname === '/api/litellm-api/reranker-pool' && req.method === 'GET') {
  try {
    // Aggregate enabled reranker models across all enabled providers,
    // de-duplicating by model id and collecting every provider offering it.
    const config = loadLiteLLMApiConfig(initialPath);
    const modelMap = new Map<string, { modelId: string; modelName: string; providers: string[] }>();
    for (const provider of config.providers) {
      if (!provider.enabled || !provider.rerankerModels) continue;
      for (const model of provider.rerankerModels) {
        if (!model.enabled) continue;
        // Single lookup instead of has()+get(); append provider on duplicate ids.
        const existing = modelMap.get(model.id);
        if (existing) {
          existing.providers.push(provider.name);
        } else {
          modelMap.set(model.id, {
            modelId: model.id,
            modelName: model.name,
            providers: [provider.name],
          });
        }
      }
    }
    // Materialize map values directly; no intermediate mutable array +
    // push(...spread) needed (spread-into-push can also hit the engine's
    // argument-count limit on very large pools).
    const availableModels = Array.from(modelMap.values());
    res.writeHead(200, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify({
      availableModels,
    }));
  } catch (err) {
    res.writeHead(500, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify({ error: (err as Error).message }));
  }
  return true;
}
// GET /api/litellm-api/embedding-pool/discover/:model - Preview auto-discovery results
const discoverMatch = pathname.match(/^\/api\/litellm-api\/embedding-pool\/discover\/([^/]+)$/);
if (discoverMatch && req.method === 'GET') {

View File

@@ -1465,6 +1465,10 @@ const i18n = {
'apiSettings.noProvidersFound': 'No providers found',
'apiSettings.llmModels': 'LLM Models',
'apiSettings.embeddingModels': 'Embedding Models',
'apiSettings.rerankerModels': 'Reranker Models',
'apiSettings.addRerankerModel': 'Add Reranker Model',
'apiSettings.rerankerTopK': 'Default Top K',
'apiSettings.rerankerTopKHint': 'Number of top results to return (default: 10)',
'apiSettings.manageModels': 'Manage',
'apiSettings.addModel': 'Add Model',
'apiSettings.multiKeySettings': 'Multi-Key Settings',
@@ -3416,6 +3420,10 @@ const i18n = {
'apiSettings.noProvidersFound': '未找到供应商',
'apiSettings.llmModels': '大语言模型',
'apiSettings.embeddingModels': '向量模型',
'apiSettings.rerankerModels': '重排模型',
'apiSettings.addRerankerModel': '添加重排模型',
'apiSettings.rerankerTopK': '默认 Top K',
'apiSettings.rerankerTopKHint': '返回的最高排名结果数量（默认：10）',
'apiSettings.manageModels': '管理',
'apiSettings.addModel': '添加模型',
'apiSettings.multiKeySettings': '多密钥设置',

View File

@@ -1221,6 +1221,9 @@ function renderProviderDetail(providerId) {
'<button class="model-tab' + (activeModelTab === 'embedding' ? ' active' : '') + '" onclick="switchModelTab(\'embedding\')">' +
t('apiSettings.embeddingModels') +
'</button>' +
'<button class="model-tab' + (activeModelTab === 'reranker' ? ' active' : '') + '" onclick="switchModelTab(\'reranker\')">' +
t('apiSettings.rerankerModels') +
'</button>' +
'</div>' +
'<div class="model-section-actions">' +
'<button class="btn btn-secondary" onclick="showManageModelsModal(\'' + providerId + '\')">' +
@@ -1275,6 +1278,8 @@ function renderModelTree(provider) {
var models = activeModelTab === 'llm'
? (provider.llmModels || [])
: activeModelTab === 'reranker'
? (provider.rerankerModels || [])
: (provider.embeddingModels || []);
if (models.length === 0) {
@@ -1537,10 +1542,11 @@ function showAddModelModal(providerId, modelType) {
if (!provider) return;
const isLlm = modelType === 'llm';
const title = isLlm ? t('apiSettings.addLlmModel') : t('apiSettings.addEmbeddingModel');
const isReranker = modelType === 'reranker';
const title = isLlm ? t('apiSettings.addLlmModel') : isReranker ? t('apiSettings.addRerankerModel') : t('apiSettings.addEmbeddingModel');
// Get model presets based on provider type
const presets = isLlm ? getLlmPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
const presets = isLlm ? getLlmPresetsForType(provider.type) : isReranker ? getRerankerPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
// Group presets by series
const groupedPresets = groupPresetsBySeries(presets);
@@ -1562,9 +1568,8 @@ function showAddModelModal(providerId, modelType) {
Object.keys(groupedPresets).map(function(series) {
return '<optgroup label="' + series + '">' +
groupedPresets[series].map(function(m) {
return '<option value="' + m.id + '">' + m.name + ' ' +
(isLlm ? '(' + (m.contextWindow/1000) + 'K)' : '(' + m.dimensions + 'D)') +
'</option>';
var info = isLlm ? '(' + (m.contextWindow/1000) + 'K)' : isReranker ? '' : '(' + m.dimensions + 'D)';
return '<option value="' + m.id + '">' + m.name + ' ' + info + '</option>';
}).join('') +
'</optgroup>';
}).join('') +
@@ -1607,6 +1612,12 @@ function showAddModelModal(providerId, modelType) {
'<input type="checkbox" id="cap-vision" /> ' + t('apiSettings.vision') +
'</label>' +
'</div>'
: isReranker ?
'<div class="form-group">' +
'<label>' + t('apiSettings.rerankerTopK') + '</label>' +
'<input type="number" id="model-top-k" class="cli-input" value="10" min="1" max="100" />' +
'<span class="field-hint">' + t('apiSettings.rerankerTopKHint') + '</span>' +
'</div>'
:
'<div class="form-group">' +
'<label>' + t('apiSettings.embeddingDimensions') + ' *</label>' +
@@ -1691,6 +1702,37 @@ function getEmbeddingPresetsForType(providerType) {
return presets[providerType] || presets.custom;
}
/**
 * Get reranker model presets based on provider type
 */
function getRerankerPresetsForType(providerType) {
  // Build one preset entry; every reranker preset defaults to top-K = 10.
  function preset(id, name, series) {
    return { id: id, name: name, series: series, topK: 10 };
  }
  var presetsByProvider = {
    openai: [
      preset('BAAI/bge-reranker-v2-m3', 'BGE Reranker v2 M3', 'BGE Reranker'),
      preset('BAAI/bge-reranker-large', 'BGE Reranker Large', 'BGE Reranker'),
      preset('BAAI/bge-reranker-base', 'BGE Reranker Base', 'BGE Reranker')
    ],
    cohere: [
      preset('rerank-english-v3.0', 'Rerank English v3.0', 'Cohere Rerank'),
      preset('rerank-multilingual-v3.0', 'Rerank Multilingual v3.0', 'Cohere Rerank'),
      preset('rerank-english-v2.0', 'Rerank English v2.0', 'Cohere Rerank')
    ],
    voyage: [
      preset('rerank-2', 'Rerank 2', 'Voyage Rerank'),
      preset('rerank-2-lite', 'Rerank 2 Lite', 'Voyage Rerank'),
      preset('rerank-1', 'Rerank 1', 'Voyage Rerank')
    ],
    jina: [
      preset('jina-reranker-v2-base-multilingual', 'Jina Reranker v2 Multilingual', 'Jina Reranker'),
      preset('jina-reranker-v1-base-en', 'Jina Reranker v1 English', 'Jina Reranker')
    ],
    custom: [
      preset('custom-reranker', 'Custom Reranker', 'Custom')
    ]
  };
  // Unknown provider types fall back to the generic custom preset list.
  return presetsByProvider[providerType] || presetsByProvider.custom;
}
/**
* Group presets by series
*/
@@ -1721,7 +1763,8 @@ function fillModelFromPreset(presetId, modelType) {
if (!provider) return;
const isLlm = modelType === 'llm';
const presets = isLlm ? getLlmPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
const isReranker = modelType === 'reranker';
const presets = isLlm ? getLlmPresetsForType(provider.type) : isReranker ? getRerankerPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
const preset = presets.find(function(p) { return p.id === presetId; });
if (preset) {
@@ -1732,7 +1775,11 @@ function fillModelFromPreset(presetId, modelType) {
if (isLlm && preset.contextWindow) {
document.getElementById('model-context-window').value = preset.contextWindow;
}
if (!isLlm && preset.dimensions) {
if (isReranker && preset.topK) {
var topKEl = document.getElementById('model-top-k');
if (topKEl) topKEl.value = preset.topK;
}
if (!isLlm && !isReranker && preset.dimensions) {
document.getElementById('model-dimensions').value = preset.dimensions;
if (preset.maxTokens) {
document.getElementById('model-max-tokens').value = preset.maxTokens;
@@ -1748,6 +1795,7 @@ function saveNewModel(event, providerId, modelType) {
event.preventDefault();
const isLlm = modelType === 'llm';
const isReranker = modelType === 'reranker';
const now = new Date().toISOString();
const newModel = {
@@ -1769,6 +1817,11 @@ function saveNewModel(event, providerId, modelType) {
functionCalling: document.getElementById('cap-function-calling').checked,
vision: document.getElementById('cap-vision').checked
};
} else if (isReranker) {
var topKEl = document.getElementById('model-top-k');
newModel.capabilities = {
topK: topKEl ? parseInt(topKEl.value) || 10 : 10
};
} else {
newModel.capabilities = {
embeddingDimension: parseInt(document.getElementById('model-dimensions').value) || 1536,
@@ -1780,7 +1833,7 @@ function saveNewModel(event, providerId, modelType) {
fetch('/api/litellm-api/providers/' + providerId)
.then(function(res) { return res.json(); })
.then(function(provider) {
const modelsKey = isLlm ? 'llmModels' : 'embeddingModels';
const modelsKey = isLlm ? 'llmModels' : isReranker ? 'rerankerModels' : 'embeddingModels';
const models = provider[modelsKey] || [];
// Check for duplicate ID
@@ -1824,7 +1877,8 @@ function showModelSettingsModal(providerId, modelId, modelType) {
if (!provider) return;
var isLlm = modelType === 'llm';
var models = isLlm ? (provider.llmModels || []) : (provider.embeddingModels || []);
var isReranker = modelType === 'reranker';
var models = isLlm ? (provider.llmModels || []) : isReranker ? (provider.rerankerModels || []) : (provider.embeddingModels || []);
var model = models.find(function(m) { return m.id === modelId; });
if (!model) return;
@@ -1834,7 +1888,7 @@ function showModelSettingsModal(providerId, modelId, modelType) {
// Calculate endpoint preview URL
var providerBase = provider.apiBase || getDefaultApiBase(provider.type);
var modelBaseUrl = endpointSettings.baseUrl || providerBase;
var endpointPath = isLlm ? '/chat/completions' : '/embeddings';
var endpointPath = isLlm ? '/chat/completions' : isReranker ? '/rerank' : '/embeddings';
var endpointPreview = modelBaseUrl + endpointPath;
var modalHtml = '<div class="modal-overlay" id="model-settings-modal">' +
@@ -1848,7 +1902,7 @@ function showModelSettingsModal(providerId, modelId, modelType) {
// Endpoint Preview Section (combined view + settings)
'<div class="form-section endpoint-preview-section">' +
'<h4><i data-lucide="' + (isLlm ? 'message-square' : 'box') + '"></i> ' + t('apiSettings.endpointPreview') + '</h4>' +
'<h4><i data-lucide="' + (isLlm ? 'message-square' : isReranker ? 'sort-asc' : 'box') + '"></i> ' + t('apiSettings.endpointPreview') + '</h4>' +
'<div class="endpoint-preview-box">' +
'<code id="model-endpoint-preview">' + escapeHtml(endpointPreview) + '</code>' +
'<button type="button" class="btn-icon-sm" onclick="copyModelEndpoint()" title="' + t('common.copy') + '">' +
@@ -1857,7 +1911,7 @@ function showModelSettingsModal(providerId, modelId, modelType) {
'</div>' +
'<div class="form-group">' +
'<label>' + t('apiSettings.modelBaseUrlOverride') + ' <span class="text-muted">(' + t('common.optional') + ')</span></label>' +
'<input type="text" id="model-settings-baseurl" class="cli-input" value="' + escapeHtml(endpointSettings.baseUrl || '') + '" placeholder="' + escapeHtml(providerBase) + '" oninput="updateModelEndpointPreview(\'' + (isLlm ? 'chat/completions' : 'embeddings') + '\', \'' + escapeHtml(providerBase) + '\')">' +
'<input type="text" id="model-settings-baseurl" class="cli-input" value="' + escapeHtml(endpointSettings.baseUrl || '') + '" placeholder="' + escapeHtml(providerBase) + '" oninput="updateModelEndpointPreview(\'' + (isLlm ? 'chat/completions' : isReranker ? 'rerank' : 'embeddings') + '\', \'' + escapeHtml(providerBase) + '\')">' +
'<small class="form-hint">' + t('apiSettings.modelBaseUrlHint') + '</small>' +
'</div>' +
'</div>' +
@@ -1968,7 +2022,8 @@ function saveModelSettings(event, providerId, modelId, modelType) {
event.preventDefault();
var isLlm = modelType === 'llm';
var modelsKey = isLlm ? 'llmModels' : 'embeddingModels';
var isReranker = modelType === 'reranker';
var modelsKey = isLlm ? 'llmModels' : isReranker ? 'rerankerModels' : 'embeddingModels';
fetch('/api/litellm-api/providers/' + providerId)
.then(function(res) { return res.json(); })
@@ -1994,6 +2049,11 @@ function saveModelSettings(event, providerId, modelId, modelType) {
functionCalling: document.getElementById('model-settings-function-calling').checked,
vision: document.getElementById('model-settings-vision').checked
};
} else if (isReranker) {
var topKEl = document.getElementById('model-settings-top-k');
models[modelIndex].capabilities = {
topK: topKEl ? parseInt(topKEl.value) || 10 : 10
};
} else {
models[modelIndex].capabilities = {
embeddingDimension: parseInt(document.getElementById('model-settings-dimensions').value) || 1536,
@@ -2042,7 +2102,8 @@ function deleteModel(providerId, modelId, modelType) {
if (!confirm(t('common.confirmDelete'))) return;
var isLlm = modelType === 'llm';
var modelsKey = isLlm ? 'llmModels' : 'embeddingModels';
var isReranker = modelType === 'reranker';
var modelsKey = isLlm ? 'llmModels' : isReranker ? 'rerankerModels' : 'embeddingModels';
fetch('/api/litellm-api/providers/' + providerId)
.then(function(res) { return res.json(); })

File diff suppressed because it is too large Load Diff

View File

@@ -226,6 +226,9 @@ export interface ProviderCredential {
/** Embedding models configured for this provider */
embeddingModels?: ModelDefinition[];
/** Reranker models configured for this provider */
rerankerModels?: ModelDefinition[];
/** Creation timestamp (ISO 8601) */
createdAt: string;

View File

@@ -125,6 +125,16 @@ class FastEmbedReranker(BaseReranker):
logger.debug("FastEmbed reranker model loaded successfully")
@staticmethod
def _sigmoid(x: float) -> float:
"""Numerically stable sigmoid function."""
if x < -709:
return 0.0
if x > 709:
return 1.0
import math
return 1.0 / (1.0 + math.exp(-x))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
@@ -165,8 +175,8 @@ class FastEmbedReranker(BaseReranker):
indices = [idx for idx, _ in indexed_docs]
try:
# TextCrossEncoder.rerank returns list of RerankResult with score attribute
results = list(
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=docs,
@@ -174,22 +184,12 @@ class FastEmbedReranker(BaseReranker):
)
)
# Map scores back to original positions
# Results are returned in descending score order, but we need original order
for result in results:
# Each result has 'index' (position in input docs) and 'score'
doc_idx = result.index if hasattr(result, "index") else 0
score = result.score if hasattr(result, "score") else 0.0
if doc_idx < len(indices):
original_idx = indices[doc_idx]
# Normalize score to [0, 1] using sigmoid if needed
# FastEmbed typically returns scores in [0, 1] already
if score < 0 or score > 1:
import math
score = 1.0 / (1.0 + math.exp(-score))
scores[original_idx] = float(score)
# Map scores back to original positions and normalize with sigmoid
for i, raw_score in enumerate(raw_scores):
if i < len(indices):
original_idx = indices[i]
# Normalize score to [0, 1] using stable sigmoid
scores[original_idx] = self._sigmoid(float(raw_score))
except Exception as e:
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
@@ -227,7 +227,8 @@ class FastEmbedReranker(BaseReranker):
return []
try:
results = list(
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=list(documents),
@@ -235,13 +236,13 @@ class FastEmbedReranker(BaseReranker):
)
)
# Convert to our format: (score, document, original_index)
# Convert to our format: (normalized_score, document, original_index)
ranked = []
for result in results:
idx = result.index if hasattr(result, "index") else 0
score = result.score if hasattr(result, "score") else 0.0
doc = documents[idx] if idx < len(documents) else ""
ranked.append((float(score), doc, idx))
for idx, raw_score in enumerate(raw_scores):
if idx < len(documents):
# Normalize score to [0, 1] using stable sigmoid
normalized = self._sigmoid(float(raw_score))
ranked.append((normalized, documents[idx], idx))
# Sort by score descending
ranked.sort(key=lambda x: x[0], reverse=True)