mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-05 01:50:27 +08:00
feat: add reranker models to ProviderCredential and improve FastEmbedReranker scoring
- Added `rerankerModels` property to the `ProviderCredential` interface in `litellm-api-config.ts` to support additional reranker configurations.
- Introduced a numerically stable sigmoid function in `FastEmbedReranker` for score normalization.
- Updated the scoring logic in `FastEmbedReranker` to use raw float scores from the encoder and normalize them with the new sigmoid function.
- Adjusted the result mapping to maintain original document order while applying normalization.
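The ±709 clamp in the new `_sigmoid` helper matters because Python's `math.exp` overflows just past exp(709); a minimal standalone sketch (not part of the diff) of the same idea:

import math

def stable_sigmoid(x: float) -> float:
    # Same shape as the _sigmoid helper introduced below: clamp extreme
    # inputs so math.exp never raises OverflowError.
    if x < -709:
        return 0.0
    if x > 709:
        return 1.0
    return 1.0 / (1.0 + math.exp(-x))

# A naive 1.0 / (1.0 + math.exp(-x)) raises OverflowError for x = -1000,
# since math.exp(1000) exceeds the double range (~1.8e308).
print(stable_sigmoid(-1000.0))  # 0.0
print(stable_sigmoid(0.0))      # 0.5
print(stable_sigmoid(1000.0))   # 1.0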
@@ -789,6 +789,46 @@ export async function handleLiteLLMApiRoutes(ctx: RouteContext): Promise<boolean
     return true;
   }
 
+  // GET /api/litellm-api/reranker-pool - Get available reranker models from all providers
+  if (pathname === '/api/litellm-api/reranker-pool' && req.method === 'GET') {
+    try {
+      // Get list of all available reranker models from all providers
+      const config = loadLiteLLMApiConfig(initialPath);
+      const availableModels: Array<{ modelId: string; modelName: string; providers: string[] }> = [];
+      const modelMap = new Map<string, { modelId: string; modelName: string; providers: string[] }>();
+
+      for (const provider of config.providers) {
+        if (!provider.enabled || !provider.rerankerModels) continue;
+
+        for (const model of provider.rerankerModels) {
+          if (!model.enabled) continue;
+
+          const key = model.id;
+          if (modelMap.has(key)) {
+            modelMap.get(key)!.providers.push(provider.name);
+          } else {
+            modelMap.set(key, {
+              modelId: model.id,
+              modelName: model.name,
+              providers: [provider.name],
+            });
+          }
+        }
+      }
+
+      availableModels.push(...Array.from(modelMap.values()));
+
+      res.writeHead(200, { 'Content-Type': 'application/json' });
+      res.end(JSON.stringify({
+        availableModels,
+      }));
+    } catch (err) {
+      res.writeHead(500, { 'Content-Type': 'application/json' });
+      res.end(JSON.stringify({ error: (err as Error).message }));
+    }
+    return true;
+  }
+
   // GET /api/litellm-api/embedding-pool/discover/:model - Preview auto-discovery results
   const discoverMatch = pathname.match(/^\/api\/litellm-api\/embedding-pool\/discover\/([^/]+)$/);
   if (discoverMatch && req.method === 'GET') {
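A quick way to exercise the new pool endpoint from a client; the host and port below are assumptions, so point it at wherever the workflow server actually listens:

import json
import urllib.request

url = "http://localhost:3000/api/litellm-api/reranker-pool"  # assumed address
with urllib.request.urlopen(url) as resp:
    payload = json.load(resp)

# Entries are deduplicated by model id; each lists every enabled provider
# that exposes that reranker.
for model in payload["availableModels"]:
    print(model["modelId"], model["modelName"], "via", ", ".join(model["providers"]))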
@@ -1465,6 +1465,10 @@ const i18n = {
   'apiSettings.noProvidersFound': 'No providers found',
   'apiSettings.llmModels': 'LLM Models',
   'apiSettings.embeddingModels': 'Embedding Models',
+  'apiSettings.rerankerModels': 'Reranker Models',
+  'apiSettings.addRerankerModel': 'Add Reranker Model',
+  'apiSettings.rerankerTopK': 'Default Top K',
+  'apiSettings.rerankerTopKHint': 'Number of top results to return (default: 10)',
   'apiSettings.manageModels': 'Manage',
   'apiSettings.addModel': 'Add Model',
   'apiSettings.multiKeySettings': 'Multi-Key Settings',
@@ -3416,6 +3420,10 @@ const i18n = {
   'apiSettings.noProvidersFound': '未找到供应商',
   'apiSettings.llmModels': '大语言模型',
   'apiSettings.embeddingModels': '向量模型',
+  'apiSettings.rerankerModels': '重排模型',
+  'apiSettings.addRerankerModel': '添加重排模型',
+  'apiSettings.rerankerTopK': '默认 Top K',
+  'apiSettings.rerankerTopKHint': '返回的最高排名结果数量(默认:10)',
   'apiSettings.manageModels': '管理',
   'apiSettings.addModel': '添加模型',
   'apiSettings.multiKeySettings': '多密钥设置',
@@ -1221,6 +1221,9 @@ function renderProviderDetail(providerId) {
       '<button class="model-tab' + (activeModelTab === 'embedding' ? ' active' : '') + '" onclick="switchModelTab(\'embedding\')">' +
         t('apiSettings.embeddingModels') +
       '</button>' +
+      '<button class="model-tab' + (activeModelTab === 'reranker' ? ' active' : '') + '" onclick="switchModelTab(\'reranker\')">' +
+        t('apiSettings.rerankerModels') +
+      '</button>' +
     '</div>' +
     '<div class="model-section-actions">' +
       '<button class="btn btn-secondary" onclick="showManageModelsModal(\'' + providerId + '\')">' +
@@ -1275,6 +1278,8 @@ function renderModelTree(provider) {
   var models = activeModelTab === 'llm'
     ? (provider.llmModels || [])
+    : activeModelTab === 'reranker'
+      ? (provider.rerankerModels || [])
     : (provider.embeddingModels || []);
 
   if (models.length === 0) {
@@ -1537,10 +1542,11 @@ function showAddModelModal(providerId, modelType) {
   if (!provider) return;
 
   const isLlm = modelType === 'llm';
-  const title = isLlm ? t('apiSettings.addLlmModel') : t('apiSettings.addEmbeddingModel');
+  const isReranker = modelType === 'reranker';
+  const title = isLlm ? t('apiSettings.addLlmModel') : isReranker ? t('apiSettings.addRerankerModel') : t('apiSettings.addEmbeddingModel');
 
   // Get model presets based on provider type
-  const presets = isLlm ? getLlmPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
+  const presets = isLlm ? getLlmPresetsForType(provider.type) : isReranker ? getRerankerPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
 
   // Group presets by series
   const groupedPresets = groupPresetsBySeries(presets);
@@ -1562,9 +1568,8 @@ function showAddModelModal(providerId, modelType) {
       Object.keys(groupedPresets).map(function(series) {
         return '<optgroup label="' + series + '">' +
           groupedPresets[series].map(function(m) {
-            return '<option value="' + m.id + '">' + m.name + ' ' +
-              (isLlm ? '(' + (m.contextWindow/1000) + 'K)' : '(' + m.dimensions + 'D)') +
-              '</option>';
+            var info = isLlm ? '(' + (m.contextWindow/1000) + 'K)' : isReranker ? '' : '(' + m.dimensions + 'D)';
+            return '<option value="' + m.id + '">' + m.name + ' ' + info + '</option>';
           }).join('') +
           '</optgroup>';
       }).join('') +
@@ -1607,6 +1612,12 @@ function showAddModelModal(providerId, modelType) {
           '<input type="checkbox" id="cap-vision" /> ' + t('apiSettings.vision') +
         '</label>' +
       '</div>'
+    : isReranker ?
+      '<div class="form-group">' +
+        '<label>' + t('apiSettings.rerankerTopK') + '</label>' +
+        '<input type="number" id="model-top-k" class="cli-input" value="10" min="1" max="100" />' +
+        '<span class="field-hint">' + t('apiSettings.rerankerTopKHint') + '</span>' +
+      '</div>'
     :
       '<div class="form-group">' +
         '<label>' + t('apiSettings.embeddingDimensions') + ' *</label>' +
@@ -1691,6 +1702,37 @@ function getEmbeddingPresetsForType(providerType) {
   return presets[providerType] || presets.custom;
 }
 
+/**
+ * Get reranker model presets based on provider type
+ */
+function getRerankerPresetsForType(providerType) {
+  const presets = {
+    openai: [
+      { id: 'BAAI/bge-reranker-v2-m3', name: 'BGE Reranker v2 M3', series: 'BGE Reranker', topK: 10 },
+      { id: 'BAAI/bge-reranker-large', name: 'BGE Reranker Large', series: 'BGE Reranker', topK: 10 },
+      { id: 'BAAI/bge-reranker-base', name: 'BGE Reranker Base', series: 'BGE Reranker', topK: 10 }
+    ],
+    cohere: [
+      { id: 'rerank-english-v3.0', name: 'Rerank English v3.0', series: 'Cohere Rerank', topK: 10 },
+      { id: 'rerank-multilingual-v3.0', name: 'Rerank Multilingual v3.0', series: 'Cohere Rerank', topK: 10 },
+      { id: 'rerank-english-v2.0', name: 'Rerank English v2.0', series: 'Cohere Rerank', topK: 10 }
+    ],
+    voyage: [
+      { id: 'rerank-2', name: 'Rerank 2', series: 'Voyage Rerank', topK: 10 },
+      { id: 'rerank-2-lite', name: 'Rerank 2 Lite', series: 'Voyage Rerank', topK: 10 },
+      { id: 'rerank-1', name: 'Rerank 1', series: 'Voyage Rerank', topK: 10 }
+    ],
+    jina: [
+      { id: 'jina-reranker-v2-base-multilingual', name: 'Jina Reranker v2 Multilingual', series: 'Jina Reranker', topK: 10 },
+      { id: 'jina-reranker-v1-base-en', name: 'Jina Reranker v1 English', series: 'Jina Reranker', topK: 10 }
+    ],
+    custom: [
+      { id: 'custom-reranker', name: 'Custom Reranker', series: 'Custom', topK: 10 }
+    ]
+  };
+  return presets[providerType] || presets.custom;
+}
+
 /**
  * Group presets by series
  */
@@ -1721,7 +1763,8 @@ function fillModelFromPreset(presetId, modelType) {
   if (!provider) return;
 
   const isLlm = modelType === 'llm';
-  const presets = isLlm ? getLlmPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
+  const isReranker = modelType === 'reranker';
+  const presets = isLlm ? getLlmPresetsForType(provider.type) : isReranker ? getRerankerPresetsForType(provider.type) : getEmbeddingPresetsForType(provider.type);
   const preset = presets.find(function(p) { return p.id === presetId; });
 
   if (preset) {
@@ -1732,7 +1775,11 @@ function fillModelFromPreset(presetId, modelType) {
     if (isLlm && preset.contextWindow) {
       document.getElementById('model-context-window').value = preset.contextWindow;
     }
-    if (!isLlm && preset.dimensions) {
+    if (isReranker && preset.topK) {
+      var topKEl = document.getElementById('model-top-k');
+      if (topKEl) topKEl.value = preset.topK;
+    }
+    if (!isLlm && !isReranker && preset.dimensions) {
       document.getElementById('model-dimensions').value = preset.dimensions;
       if (preset.maxTokens) {
         document.getElementById('model-max-tokens').value = preset.maxTokens;
@@ -1748,6 +1795,7 @@ function saveNewModel(event, providerId, modelType) {
   event.preventDefault();
 
   const isLlm = modelType === 'llm';
+  const isReranker = modelType === 'reranker';
   const now = new Date().toISOString();
 
   const newModel = {
@@ -1769,6 +1817,11 @@ function saveNewModel(event, providerId, modelType) {
       functionCalling: document.getElementById('cap-function-calling').checked,
       vision: document.getElementById('cap-vision').checked
     };
+  } else if (isReranker) {
+    var topKEl = document.getElementById('model-top-k');
+    newModel.capabilities = {
+      topK: topKEl ? parseInt(topKEl.value) || 10 : 10
+    };
   } else {
     newModel.capabilities = {
       embeddingDimension: parseInt(document.getElementById('model-dimensions').value) || 1536,
@@ -1780,7 +1833,7 @@ function saveNewModel(event, providerId, modelType) {
   fetch('/api/litellm-api/providers/' + providerId)
     .then(function(res) { return res.json(); })
     .then(function(provider) {
-      const modelsKey = isLlm ? 'llmModels' : 'embeddingModels';
+      const modelsKey = isLlm ? 'llmModels' : isReranker ? 'rerankerModels' : 'embeddingModels';
       const models = provider[modelsKey] || [];
 
       // Check for duplicate ID
@@ -1824,7 +1877,8 @@ function showModelSettingsModal(providerId, modelId, modelType) {
   if (!provider) return;
 
   var isLlm = modelType === 'llm';
-  var models = isLlm ? (provider.llmModels || []) : (provider.embeddingModels || []);
+  var isReranker = modelType === 'reranker';
+  var models = isLlm ? (provider.llmModels || []) : isReranker ? (provider.rerankerModels || []) : (provider.embeddingModels || []);
   var model = models.find(function(m) { return m.id === modelId; });
   if (!model) return;
@@ -1834,7 +1888,7 @@ function showModelSettingsModal(providerId, modelId, modelType) {
   // Calculate endpoint preview URL
   var providerBase = provider.apiBase || getDefaultApiBase(provider.type);
   var modelBaseUrl = endpointSettings.baseUrl || providerBase;
-  var endpointPath = isLlm ? '/chat/completions' : '/embeddings';
+  var endpointPath = isLlm ? '/chat/completions' : isReranker ? '/rerank' : '/embeddings';
   var endpointPreview = modelBaseUrl + endpointPath;
 
   var modalHtml = '<div class="modal-overlay" id="model-settings-modal">' +
@@ -1848,7 +1902,7 @@ function showModelSettingsModal(providerId, modelId, modelType) {
     // Endpoint Preview Section (combined view + settings)
     '<div class="form-section endpoint-preview-section">' +
-      '<h4><i data-lucide="' + (isLlm ? 'message-square' : 'box') + '"></i> ' + t('apiSettings.endpointPreview') + '</h4>' +
+      '<h4><i data-lucide="' + (isLlm ? 'message-square' : isReranker ? 'sort-asc' : 'box') + '"></i> ' + t('apiSettings.endpointPreview') + '</h4>' +
       '<div class="endpoint-preview-box">' +
         '<code id="model-endpoint-preview">' + escapeHtml(endpointPreview) + '</code>' +
         '<button type="button" class="btn-icon-sm" onclick="copyModelEndpoint()" title="' + t('common.copy') + '">' +
@@ -1857,7 +1911,7 @@ function showModelSettingsModal(providerId, modelId, modelType) {
       '</div>' +
       '<div class="form-group">' +
         '<label>' + t('apiSettings.modelBaseUrlOverride') + ' <span class="text-muted">(' + t('common.optional') + ')</span></label>' +
-        '<input type="text" id="model-settings-baseurl" class="cli-input" value="' + escapeHtml(endpointSettings.baseUrl || '') + '" placeholder="' + escapeHtml(providerBase) + '" oninput="updateModelEndpointPreview(\'' + (isLlm ? 'chat/completions' : 'embeddings') + '\', \'' + escapeHtml(providerBase) + '\')">' +
+        '<input type="text" id="model-settings-baseurl" class="cli-input" value="' + escapeHtml(endpointSettings.baseUrl || '') + '" placeholder="' + escapeHtml(providerBase) + '" oninput="updateModelEndpointPreview(\'' + (isLlm ? 'chat/completions' : isReranker ? 'rerank' : 'embeddings') + '\', \'' + escapeHtml(providerBase) + '\')">' +
         '<small class="form-hint">' + t('apiSettings.modelBaseUrlHint') + '</small>' +
       '</div>' +
     '</div>' +
@@ -1968,7 +2022,8 @@ function saveModelSettings(event, providerId, modelId, modelType) {
   event.preventDefault();
 
   var isLlm = modelType === 'llm';
-  var modelsKey = isLlm ? 'llmModels' : 'embeddingModels';
+  var isReranker = modelType === 'reranker';
+  var modelsKey = isLlm ? 'llmModels' : isReranker ? 'rerankerModels' : 'embeddingModels';
 
   fetch('/api/litellm-api/providers/' + providerId)
     .then(function(res) { return res.json(); })
@@ -1994,6 +2049,11 @@ function saveModelSettings(event, providerId, modelId, modelType) {
         functionCalling: document.getElementById('model-settings-function-calling').checked,
         vision: document.getElementById('model-settings-vision').checked
       };
+    } else if (isReranker) {
+      var topKEl = document.getElementById('model-settings-top-k');
+      models[modelIndex].capabilities = {
+        topK: topKEl ? parseInt(topKEl.value) || 10 : 10
+      };
     } else {
       models[modelIndex].capabilities = {
         embeddingDimension: parseInt(document.getElementById('model-settings-dimensions').value) || 1536,
@@ -2042,7 +2102,8 @@ function deleteModel(providerId, modelId, modelType) {
   if (!confirm(t('common.confirmDelete'))) return;
 
   var isLlm = modelType === 'llm';
-  var modelsKey = isLlm ? 'llmModels' : 'embeddingModels';
+  var isReranker = modelType === 'reranker';
+  var modelsKey = isLlm ? 'llmModels' : isReranker ? 'rerankerModels' : 'embeddingModels';
 
   fetch('/api/litellm-api/providers/' + providerId)
     .then(function(res) { return res.json(); })
File diff suppressed because it is too large
@@ -226,6 +226,9 @@ export interface ProviderCredential {
   /** Embedding models configured for this provider */
   embeddingModels?: ModelDefinition[];
 
+  /** Reranker models configured for this provider */
+  rerankerModels?: ModelDefinition[];
+
   /** Creation timestamp (ISO 8601) */
   createdAt: string;
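For illustration only, a stored provider entry using the new optional list might look like the sketch below (a Python dict mirroring the JSON shape; fields not visible in this diff, such as the provider id, are guesses):

provider_credential = {
    "id": "example-provider",   # assumed identifier field
    "name": "Example Provider",
    "type": "openai",
    "enabled": True,
    "apiBase": "https://api.example.com/v1",
    "llmModels": [],
    "embeddingModels": [],
    "rerankerModels": [         # new optional ModelDefinition[] added by this commit
        {
            "id": "BAAI/bge-reranker-v2-m3",
            "name": "BGE Reranker v2 M3",
            "enabled": True,
            "capabilities": {"topK": 10},
        }
    ],
    "createdAt": "2026-02-04T00:00:00Z",
}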
@@ -125,6 +125,16 @@ class FastEmbedReranker(BaseReranker):
         logger.debug("FastEmbed reranker model loaded successfully")
 
+    @staticmethod
+    def _sigmoid(x: float) -> float:
+        """Numerically stable sigmoid function."""
+        if x < -709:
+            return 0.0
+        if x > 709:
+            return 1.0
+        import math
+        return 1.0 / (1.0 + math.exp(-x))
+
     def score_pairs(
         self,
         pairs: Sequence[tuple[str, str]],
@@ -165,8 +175,8 @@ class FastEmbedReranker(BaseReranker):
         indices = [idx for idx, _ in indexed_docs]
 
         try:
-            # TextCrossEncoder.rerank returns list of RerankResult with score attribute
-            results = list(
+            # TextCrossEncoder.rerank returns raw float scores in same order as input
+            raw_scores = list(
                 self._encoder.rerank(
                     query=query,
                     documents=docs,
@@ -174,22 +184,12 @@ class FastEmbedReranker(BaseReranker):
                 )
             )
 
-            # Map scores back to original positions
-            # Results are returned in descending score order, but we need original order
-            for result in results:
-                # Each result has 'index' (position in input docs) and 'score'
-                doc_idx = result.index if hasattr(result, "index") else 0
-                score = result.score if hasattr(result, "score") else 0.0
-
-                if doc_idx < len(indices):
-                    original_idx = indices[doc_idx]
-                    # Normalize score to [0, 1] using sigmoid if needed
-                    # FastEmbed typically returns scores in [0, 1] already
-                    if score < 0 or score > 1:
-                        import math
-
-                        score = 1.0 / (1.0 + math.exp(-score))
-                    scores[original_idx] = float(score)
+            # Map scores back to original positions and normalize with sigmoid
+            for i, raw_score in enumerate(raw_scores):
+                if i < len(indices):
+                    original_idx = indices[i]
+                    # Normalize score to [0, 1] using stable sigmoid
+                    scores[original_idx] = self._sigmoid(float(raw_score))
 
         except Exception as e:
             logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
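A toy walk-through of the new mapping logic (how indexed_docs is built lies outside this hunk, so skipping empty documents is assumed here; a plain sigmoid stands in for the clamped _sigmoid):

import math

documents = ["", "relevant passage", "off-topic passage"]
indexed_docs = [(i, d) for i, d in enumerate(documents) if d.strip()]
indices = [idx for idx, _ in indexed_docs]   # [1, 2]

raw_scores = [4.2, -3.7]                     # pretend encoder output, in input order
scores = [0.0] * len(documents)
for i, raw in enumerate(raw_scores):
    if i < len(indices):
        scores[indices[i]] = 1.0 / (1.0 + math.exp(-float(raw)))

print(scores)  # [0.0, ~0.985, ~0.024]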
@@ -227,7 +227,8 @@ class FastEmbedReranker(BaseReranker):
             return []
 
         try:
-            results = list(
+            # TextCrossEncoder.rerank returns raw float scores in same order as input
+            raw_scores = list(
                 self._encoder.rerank(
                     query=query,
                     documents=list(documents),
@@ -235,13 +236,13 @@ class FastEmbedReranker(BaseReranker):
                 )
             )
 
-            # Convert to our format: (score, document, original_index)
+            # Convert to our format: (normalized_score, document, original_index)
             ranked = []
-            for result in results:
-                idx = result.index if hasattr(result, "index") else 0
-                score = result.score if hasattr(result, "score") else 0.0
-                doc = documents[idx] if idx < len(documents) else ""
-                ranked.append((float(score), doc, idx))
+            for idx, raw_score in enumerate(raw_scores):
+                if idx < len(documents):
+                    # Normalize score to [0, 1] using stable sigmoid
+                    normalized = self._sigmoid(float(raw_score))
+                    ranked.append((normalized, documents[idx], idx))
 
             # Sort by score descending
             ranked.sort(key=lambda x: x[0], reverse=True)