feat: enhance model download to support downloading ONNX-format models directly from HuggingFace Hub

catlog22
2026-01-11 18:22:36 +08:00
parent 1e91fa9f9e
commit b77672dda4
3 changed files with 75 additions and 45 deletions

View File

@@ -916,7 +916,7 @@ select.cli-input {
 .cli-textarea {
   resize: vertical;
   min-height: 4rem;
-  max-height: 12rem;
+  max-height: 20rem;
   font-family: 'SF Mono', 'Consolas', 'Liberation Mono', monospace;
   font-size: 0.8125rem;
 }
@@ -2681,7 +2681,7 @@ select.cli-input {
   display: flex;
   position: relative;
   min-height: 200px;
-  max-height: 350px;
+  max-height: min(450px, 50vh);
 }
 .json-line-numbers {

View File

@@ -2672,14 +2672,15 @@ async function loadModelList() {
                 '<i data-lucide="plus-circle" class="w-3 h-3"></i> Download Custom Model' +
                 '</div>' +
                 '<div class="flex gap-2">' +
-                '<input type="text" id="customModelInput" placeholder="e.g., BAAI/bge-small-en-v1.5" ' +
+                '<input type="text" id="customModelInput" placeholder="e.g., Xenova/bge-small-en-v1.5" ' +
                 'class="flex-1 text-xs px-2 py-1.5 border border-border rounded bg-background focus:border-primary focus:ring-1 focus:ring-primary outline-none" />' +
                 '<button onclick="downloadCustomModel()" class="text-xs px-3 py-1.5 bg-primary text-primary-foreground rounded hover:bg-primary/90">' +
                 'Download' +
                 '</button>' +
                 '</div>' +
-                '<div class="text-[10px] text-muted-foreground mt-2">' +
-                'Enter any HuggingFace model name compatible with FastEmbed' +
+                '<div class="text-[10px] text-muted-foreground mt-2 space-y-1">' +
+                '<div><span class="text-amber-500">⚠</span> Only <strong>ONNX-format</strong> models work with FastEmbed (e.g., Xenova/* models)</div>' +
+                '<div>PyTorch models (intfloat/*, sentence-transformers/*) will download but won\'t work with local embedding</div>' +
                 '</div>' +
                 '</div>';
     } else {

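The ONNX-only warning added to this hint mirrors a server-side check introduced in the Python file below. For illustration only, here is a minimal standalone probe of the same condition, assuming huggingface_hub is installed; ships_onnx is a hypothetical helper, and the expected outputs reflect the usual layout of these repos:

    from huggingface_hub import list_repo_files

    def ships_onnx(repo_id: str) -> bool:
        """Return True if the repo publishes ONNX weights FastEmbed can load."""
        files = list_repo_files(repo_id=repo_id)
        return any(f.endswith(".onnx") for f in files)

    # Xenova/* repos ship transformers.js-style ONNX exports; intfloat/*
    # repos typically publish only PyTorch/safetensors weights.
    print(ships_onnx("Xenova/bge-small-en-v1.5"))  # expected: True
    print(ships_onnx("intfloat/e5-small-v2"))      # expected: False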
View File

@@ -6,6 +6,12 @@ import shutil
 from pathlib import Path
 from typing import Dict, List, Optional
 
+try:
+    from huggingface_hub import snapshot_download, list_repo_files
+    HUGGINGFACE_HUB_AVAILABLE = True
+except ImportError:
+    HUGGINGFACE_HUB_AVAILABLE = False
+
 try:
     from fastembed import TextEmbedding
     FASTEMBED_AVAILABLE = True
@@ -557,64 +563,85 @@ def download_model(profile: str, progress_callback: Optional[callable] = None) -
 def download_custom_model(model_name: str, model_type: str = "embedding", progress_callback: Optional[callable] = None) -> Dict[str, any]:
     """Download a custom model by HuggingFace model name.
 
-    This allows users to download any HuggingFace model that is compatible
-    with fastembed (TextEmbedding or TextCrossEncoder).
+    This allows users to download any HuggingFace model directly from
+    HuggingFace Hub. The model will be placed in the standard cache
+    directory where it can be discovered by scan_discovered_models().
+
+    Note: Downloaded models may not be directly usable by FastEmbed unless
+    they are in ONNX format. This function is primarily for downloading
+    models that users want to use with other frameworks or custom code.
 
     Args:
-        model_name: Full HuggingFace model name (e.g., "BAAI/bge-small-en-v1.5")
-        model_type: Type of model ("embedding" or "reranker")
+        model_name: Full HuggingFace model name (e.g., "intfloat/e5-small-v2")
+        model_type: Type of model ("embedding" or "reranker") - for metadata only
         progress_callback: Optional callback function to report progress
 
     Returns:
         Result dictionary with success status
     """
-    if model_type == "embedding":
-        if not FASTEMBED_AVAILABLE:
-            return {
-                "success": False,
-                "error": "fastembed not installed. Install with: pip install codexlens[semantic]",
-            }
-    else:
-        if not RERANKER_AVAILABLE:
-            return {
-                "success": False,
-                "error": "fastembed reranker not available. Install with: pip install fastembed>=0.4.0",
-            }
+    if not HUGGINGFACE_HUB_AVAILABLE:
+        return {
+            "success": False,
+            "error": "huggingface_hub not installed. Install with: pip install huggingface_hub",
+        }
 
     # Validate model name format (org/model-name)
     if not model_name or "/" not in model_name:
         return {
             "success": False,
-            "error": "Invalid model name format. Expected: 'org/model-name' (e.g., 'BAAI/bge-small-en-v1.5')",
+            "error": "Invalid model name format. Expected: 'org/model-name' (e.g., 'intfloat/e5-small-v2')",
         }
 
     try:
         cache_dir = get_cache_dir()
 
         if progress_callback:
-            progress_callback(f"Downloading custom model {model_name}...")
+            progress_callback(f"Checking model format for {model_name}...")
 
-        if model_type == "reranker":
-            # Download reranker model
-            reranker = TextCrossEncoder(model_name=model_name, cache_dir=str(cache_dir))
-            if progress_callback:
-                progress_callback(f"Initializing reranker {model_name}...")
-            list(reranker.rerank("test query", ["test document"]))
-        else:
-            # Download embedding model
-            embedder = TextEmbedding(model_name=model_name, cache_dir=str(cache_dir))
-            if progress_callback:
-                progress_callback(f"Initializing {model_name}...")
-            list(embedder.embed(["test"]))
+        # Check if model contains ONNX files before downloading
+        try:
+            files = list_repo_files(repo_id=model_name)
+            has_onnx = any(
+                f.endswith('.onnx') or
+                f.startswith('onnx/') or
+                '/onnx/' in f or
+                f == 'model.onnx'
+                for f in files
+            )
+            if not has_onnx:
+                return {
+                    "success": False,
+                    "error": f"Model '{model_name}' does not contain ONNX files. "
+                             f"FastEmbed requires ONNX-format models. "
+                             f"Try Xenova/* versions or check the recommended models list.",
+                    "files_found": len(files),
+                    "suggestion": "Use models from the 'Recommended Models' list, or search for ONNX versions (e.g., Xenova/*).",
+                }
+            if progress_callback:
+                progress_callback(f"ONNX format detected. Downloading {model_name}...")
+        except Exception as check_err:
+            # If we can't check, warn but allow download
+            if progress_callback:
+                progress_callback(f"Could not verify format, proceeding with download...")
+
+        # Use huggingface_hub to download the model
+        # This downloads to the standard HuggingFace cache directory
+        local_path = snapshot_download(
+            repo_id=model_name,
+            cache_dir=str(cache_dir),
+        )
 
         if progress_callback:
-            progress_callback(f"Custom model {model_name} downloaded successfully")
+            progress_callback(f"Model {model_name} downloaded successfully")
 
         # Get cache info
         sanitized_name = f"models--{model_name.replace('/', '--')}"
        model_cache_path = cache_dir / sanitized_name
         cache_size = 0
         if model_cache_path.exists():
             total_size = sum(
@@ -623,7 +650,7 @@ def download_custom_model(model_name: str, model_type: str = "embedding", progress_callback: Optional[callable] = None) -> Dict[str, any]:
             if f.is_file()
         )
         cache_size = round(total_size / (1024 * 1024), 1)
 
         return {
             "success": True,
             "result": {
@@ -631,6 +658,8 @@ def download_custom_model(model_name: str, model_type: str = "embedding", progress_callback: Optional[callable] = None) -> Dict[str, any]:
                 "model_type": model_type,
                 "cache_size_mb": cache_size,
                 "cache_path": str(model_cache_path),
+                "local_path": local_path,
+                "note": "Model downloaded. Note: Only ONNX-format models are compatible with FastEmbed.",
             },
         }
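
Taken together, the new flow validates the name, probes the repo for ONNX files, snapshot-downloads into the FastEmbed cache directory, and reports the resolved path and size. A minimal usage sketch of the updated function with a progress callback; the import path is hypothetical, since the module name is not visible in this commit:

    from codexlens.semantic import download_custom_model  # hypothetical import path

    result = download_custom_model(
        "Xenova/bge-small-en-v1.5",   # ONNX export, so the format probe passes
        model_type="embedding",       # metadata only; no longer gates the download
        progress_callback=print,      # progress strings go straight to stdout
    )

    if result["success"]:
        info = result["result"]
        print(f"Cached at {info['cache_path']} ({info['cache_size_mb']} MB)")
    else:
        # A PyTorch-only repo fails fast with a Xenova/* suggestion.
        print(result["error"], result.get("suggestion", ""))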