feat: add support for the LiteLLM embedding backend and concurrent API calls for embedding generation

Author: catlog22
Date: 2025-12-24 22:20:13 +08:00
Parent: e3e61bcae9
Commit: 3c3ce55842
8 changed files with 412 additions and 141 deletions

View File

@@ -108,6 +108,7 @@ def init(
no_embeddings: bool = typer.Option(False, "--no-embeddings", help="Skip automatic embedding generation (if semantic deps installed)."),
embedding_backend: str = typer.Option("fastembed", "--embedding-backend", help="Embedding backend: fastembed (local) or litellm (remote API)."),
embedding_model: str = typer.Option("code", "--embedding-model", help="Embedding model: profile name for fastembed (fast/code/multilingual/balanced) or model name for litellm (e.g. text-embedding-3-small)."),
max_workers: int = typer.Option(1, "--max-workers", min=1, max=16, help="Max concurrent API calls for embedding generation. Recommended: 4-8 for litellm backend."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
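The new --max-workers option sits alongside the existing --embedding-backend and --embedding-model flags. A minimal sketch of exercising them through Typer's test runner; the `app` import path is an assumption, since the diff does not show where the Typer application object lives:

```python
# Illustrative only: the `app` import location is assumed, not part of this diff.
from typer.testing import CliRunner
from codexlens.cli import app  # hypothetical location of the Typer app object

runner = CliRunner()
result = runner.invoke(
    app,
    [
        "init",
        "--embedding-backend", "litellm",
        "--embedding-model", "text-embedding-3-small",
        "--max-workers", "8",
        "--json",
    ],
)
print(result.exit_code, result.stdout)
```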
@@ -165,31 +166,31 @@ def init(
"errors": len(build_result.errors),
}
if json_mode:
print_json(success=True, result=result)
else:
if not json_mode:
console.print(f"[green]OK[/green] Indexed [bold]{build_result.total_files}[/bold] files in [bold]{build_result.total_dirs}[/bold] directories")
console.print(f" Index root: {build_result.index_root}")
if build_result.errors:
console.print(f" [yellow]Warnings:[/yellow] {len(build_result.errors)} errors")
# Auto-generate embeddings if semantic search is available
# Auto-generate embeddings if the requested backend is available
if not no_embeddings:
try:
from codexlens.semantic import SEMANTIC_AVAILABLE
from codexlens.semantic import is_embedding_backend_available
from codexlens.cli.embedding_manager import generate_embeddings_recursive, get_embeddings_status
if SEMANTIC_AVAILABLE:
# Validate embedding backend
valid_backends = ["fastembed", "litellm"]
if embedding_backend not in valid_backends:
error_msg = f"Invalid embedding backend: {embedding_backend}. Must be one of: {', '.join(valid_backends)}"
if json_mode:
print_json(success=False, error=error_msg)
else:
console.print(f"[red]Error:[/red] {error_msg}")
raise typer.Exit(code=1)
# Validate embedding backend
valid_backends = ["fastembed", "litellm"]
if embedding_backend not in valid_backends:
error_msg = f"Invalid embedding backend: {embedding_backend}. Must be one of: {', '.join(valid_backends)}"
if json_mode:
print_json(success=False, error=error_msg)
else:
console.print(f"[red]Error:[/red] {error_msg}")
raise typer.Exit(code=1)
backend_available, backend_error = is_embedding_backend_available(embedding_backend)
if backend_available:
# Use the index root directory (not the _index.db file)
index_root = Path(build_result.index_root)
@@ -221,6 +222,7 @@ def init(
force=False, # Don't force regenerate during init
chunk_size=2000,
progress_callback=progress_update, # Always use callback
max_workers=max_workers,
)
if embed_result["success"]:
@@ -262,10 +264,10 @@ def init(
}
else:
if not json_mode and verbose:
console.print("[dim]Semantic search not available. Skipping embeddings.[/dim]")
console.print(f"[dim]Embedding backend '{embedding_backend}' not available. Skipping embeddings.[/dim]")
result["embeddings"] = {
"generated": False,
"error": "Semantic dependencies not installed",
"error": backend_error or "Embedding backend not available",
}
except Exception as e:
if not json_mode and verbose:
@@ -280,6 +282,10 @@ def init(
"error": "Skipped (--no-embeddings)",
}
# Output final JSON result with embeddings status
if json_mode:
print_json(success=True, result=result)
except StorageError as exc:
if json_mode:
print_json(success=False, error=f"Storage error: {exc}")
@@ -1971,9 +1977,12 @@ def embeddings_generate(
# Provide helpful hints
if "already has" in error_msg:
console.print("\n[dim]Use --force to regenerate existing embeddings[/dim]")
elif "Semantic search not available" in error_msg:
elif "fastembed not available" in error_msg or "Semantic search not available" in error_msg:
console.print("\n[dim]Install semantic dependencies:[/dim]")
console.print(" [cyan]pip install codexlens[semantic][/cyan]")
elif "ccw-litellm not available" in error_msg:
console.print("\n[dim]Install LiteLLM backend dependencies:[/dim]")
console.print(" [cyan]pip install ccw-litellm[/cyan]")
raise typer.Exit(code=1)

View File

@@ -10,13 +10,11 @@ from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple
try:
from codexlens.semantic import SEMANTIC_AVAILABLE
if SEMANTIC_AVAILABLE:
from codexlens.semantic.embedder import Embedder, get_embedder, clear_embedder_cache
from codexlens.semantic.vector_store import VectorStore
from codexlens.semantic.chunker import Chunker, ChunkConfig
from codexlens.semantic import SEMANTIC_AVAILABLE, is_embedding_backend_available
except ImportError:
SEMANTIC_AVAILABLE = False
def is_embedding_backend_available(_backend: str): # type: ignore[no-redef]
return False, "codexlens.semantic not available"
logger = logging.getLogger(__name__)
@@ -25,6 +23,15 @@ logger = logging.getLogger(__name__)
EMBEDDING_BATCH_SIZE = 256
def _cleanup_fastembed_resources() -> None:
"""Best-effort cleanup for fastembed/ONNX resources (no-op for other backends)."""
try:
from codexlens.semantic.embedder import clear_embedder_cache
clear_embedder_cache()
except Exception:
pass
def _generate_chunks_from_cursor(
cursor,
chunker,
@@ -252,7 +259,7 @@ def generate_embeddings(
progress_callback: Optional[callable] = None,
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: int = 1,
max_workers: Optional[int] = None,
) -> Dict[str, any]:
"""Generate embeddings for an index using memory-efficient batch processing.
@@ -276,8 +283,9 @@ def generate_embeddings(
max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
Recommended: 2-4 for LiteLLM API backends.
max_workers: Maximum number of concurrent API calls.
If None, uses dynamic defaults: 1 for fastembed (CPU bound),
4 for litellm (network I/O bound).
Returns:
Result dictionary with generation statistics
@@ -291,11 +299,19 @@ def generate_embeddings(
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
if not SEMANTIC_AVAILABLE:
return {
"success": False,
"error": "Semantic search not available. Install with: pip install codexlens[semantic]",
}
# Set dynamic max_workers default based on backend type
# - FastEmbed: CPU-bound, sequential is optimal (1 worker)
# - LiteLLM: Network I/O bound, concurrent calls improve throughput (4 workers)
if max_workers is None:
if embedding_backend == "litellm":
max_workers = 4
else:
max_workers = 1
backend_available, backend_error = is_embedding_backend_available(embedding_backend)
if not backend_available:
return {"success": False, "error": backend_error or "Embedding backend not available"}
if not index_path.exists():
return {
@@ -335,6 +351,8 @@ def generate_embeddings(
try:
# Import factory function to support both backends
from codexlens.semantic.factory import get_embedder as get_embedder_factory
from codexlens.semantic.vector_store import VectorStore
from codexlens.semantic.chunker import Chunker, ChunkConfig
# Initialize embedder using factory (supports both fastembed and litellm)
# For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
@@ -441,7 +459,7 @@ def generate_embeddings(
files_seen = set()
def compute_embeddings_only(batch_data: Tuple[int, List[Tuple]]):
"""Compute embeddings for a batch (no DB write).
"""Compute embeddings for a batch (no DB write) with retry logic.
Args:
batch_data: Tuple of (batch_number, chunk_batch)
@@ -449,22 +467,43 @@ def generate_embeddings(
Returns:
Tuple of (batch_num, chunk_batch, embeddings_numpy, batch_files, error)
"""
import random
batch_num, chunk_batch = batch_data
batch_files = set()
for _, file_path in chunk_batch:
batch_files.add(file_path)
try:
for _, file_path in chunk_batch:
batch_files.add(file_path)
max_retries = 3
base_delay = 1.0
batch_contents = [chunk.content for chunk, _ in chunk_batch]
embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
for attempt in range(max_retries + 1):
try:
batch_contents = [chunk.content for chunk, _ in chunk_batch]
embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
return batch_num, chunk_batch, embeddings_numpy, batch_files, None
return batch_num, chunk_batch, embeddings_numpy, batch_files, None
except Exception as e:
error_str = str(e).lower()
# Check for retryable errors (rate limit, connection issues)
is_retryable = any(x in error_str for x in [
"429", "rate limit", "connection", "timeout",
"502", "503", "504", "service unavailable"
])
except Exception as e:
error_msg = f"Batch {batch_num}: {str(e)}"
logger.error(f"Failed to compute embeddings for batch {batch_num}: {str(e)}")
return batch_num, chunk_batch, None, batch_files, error_msg
if attempt < max_retries and is_retryable:
sleep_time = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
logger.warning(f"Batch {batch_num} failed (attempt {attempt+1}/{max_retries+1}). "
f"Retrying in {sleep_time:.1f}s. Error: {e}")
time.sleep(sleep_time)
continue
error_msg = f"Batch {batch_num}: {str(e)}"
logger.error(f"Failed to compute embeddings for batch {batch_num}: {str(e)}")
return batch_num, chunk_batch, None, batch_files, error_msg
# Should not reach here, but just in case
return batch_num, chunk_batch, None, batch_files, f"Batch {batch_num}: Max retries exceeded"
# Process batches based on max_workers setting
if max_workers <= 1:
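The retry loop added to compute_embeddings_only treats rate-limit and transient network errors as retryable and backs off exponentially with jitter. A self-contained sketch of the same policy, using a hypothetical helper name (call_with_backoff is not part of this commit):

```python
# Minimal sketch of the retry policy above; not the actual implementation.
import logging
import random
import time

logger = logging.getLogger(__name__)

RETRYABLE_MARKERS = ("429", "rate limit", "connection", "timeout",
                     "502", "503", "504", "service unavailable")

def call_with_backoff(fn, *, max_retries=3, base_delay=1.0):
    """Call fn(), retrying transient failures with exponential backoff plus jitter."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception as exc:
            retryable = any(m in str(exc).lower() for m in RETRYABLE_MARKERS)
            if attempt < max_retries and retryable:
                delay = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
                logger.warning("Attempt %d/%d failed (%s); retrying in %.1fs",
                               attempt + 1, max_retries + 1, exc, delay)
                time.sleep(delay)
                continue
            raise  # non-retryable, or retries exhausted
```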
@@ -496,77 +535,74 @@ def generate_embeddings(
logger.error(f"Failed to process batch {batch_number}: {str(e)}")
files_seen.update(batch_files)
else:
# Concurrent processing with producer-consumer pattern
# Workers compute embeddings (parallel), main thread writes to DB (serial)
from queue import Queue
from threading import Thread
result_queue = Queue(maxsize=max_workers * 2) # Bounded queue to limit memory
batch_counter = [0] # Mutable counter for producer thread
producer_done = [False]
def producer():
"""Submit batches to executor, put results in queue."""
with ThreadPoolExecutor(max_workers=max_workers) as executor:
pending_futures = []
for chunk_batch in batch_generator:
batch_counter[0] += 1
batch_num = batch_counter[0]
# Submit compute task
future = executor.submit(compute_embeddings_only, (batch_num, chunk_batch))
pending_futures.append(future)
# Check for completed futures and add to queue
for f in list(pending_futures):
if f.done():
try:
result_queue.put(f.result())
except Exception as e:
logger.error(f"Future raised exception: {e}")
pending_futures.remove(f)
# Wait for remaining futures
for future in as_completed(pending_futures):
try:
result_queue.put(future.result())
except Exception as e:
logger.error(f"Future raised exception: {e}")
producer_done[0] = True
result_queue.put(None) # Sentinel to signal completion
# Start producer thread
producer_thread = Thread(target=producer, daemon=True)
producer_thread.start()
# Concurrent processing - main thread iterates batches (SQLite safe),
# workers compute embeddings (parallel), main thread writes to DB (serial)
if progress_callback:
progress_callback(f"Processing with {max_workers} concurrent embedding workers...")
# Consumer: main thread writes to DB (serial, no contention)
completed = 0
while True:
result = result_queue.get()
if result is None: # Sentinel
break
with ThreadPoolExecutor(max_workers=max_workers) as executor:
pending_futures = {} # future -> (batch_num, chunk_batch)
completed_batches = 0
last_reported_batch = 0
batch_num, chunk_batch, embeddings_numpy, batch_files, error = result
def process_completed_futures():
"""Process any completed futures and write to DB."""
nonlocal total_chunks_created, total_files_processed, completed_batches, last_reported_batch
done_futures = [f for f in pending_futures if f.done()]
for f in done_futures:
try:
batch_num, chunk_batch, embeddings_numpy, batch_files, error = f.result()
if embeddings_numpy is not None and error is None:
# Write to DB in main thread (no contention)
vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
total_chunks_created += len(chunk_batch)
files_seen.update(batch_files)
total_files_processed = len(files_seen)
completed_batches += 1
except Exception as e:
logger.error(f"Future raised exception: {e}")
completed_batches += 1
del pending_futures[f]
if embeddings_numpy is not None and error is None:
# Write to DB in main thread (no contention)
vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
total_chunks_created += len(chunk_batch)
# Report progress based on completed batches (every 5 batches)
if progress_callback and completed_batches >= last_reported_batch + 5:
progress_callback(f" Batch {completed_batches}: {total_chunks_created} chunks, {total_files_processed} files")
last_reported_batch = completed_batches
files_seen.update(batch_files)
total_files_processed = len(files_seen)
completed += 1
# Iterate batches in main thread (SQLite cursor is main-thread bound)
for chunk_batch in batch_generator:
batch_number += 1
if progress_callback and completed % 10 == 0:
progress_callback(f" Completed {completed} batches: {total_chunks_created} chunks")
# Submit compute task to worker pool
future = executor.submit(compute_embeddings_only, (batch_number, chunk_batch))
pending_futures[future] = batch_number
producer_thread.join()
batch_number = batch_counter[0]
# Process any completed futures to free memory and write to DB
process_completed_futures()
# Backpressure: wait if too many pending
while len(pending_futures) >= max_workers * 2:
process_completed_futures()
if len(pending_futures) >= max_workers * 2:
time.sleep(0.1) # time is imported at module level
# Wait for remaining futures
for future in as_completed(list(pending_futures.keys())):
try:
batch_num, chunk_batch, embeddings_numpy, batch_files, error = future.result()
if embeddings_numpy is not None and error is None:
vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
total_chunks_created += len(chunk_batch)
files_seen.update(batch_files)
total_files_processed = len(files_seen)
completed_batches += 1
# Report progress for remaining batches
if progress_callback and completed_batches >= last_reported_batch + 5:
progress_callback(f" Batch {completed_batches}: {total_chunks_created} chunks, {total_files_processed} files")
last_reported_batch = completed_batches
except Exception as e:
logger.error(f"Future raised exception: {e}")
# Notify before ANN index finalization (happens when bulk_insert context exits)
if progress_callback:
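The concurrent path keeps all SQLite access on the main thread: the batch generator is iterated there, workers only compute embeddings, finished futures are drained and written serially, and a bounded pending set (max_workers * 2) provides backpressure. A condensed sketch of that pattern with placeholder compute and write_to_db callables (error handling omitted):

```python
# Sketch of the producer/consumer pattern above with hypothetical callables.
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_concurrent(batches, compute, write_to_db, max_workers=4):
    """Main thread iterates batches and writes results; workers only compute."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        pending = set()

        def drain_done():
            # Collect finished futures and persist results on the main thread only.
            for f in [f for f in pending if f.done()]:
                write_to_db(f.result())
                pending.discard(f)

        for batch in batches:  # generator consumed on the main thread (cursor-bound)
            pending.add(executor.submit(compute, batch))
            drain_done()
            while len(pending) >= max_workers * 2:  # backpressure: cap in-flight batches
                drain_done()
                if len(pending) >= max_workers * 2:
                    time.sleep(0.1)

        for f in as_completed(list(pending)):  # flush whatever is still running
            write_to_db(f.result())
```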
@@ -575,7 +611,7 @@ def generate_embeddings(
except Exception as e:
# Cleanup on error to prevent process hanging
try:
clear_embedder_cache()
_cleanup_fastembed_resources()
gc.collect()
except Exception:
pass
@@ -586,7 +622,7 @@ def generate_embeddings(
# Final cleanup: release ONNX resources to allow process exit
# This is critical - without it, ONNX Runtime threads prevent Python from exiting
try:
clear_embedder_cache()
_cleanup_fastembed_resources()
gc.collect()
except Exception:
pass
@@ -647,7 +683,7 @@ def generate_embeddings_recursive(
progress_callback: Optional[callable] = None,
use_gpu: Optional[bool] = None,
max_tokens_per_batch: Optional[int] = None,
max_workers: int = 1,
max_workers: Optional[int] = None,
) -> Dict[str, any]:
"""Generate embeddings for all index databases in a project recursively.
@@ -667,8 +703,9 @@ def generate_embeddings_recursive(
max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
If None, attempts to get from embedder.max_tokens,
then falls back to 8000. If set, overrides automatic detection.
max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
Recommended: 2-4 for LiteLLM API backends.
max_workers: Maximum number of concurrent API calls.
If None, uses dynamic defaults: 1 for fastembed (CPU bound),
4 for litellm (network I/O bound).
Returns:
Aggregated result dictionary with generation statistics
@@ -682,6 +719,14 @@ def generate_embeddings_recursive(
model_profile = default_model
if use_gpu is None:
use_gpu = default_gpu
# Set dynamic max_workers default based on backend type
if max_workers is None:
if embedding_backend == "litellm":
max_workers = 4
else:
max_workers = 1
# Discover all _index.db files
index_files = discover_all_index_dbs(index_root)
@@ -740,9 +785,8 @@ def generate_embeddings_recursive(
# Final cleanup after processing all indexes
# Each generate_embeddings() call does its own cleanup, but do a final one to be safe
try:
if SEMANTIC_AVAILABLE:
clear_embedder_cache()
gc.collect()
_cleanup_fastembed_resources()
gc.collect()
except Exception:
pass
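A hedged call into the recursive entry point; apart from max_workers and progress_callback, the keyword names below are inferred from this diff rather than a confirmed signature:

```python
# Hypothetical usage sketch; keyword names other than max_workers and
# progress_callback are assumptions, not a documented API.
from pathlib import Path
from codexlens.cli.embedding_manager import generate_embeddings_recursive

result = generate_embeddings_recursive(
    Path(".codexlens"),              # assumed index root location
    embedding_backend="litellm",
    max_workers=None,                # None resolves to 4 for litellm, 1 for fastembed
    progress_callback=print,
)
print(result.get("success"), result.get("error"))
```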

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
GPU_AVAILABLE = False
LITELLM_AVAILABLE = False
_import_error: str | None = None
@@ -76,18 +77,40 @@ from .factory import get_embedder as get_embedder_factory
# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
try:
import ccw_litellm # noqa: F401
from .litellm_embedder import LiteLLMEmbedderWrapper
_LITELLM_AVAILABLE = True
LITELLM_AVAILABLE = True
except ImportError:
LiteLLMEmbedderWrapper = None
_LITELLM_AVAILABLE = False
LITELLM_AVAILABLE = False
def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific embedding backend can be used.
Notes:
- "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
- "litellm" requires ccw-litellm to be installed in the same environment.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
if SEMANTIC_AVAILABLE:
return True, None
return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
if backend == "litellm":
if LITELLM_AVAILABLE:
return True, None
return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."
__all__ = [
"SEMANTIC_AVAILABLE",
"SEMANTIC_BACKEND",
"GPU_AVAILABLE",
"LITELLM_AVAILABLE",
"check_semantic_available",
"is_embedding_backend_available",
"check_gpu_available",
"BaseEmbedder",
"get_embedder_factory",