Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-05 01:50:27 +08:00)
feat: Enhance configuration management and embedding capabilities
- Added JSON-based settings management in Config class for embedding and LLM configurations.
- Introduced methods to save and load settings from a JSON file.
- Updated BaseEmbedder and its subclasses to include max_tokens property for better token management.
- Enhanced chunking strategy to support recursive splitting of large symbols with improved overlap handling.
- Implemented comprehensive tests for recursive splitting and chunking behavior.
- Added CLI tools configuration management for better integration with external tools.
- Introduced a new command for compacting session memory into structured text for recovery.
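Taken together, the settings file and the relaxed embedding defaults mean callers no longer need to pass backend/model/GPU flags explicitly. A minimal sketch of the intended flow, assuming the Config class shown in the diff; the index path is illustrative and the exact import path of generate_embeddings is omitted because the diff does not show the module name:

    from pathlib import Path
    from codexlens.config import Config

    config = Config.load()                # picks up settings.json if it exists
    config.embedding_backend = "litellm"  # switch backends once
    config.save_settings()                # persisted for later runs

    # With backend/model/use_gpu left as None, generate_embeddings falls back
    # to _get_embedding_defaults(), which reads the same Config.
    result = generate_embeddings(index_path=Path("_index.db"))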
@@ -4,8 +4,10 @@ import gc
 import logging
 import sqlite3
 import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from itertools import islice
 from pathlib import Path
+from threading import Lock
 from typing import Dict, Generator, List, Optional, Tuple
 
 try:
@@ -79,6 +81,44 @@ def _generate_chunks_from_cursor(
             failed_files.append((file_path, str(e)))
 
 
+def _create_token_aware_batches(
+    chunk_generator: Generator,
+    max_tokens_per_batch: int = 8000,
+) -> Generator[List[Tuple], None, None]:
+    """Group chunks by total token count instead of fixed count.
+
+    Uses fast token estimation (len(content) // 4) for efficiency.
+    Yields batches when approaching the token limit.
+
+    Args:
+        chunk_generator: Generator yielding (chunk, file_path) tuples
+        max_tokens_per_batch: Maximum tokens per batch (default: 8000)
+
+    Yields:
+        List of (chunk, file_path) tuples representing a batch
+    """
+    current_batch = []
+    current_tokens = 0
+
+    for chunk, file_path in chunk_generator:
+        # Fast token estimation: len(content) // 4
+        chunk_tokens = len(chunk.content) // 4
+
+        # If adding this chunk would exceed limit and we have items, yield current batch
+        if current_tokens + chunk_tokens > max_tokens_per_batch and current_batch:
+            yield current_batch
+            current_batch = []
+            current_tokens = 0
+
+        # Add chunk to current batch
+        current_batch.append((chunk, file_path))
+        current_tokens += chunk_tokens
+
+    # Yield final batch if not empty
+    if current_batch:
+        yield current_batch
+
+
 def _get_path_column(conn: sqlite3.Connection) -> str:
     """Detect whether files table uses 'path' or 'full_path' column.
 
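The helper above only needs objects exposing a .content string, so its batching behavior can be checked in isolation. A minimal sketch, assuming a hypothetical FakeChunk stand-in and hand-picked sizes (neither is part of the commit):

    from dataclasses import dataclass

    @dataclass
    class FakeChunk:
        content: str

    # Three chunks of 4000 chars each estimate to 1000 tokens apiece (len // 4).
    chunks = ((FakeChunk("x" * 4000), f"file_{i}.py") for i in range(3))
    batches = list(_create_token_aware_batches(chunks, max_tokens_per_batch=2500))

    # The first two chunks fit the 2500-token budget together; the third one
    # overflows the budget and starts a new batch.
    assert [len(b) for b in batches] == [2, 1]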
@@ -189,31 +229,69 @@ def check_index_embeddings(index_path: Path) -> Dict[str, any]:
     }
 
 
+def _get_embedding_defaults() -> tuple[str, str, bool]:
+    """Get default embedding settings from config.
+
+    Returns:
+        Tuple of (backend, model, use_gpu)
+    """
+    try:
+        from codexlens.config import Config
+        config = Config.load()
+        return config.embedding_backend, config.embedding_model, config.embedding_use_gpu
+    except Exception:
+        return "fastembed", "code", True
+
+
 def generate_embeddings(
     index_path: Path,
-    embedding_backend: str = "fastembed",
-    model_profile: str = "code",
+    embedding_backend: Optional[str] = None,
+    model_profile: Optional[str] = None,
     force: bool = False,
     chunk_size: int = 2000,
+    overlap: int = 200,
     progress_callback: Optional[callable] = None,
+    use_gpu: Optional[bool] = None,
+    max_tokens_per_batch: Optional[int] = None,
+    max_workers: int = 1,
 ) -> Dict[str, any]:
     """Generate embeddings for an index using memory-efficient batch processing.
 
     This function processes files in small batches to keep memory usage under 2GB,
-    regardless of the total project size.
+    regardless of the total project size. Supports concurrent API calls for
+    LiteLLM backend to improve throughput.
 
     Args:
         index_path: Path to _index.db file
-        embedding_backend: Embedding backend to use (fastembed or litellm)
+        embedding_backend: Embedding backend to use (fastembed or litellm).
+            Defaults to config setting.
         model_profile: Model profile for fastembed (fast, code, multilingual, balanced)
-            or model name for litellm (e.g., text-embedding-3-small)
+            or model name for litellm (e.g., qwen3-embedding).
+            Defaults to config setting.
         force: If True, regenerate even if embeddings exist
         chunk_size: Maximum chunk size in characters
+        overlap: Overlap size in characters for sliding window chunking (default: 200)
         progress_callback: Optional callback for progress updates
+        use_gpu: Whether to use GPU acceleration (fastembed only).
+            Defaults to config setting.
+        max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
+            If None, attempts to get from embedder.max_tokens,
+            then falls back to 8000. If set, overrides automatic detection.
+        max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
+            Recommended: 2-4 for LiteLLM API backends.
 
     Returns:
         Result dictionary with generation statistics
     """
+    # Get defaults from config if not specified
+    default_backend, default_model, default_gpu = _get_embedding_defaults()
+
+    if embedding_backend is None:
+        embedding_backend = default_backend
+    if model_profile is None:
+        model_profile = default_model
+    if use_gpu is None:
+        use_gpu = default_gpu
     if not SEMANTIC_AVAILABLE:
         return {
             "success": False,
@@ -261,9 +339,9 @@ generate_embeddings(
 
     # Initialize embedder using factory (supports both fastembed and litellm)
     # For fastembed: model_profile is a profile name (fast/code/multilingual/balanced)
-    # For litellm: model_profile is a model name (e.g., text-embedding-3-small)
+    # For litellm: model_profile is a model name (e.g., qwen3-embedding)
     if embedding_backend == "fastembed":
-        embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=True)
+        embedder = get_embedder_factory(backend="fastembed", profile=model_profile, use_gpu=use_gpu)
     elif embedding_backend == "litellm":
         embedder = get_embedder_factory(backend="litellm", model=model_profile)
     else:
@@ -274,7 +352,11 @@ generate_embeddings(
 
     # skip_token_count=True: Use fast estimation (len/4) instead of expensive tiktoken
     # This significantly reduces CPU usage with minimal impact on metadata accuracy
-    chunker = Chunker(config=ChunkConfig(max_chunk_size=chunk_size, skip_token_count=True))
+    chunker = Chunker(config=ChunkConfig(
+        max_chunk_size=chunk_size,
+        overlap=overlap,
+        skip_token_count=True
+    ))
 
     if progress_callback:
         progress_callback(f"Using model: {embedder.model_name} ({embedder.embedding_dim} dimensions)")
@@ -336,43 +418,105 @@ generate_embeddings(
         cursor, chunker, path_column, FILE_BATCH_SIZE, failed_files
     )
 
+    # Determine max tokens per batch
+    # Priority: explicit parameter > embedder.max_tokens > default 8000
+    if max_tokens_per_batch is None:
+        max_tokens_per_batch = getattr(embedder, 'max_tokens', 8000)
+
+    # Create token-aware batches or fall back to fixed-size batching
+    if max_tokens_per_batch:
+        batch_generator = _create_token_aware_batches(
+            chunk_generator, max_tokens_per_batch
+        )
+    else:
+        # Fallback to fixed-size batching for backward compatibility
+        def fixed_size_batches():
+            while True:
+                batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
+                if not batch:
+                    break
+                yield batch
+        batch_generator = fixed_size_batches()
+
     batch_number = 0
     files_seen = set()
 
-    while True:
-        # Get a small batch of chunks from the generator (EMBEDDING_BATCH_SIZE at a time)
-        chunk_batch = list(islice(chunk_generator, EMBEDDING_BATCH_SIZE))
-        if not chunk_batch:
-            break
+    # Thread-safe counters for concurrent processing
+    counter_lock = Lock()
 
-        batch_number += 1
+    def process_batch(batch_data: Tuple[int, List[Tuple]]) -> Tuple[int, set, Optional[str]]:
+        """Process a single batch: generate embeddings and store.
 
-        # Track unique files for progress
-        for _, file_path in chunk_batch:
-            files_seen.add(file_path)
+        Args:
+            batch_data: Tuple of (batch_number, chunk_batch)
+
+        Returns:
+            Tuple of (chunks_created, files_in_batch, error_message)
+        """
+        batch_num, chunk_batch = batch_data
+        batch_files = set()
 
         # Generate embeddings directly to numpy (no tolist() conversion)
         try:
+            # Track files in this batch
+            for _, file_path in chunk_batch:
+                batch_files.add(file_path)
+
+            # Generate embeddings
             batch_contents = [chunk.content for chunk, _ in chunk_batch]
             # Pass batch_size to fastembed for optimal GPU utilization
            embeddings_numpy = embedder.embed_to_numpy(batch_contents, batch_size=EMBEDDING_BATCH_SIZE)
 
             # Use add_chunks_batch_numpy to avoid numpy->list->numpy roundtrip
+            # Store embeddings (thread-safe via SQLite's serialized mode)
             vector_store.add_chunks_batch_numpy(chunk_batch, embeddings_numpy)
 
-            total_chunks_created += len(chunk_batch)
-            total_files_processed = len(files_seen)
-
-            if progress_callback and batch_number % 10 == 0:
-                progress_callback(f"  Batch {batch_number}: {total_chunks_created} chunks, {total_files_processed} files")
-
-            # Cleanup intermediate data
-            del batch_contents, embeddings_numpy, chunk_batch
+            return len(chunk_batch), batch_files, None
 
         except Exception as e:
-            logger.error(f"Failed to process embedding batch {batch_number}: {str(e)}")
-            # Continue to next batch instead of failing entirely
-            continue
+            error_msg = f"Batch {batch_num}: {str(e)}"
+            logger.error(f"Failed to process embedding batch {batch_num}: {str(e)}")
+            return 0, batch_files, error_msg
+
+    # Collect batches for concurrent processing
+    all_batches = []
+    for chunk_batch in batch_generator:
+        batch_number += 1
+        all_batches.append((batch_number, chunk_batch))
+
+    # Process batches (sequential or concurrent based on max_workers)
+    if max_workers <= 1:
+        # Sequential processing (original behavior)
+        for batch_num, chunk_batch in all_batches:
+            chunks_created, batch_files, error = process_batch((batch_num, chunk_batch))
+            files_seen.update(batch_files)
+            total_chunks_created += chunks_created
+            total_files_processed = len(files_seen)
+
+            if progress_callback and batch_num % 10 == 0:
+                progress_callback(f"  Batch {batch_num}: {total_chunks_created} chunks, {total_files_processed} files")
+    else:
+        # Concurrent processing for API backends
+        if progress_callback:
+            progress_callback(f"Processing {len(all_batches)} batches with {max_workers} concurrent workers...")
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = {executor.submit(process_batch, batch): batch[0] for batch in all_batches}
+
+            completed = 0
+            for future in as_completed(futures):
+                batch_num = futures[future]
+                try:
+                    chunks_created, batch_files, error = future.result()
+
+                    with counter_lock:
+                        files_seen.update(batch_files)
+                        total_chunks_created += chunks_created
+                        total_files_processed = len(files_seen)
+                        completed += 1
+
+                    if progress_callback and completed % 10 == 0:
+                        progress_callback(f"  Completed {completed}/{len(all_batches)} batches: {total_chunks_created} chunks")
+
+                except Exception as e:
+                    logger.error(f"Batch {batch_num} raised exception: {str(e)}")
 
     # Notify before ANN index finalization (happens when bulk_insert context exits)
     if progress_callback:
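With the batching and concurrency changes in place, an API-backed run can be driven with a single call. A sketch of such an invocation; the index path, model name, and worker count are illustrative, and omitted arguments fall back to the Config-based defaults described above:

    from pathlib import Path

    result = generate_embeddings(
        index_path=Path("_index.db"),      # path to an existing index database
        embedding_backend="litellm",       # API backend, benefits from concurrency
        model_profile="qwen3-embedding",   # model name passed through to LiteLLM
        max_workers=4,                     # 2-4 recommended for LiteLLM backends
    )
    print(result.get("success"))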
@@ -445,26 +589,49 @@ def find_all_indexes(scan_dir: Path) -> List[Path]:
 
 def generate_embeddings_recursive(
     index_root: Path,
-    embedding_backend: str = "fastembed",
-    model_profile: str = "code",
+    embedding_backend: Optional[str] = None,
+    model_profile: Optional[str] = None,
     force: bool = False,
     chunk_size: int = 2000,
+    overlap: int = 200,
     progress_callback: Optional[callable] = None,
+    use_gpu: Optional[bool] = None,
+    max_tokens_per_batch: Optional[int] = None,
+    max_workers: int = 1,
 ) -> Dict[str, any]:
     """Generate embeddings for all index databases in a project recursively.
 
     Args:
         index_root: Root index directory containing _index.db files
-        embedding_backend: Embedding backend to use (fastembed or litellm)
+        embedding_backend: Embedding backend to use (fastembed or litellm).
+            Defaults to config setting.
         model_profile: Model profile for fastembed (fast, code, multilingual, balanced)
-            or model name for litellm (e.g., text-embedding-3-small)
+            or model name for litellm (e.g., qwen3-embedding).
+            Defaults to config setting.
         force: If True, regenerate even if embeddings exist
         chunk_size: Maximum chunk size in characters
+        overlap: Overlap size in characters for sliding window chunking (default: 200)
         progress_callback: Optional callback for progress updates
+        use_gpu: Whether to use GPU acceleration (fastembed only).
+            Defaults to config setting.
+        max_tokens_per_batch: Maximum tokens per batch for token-aware batching.
+            If None, attempts to get from embedder.max_tokens,
+            then falls back to 8000. If set, overrides automatic detection.
+        max_workers: Maximum number of concurrent API calls (default: 1 for sequential).
+            Recommended: 2-4 for LiteLLM API backends.
 
     Returns:
         Aggregated result dictionary with generation statistics
     """
+    # Get defaults from config if not specified
+    default_backend, default_model, default_gpu = _get_embedding_defaults()
+
+    if embedding_backend is None:
+        embedding_backend = default_backend
+    if model_profile is None:
+        model_profile = default_model
+    if use_gpu is None:
+        use_gpu = default_gpu
     # Discover all _index.db files
     index_files = discover_all_index_dbs(index_root)
 
@@ -498,7 +665,11 @@ def generate_embeddings_recursive(
             model_profile=model_profile,
             force=force,
             chunk_size=chunk_size,
+            overlap=overlap,
             progress_callback=None,  # Don't cascade callbacks
+            use_gpu=use_gpu,
+            max_tokens_per_batch=max_tokens_per_batch,
+            max_workers=max_workers,
         )
 
         all_results.append({
 
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from dataclasses import dataclass, field
 from functools import cached_property
@@ -14,6 +15,9 @@ from .errors import ConfigError
 # Workspace-local directory name
 WORKSPACE_DIR_NAME = ".codexlens"
 
+# Settings file name
+SETTINGS_FILE_NAME = "settings.json"
+
 
 def _default_global_dir() -> Path:
     """Get global CodexLens data directory."""
@@ -89,6 +93,13 @@ class Config:
     # Hybrid chunker configuration
     hybrid_max_chunk_size: int = 2000  # Max characters per chunk before LLM refinement
     hybrid_llm_refinement: bool = False  # Enable LLM-based semantic boundary refinement
 
+    # Embedding configuration
+    embedding_backend: str = "fastembed"  # "fastembed" (local) or "litellm" (API)
+    embedding_model: str = "code"  # For fastembed: profile (fast/code/multilingual/balanced)
+                                   # For litellm: model name from config (e.g., "qwen3-embedding")
+    embedding_use_gpu: bool = True  # For fastembed: whether to use GPU acceleration
+
     def __post_init__(self) -> None:
         try:
             self.data_dir = self.data_dir.expanduser().resolve()
@@ -133,6 +144,67 @@ class Config:
         """Get parsing rules for a specific language, falling back to defaults."""
         return {**self.parsing_rules.get("default", {}), **self.parsing_rules.get(language_id, {})}
 
+    @cached_property
+    def settings_path(self) -> Path:
+        """Path to the settings file."""
+        return self.data_dir / SETTINGS_FILE_NAME
+
+    def save_settings(self) -> None:
+        """Save embedding and other settings to file."""
+        settings = {
+            "embedding": {
+                "backend": self.embedding_backend,
+                "model": self.embedding_model,
+                "use_gpu": self.embedding_use_gpu,
+            },
+            "llm": {
+                "enabled": self.llm_enabled,
+                "tool": self.llm_tool,
+                "timeout_ms": self.llm_timeout_ms,
+                "batch_size": self.llm_batch_size,
+            },
+        }
+        with open(self.settings_path, "w", encoding="utf-8") as f:
+            json.dump(settings, f, indent=2)
+
+    def load_settings(self) -> None:
+        """Load settings from file if exists."""
+        if not self.settings_path.exists():
+            return
+
+        try:
+            with open(self.settings_path, "r", encoding="utf-8") as f:
+                settings = json.load(f)
+
+            # Load embedding settings
+            embedding = settings.get("embedding", {})
+            if "backend" in embedding:
+                self.embedding_backend = embedding["backend"]
+            if "model" in embedding:
+                self.embedding_model = embedding["model"]
+            if "use_gpu" in embedding:
+                self.embedding_use_gpu = embedding["use_gpu"]
+
+            # Load LLM settings
+            llm = settings.get("llm", {})
+            if "enabled" in llm:
+                self.llm_enabled = llm["enabled"]
+            if "tool" in llm:
+                self.llm_tool = llm["tool"]
+            if "timeout_ms" in llm:
+                self.llm_timeout_ms = llm["timeout_ms"]
+            if "batch_size" in llm:
+                self.llm_batch_size = llm["batch_size"]
+        except Exception:
+            pass  # Silently ignore errors
+
+    @classmethod
+    def load(cls) -> "Config":
+        """Load config with settings from file."""
+        config = cls()
+        config.load_settings()
+        return config
+
 
 @dataclass
 class WorkspaceConfig:
 
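For reference, the settings document written by save_settings and consumed by load_settings has the shape sketched below. The concrete values are illustrative only; the llm fields simply mirror whatever the Config instance currently holds:

    # Rough shape of <data_dir>/settings.json after config.save_settings()
    example_settings = {
        "embedding": {
            "backend": "litellm",        # or "fastembed"
            "model": "qwen3-embedding",  # profile name for fastembed, model name for litellm
            "use_gpu": True,
        },
        "llm": {
            "enabled": False,            # illustrative values, taken from Config fields
            "tool": "example-cli-tool",
            "timeout_ms": 60000,
            "batch_size": 10,
        },
    }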
@@ -38,6 +38,16 @@ class BaseEmbedder(ABC):
         """
         ...
 
+    @property
+    def max_tokens(self) -> int:
+        """Return maximum token limit for embeddings.
+
+        Returns:
+            int: Maximum number of tokens that can be embedded at once.
+                Default is 8192 if not overridden by implementation.
+        """
+        return 8192
+
     @abstractmethod
     def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
         """Embed texts to numpy array.
 
@@ -39,7 +39,7 @@ from codexlens.parsers.tokenizer import get_default_tokenizer
 class ChunkConfig:
     """Configuration for chunking strategies."""
     max_chunk_size: int = 1000  # Max characters per chunk
-    overlap: int = 100  # Overlap for sliding window
+    overlap: int = 200  # Overlap for sliding window (increased from 100 for better context)
     strategy: str = "auto"  # Chunking strategy: auto, symbol, sliding_window, hybrid
     min_chunk_size: int = 50  # Minimum chunk size
     skip_token_count: bool = False  # Skip expensive token counting (use char/4 estimate)
@@ -80,6 +80,7 @@ class Chunker:
         """Chunk code by extracted symbols (functions, classes).
 
         Each symbol becomes one chunk with its full content.
+        Large symbols exceeding max_chunk_size are recursively split using sliding window.
 
         Args:
             content: Source code content
@@ -101,27 +102,49 @@ class Chunker:
             if len(chunk_content.strip()) < self.config.min_chunk_size:
                 continue
 
-            # Calculate token count if not provided
-            token_count = None
-            if symbol_token_counts and symbol.name in symbol_token_counts:
-                token_count = symbol_token_counts[symbol.name]
-            else:
-                token_count = self._estimate_token_count(chunk_content)
+            # Check if symbol content exceeds max_chunk_size
+            if len(chunk_content) > self.config.max_chunk_size:
+                # Create line mapping for correct line number tracking
+                line_mapping = list(range(start_line, end_line + 1))
 
-            chunks.append(SemanticChunk(
-                content=chunk_content,
-                embedding=None,
-                metadata={
-                    "file": str(file_path),
-                    "language": language,
-                    "symbol_name": symbol.name,
-                    "symbol_kind": symbol.kind,
-                    "start_line": start_line,
-                    "end_line": end_line,
-                    "strategy": "symbol",
-                    "token_count": token_count,
-                }
-            ))
+                # Use sliding window to split large symbol
+                sub_chunks = self.chunk_sliding_window(
+                    chunk_content,
+                    file_path=file_path,
+                    language=language,
+                    line_mapping=line_mapping
+                )
+
+                # Update sub_chunks with parent symbol metadata
+                for sub_chunk in sub_chunks:
+                    sub_chunk.metadata["symbol_name"] = symbol.name
+                    sub_chunk.metadata["symbol_kind"] = symbol.kind
+                    sub_chunk.metadata["strategy"] = "symbol_split"
+                    sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
+
+                chunks.extend(sub_chunks)
+            else:
+                # Calculate token count if not provided
+                token_count = None
+                if symbol_token_counts and symbol.name in symbol_token_counts:
+                    token_count = symbol_token_counts[symbol.name]
+                else:
+                    token_count = self._estimate_token_count(chunk_content)
+
+                chunks.append(SemanticChunk(
+                    content=chunk_content,
+                    embedding=None,
+                    metadata={
+                        "file": str(file_path),
+                        "language": language,
+                        "symbol_name": symbol.name,
+                        "symbol_kind": symbol.kind,
+                        "start_line": start_line,
+                        "end_line": end_line,
+                        "strategy": "symbol",
+                        "token_count": token_count,
+                    }
+                ))
 
         return chunks
 
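The practical outcome of this branch: a symbol that fits within max_chunk_size still becomes a single chunk tagged "symbol", while an oversized one is re-chunked by the sliding window and each piece keeps the parent's identity. As a rough illustration, a sub-chunk produced from a symbol spanning lines 10-80 would carry metadata along these lines (the path and symbol name are hypothetical):

    # Metadata attached to each sub-chunk of a recursively split symbol
    sub_chunk_metadata = {
        "file": "src/example.py",        # hypothetical file
        "language": "python",
        "symbol_name": "process_data",   # hypothetical parent symbol
        "symbol_kind": "function",
        "strategy": "symbol_split",      # replaces "symbol" for split pieces
        "parent_symbol_range": (10, 80), # start/end lines of the original symbol
    }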
@@ -165,6 +165,33 @@ class Embedder(BaseEmbedder):
         """Get embedding dimension for current model."""
         return self.MODEL_DIMS.get(self._model_name, 768)  # Default to 768 if unknown
 
+    @property
+    def max_tokens(self) -> int:
+        """Get maximum token limit for current model.
+
+        Returns:
+            int: Maximum number of tokens based on model profile.
+                - fast: 512 (lightweight, optimized for speed)
+                - code: 8192 (code-optimized, larger context)
+                - multilingual: 512 (standard multilingual model)
+                - balanced: 512 (general purpose)
+        """
+        # Determine profile from model name
+        profile = None
+        for prof, model in self.MODELS.items():
+            if model == self._model_name:
+                profile = prof
+                break
+
+        # Return token limit based on profile
+        if profile == "code":
+            return 8192
+        elif profile in ("fast", "multilingual", "balanced"):
+            return 512
+        else:
+            # Default for unknown models
+            return 512
+
     @property
     def providers(self) -> List[str]:
         """Get configured ONNX execution providers."""
 
@@ -63,11 +63,39 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):
         """
         return self._embedder.model_name
 
-    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
+    @property
+    def max_tokens(self) -> int:
+        """Return maximum token limit for the embedding model.
+
+        Returns:
+            int: Maximum number of tokens that can be embedded at once.
+                Inferred from model config or model name patterns.
+        """
+        # Try to get from LiteLLM config first
+        if hasattr(self._embedder, 'max_input_tokens') and self._embedder.max_input_tokens:
+            return self._embedder.max_input_tokens
+
+        # Infer from model name
+        model_name_lower = self.model_name.lower()
+
+        # Large models (8B or "large" in name)
+        if '8b' in model_name_lower or 'large' in model_name_lower:
+            return 32768
+
+        # OpenAI text-embedding-3-* models
+        if 'text-embedding-3' in model_name_lower:
+            return 8191
+
+        # Default fallback
+        return 8192
+
+    def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
         """Embed texts to numpy array using LiteLLMEmbedder.
 
         Args:
             texts: Single text or iterable of texts to embed.
+            **kwargs: Additional arguments (ignored for LiteLLM backend).
+                Accepts batch_size for API compatibility with fastembed.
 
         Returns:
             numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
@@ -76,4 +104,5 @@ class LiteLLMEmbedderWrapper(BaseEmbedder):
             texts = [texts]
         else:
             texts = list(texts)
+        # LiteLLM handles batching internally, ignore batch_size parameter
         return self._embedder.embed(texts)
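The name-based fallback in max_tokens is easy to trace by hand; the model names below are illustrative inputs rather than a supported list:

    # model name -> expected limit under the heuristic above
    expected_limits = {
        "qwen3-embedding-8b": 32768,     # '8b' in name
        "text-embedding-3-small": 8191,  # 'text-embedding-3' pattern
        "some-other-embedder": 8192,     # default fallback
    }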