Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-10 02:24:35 +08:00)
feat: Add support for the LiteLLM embedding backend and enhance concurrent API call capability
@@ -10,13 +10,11 @@ from pathlib import Path
 from typing import Dict, Generator, List, Optional, Tuple
 
 try:
-    from codexlens.semantic import SEMANTIC_AVAILABLE
-    if SEMANTIC_AVAILABLE:
-        from codexlens.semantic.embedder import Embedder, get_embedder, clear_embedder_cache
-        from codexlens.semantic.vector_store import VectorStore
-        from codexlens.semantic.chunker import Chunker, ChunkConfig
+    from codexlens.semantic import SEMANTIC_AVAILABLE, is_embedding_backend_available
 except ImportError:
     SEMANTIC_AVAILABLE = False
+    def is_embedding_backend_available(_backend: str):  # type: ignore[no-redef]
+        return False, "codexlens.semantic not available"
 
 logger = logging.getLogger(__name__)
 
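
The reworked import guard above drops the conditional imports of Embedder, VectorStore, and Chunker at module import time: when codexlens.semantic is missing, a stub is_embedding_backend_available is defined instead, so callers always receive an (available, reason) tuple rather than an ImportError. A minimal self-contained sketch of that optional-dependency pattern (the module name optional_backend and the helper check_backend are illustrative, not from the codebase):

import logging
from typing import Tuple

try:
    # Hypothetical optional dependency; the ImportError path defines the stub below.
    from optional_backend import check_backend  # type: ignore[import-not-found]
    BACKEND_AVAILABLE = True
except ImportError:
    BACKEND_AVAILABLE = False

    def check_backend(_backend: str) -> Tuple[bool, str]:  # type: ignore[no-redef]
        # Same signature as the real helper, so call sites need no special casing.
        return False, "optional_backend is not installed"

logger = logging.getLogger(__name__)

ok, reason = check_backend("litellm")
if not ok:
    logger.warning("Embedding backend unavailable: %s", reason)
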
@@ -25,6 +23,15 @@ logger = logging.getLogger(__name__)
 EMBEDDING_BATCH_SIZE = 256
 
 
+def _cleanup_fastembed_resources() -> None:
+    """Best-effort cleanup for fastembed/ONNX resources (no-op for other backends)."""
+    try:
+        from codexlens.semantic.embedder import clear_embedder_cache
+        clear_embedder_cache()
+    except Exception:
+        pass
+
+
 def _generate_chunks_from_cursor(
     cursor,
     chunker,
@@ -252,7 +259,7 @@ def generate_embeddings(
     progress_callback: Optional[callable] = None,
     use_gpu: Optional[bool] = None,
     max_tokens_per_batch: Optional[int] = None,
-    max_workers: int = 1,
+    max_workers: Optional[int] = None,
 ) -> Dict[str, any]:
     """Generate embeddings for an index using memory-efficient batch processing.
 
@@ -276,8 +283,9 @@ def generate_embeddings(
         max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
             If None, attempts to get from embedder.max_tokens,
             then falls back to 8000. If set, overrides automatic detection.
-        max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
-            Recommended: 2-4 for LiteLLM API backends.
+        max_workers: Maximum number of concurrent API calls.
+            If None, uses dynamic defaults: 1 for fastembed (CPU bound),
+            4 for litellm (network I/O bound).
 
     Returns:
         Result dictionary with generation statistics
@@ -291,11 +299,19 @@ def generate_embeddings(
         model_profile = default_model
     if use_gpu is None:
         use_gpu = default_gpu
-    if not SEMANTIC_AVAILABLE:
-        return {
-            "success": False,
-            "error": "Semantic search not available. Install with: pip install codexlens[semantic]",
-        }
+
+    # Set dynamic max_workers default based on backend type
+    # - FastEmbed: CPU-bound, sequential is optimal (1 worker)
+    # - LiteLLM: Network I/O bound, concurrent calls improve throughput (4 workers)
+    if max_workers is None:
+        if embedding_backend == "litellm":
+            max_workers = 4
+        else:
+            max_workers = 1
+
+    backend_available, backend_error = is_embedding_backend_available(embedding_backend)
+    if not backend_available:
+        return {"success": False, "error": backend_error or "Embedding backend not available"}
 
     if not index_path.exists():
         return {
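
The added block picks the concurrency default from the backend when the caller leaves max_workers as None, and the per-backend availability check takes over from the earlier blanket SEMANTIC_AVAILABLE guard. A short sketch of the same selection logic as a standalone helper (resolve_max_workers is an illustrative name, not part of codexlens):

from typing import Optional

def resolve_max_workers(embedding_backend: str, max_workers: Optional[int] = None) -> int:
    """Pick a worker count: an explicit value wins, otherwise a per-backend default."""
    if max_workers is not None:
        return max_workers
    # Network-bound API backends benefit from a few concurrent requests;
    # local CPU-bound embedding is fastest sequentially.
    return 4 if embedding_backend == "litellm" else 1

assert resolve_max_workers("litellm") == 4
assert resolve_max_workers("fastembed") == 1
assert resolve_max_workers("litellm", max_workers=2) == 2
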
@@ -335,6 +351,8 @@ def generate_embeddings(
     try:
         # Import factory function to support both backends
        from codexlens.semantic.factory import get_embedder as get_embedder_factory
+        from codexlens.semantic.vector_store import VectorStore
+        from codexlens.semantic.chunker import Chunker, ChunkConfig
 
         # Initialize embedder using factory (supports both fastembed and litellm)
         # For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
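
VectorStore and Chunker are now imported lazily inside this try block because the top-level conditional imports were removed. The factory call hides which embedder class gets built per backend; a purely illustrative sketch of such dispatch is shown below (the stub classes and constructor arguments are assumptions, not the actual codexlens.semantic.factory API):

from dataclasses import dataclass

@dataclass
class _LiteLLMEmbedder:      # stand-in for an API-backed embedder
    model: str

@dataclass
class _FastEmbedEmbedder:    # stand-in for a local ONNX embedder
    profile: str
    use_gpu: bool = False

def get_embedder(backend: str, model: str, use_gpu: bool = False):
    """Dispatch on the backend name; unknown backends fail loudly."""
    if backend == "litellm":
        return _LiteLLMEmbedder(model=model)
    if backend == "fastembed":
        return _FastEmbedEmbedder(profile=model, use_gpu=use_gpu)
    raise ValueError(f"Unknown embedding backend: {backend}")

print(get_embedder("litellm", "text-embedding-3-small"))
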
@@ -441,7 +459,7 @@ def generate_embeddings(
         files_seen = set()
 
         def compute_embeddings_only(batch_data: Tuple[int, List[Tuple]]):
-            """Compute embeddings for a batch (no DB write).
+            """Compute embeddings for a batch (no DB write) with retry logic.
 
             Args:
                 batch_data: Tuple of (batch_number, chunk_batch)
@@ -449,22 +467,43 @@ def generate_embeddings(
             Returns:
                 Tuple of (batch_num, chunk_batch, embeddings_numpy, batch_files, error)
             """
+            import random
+
             batch_num, chunk_batch = batch_data
             batch_files = set()
+            for _, file_path in chunk_batch:
+                batch_files.add(file_path)
 
-            try:
-                for _, file_path in chunk_batch:
-                    batch_files.add(file_path)
-
-                batch_contents = [chunk.content for chunk, _ in chunk_batch]
-                embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
-
-                return batch_num, chunk_batch, embeddings_numpy, batch_files, None
-
-            except Exception as e:
-                error_msg = f"Batch {batch_num}: {str(e)}"
-                logger.error(f"Failed to compute embeddings for batch {batch_num}: {str(e)}")
-                return batch_num, chunk_batch, None, batch_files, error_msg
+            max_retries = 3
+            base_delay = 1.0
+
+            for attempt in range(max_retries + 1):
+                try:
+                    batch_contents = [chunk.content for chunk, _ in chunk_batch]
+                    embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
+                    return batch_num, chunk_batch, embeddings_numpy, batch_files, None
+
+                except Exception as e:
+                    error_str = str(e).lower()
+                    # Check for retryable errors (rate limit, connection issues)
+                    is_retryable = any(x in error_str for x in [
+                        "429", "rate limit", "connection", "timeout",
+                        "502", "503", "504", "service unavailable"
+                    ])
+
+                    if attempt < max_retries and is_retryable:
+                        sleep_time = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
+                        logger.warning(f"Batch {batch_num} failed (attempt {attempt+1}/{max_retries+1}). "
+                                       f"Retrying in {sleep_time:.1f}s. Error: {e}")
+                        time.sleep(sleep_time)
+                        continue
+
+                    error_msg = f"Batch {batch_num}: {str(e)}"
+                    logger.error(f"Failed to compute embeddings for batch {batch_num}: {str(e)}")
+                    return batch_num, chunk_batch, None, batch_files, error_msg
+
+            # Should not reach here, but just in case
+            return batch_num, chunk_batch, None, batch_files, f"Batch {batch_num}: Max retries exceeded"
 
         # Process batches based on max_workers setting
         if max_workers <= 1:
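
The rewritten worker applies exponential backoff with jitter: only errors that look transient (HTTP 429/5xx, connection problems, timeouts) are retried, and the sleep grows as base_delay * 2**attempt plus up to 0.5 s of random jitter. A self-contained sketch of that pattern as a generic helper (call_with_backoff and the flaky example are illustrative, not codexlens code):

import logging
import random
import time

logger = logging.getLogger(__name__)

RETRYABLE_MARKERS = ("429", "rate limit", "connection", "timeout", "502", "503", "504")

def call_with_backoff(fn, *args, max_retries=3, base_delay=1.0, **kwargs):
    """Call fn, retrying transient failures with exponential backoff plus jitter."""
    for attempt in range(max_retries + 1):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            retryable = any(m in str(exc).lower() for m in RETRYABLE_MARKERS)
            if attempt >= max_retries or not retryable:
                raise
            sleep_time = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
            logger.warning("Attempt %d failed (%s); retrying in %.1fs", attempt + 1, exc, sleep_time)
            time.sleep(sleep_time)

# Example: a call that fails once with a retryable error, then succeeds.
_calls = {"n": 0}
def flaky():
    _calls["n"] += 1
    if _calls["n"] == 1:
        raise RuntimeError("429 rate limit exceeded")
    return "ok"

print(call_with_backoff(flaky))  # sleeps roughly 1-1.5 s once, then prints "ok"
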
@@ -496,77 +535,74 @@ def generate_embeddings(
                     logger.error(f"Failed to process batch {batch_number}: {str(e)}")
                 files_seen.update(batch_files)
         else:
-            # Concurrent processing with producer-consumer pattern
-            # Workers compute embeddings (parallel), main thread writes to DB (serial)
-            from queue import Queue
-            from threading import Thread
-
-            result_queue = Queue(maxsize=max_workers * 2)  # Bounded queue to limit memory
-            batch_counter = [0]  # Mutable counter for producer thread
-            producer_done = [False]
-
-            def producer():
-                """Submit batches to executor, put results in queue."""
-                with ThreadPoolExecutor(max_workers=max_workers) as executor:
-                    pending_futures = []
-
-                    for chunk_batch in batch_generator:
-                        batch_counter[0] += 1
-                        batch_num = batch_counter[0]
-
-                        # Submit compute task
-                        future = executor.submit(compute_embeddings_only, (batch_num, chunk_batch))
-                        pending_futures.append(future)
-
-                        # Check for completed futures and add to queue
-                        for f in list(pending_futures):
-                            if f.done():
-                                try:
-                                    result_queue.put(f.result())
-                                except Exception as e:
-                                    logger.error(f"Future raised exception: {e}")
-                                pending_futures.remove(f)
-
-                    # Wait for remaining futures
-                    for future in as_completed(pending_futures):
-                        try:
-                            result_queue.put(future.result())
-                        except Exception as e:
-                            logger.error(f"Future raised exception: {e}")
-
-                producer_done[0] = True
-                result_queue.put(None)  # Sentinel to signal completion
-
-            # Start producer thread
-            producer_thread = Thread(target=producer, daemon=True)
-            producer_thread.start()
-
+            # Concurrent processing - main thread iterates batches (SQLite safe),
+            # workers compute embeddings (parallel), main thread writes to DB (serial)
             if progress_callback:
                 progress_callback(f"Processing with {max_workers} concurrent embedding workers...")
 
-            # Consumer: main thread writes to DB (serial, no contention)
-            completed = 0
-            while True:
-                result = result_queue.get()
-                if result is None:  # Sentinel
-                    break
-
-                batch_num, chunk_batch, embeddings_numpy, batch_files, error = result
-
-                if embeddings_numpy is not None and error is None:
-                    # Write to DB in main thread (no contention)
-                    vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
-                    total_chunks_created += len(chunk_batch)
-
-                files_seen.update(batch_files)
-                total_files_processed = len(files_seen)
-                completed += 1
-
-                if progress_callback and completed % 10 == 0:
-                    progress_callback(f"  Completed {completed} batches: {total_chunks_created} chunks")
-
-            producer_thread.join()
-            batch_number = batch_counter[0]
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                pending_futures = {}  # future -> (batch_num, chunk_batch)
+                completed_batches = 0
+                last_reported_batch = 0
+
+                def process_completed_futures():
+                    """Process any completed futures and write to DB."""
+                    nonlocal total_chunks_created, total_files_processed, completed_batches, last_reported_batch
+                    done_futures = [f for f in pending_futures if f.done()]
+                    for f in done_futures:
+                        try:
+                            batch_num, chunk_batch, embeddings_numpy, batch_files, error = f.result()
+                            if embeddings_numpy is not None and error is None:
+                                # Write to DB in main thread (no contention)
+                                vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
+                                total_chunks_created += len(chunk_batch)
+                                files_seen.update(batch_files)
+                                total_files_processed = len(files_seen)
+                            completed_batches += 1
+                        except Exception as e:
+                            logger.error(f"Future raised exception: {e}")
+                            completed_batches += 1
+                        del pending_futures[f]
+
+                    # Report progress based on completed batches (every 5 batches)
+                    if progress_callback and completed_batches >= last_reported_batch + 5:
+                        progress_callback(f"  Batch {completed_batches}: {total_chunks_created} chunks, {total_files_processed} files")
+                        last_reported_batch = completed_batches
+
+                # Iterate batches in main thread (SQLite cursor is main-thread bound)
+                for chunk_batch in batch_generator:
+                    batch_number += 1
+
+                    # Submit compute task to worker pool
+                    future = executor.submit(compute_embeddings_only, (batch_number, chunk_batch))
+                    pending_futures[future] = batch_number
+
+                    # Process any completed futures to free memory and write to DB
+                    process_completed_futures()
+
+                    # Backpressure: wait if too many pending
+                    while len(pending_futures) >= max_workers * 2:
+                        process_completed_futures()
+                        if len(pending_futures) >= max_workers * 2:
+                            time.sleep(0.1)  # time is imported at module level
+
+                # Wait for remaining futures
+                for future in as_completed(list(pending_futures.keys())):
+                    try:
+                        batch_num, chunk_batch, embeddings_numpy, batch_files, error = future.result()
+                        if embeddings_numpy is not None and error is None:
+                            vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
+                            total_chunks_created += len(chunk_batch)
+                            files_seen.update(batch_files)
+                            total_files_processed = len(files_seen)
+                        completed_batches += 1
+
+                        # Report progress for remaining batches
+                        if progress_callback and completed_batches >= last_reported_batch + 5:
+                            progress_callback(f"  Batch {completed_batches}: {total_chunks_created} chunks, {total_files_processed} files")
+                            last_reported_batch = completed_batches
+                    except Exception as e:
+                        logger.error(f"Future raised exception: {e}")
 
         # Notify before ANN index finalization (happens when bulk_insert context exits)
         if progress_callback:
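
The new concurrent path replaces the producer-consumer queue with a simpler layout: the main thread iterates the SQLite cursor and performs every vector-store write, worker threads only compute embeddings, and backpressure caps in-flight futures at max_workers * 2. A toy, self-contained sketch of that single-writer pattern (compute, run, and the in-memory store are stand-ins, not codexlens code):

import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def compute(batch):                      # stand-in for embedding a batch (safe to run in parallel)
    time.sleep(0.01)
    return [x * 2 for x in batch]

def run(batches, max_workers=4):
    store = []                           # stand-in for the SQLite-backed vector store
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        pending = {}

        def drain():
            # Main thread collects finished work and does the only "DB" writes.
            for f in [f for f in pending if f.done()]:
                store.extend(f.result())
                del pending[f]

        for batch in batches:
            pending[executor.submit(compute, batch)] = batch
            drain()
            # Backpressure: never hold more than 2x max_workers results in memory.
            while len(pending) >= max_workers * 2:
                drain()
                time.sleep(0.01)

        for f in as_completed(list(pending)):
            store.extend(f.result())
    return store

print(len(run([[i] * 3 for i in range(20)])))  # 60 computed items, all written serially
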
@@ -575,7 +611,7 @@ def generate_embeddings(
     except Exception as e:
         # Cleanup on error to prevent process hanging
         try:
-            clear_embedder_cache()
+            _cleanup_fastembed_resources()
             gc.collect()
         except Exception:
             pass
@@ -586,7 +622,7 @@ def generate_embeddings(
         # Final cleanup: release ONNX resources to allow process exit
         # This is critical - without it, ONNX Runtime threads prevent Python from exiting
         try:
-            clear_embedder_cache()
+            _cleanup_fastembed_resources()
             gc.collect()
         except Exception:
             pass
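
Both the error path and the final cleanup now go through _cleanup_fastembed_resources(), which swallows its own failures so teardown can never mask the real result and lingering ONNX threads do not keep the process alive. The general shape of that pattern, sketched with a placeholder cleanup (illustrative only, not codexlens code):

import gc

def _best_effort_cleanup():
    """Swallow teardown errors so they never mask the real result."""
    try:
        # e.g. release cached embedders / ONNX sessions here
        gc.collect()
    except Exception:
        pass

def do_work(fail: bool = False):
    try:
        if fail:
            raise RuntimeError("boom")
        return {"success": True}
    except Exception as e:
        _best_effort_cleanup()              # cleanup on the error path
        return {"success": False, "error": str(e)}
    finally:
        _best_effort_cleanup()              # final cleanup on every exit path

print(do_work())
print(do_work(fail=True))
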
@@ -647,7 +683,7 @@ def generate_embeddings_recursive(
     progress_callback: Optional[callable] = None,
     use_gpu: Optional[bool] = None,
     max_tokens_per_batch: Optional[int] = None,
-    max_workers: int = 1,
+    max_workers: Optional[int] = None,
 ) -> Dict[str, any]:
     """Generate embeddings for all index databases in a project recursively.
 
@@ -667,8 +703,9 @@ def generate_embeddings_recursive(
         max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
             If None, attempts to get from embedder.max_tokens,
             then falls back to 8000. If set, overrides automatic detection.
-        max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
-            Recommended: 2-4 for LiteLLM API backends.
+        max_workers: Maximum number of concurrent API calls.
+            If None, uses dynamic defaults: 1 for fastembed (CPU bound),
+            4 for litellm (network I/O bound).
 
     Returns:
         Aggregated result dictionary with generation statistics
@@ -682,6 +719,14 @@ def generate_embeddings_recursive(
         model_profile = default_model
     if use_gpu is None:
         use_gpu = default_gpu
+
+    # Set dynamic max_workers default based on backend type
+    if max_workers is None:
+        if embedding_backend == "litellm":
+            max_workers = 4
+        else:
+            max_workers = 1
+
     # Discover all _index.db files
     index_files = discover_all_index_dbs(index_root)
 
@@ -740,9 +785,8 @@ def generate_embeddings_recursive(
     # Final cleanup after processing all indexes
     # Each generate_embeddings() call does its own cleanup, but do a final one to be safe
     try:
-        if SEMANTIC_AVAILABLE:
-            clear_embedder_cache()
-            gc.collect()
+        _cleanup_fastembed_resources()
+        gc.collect()
     except Exception:
         pass