feat(model-lock): implement model lock management with localStorage support

catlog22
2026-01-03 19:48:07 +08:00
parent 6043e6aa3b
commit 0af84be775
4 changed files with 570 additions and 99 deletions

View File

@@ -283,9 +283,10 @@ function buildCodexLensConfigContent(config) {
'<div id="semanticDepsStatus" class="space-y-2">' + '<div id="semanticDepsStatus" class="space-y-2">' +
'<div class="text-sm text-muted-foreground">' + t('codexlens.checkingDeps') + '</div>' + '<div class="text-sm text-muted-foreground">' + t('codexlens.checkingDeps') + '</div>' +
'</div>' + '</div>' +
'<div id="spladeStatus" class="space-y-2 mt-3 pt-3 border-t border-border">' + // SPLADE status hidden - not currently used
'<div class="text-sm text-muted-foreground">' + t('common.loading') + '</div>' + // '<div id="spladeStatus" class="space-y-2 mt-3 pt-3 border-t border-border">' +
'</div>' + // '<div class="text-sm text-muted-foreground">' + t('common.loading') + '</div>' +
// '</div>' +
'</div>' + '</div>' +
// Model Management // Model Management
@@ -498,13 +499,142 @@ function initCodexLensConfigEvents(currentConfig) {
  // Load semantic dependencies status
  loadSemanticDepsStatus();
-  // Load SPLADE status
-  loadSpladeStatus();
+  // SPLADE status hidden - not currently used
+  // loadSpladeStatus();
  // Load model list
  loadModelList();
}
// ============================================================
// MODEL LOCK/UNLOCK MANAGEMENT
// ============================================================
var MODEL_LOCK_KEY = 'codexlens_model_lock';

/**
 * Get model lock state from localStorage
 * @returns {Object} { locked: boolean, backend: string, model: string }
 */
function getModelLockState() {
  try {
    var stored = localStorage.getItem(MODEL_LOCK_KEY);
    if (stored) {
      return JSON.parse(stored);
    }
  } catch (e) {
    console.warn('[CodexLens] Failed to get model lock state:', e);
  }
  return { locked: false, backend: 'fastembed', model: 'code' };
}
/**
 * Set model lock state in localStorage
 * @param {boolean} locked - Whether model is locked
 * @param {string} backend - Selected backend
 * @param {string} model - Selected model
 */
function setModelLockState(locked, backend, model) {
  try {
    localStorage.setItem(MODEL_LOCK_KEY, JSON.stringify({
      locked: locked,
      backend: backend || 'fastembed',
      model: model || 'code'
    }));
  } catch (e) {
    console.warn('[CodexLens] Failed to save model lock state:', e);
  }
}
/**
 * Toggle model lock state
 */
function toggleModelLock() {
  var backendSelect = document.getElementById('pageBackendSelect');
  var modelSelect = document.getElementById('pageModelSelect');
  var currentState = getModelLockState();
  var newLocked = !currentState.locked;
  // Get current values if locking
  var backend = newLocked ? (backendSelect ? backendSelect.value : 'fastembed') : currentState.backend;
  var model = newLocked ? (modelSelect ? modelSelect.value : 'code') : currentState.model;
  // Save state
  setModelLockState(newLocked, backend, model);
  // Update UI (applyModelLockUI resolves the lock button and icon itself)
  applyModelLockUI(newLocked, backend, model);
  // Show feedback
  if (newLocked) {
    showRefreshToast('Model locked: ' + backend + ' / ' + model, 'success');
  } else {
    showRefreshToast('Model unlocked', 'info');
  }
}
/**
 * Apply model lock UI state
 */
function applyModelLockUI(locked, backend, model) {
  var backendSelect = document.getElementById('pageBackendSelect');
  var modelSelect = document.getElementById('pageModelSelect');
  var lockBtn = document.getElementById('modelLockBtn');
  var lockIcon = document.getElementById('modelLockIcon');
  var lockText = document.getElementById('modelLockText');
  if (backendSelect) {
    backendSelect.disabled = locked;
    if (locked && backend) {
      backendSelect.value = backend;
    }
  }
  if (modelSelect) {
    modelSelect.disabled = locked;
    if (locked && model) {
      modelSelect.value = model;
    }
  }
  if (lockBtn) {
    if (locked) {
      lockBtn.classList.remove('btn-outline');
      lockBtn.classList.add('btn-primary');
    } else {
      lockBtn.classList.remove('btn-primary');
      lockBtn.classList.add('btn-outline');
    }
  }
  if (lockIcon) {
    lockIcon.setAttribute('data-lucide', locked ? 'lock' : 'unlock');
    if (window.lucide) lucide.createIcons();
  }
  if (lockText) {
    lockText.textContent = locked ? 'Locked' : 'Lock Model';
  }
}
/**
 * Initialize model lock state on page load
 */
function initModelLockState() {
  var state = getModelLockState();
  if (state.locked) {
    applyModelLockUI(true, state.backend, state.model);
  }
}

// Make functions globally accessible
window.toggleModelLock = toggleModelLock;
window.initModelLockState = initModelLockState;
window.getModelLockState = getModelLockState;
// ============================================================
// ENVIRONMENT VARIABLES MANAGEMENT
// ============================================================
@@ -987,12 +1117,12 @@ async function installSemanticDeps() {
}
// ============================================================
-// SPLADE MANAGEMENT
+// SPLADE MANAGEMENT - Hidden (not currently used)
// ============================================================
+// SPLADE functionality is hidden from the UI. The code is preserved
+// for potential future use but is not exposed to users.
-/**
- * Load SPLADE status
- */
+/*
async function loadSpladeStatus() {
  var container = document.getElementById('spladeStatus');
  if (!container) return;
@@ -1039,9 +1169,6 @@ async function loadSpladeStatus() {
  }
}
-/**
- * Install SPLADE package
- */
async function installSplade(gpu) {
  var container = document.getElementById('spladeStatus');
  if (!container) return;
@@ -1073,6 +1200,7 @@ async function installSplade(gpu) {
    loadSpladeStatus();
  }
}
+*/
// ============================================================
@@ -2342,9 +2470,6 @@ function buildCodexLensManagerPage(config) {
var indexCount = config.index_count || 0;
var isInstalled = window.cliToolsStatus?.codexlens?.installed || false;
-// Build model options for vector indexing
-var modelOptions = buildModelSelectOptionsForPage();
return '<div class="codexlens-manager-page space-y-6">' +
// Header with status
'<div class="bg-card border border-border rounded-lg p-6">' +
@@ -2375,64 +2500,11 @@ function buildCodexLensManagerPage(config) {
'<div class="grid grid-cols-1 lg:grid-cols-2 gap-6">' + '<div class="grid grid-cols-1 lg:grid-cols-2 gap-6">' +
// Left Column // Left Column
'<div class="space-y-6">' + '<div class="space-y-6">' +
// Create Index Section // Create Index Section - Simplified (model config in Environment Variables)
'<div class="bg-card border border-border rounded-lg p-5">' + '<div class="bg-card border border-border rounded-lg p-5">' +
'<h4 class="text-lg font-semibold mb-4 flex items-center gap-2"><i data-lucide="layers" class="w-5 h-5 text-primary"></i> ' + t('codexlens.createIndex') + '</h4>' + '<h4 class="text-lg font-semibold mb-4 flex items-center gap-2"><i data-lucide="layers" class="w-5 h-5 text-primary"></i> ' + t('codexlens.createIndex') + '</h4>' +
'<div class="space-y-4">' + '<div class="space-y-4">' +
// Backend selector (fastembed local or litellm API) // Index Actions - Primary buttons
'<div class="mb-4">' +
'<label class="block text-sm font-medium mb-1.5">' + (t('codexlens.embeddingBackend') || 'Embedding Backend') + '</label>' +
'<select id="pageBackendSelect" class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm" onchange="onEmbeddingBackendChange()">' +
'<option value="fastembed">' + (t('codexlens.localFastembed') || 'Local (FastEmbed)') + '</option>' +
'<option value="litellm">' + (t('codexlens.apiLitellm') || 'API (LiteLLM)') + '</option>' +
'</select>' +
'<p class="text-xs text-muted-foreground mt-1">' + (t('codexlens.backendHint') || 'Select local model or remote API endpoint') + '</p>' +
'</div>' +
// Model selector
'<div>' +
'<label class="block text-sm font-medium mb-1.5">' + t('codexlens.embeddingModel') + '</label>' +
'<select id="pageModelSelect" class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm">' +
modelOptions +
'</select>' +
'<p class="text-xs text-muted-foreground mt-1">' + t('codexlens.modelHint') + '</p>' +
'</div>' +
// Concurrency selector (only for LiteLLM backend)
'<div id="concurrencySelector" class="hidden">' +
'<label class="block text-sm font-medium mb-1.5">' + t('codexlens.concurrency') + '</label>' +
'<div class="flex items-center gap-2">' +
'<input type="number" id="pageConcurrencyInput" min="1" value="4" ' +
'class="w-24 px-3 py-2 border border-border rounded-lg bg-background text-sm" ' +
'onchange="validateConcurrencyInput(this)" />' +
'<span class="text-sm text-muted-foreground">workers</span>' +
'<span class="text-xs text-primary ml-2">(4 = recommended)</span>' +
'</div>' +
'<p class="text-xs text-muted-foreground mt-1">' + t('codexlens.concurrencyHint') + '</p>' +
'</div>' +
// Multi-Provider Rotation (only for LiteLLM backend) - Simplified, config in API Settings
'<div id="rotationSection" class="hidden">' +
'<div class="border border-border rounded-lg p-3 bg-muted/30">' +
'<div class="flex items-center justify-between mb-2">' +
'<div class="flex items-center gap-2">' +
'<i data-lucide="rotate-cw" class="w-4 h-4 text-primary"></i>' +
'<span class="text-sm font-medium">' + t('codexlens.rotation') + '</span>' +
'</div>' +
'<div id="rotationStatusBadge" class="text-xs px-2 py-0.5 rounded-full bg-muted text-muted-foreground">' +
t('common.disabled') +
'</div>' +
'</div>' +
'<p class="text-xs text-muted-foreground mb-2">' + t('codexlens.rotationDesc') + '</p>' +
'<div id="rotationDetails" class="text-xs text-muted-foreground mb-3 hidden">' +
'<span id="rotationModelName"></span> · <span id="rotationEndpointCount"></span>' +
'</div>' +
'<div class="flex items-center gap-2">' +
'<a href="#" class="btn-sm btn-outline flex items-center gap-1.5" onclick="navigateToApiSettingsEmbeddingPool(); return false;">' +
'<i data-lucide="external-link" class="w-3.5 h-3.5"></i>' +
t('codexlens.configureInApiSettings') +
'</a>' +
'</div>' +
'</div>' +
'</div>' +
// Index buttons - two modes: full (FTS + Vector) or FTS only
'<div class="grid grid-cols-2 gap-3">' + '<div class="grid grid-cols-2 gap-3">' +
'<button class="btn btn-primary flex items-center justify-center gap-2 py-3" onclick="initCodexLensIndexFromPage(\'full\')" title="' + t('codexlens.fullIndexDesc') + '">' + '<button class="btn btn-primary flex items-center justify-center gap-2 py-3" onclick="initCodexLensIndexFromPage(\'full\')" title="' + t('codexlens.fullIndexDesc') + '">' +
'<i data-lucide="layers" class="w-4 h-4"></i>' + '<i data-lucide="layers" class="w-4 h-4"></i>' +
@@ -2443,7 +2515,28 @@ function buildCodexLensManagerPage(config) {
'<span>' + t('codexlens.ftsIndex') + '</span>' +
'</button>' +
'</div>' +
'<p class="text-xs text-muted-foreground">' + t('codexlens.indexTypeHint') + '</p>' + // Incremental Update button
'<button class="btn btn-outline w-full flex items-center justify-center gap-2 py-2.5" onclick="runIncrementalUpdate()" title="Update index with changed files only">' +
'<i data-lucide="refresh-cw" class="w-4 h-4"></i>' +
'<span>Incremental Update</span>' +
'</button>' +
// Watchdog Section
'<div class="border border-border rounded-lg p-3 bg-muted/30">' +
'<div class="flex items-center justify-between">' +
'<div class="flex items-center gap-2">' +
'<i data-lucide="eye" class="w-4 h-4 text-primary"></i>' +
'<span class="text-sm font-medium">File Watcher</span>' +
'</div>' +
'<div id="watcherStatusBadge" class="flex items-center gap-2">' +
'<span class="text-xs px-2 py-0.5 rounded-full bg-muted text-muted-foreground">Stopped</span>' +
'<button class="btn-sm btn-outline" onclick="toggleWatcher()" id="watcherToggleBtn">' +
'<i data-lucide="play" class="w-3.5 h-3.5"></i>' +
'</button>' +
'</div>' +
'</div>' +
'<p class="text-xs text-muted-foreground mt-2">Auto-update index when files change</p>' +
'</div>' +
'<p class="text-xs text-muted-foreground">' + t('codexlens.indexTypeHint') + ' Configure embedding model in Environment Variables below.</p>' +
'</div>' +
'</div>' +
// Storage Path Section
@@ -2464,6 +2557,16 @@ function buildCodexLensManagerPage(config) {
'</div>' +
'</div>' +
'</div>' +
+// Environment Variables Section
+'<div class="bg-card border border-border rounded-lg p-5">' +
+'<div class="flex items-center justify-between mb-4">' +
+'<h4 class="text-lg font-semibold flex items-center gap-2"><i data-lucide="file-code" class="w-5 h-5 text-primary"></i> Environment Variables</h4>' +
+'<button class="btn-sm btn-outline" onclick="loadEnvVariables()"><i data-lucide="refresh-cw" class="w-3.5 h-3.5"></i> Load</button>' +
+'</div>' +
+'<div id="envVarsContainer" class="space-y-2">' +
+'<div class="text-sm text-muted-foreground">Click Load to view/edit ~/.codexlens/.env</div>' +
+'</div>' +
+'</div>' +
// Maintenance Section
'<div class="bg-card border border-border rounded-lg p-5">' +
'<h4 class="text-lg font-semibold mb-4 flex items-center gap-2"><i data-lucide="settings" class="w-5 h-5 text-primary"></i> ' + t('codexlens.maintenance') + '</h4>' +
@@ -2743,26 +2846,114 @@ function buildLiteLLMModelOptions() {
window.onEmbeddingBackendChange = onEmbeddingBackendChange;

/**
- * Initialize index from page with selected model
+ * Initialize index from page - uses env-based config
+ * Model/backend configured in Environment Variables section
 */
function initCodexLensIndexFromPage(indexType) {
-  var backendSelect = document.getElementById('pageBackendSelect');
-  var modelSelect = document.getElementById('pageModelSelect');
-  var concurrencyInput = document.getElementById('pageConcurrencyInput');
-  var selectedBackend = backendSelect ? backendSelect.value : 'fastembed';
-  var selectedModel = modelSelect ? modelSelect.value : 'code';
-  var selectedConcurrency = concurrencyInput ? Math.max(1, parseInt(concurrencyInput.value, 10) || 4) : 4;
-  // For FTS-only index, model is not needed
+  // For FTS-only index, no embedding config needed
  if (indexType === 'normal') {
    initCodexLensIndex(indexType);
  } else {
-    // Pass concurrency only for litellm backend
-    var maxWorkers = selectedBackend === 'litellm' ? selectedConcurrency : 1;
-    initCodexLensIndex(indexType, selectedModel, selectedBackend, maxWorkers);
+    // Use litellm backend with env-configured model (default 4 workers)
+    // The CLI will read EMBEDDING_MODEL/LITELLM_MODEL from env
+    initCodexLensIndex(indexType, null, 'litellm', 4);
  }
}
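Since the page no longer collects a model, the vector path now depends on the environment file surfaced in the new Environment Variables card (~/.codexlens/.env). A minimal Python sketch of seeding that file; the variable names come from the comment above, while the values and the CLI's precedence between them are assumptions:

from pathlib import Path

env_path = Path.home() / ".codexlens" / ".env"
env_path.parent.mkdir(parents=True, exist_ok=True)
# EMBEDDING_MODEL / LITELLM_MODEL are the names referenced in
# initCodexLensIndexFromPage; the values here are only examples.
env_path.write_text("EMBEDDING_MODEL=code\nLITELLM_MODEL=text-embedding-3-small\n")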
/**
 * Run incremental update on the current workspace index
 */
window.runIncrementalUpdate = async function runIncrementalUpdate() {
  var projectPath = window.CCW_PROJECT_ROOT || '.';
  showRefreshToast('Starting incremental update...', 'info');
  try {
    var response = await fetch('/api/codexlens/update', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ path: projectPath })
    });
    var result = await response.json();
    if (result.success) {
      showRefreshToast('Incremental update completed', 'success');
    } else {
      showRefreshToast('Update failed: ' + (result.error || 'Unknown error'), 'error');
    }
  } catch (err) {
    showRefreshToast('Update error: ' + err.message, 'error');
  }
};
/**
 * Toggle file watcher (watchdog) on/off
 */
window.toggleWatcher = async function toggleWatcher() {
  console.log('[CodexLens] toggleWatcher called');
  // Debug: uncomment to test if function is called
  // alert('toggleWatcher called!');
  var projectPath = window.CCW_PROJECT_ROOT || '.';
  console.log('[CodexLens] Project path:', projectPath);
  // Check current status first
  try {
    console.log('[CodexLens] Checking watcher status...');
    var statusResponse = await fetch('/api/codexlens/watch/status');
    var statusResult = await statusResponse.json();
    console.log('[CodexLens] Status result:', statusResult);
    var isRunning = statusResult.success && statusResult.running;
    // Toggle: if running, stop; if stopped, start
    var action = isRunning ? 'stop' : 'start';
    console.log('[CodexLens] Action:', action);
    var response = await fetch('/api/codexlens/watch/' + action, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ path: projectPath })
    });
    var result = await response.json();
    console.log('[CodexLens] Action result:', result);
    if (result.success) {
      var newRunning = action === 'start';
      updateWatcherUI(newRunning);
      showRefreshToast('File watcher ' + (newRunning ? 'started' : 'stopped'), 'success');
    } else {
      showRefreshToast('Watcher ' + action + ' failed: ' + (result.error || 'Unknown error'), 'error');
    }
  } catch (err) {
    console.error('[CodexLens] Watcher error:', err);
    showRefreshToast('Watcher error: ' + err.message, 'error');
  }
};
/**
 * Update watcher UI state
 */
function updateWatcherUI(running) {
  var statusBadge = document.getElementById('watcherStatusBadge');
  if (statusBadge) {
    var badgeClass = running ? 'bg-success/20 text-success' : 'bg-muted text-muted-foreground';
    var badgeText = running ? 'Running' : 'Stopped';
    var iconName = running ? 'pause' : 'play';
    statusBadge.innerHTML =
      '<span class="text-xs px-2 py-0.5 rounded-full ' + badgeClass + '">' + badgeText + '</span>' +
      '<button class="btn-sm btn-outline" onclick="toggleWatcher()" id="watcherToggleBtn">' +
      '<i data-lucide="' + iconName + '" class="w-3.5 h-3.5"></i>' +
      '</button>';
    if (window.lucide) lucide.createIcons();
  }
}

// Make updateWatcherUI globally accessible (runIncrementalUpdate and
// toggleWatcher are already assigned to window at their definitions)
window.updateWatcherUI = updateWatcherUI;
/**
 * Initialize CodexLens Manager page event handlers
 */
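The watcher and update handlers above talk to a small HTTP API: POST /api/codexlens/update with a {"path": ...} body runs a one-off incremental update, GET /api/codexlens/watch/status reports {success, running}, and POST /api/codexlens/watch/start or /api/codexlens/watch/stop toggles the watcher. A minimal Python sketch of the same round-trip, assuming the dashboard listens on localhost:3000 (the port and the use of the requests library are assumptions; the paths and payload shapes are taken from the fetch calls above):

import requests

BASE = "http://localhost:3000"  # assumed dashboard address
payload = {"path": "."}

# One-off incremental update, mirroring runIncrementalUpdate()
result = requests.post(BASE + "/api/codexlens/update", json=payload).json()
print("update ok:", result.get("success"), result.get("error"))

# Toggle the file watcher, mirroring toggleWatcher()
status = requests.get(BASE + "/api/codexlens/watch/status").json()
action = "stop" if status.get("success") and status.get("running") else "start"
result = requests.post(BASE + "/api/codexlens/watch/" + action, json=payload).json()
print("watcher", action + ":", result.get("success"))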

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
+from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
@@ -17,6 +18,8 @@ __all__ = [
"get_reranker", "get_reranker",
"CrossEncoderReranker", "CrossEncoderReranker",
"check_cross_encoder_available", "check_cross_encoder_available",
"FastEmbedReranker",
"check_fastembed_reranker_available",
"ONNXReranker", "ONNXReranker",
"check_onnx_reranker_available", "check_onnx_reranker_available",
] ]

View File

@@ -14,8 +14,9 @@ def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used. """Check whether a specific reranker backend can be used.
Notes: Notes:
- "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
- "onnx" redirects to "fastembed" for backward compatibility.
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]). - "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "onnx" uses Optimum + ONNX Runtime (pip install codexlens[reranker] or codexlens[reranker-onnx]).
- "api" uses a remote reranking HTTP API (requires httpx). - "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers. - "litellm" uses `ccw-litellm` for unified access to LLM providers.
""" """
@@ -26,10 +27,16 @@ def check_reranker_available(backend: str) -> tuple[bool, str | None]:
        return check_cross_encoder_available()
-   if backend == "onnx":
-       from .onnx_reranker import check_onnx_reranker_available
-       return check_onnx_reranker_available()
+   if backend == "fastembed":
+       from .fastembed_reranker import check_fastembed_reranker_available
+       return check_fastembed_reranker_available()
+   if backend == "onnx":
+       # Redirect to fastembed for backward compatibility
+       from .fastembed_reranker import check_fastembed_reranker_available
+       return check_fastembed_reranker_available()
    if backend == "litellm":
        try:
@@ -54,12 +61,12 @@ def check_reranker_available(backend: str) -> tuple[bool, str | None]:
        return False, (
            f"Invalid reranker backend: {backend}. "
-           "Must be 'onnx', 'api', 'litellm', or 'legacy'."
+           "Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
        )

def get_reranker(
-   backend: str = "onnx",
+   backend: str = "fastembed",
    model_name: str | None = None,
    *,
    device: str | None = None,
@@ -69,12 +76,14 @@ def get_reranker(
    Args:
        backend: Reranker backend to use. Options:
-           - "onnx": Optimum + onnxruntime backend (default)
+           - "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
+           - "onnx": Redirects to fastembed for backward compatibility
            - "api": HTTP API backend (remote providers)
-           - "litellm": LiteLLM backend (LLM-based, experimental)
+           - "litellm": LiteLLM backend (LLM-based, for API mode)
            - "legacy": sentence-transformers CrossEncoder backend (optional)
        model_name: Model identifier for model-based backends. Defaults depend on backend:
-           - onnx: Xenova/ms-marco-MiniLM-L-6-v2
+           - fastembed: Xenova/ms-marco-MiniLM-L-6-v2
+           - onnx: (redirects to fastembed)
            - api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
            - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
            - litellm: default
@@ -90,16 +99,28 @@ def get_reranker(
""" """
backend = (backend or "").strip().lower() backend = (backend or "").strip().lower()
if backend == "onnx": if backend == "fastembed":
ok, err = check_reranker_available("onnx") ok, err = check_reranker_available("fastembed")
if not ok: if not ok:
raise ImportError(err) raise ImportError(err)
from .onnx_reranker import ONNXReranker from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or ONNXReranker.DEFAULT_MODEL resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via ONNX Runtime providers. _ = device # Device selection is managed via fastembed providers.
return ONNXReranker(model_name=resolved_model_name, **kwargs) return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "onnx":
# Redirect to fastembed for backward compatibility
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy": if backend == "legacy":
ok, err = check_reranker_available("legacy") ok, err = check_reranker_available("legacy")
@@ -134,5 +155,5 @@ def get_reranker(
        return APIReranker(model_name=resolved_model_name, **kwargs)

    raise ValueError(
-       f"Unknown backend: {backend}. Supported backends: 'onnx', 'api', 'litellm', 'legacy'"
+       f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
    )
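With this change the default backend resolves to FastEmbed, and "onnx" becomes an alias for it. A minimal usage sketch against the factory API above (the package import path codexlens.rerankers is assumed from the __init__.py exports; the query and documents are illustrative):

from codexlens.rerankers import check_reranker_available, get_reranker

ok, err = check_reranker_available("fastembed")
if not ok:
    raise SystemExit(err)

reranker = get_reranker("fastembed")  # "onnx" would return the same FastEmbedReranker
scores = reranker.score_pairs([
    ("parse json in python", "Use the json module's loads() function."),
    ("parse json in python", "Binary trees store ordered data."),
])
print(scores)  # one [0, 1] score per pair, higher = more relevant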

View File

@@ -0,0 +1,256 @@
"""FastEmbed-based reranker backend.
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
Install:
pip install fastembed>=0.4.0
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
    """Check whether fastembed reranker dependencies are available."""
    try:
        import fastembed  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
        )
    try:
        from fastembed.rerank.cross_encoder import TextCrossEncoder  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"fastembed TextCrossEncoder not available: {exc}. "
            "Upgrade with: pip install fastembed>=0.4.0",
        )
    return True, None
class FastEmbedReranker(BaseReranker):
    """Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
    # Alternative models supported by fastembed:
    # - "BAAI/bge-reranker-base"
    # - "BAAI/bge-reranker-large"
    # - "cross-encoder/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        cache_dir: str | None = None,
        threads: int | None = None,
    ) -> None:
        """Initialize FastEmbed reranker.

        Args:
            model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
            use_gpu: Whether to use GPU acceleration when available.
            cache_dir: Optional directory for caching downloaded models.
            threads: Optional number of threads for ONNX Runtime.
        """
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")
        self.use_gpu = bool(use_gpu)
        self.cache_dir = cache_dir
        self.threads = threads
        self._encoder: Any | None = None
        self._lock = threading.RLock()
    def _load_model(self) -> None:
        """Lazy-load the TextCrossEncoder model."""
        if self._encoder is not None:
            return
        ok, err = check_fastembed_reranker_available()
        if not ok:
            raise ImportError(err)
        with self._lock:
            if self._encoder is not None:
                return
            from fastembed.rerank.cross_encoder import TextCrossEncoder

            # Determine providers based on GPU preference
            providers: list[str] | None = None
            if self.use_gpu:
                try:
                    from ..gpu_support import get_optimal_providers

                    providers = get_optimal_providers(use_gpu=True, with_device_options=False)
                except Exception:
                    # Fallback: let fastembed decide
                    providers = None
            # Build initialization kwargs
            init_kwargs: dict[str, Any] = {}
            if self.cache_dir:
                init_kwargs["cache_dir"] = self.cache_dir
            if self.threads is not None:
                init_kwargs["threads"] = self.threads
            if providers:
                init_kwargs["providers"] = providers
            logger.debug(
                "Loading FastEmbed reranker model: %s (use_gpu=%s)",
                self.model_name,
                self.use_gpu,
            )
            self._encoder = TextCrossEncoder(
                model_name=self.model_name,
                **init_kwargs,
            )
            logger.debug("FastEmbed reranker model loaded successfully")
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair), normalized to [0, 1] range.
        """
        if not pairs:
            return []
        self._load_model()
        if self._encoder is None:  # pragma: no cover - defensive
            return []
        # FastEmbed's TextCrossEncoder.rerank() expects a query and a list of documents.
        # For batch scoring of multiple query-doc pairs, group by query so each
        # distinct query is encoded once.
        query_to_docs: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            if query not in query_to_docs:
                query_to_docs[query] = []
            query_to_docs[query].append((idx, doc))
        # Score each query group
        scores: list[float] = [0.0] * len(pairs)
        for query, indexed_docs in query_to_docs.items():
            docs = [doc for _, doc in indexed_docs]
            indices = [idx for idx, _ in indexed_docs]
            try:
                results = list(
                    self._encoder.rerank(
                        query=query,
                        documents=docs,
                        batch_size=batch_size,
                    )
                )
                # TextCrossEncoder.rerank yields one score per document, in input
                # order; handle object-style results (with index/score attributes)
                # defensively in case of API variations.
                for pos, result in enumerate(results):
                    if hasattr(result, "score"):
                        doc_idx = getattr(result, "index", pos)
                        score = float(result.score)
                    else:
                        doc_idx = pos
                        score = float(result)
                    if doc_idx < len(indices):
                        original_idx = indices[doc_idx]
                        # Cross-encoder outputs are logits; squash out-of-range
                        # values into [0, 1] with a sigmoid.
                        if score < 0 or score > 1:
                            score = 1.0 / (1.0 + math.exp(-score))
                        scores[original_idx] = score
            except Exception as e:
                logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
                # Leave scores as 0.0 for failed queries
        return scores
    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        *,
        top_k: int | None = None,
        batch_size: int = 32,
    ) -> list[tuple[float, str, int]]:
        """Rerank documents for a single query.

        This is a convenience method that provides results in ranked order.

        Args:
            query: The query string.
            documents: List of documents to rerank.
            top_k: Return only top K results. None returns all.
            batch_size: Batch size for scoring.

        Returns:
            List of (score, document, original_index) tuples, sorted by score descending.
        """
        if not documents:
            return []
        self._load_model()
        if self._encoder is None:  # pragma: no cover - defensive
            return []
        try:
            results = list(
                self._encoder.rerank(
                    query=query,
                    documents=list(documents),
                    batch_size=batch_size,
                )
            )
            # Convert to our format: (score, document, original_index).
            # Plain float results arrive in input order, so the enumeration
            # position is the original index; object-style results carry it.
            ranked = []
            for pos, result in enumerate(results):
                if hasattr(result, "score"):
                    idx = getattr(result, "index", pos)
                    score = float(result.score)
                else:
                    idx = pos
                    score = float(result)
                doc = documents[idx] if idx < len(documents) else ""
                ranked.append((score, doc, idx))
            # Sort by score descending
            ranked.sort(key=lambda x: x[0], reverse=True)
            if top_k is not None and top_k > 0:
                ranked = ranked[:top_k]
            return ranked
        except Exception as e:
            logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
            return []
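A short end-to-end sketch of the class above (the model downloads lazily on the first call; the import path is assumed from the __init__.py exports and the sample texts are illustrative):

from codexlens.rerankers import FastEmbedReranker

reranker = FastEmbedReranker(use_gpu=False)  # defaults to Xenova/ms-marco-MiniLM-L-6-v2
ranked = reranker.rerank(
    "what does a cross-encoder do",
    [
        "A cross-encoder scores a query and a document jointly.",
        "The weather is sunny today.",
    ],
    top_k=1,
)
for score, doc, idx in ranked:
    print(f"{score:.3f} [{idx}] {doc}")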