feat: 增强模型下载功能,支持 HuggingFace Hub 直接下载 ONNX 格式模型

This commit is contained in:
catlog22
2026-01-11 18:22:36 +08:00
parent 1e91fa9f9e
commit b77672dda4
3 changed files with 75 additions and 45 deletions

View File

@@ -916,7 +916,7 @@ select.cli-input {
 .cli-textarea {
   resize: vertical;
   min-height: 4rem;
-  max-height: 12rem;
+  max-height: 20rem;
   font-family: 'SF Mono', 'Consolas', 'Liberation Mono', monospace;
   font-size: 0.8125rem;
 }
@@ -2681,7 +2681,7 @@ select.cli-input {
   display: flex;
   position: relative;
   min-height: 200px;
-  max-height: 350px;
+  max-height: min(450px, 50vh);
 }
 .json-line-numbers {

View File

@@ -2672,14 +2672,15 @@ async function loadModelList() {
   '<i data-lucide="plus-circle" class="w-3 h-3"></i> Download Custom Model' +
   '</div>' +
   '<div class="flex gap-2">' +
-  '<input type="text" id="customModelInput" placeholder="e.g., BAAI/bge-small-en-v1.5" ' +
+  '<input type="text" id="customModelInput" placeholder="e.g., Xenova/bge-small-en-v1.5" ' +
   'class="flex-1 text-xs px-2 py-1.5 border border-border rounded bg-background focus:border-primary focus:ring-1 focus:ring-primary outline-none" />' +
   '<button onclick="downloadCustomModel()" class="text-xs px-3 py-1.5 bg-primary text-primary-foreground rounded hover:bg-primary/90">' +
   'Download' +
   '</button>' +
   '</div>' +
-  '<div class="text-[10px] text-muted-foreground mt-2">' +
-  'Enter any HuggingFace model name compatible with FastEmbed' +
+  '<div class="text-[10px] text-muted-foreground mt-2 space-y-1">' +
+  '<div><span class="text-amber-500">⚠</span> Only <strong>ONNX-format</strong> models work with FastEmbed (e.g., Xenova/* models)</div>' +
+  '<div>PyTorch models (intfloat/*, sentence-transformers/*) will download but won\'t work with local embedding</div>' +
   '</div>' +
   '</div>';
 } else {

View File

@@ -6,6 +6,12 @@ import shutil
 from pathlib import Path
 from typing import Dict, List, Optional
 
+try:
+    from huggingface_hub import snapshot_download, list_repo_files
+    HUGGINGFACE_HUB_AVAILABLE = True
+except ImportError:
+    HUGGINGFACE_HUB_AVAILABLE = False
+
 try:
     from fastembed import TextEmbedding
     FASTEMBED_AVAILABLE = True
@@ -557,64 +563,85 @@ def download_custom_model(model_name: str, model_type: str = "embedding", progre
 def download_custom_model(model_name: str, model_type: str = "embedding", progress_callback: Optional[callable] = None) -> Dict[str, any]:
     """Download a custom model by HuggingFace model name.
 
-    This allows users to download any HuggingFace model that is compatible
-    with fastembed (TextEmbedding or TextCrossEncoder).
+    This allows users to download any HuggingFace model directly from
+    HuggingFace Hub. The model will be placed in the standard cache
+    directory where it can be discovered by scan_discovered_models().
+
+    Note: Downloaded models may not be directly usable by FastEmbed unless
+    they are in ONNX format. This function is primarily for downloading
+    models that users want to use with other frameworks or custom code.
 
     Args:
-        model_name: Full HuggingFace model name (e.g., "BAAI/bge-small-en-v1.5")
-        model_type: Type of model ("embedding" or "reranker")
+        model_name: Full HuggingFace model name (e.g., "intfloat/e5-small-v2")
+        model_type: Type of model ("embedding" or "reranker") - for metadata only
         progress_callback: Optional callback function to report progress
 
     Returns:
         Result dictionary with success status
     """
-    if model_type == "embedding":
-        if not FASTEMBED_AVAILABLE:
-            return {
-                "success": False,
-                "error": "fastembed not installed. Install with: pip install codexlens[semantic]",
-            }
-    else:
-        if not RERANKER_AVAILABLE:
-            return {
-                "success": False,
-                "error": "fastembed reranker not available. Install with: pip install fastembed>=0.4.0",
-            }
+    if not HUGGINGFACE_HUB_AVAILABLE:
+        return {
+            "success": False,
+            "error": "huggingface_hub not installed. Install with: pip install huggingface_hub",
+        }
 
     # Validate model name format (org/model-name)
     if not model_name or "/" not in model_name:
         return {
             "success": False,
-            "error": "Invalid model name format. Expected: 'org/model-name' (e.g., 'BAAI/bge-small-en-v1.5')",
+            "error": "Invalid model name format. Expected: 'org/model-name' (e.g., 'intfloat/e5-small-v2')",
         }
 
     try:
         cache_dir = get_cache_dir()
 
         if progress_callback:
-            progress_callback(f"Downloading custom model {model_name}...")
+            progress_callback(f"Checking model format for {model_name}...")
 
-        if model_type == "reranker":
-            # Download reranker model
-            reranker = TextCrossEncoder(model_name=model_name, cache_dir=str(cache_dir))
+        # Check if model contains ONNX files before downloading
+        try:
+            files = list_repo_files(repo_id=model_name)
+            has_onnx = any(
+                f.endswith('.onnx') or
+                f.startswith('onnx/') or
+                '/onnx/' in f or
+                f == 'model.onnx'
+                for f in files
+            )
+            if not has_onnx:
+                return {
+                    "success": False,
+                    "error": f"Model '{model_name}' does not contain ONNX files. "
+                             f"FastEmbed requires ONNX-format models. "
+                             f"Try Xenova/* versions or check the recommended models list.",
+                    "files_found": len(files),
+                    "suggestion": "Use models from the 'Recommended Models' list, or search for ONNX versions (e.g., Xenova/*).",
+                }
             if progress_callback:
-                progress_callback(f"Initializing reranker {model_name}...")
-            list(reranker.rerank("test query", ["test document"]))
-        else:
-            # Download embedding model
-            embedder = TextEmbedding(model_name=model_name, cache_dir=str(cache_dir))
+                progress_callback(f"ONNX format detected. Downloading {model_name}...")
+        except Exception as check_err:
+            # If we can't check, warn but allow download
             if progress_callback:
-                progress_callback(f"Initializing {model_name}...")
-            list(embedder.embed(["test"]))
+                progress_callback(f"Could not verify format, proceeding with download...")
+
+        # Use huggingface_hub to download the model
+        # This downloads to the standard HuggingFace cache directory
+        local_path = snapshot_download(
+            repo_id=model_name,
+            cache_dir=str(cache_dir),
+        )
 
         if progress_callback:
-            progress_callback(f"Custom model {model_name} downloaded successfully")
+            progress_callback(f"Model {model_name} downloaded successfully")
 
         # Get cache info
         sanitized_name = f"models--{model_name.replace('/', '--')}"
         model_cache_path = cache_dir / sanitized_name
         cache_size = 0
         if model_cache_path.exists():
             total_size = sum(
@@ -623,7 +650,7 @@ def download_custom_model(model_name: str, model_type: str = "embedding", progre
             if f.is_file()
         )
         cache_size = round(total_size / (1024 * 1024), 1)
 
         return {
             "success": True,
             "result": {
@@ -631,6 +658,8 @@ def download_custom_model(model_name: str, model_type: str = "embedding", progre
                 "model_type": model_type,
                 "cache_size_mb": cache_size,
                 "cache_path": str(model_cache_path),
+                "local_path": local_path,
+                "note": "Model downloaded. Note: Only ONNX-format models are compatible with FastEmbed.",
             },
         }