Refactor code structure and remove redundant changes

catlog22
2026-01-24 14:47:47 +08:00
parent cf5fecd66d
commit f2b0a5bbc9
113 changed files with 43217 additions and 235 deletions


@@ -0,0 +1,118 @@
"""Optional semantic search module for CodexLens.
Install with: pip install codexlens[semantic]
Uses fastembed (ONNX-based, lightweight ~200MB)
GPU Acceleration:
- Automatic GPU detection and usage when available
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
"""
from __future__ import annotations
SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
GPU_AVAILABLE = False
LITELLM_AVAILABLE = False
_import_error: str | None = None
def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
"""Detect if fastembed and GPU are available."""
try:
import numpy as np
except ImportError as e:
return False, None, False, f"numpy not available: {e}"
try:
from fastembed import TextEmbedding
except ImportError:
return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"
# Check GPU availability
gpu_available = False
try:
from .gpu_support import is_gpu_available
gpu_available = is_gpu_available()
except ImportError:
pass
return True, "fastembed", gpu_available, None
# Initialize on module load
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()
def check_semantic_available() -> tuple[bool, str | None]:
"""Check if semantic search dependencies are available."""
return SEMANTIC_AVAILABLE, _import_error
def check_gpu_available() -> tuple[bool, str]:
"""Check if GPU acceleration is available.
Returns:
Tuple of (is_available, status_message)
"""
if not SEMANTIC_AVAILABLE:
return False, "Semantic search not available"
try:
from .gpu_support import is_gpu_available, get_gpu_summary
if is_gpu_available():
return True, get_gpu_summary()
return False, "No GPU detected (using CPU)"
except ImportError:
return False, "GPU support module not available"
# Export embedder components
# BaseEmbedder is always available (abstract base class)
from .base import BaseEmbedder
# Factory function for creating embedders
from .factory import get_embedder as get_embedder_factory
# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
try:
import ccw_litellm # noqa: F401
from .litellm_embedder import LiteLLMEmbedderWrapper
LITELLM_AVAILABLE = True
except ImportError:
LiteLLMEmbedderWrapper = None
LITELLM_AVAILABLE = False
def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific embedding backend can be used.
Notes:
- "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
- "litellm" requires ccw-litellm to be installed in the same environment.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
if SEMANTIC_AVAILABLE:
return True, None
return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
if backend == "litellm":
if LITELLM_AVAILABLE:
return True, None
return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."
__all__ = [
"SEMANTIC_AVAILABLE",
"SEMANTIC_BACKEND",
"GPU_AVAILABLE",
"LITELLM_AVAILABLE",
"check_semantic_available",
"is_embedding_backend_available",
"check_gpu_available",
"BaseEmbedder",
"get_embedder_factory",
"LiteLLMEmbedderWrapper",
]
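
A minimal usage sketch for the checks exported above; the import path codexlens.semantic is an assumption about where this module lives in the package:

from codexlens.semantic import (
    check_semantic_available,
    check_gpu_available,
    is_embedding_backend_available,
)

available, error = check_semantic_available()
if not available:
    print(f"Semantic search disabled: {error}")

gpu_ok, gpu_status = check_gpu_available()
print(gpu_status)

backend_ok, reason = is_embedding_backend_available("fastembed")
if not backend_ok:
    print(reason)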

File diff suppressed because it is too large.


@@ -0,0 +1,61 @@
"""Base class for embedders.
Defines the interface that all embedders must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable
import numpy as np
class BaseEmbedder(ABC):
"""Base class for all embedders.
All embedder implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
@abstractmethod
def embedding_dim(self) -> int:
"""Return embedding dimensions.
Returns:
int: Dimension of the embedding vectors.
"""
...
@property
@abstractmethod
def model_name(self) -> str:
"""Return model name.
Returns:
str: Name or identifier of the underlying model.
"""
...
@property
def max_tokens(self) -> int:
"""Return maximum token limit for embeddings.
Returns:
int: Maximum number of tokens that can be embedded at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
"""Embed texts to numpy array.
Args:
texts: Single text or iterable of texts to embed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
"""
...
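
A hedged sketch of a concrete subclass, showing the minimal surface an embedder has to implement; the hash-seeded vectors are purely illustrative and not a real model (imports are repeated so the sketch stands alone):

from typing import Iterable

import numpy as np


class DummyEmbedder(BaseEmbedder):
    """Toy embedder mapping each text to a deterministic pseudo-random vector."""

    @property
    def embedding_dim(self) -> int:
        return 8

    @property
    def model_name(self) -> str:
        return "dummy-hash-embedder"

    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        if isinstance(texts, str):
            texts = [texts]
        rows = []
        for text in texts:
            # Seed a generator from the text so identical inputs embed identically
            rng = np.random.default_rng(abs(hash(text)) % (2**32))
            rows.append(rng.random(self.embedding_dim))
        return np.stack(rows)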


@@ -0,0 +1,821 @@
"""Code chunking strategies for semantic search.
This module provides various chunking strategies for breaking down source code
into semantic chunks suitable for embedding and search.
Lightweight Mode:
The ChunkConfig supports a `skip_token_count` option for performance optimization.
When enabled, token counting uses a fast character-based estimation (char/4)
instead of expensive tiktoken encoding.
Use cases for lightweight mode:
- Large-scale indexing where speed is critical
- Scenarios where approximate token counts are acceptable
- Memory-constrained environments
- Initial prototyping and development
Example:
# Default mode (accurate tiktoken encoding)
config = ChunkConfig()
chunker = Chunker(config)
# Lightweight mode (fast char/4 estimation)
config = ChunkConfig(skip_token_count=True)
chunker = Chunker(config)
chunks = chunker.chunk_file(content, symbols, path, language)
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer
@dataclass
class ChunkConfig:
"""Configuration for chunking strategies."""
max_chunk_size: int = 1000 # Max characters per chunk
overlap: int = 200 # Overlap for sliding window (increased from 100 for better context)
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
strip_comments: bool = True # Remove comments from chunk content for embedding
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
preserve_original: bool = True # Store original content in metadata when stripping
class CommentStripper:
"""Remove comments from source code while preserving structure."""
@staticmethod
def strip_python_comments(content: str) -> str:
"""Strip Python comments (# style) but preserve docstrings.
Args:
content: Python source code
Returns:
Code with comments removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
in_string = False
string_char = None
for line in lines:
new_line = []
i = 0
while i < len(line):
char = line[i]
# Handle string literals
if char in ('"', "'") and not in_string:
# Check for triple quotes
if line[i:i+3] in ('"""', "'''"):
in_string = True
string_char = line[i:i+3]
new_line.append(line[i:i+3])
i += 3
continue
else:
in_string = True
string_char = char
elif in_string:
if string_char and len(string_char) == 3:
if line[i:i+3] == string_char:
in_string = False
new_line.append(line[i:i+3])
i += 3
string_char = None
continue
elif char == string_char:
# Check for escape
if i > 0 and line[i-1] != '\\':
in_string = False
string_char = None
# Handle comments (only outside strings)
if char == '#' and not in_string:
# Rest of line is comment, skip it
new_line.append('\n' if line.endswith('\n') else '')
break
new_line.append(char)
i += 1
result_lines.append(''.join(new_line))
return ''.join(result_lines)
@staticmethod
def strip_c_style_comments(content: str) -> str:
"""Strip C-style comments (// and /* */) from code.
Args:
content: Source code with C-style comments
Returns:
Code with comments removed
"""
result = []
i = 0
in_string = False
string_char = None
in_multiline_comment = False
while i < len(content):
# Handle multi-line comment end
if in_multiline_comment:
if content[i:i+2] == '*/':
in_multiline_comment = False
i += 2
continue
i += 1
continue
char = content[i]
# Handle string literals
if char in ('"', "'", '`') and not in_string:
in_string = True
string_char = char
result.append(char)
i += 1
continue
elif in_string:
result.append(char)
if char == string_char and (i == 0 or content[i-1] != '\\'):
in_string = False
string_char = None
i += 1
continue
# Handle comments
if content[i:i+2] == '//':
# Single line comment - skip to end of line
while i < len(content) and content[i] != '\n':
i += 1
if i < len(content):
result.append('\n')
i += 1
continue
if content[i:i+2] == '/*':
in_multiline_comment = True
i += 2
continue
result.append(char)
i += 1
return ''.join(result)
@classmethod
def strip_comments(cls, content: str, language: str) -> str:
"""Strip comments based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with comments removed
"""
if language == "python":
return cls.strip_python_comments(content)
elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
return cls.strip_c_style_comments(content)
return content
class DocstringStripper:
"""Remove docstrings from source code."""
@staticmethod
def strip_python_docstrings(content: str) -> str:
"""Strip Python docstrings (triple-quoted strings at module/class/function level).
Args:
content: Python source code
Returns:
Code with docstrings removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
# Check for docstring start
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
# Single line docstring
if stripped.count(quote_type) >= 2:
# Skip this line (docstring)
i += 1
continue
# Multi-line docstring - skip until closing
i += 1
while i < len(lines):
if quote_type in lines[i]:
i += 1
break
i += 1
continue
result_lines.append(line)
i += 1
return ''.join(result_lines)
@staticmethod
def strip_jsdoc_comments(content: str) -> str:
"""Strip JSDoc comments (/** ... */) from code.
Args:
content: JavaScript/TypeScript source code
Returns:
Code with JSDoc comments removed
"""
result = []
i = 0
in_jsdoc = False
while i < len(content):
if in_jsdoc:
if content[i:i+2] == '*/':
in_jsdoc = False
i += 2
continue
i += 1
continue
# Check for JSDoc start (/** but not /*)
if content[i:i+3] == '/**':
in_jsdoc = True
i += 3
continue
result.append(content[i])
i += 1
return ''.join(result)
@classmethod
def strip_docstrings(cls, content: str, language: str) -> str:
"""Strip docstrings based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with docstrings removed
"""
if language == "python":
return cls.strip_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.strip_jsdoc_comments(content)
return content
class Chunker:
"""Chunk code files for semantic embedding."""
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
self._comment_stripper = CommentStripper()
self._docstring_stripper = DocstringStripper()
def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
"""Process chunk content by stripping comments/docstrings if configured.
Args:
content: Original chunk content
language: Programming language
Returns:
Tuple of (processed_content, original_content_if_preserved)
"""
original = content if self.config.preserve_original else None
processed = content
if self.config.strip_comments:
processed = self._comment_stripper.strip_comments(processed, language)
if self.config.strip_docstrings:
processed = self._docstring_stripper.strip_docstrings(processed, language)
# If nothing changed, don't store original
if processed == content:
original = None
return processed, original
def _estimate_token_count(self, text: str) -> int:
"""Estimate token count based on config.
If skip_token_count is True, uses character-based estimation (char/4).
Otherwise, uses accurate tiktoken encoding.
Args:
text: Text to count tokens for
Returns:
Estimated token count
"""
if self.config.skip_token_count:
# Fast character-based estimation: ~4 chars per token
return max(1, len(text) // 4)
return self._tokenizer.count_tokens(text)
def chunk_by_symbol(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk code by extracted symbols (functions, classes).
Each symbol becomes one chunk with its full content.
Large symbols exceeding max_chunk_size are recursively split using a sliding window.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
for symbol in symbols:
start_line, end_line = symbol.range
# Convert to 0-indexed
start_idx = max(0, start_line - 1)
end_idx = min(len(lines), end_line)
chunk_content = "".join(lines[start_idx:end_idx])
if len(chunk_content.strip()) < self.config.min_chunk_size:
continue
# Check if symbol content exceeds max_chunk_size
if len(chunk_content) > self.config.max_chunk_size:
# Create line mapping for correct line number tracking
line_mapping = list(range(start_line, end_line + 1))
# Use sliding window to split large symbol
sub_chunks = self.chunk_sliding_window(
chunk_content,
file_path=file_path,
language=language,
line_mapping=line_mapping
)
# Update sub_chunks with parent symbol metadata
for sub_chunk in sub_chunks:
sub_chunk.metadata["symbol_name"] = symbol.name
sub_chunk.metadata["symbol_kind"] = symbol.kind
sub_chunk.metadata["strategy"] = "symbol_split"
sub_chunk.metadata["chunk_type"] = "code"
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
chunks.extend(sub_chunks)
else:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._estimate_token_count(processed_content)
metadata = {
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=processed_content,
embedding=None,
metadata=metadata
))
return chunks
def chunk_sliding_window(
self,
content: str,
file_path: str | Path,
language: str,
line_mapping: Optional[List[int]] = None,
) -> List[SemanticChunk]:
"""Chunk code using sliding window approach.
Used for files without clear symbol boundaries, or to split very long functions.
Args:
content: Source code content
file_path: Path to source file
language: Programming language
line_mapping: Optional list mapping content line indices to original line numbers
(1-indexed). If provided, line_mapping[i] is the original line number
for the i-th line in content.
"""
chunks: List[SemanticChunk] = []
lines = content.splitlines(keepends=True)
if not lines:
return chunks
# Calculate lines per chunk based on average line length
avg_line_len = len(content) / max(len(lines), 1)
lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
# Ensure overlap is less than chunk size to prevent infinite loop
overlap_lines = min(overlap_lines, lines_per_chunk - 1)
start = 0
chunk_idx = 0
while start < len(lines):
end = min(start + lines_per_chunk, len(lines))
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
# Move window forward
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1
start += step
continue
token_count = self._estimate_token_count(processed_content)
# Calculate correct line numbers
if line_mapping:
# Use line mapping to get original line numbers
start_line = line_mapping[start]
end_line = line_mapping[end - 1]
else:
# Default behavior: treat content as starting at line 1
start_line = start + 1
end_line = end
metadata = {
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=processed_content,
embedding=None,
metadata=metadata
))
chunk_idx += 1
# Move window, accounting for overlap
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1 # Failsafe to prevent infinite loop
start += step
# Break if we've reached the end
if end >= len(lines):
break
return chunks
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk a file using the best strategy.
Uses symbol-based chunking when symbols are available and
falls back to sliding-window chunking for files without symbols.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
if symbols:
return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
return self.chunk_sliding_window(content, file_path, language)
class DocstringExtractor:
"""Extract docstrings from source code."""
@staticmethod
def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
"""Extract Python docstrings with their line ranges.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
docstrings: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
start_line = i + 1
if stripped.count(quote_type) >= 2:
docstring_content = line
end_line = i + 1
docstrings.append((docstring_content, start_line, end_line))
i += 1
continue
docstring_lines = [line]
i += 1
while i < len(lines):
docstring_lines.append(lines[i])
if quote_type in lines[i]:
break
i += 1
end_line = i + 1
docstring_content = "".join(docstring_lines)
docstrings.append((docstring_content, start_line, end_line))
i += 1
return docstrings
@staticmethod
def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
"""Extract JSDoc comments with their line ranges.
Returns: List of (comment_content, start_line, end_line) tuples
"""
comments: List[Tuple[str, int, int]] = []
lines = content.splitlines(keepends=True)
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
if stripped.startswith('/**'):
start_line = i + 1
comment_lines = [line]
i += 1
while i < len(lines):
comment_lines.append(lines[i])
if '*/' in lines[i]:
break
i += 1
end_line = i + 1
comment_content = "".join(comment_lines)
comments.append((comment_content, start_line, end_line))
i += 1
return comments
@classmethod
def extract_docstrings(
cls,
content: str,
language: str
) -> List[Tuple[str, int, int]]:
"""Extract docstrings based on language.
Returns: List of (docstring_content, start_line, end_line) tuples
"""
if language == "python":
return cls.extract_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.extract_jsdoc_comments(content)
return []
class HybridChunker:
"""Hybrid chunker that prioritizes docstrings before symbol-based chunking.
Composition-based strategy that:
1. Extracts docstrings as dedicated chunks
2. Chunks the remaining code with the base chunker (symbol or sliding window)
"""
def __init__(
self,
base_chunker: Chunker | None = None,
config: ChunkConfig | None = None
) -> None:
"""Initialize hybrid chunker.
Args:
base_chunker: Chunker to use for non-docstring content
config: Configuration for chunking
"""
self.config = config or ChunkConfig()
self.base_chunker = base_chunker or Chunker(self.config)
self.docstring_extractor = DocstringExtractor()
def _get_excluded_line_ranges(
self,
docstrings: List[Tuple[str, int, int]]
) -> set[int]:
"""Get set of line numbers that are part of docstrings."""
excluded_lines: set[int] = set()
for _, start_line, end_line in docstrings:
for line_num in range(start_line, end_line + 1):
excluded_lines.add(line_num)
return excluded_lines
def _filter_symbols_outside_docstrings(
self,
symbols: List[Symbol],
excluded_lines: set[int]
) -> List[Symbol]:
"""Filter symbols to exclude those completely within docstrings."""
filtered: List[Symbol] = []
for symbol in symbols:
start_line, end_line = symbol.range
symbol_lines = set(range(start_line, end_line + 1))
if not symbol_lines.issubset(excluded_lines):
filtered.append(symbol)
return filtered
def _find_parent_symbol(
self,
start_line: int,
end_line: int,
symbols: List[Symbol],
) -> Optional[Symbol]:
"""Find the smallest symbol range that fully contains a docstring span."""
candidates: List[Symbol] = []
for symbol in symbols:
sym_start, sym_end = symbol.range
if sym_start <= start_line and end_line <= sym_end:
candidates.append(symbol)
if not candidates:
return None
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
def chunk_file(
self,
content: str,
symbols: List[Symbol],
file_path: str | Path,
language: str,
symbol_token_counts: Optional[dict[str, int]] = None,
) -> List[SemanticChunk]:
"""Chunk file using hybrid strategy.
Extracts docstrings first, then chunks remaining code.
Args:
content: Source code content
symbols: List of extracted symbols
file_path: Path to source file
language: Programming language
symbol_token_counts: Optional dict mapping symbol names to token counts
"""
chunks: List[SemanticChunk] = []
# Step 1: Extract docstrings as dedicated chunks
docstrings: List[Tuple[str, int, int]] = []
if language == "python":
# Fast path: avoid expensive docstring extraction if delimiters are absent.
if '"""' in content or "'''" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
elif language in {"javascript", "typescript"}:
if "/**" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
else:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
# Fast path: no docstrings -> delegate to base chunker directly.
if not docstrings:
if symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, symbols, file_path, language, symbol_token_counts
)
else:
base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
return base_chunks
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
# Use base chunker's token estimation method
token_count = self.base_chunker._estimate_token_count(docstring_content)
metadata = {
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
if parent_symbol is not None:
metadata["parent_symbol"] = parent_symbol.name
metadata["parent_symbol_kind"] = parent_symbol.kind
metadata["parent_symbol_range"] = parent_symbol.range
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata=metadata
))
# Step 2: Get line ranges occupied by docstrings
excluded_lines = self._get_excluded_line_ranges(docstrings)
# Step 3: Filter symbols to exclude docstring-only ranges
filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)
# Step 4: Chunk remaining content using base chunker
if filtered_symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, filtered_symbols, file_path, language, symbol_token_counts
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
else:
lines = content.splitlines(keepends=True)
remaining_lines: List[str] = []
for i, line in enumerate(lines, start=1):
if i not in excluded_lines:
remaining_lines.append(line)
if remaining_lines:
remaining_content = "".join(remaining_lines)
if len(remaining_content.strip()) >= self.config.min_chunk_size:
base_chunks = self.base_chunker.chunk_sliding_window(
remaining_content, file_path, language
)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
chunks.append(chunk)
return chunks
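
A brief usage sketch for the chunkers above, run on a small inline Python snippet with no pre-extracted symbols (an empty symbol list forces the sliding-window path); all values are illustrative:

source_text = '"""Example module used only for this sketch."""\n\n'
source_text += "def add(a, b):\n    # add two numbers\n    return a + b\n\n" * 10

config = ChunkConfig(max_chunk_size=400, skip_token_count=True)
chunker = HybridChunker(config=config)

chunks = chunker.chunk_file(
    content=source_text,
    symbols=[],              # normally produced by the indexer; empty here
    file_path="example.py",
    language="python",
)
for chunk in chunks:
    meta = chunk.metadata
    print(meta["chunk_type"], meta["strategy"], meta["start_line"], meta["end_line"])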


@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Tuple
from codexlens.entities import SearchResult, Symbol
def extract_complete_code_block(
result: SearchResult,
source_file_path: Optional[str] = None,
context_lines: int = 0,
) -> str:
"""Extract complete code block from a search result.
Args:
result: SearchResult from semantic search.
source_file_path: Optional path to source file for re-reading.
context_lines: Additional lines of context to include above/below.
Returns:
Complete code block as string.
"""
# If we have full content stored, use it
if result.content:
if context_lines == 0:
return result.content
# Need to add context, read from file
# Try to read from source file
file_path = source_file_path or result.path
if not file_path or not Path(file_path).exists():
# Fall back to excerpt
return result.excerpt or ""
try:
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
lines = content.splitlines()
# Get line range
start_line = result.start_line or 1
end_line = result.end_line or len(lines)
# Add context
start_idx = max(0, start_line - 1 - context_lines)
end_idx = min(len(lines), end_line + context_lines)
return "\n".join(lines[start_idx:end_idx])
except Exception:
return result.excerpt or result.content or ""
def extract_symbol_with_context(
file_path: str,
symbol: Symbol,
include_docstring: bool = True,
include_decorators: bool = True,
) -> str:
"""Extract a symbol (function/class) with its docstring and decorators.
Args:
file_path: Path to source file.
symbol: Symbol to extract.
include_docstring: Include docstring if present.
include_decorators: Include decorators/annotations above symbol.
Returns:
Complete symbol code with context.
"""
try:
content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
lines = content.splitlines()
start_line, end_line = symbol.range
start_idx = start_line - 1
end_idx = end_line
# Look for decorators above the symbol
if include_decorators and start_idx > 0:
decorator_start = start_idx
# Search backwards for decorators
i = start_idx - 1
while i >= 0 and i >= start_idx - 20: # Look up to 20 lines back
line = lines[i].strip()
if line.startswith("@"):
decorator_start = i
i -= 1
elif line == "" or line.startswith("#"):
# Skip empty lines and comments, continue looking
i -= 1
elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
# JavaScript/Java style comments
decorator_start = i
i -= 1
else:
# Found non-decorator, non-comment line, stop
break
start_idx = decorator_start
return "\n".join(lines[start_idx:end_idx])
except Exception:
return ""
def format_search_result_code(
result: SearchResult,
max_lines: Optional[int] = None,
show_line_numbers: bool = True,
highlight_match: bool = False,
) -> str:
"""Format search result code for display.
Args:
result: SearchResult to format.
max_lines: Maximum lines to show (None for all).
show_line_numbers: Include line numbers in output.
highlight_match: Add markers for matched region.
Returns:
Formatted code string.
"""
content = result.content or result.excerpt or ""
if not content:
return ""
lines = content.splitlines()
# Truncate if needed
truncated = False
if max_lines and len(lines) > max_lines:
lines = lines[:max_lines]
truncated = True
# Format with line numbers
if show_line_numbers:
start = result.start_line or 1
formatted_lines = []
for i, line in enumerate(lines):
line_num = start + i
formatted_lines.append(f"{line_num:4d} | {line}")
output = "\n".join(formatted_lines)
else:
output = "\n".join(lines)
if truncated:
output += "\n... (truncated)"
return output
def get_code_block_summary(result: SearchResult) -> str:
"""Get a concise summary of a code block.
Args:
result: SearchResult to summarize.
Returns:
Summary string like "function hello_world (lines 10-25)"
"""
parts = []
if result.symbol_kind:
parts.append(result.symbol_kind)
if result.symbol_name:
parts.append(f"`{result.symbol_name}`")
elif result.excerpt:
# Extract first meaningful identifier
first_line = result.excerpt.split("\n")[0][:50]
parts.append(f'"{first_line}..."')
if result.start_line and result.end_line:
if result.start_line == result.end_line:
parts.append(f"(line {result.start_line})")
else:
parts.append(f"(lines {result.start_line}-{result.end_line})")
if result.path:
file_name = Path(result.path).name
parts.append(f"in {file_name}")
return " ".join(parts) if parts else "unknown code block"
class CodeBlockResult:
"""Enhanced search result with complete code block."""
def __init__(self, result: SearchResult, source_path: Optional[str] = None):
self.result = result
self.source_path = source_path or result.path
self._full_code: Optional[str] = None
@property
def score(self) -> float:
return self.result.score
@property
def path(self) -> str:
return self.result.path
@property
def file_name(self) -> str:
return Path(self.result.path).name
@property
def symbol_name(self) -> Optional[str]:
return self.result.symbol_name
@property
def symbol_kind(self) -> Optional[str]:
return self.result.symbol_kind
@property
def line_range(self) -> Tuple[int, int]:
return (
self.result.start_line or 1,
self.result.end_line or 1
)
@property
def full_code(self) -> str:
"""Get full code block content."""
if self._full_code is None:
self._full_code = extract_complete_code_block(self.result, self.source_path)
return self._full_code
@property
def excerpt(self) -> str:
"""Get short excerpt."""
return self.result.excerpt or ""
@property
def summary(self) -> str:
"""Get code block summary."""
return get_code_block_summary(self.result)
def format(
self,
max_lines: Optional[int] = None,
show_line_numbers: bool = True,
) -> str:
"""Format code for display."""
# Use full code if available
display_result = SearchResult(
path=self.result.path,
score=self.result.score,
content=self.full_code,
start_line=self.result.start_line,
end_line=self.result.end_line,
)
return format_search_result_code(
display_result,
max_lines=max_lines,
show_line_numbers=show_line_numbers
)
def __repr__(self) -> str:
return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"
def enhance_search_results(
results: List[SearchResult],
) -> List[CodeBlockResult]:
"""Enhance search results with complete code block access.
Args:
results: List of SearchResult from semantic search.
Returns:
List of CodeBlockResult with full code access.
"""
return [CodeBlockResult(r) for r in results]
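
A small sketch of how these helpers combine; the SearchResult values are made up for illustration, and it assumes SearchResult's optional fields (excerpt, symbol_name, symbol_kind) default to None as the code above implies:

result = SearchResult(
    path="src/example.py",
    score=0.87,
    content="def greet(name):\n    return f'Hello, {name}!'",
    start_line=10,
    end_line=11,
)

for block in enhance_search_results([result]):
    print(block.summary)  # e.g. "(lines 10-11) in example.py"
    print(block.format(max_lines=20, show_line_numbers=True))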


@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.
Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
GPU acceleration is automatic when available, with transparent CPU fallback.
"""
from __future__ import annotations
import gc
import logging
import threading
from typing import Dict, Iterable, List, Optional
import numpy as np
from . import SEMANTIC_AVAILABLE
from .base import BaseEmbedder
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id
logger = logging.getLogger(__name__)
# Global embedder cache for singleton pattern
_embedder_cache: Dict[str, "Embedder"] = {}
_cache_lock = threading.RLock()
def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
"""Get or create a cached Embedder instance (thread-safe singleton).
This function provides a significant performance improvement by reusing
Embedder instances across multiple searches, avoiding repeated model-loading
overhead (~0.8s per load).
Args:
profile: Model profile ("fast", "code", "multilingual", "balanced")
use_gpu: If True, use GPU acceleration when available (default: True)
Returns:
Cached Embedder instance for the given profile
"""
global _embedder_cache
# Cache key includes GPU preference to support mixed configurations
cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"
# All cache access is protected by _cache_lock to avoid races with
# clear_embedder_cache() during concurrent access.
with _cache_lock:
embedder = _embedder_cache.get(cache_key)
if embedder is not None:
return embedder
# Create new embedder and cache it
embedder = Embedder(profile=profile, use_gpu=use_gpu)
# Pre-load model to ensure it's ready
embedder._load_model()
_embedder_cache[cache_key] = embedder
# Log GPU status on first embedder creation
if use_gpu and is_gpu_available():
logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
elif use_gpu:
logger.debug("GPU not available, using CPU for embeddings")
return embedder
def clear_embedder_cache() -> None:
"""Clear the embedder cache and release ONNX resources.
This function ensures proper cleanup of ONNX model resources to prevent
memory leaks when embedders are no longer needed.
"""
global _embedder_cache
with _cache_lock:
# Release ONNX resources before clearing cache
for embedder in _embedder_cache.values():
if embedder._model is not None:
del embedder._model
embedder._model = None
_embedder_cache.clear()
gc.collect()
class Embedder(BaseEmbedder):
"""Generate embeddings for code chunks using fastembed (ONNX-based).
Supported Model Profiles:
- fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
- code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
- multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
- balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
"""
# Model profiles for different use cases
MODELS = {
"fast": "BAAI/bge-small-en-v1.5", # 384 dim - Fast, lightweight
"code": "jinaai/jina-embeddings-v2-base-code", # 768 dim - Code-optimized
"multilingual": "intfloat/multilingual-e5-large", # 1024 dim - Multilingual
"balanced": "mixedbread-ai/mxbai-embed-large-v1", # 1024 dim - High accuracy
}
# Dimension mapping for each model
MODEL_DIMS = {
"BAAI/bge-small-en-v1.5": 384,
"jinaai/jina-embeddings-v2-base-code": 768,
"intfloat/multilingual-e5-large": 1024,
"mixedbread-ai/mxbai-embed-large-v1": 1024,
}
# Default model (fast profile)
DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
DEFAULT_PROFILE = "fast"
def __init__(
self,
model_name: str | None = None,
profile: str | None = None,
use_gpu: bool = True,
providers: List[str] | None = None,
) -> None:
"""Initialize embedder with model or profile.
Args:
model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
profile: Model profile shortcut ("fast", "code", "multilingual", "balanced")
If both provided, model_name takes precedence.
use_gpu: If True, use GPU acceleration when available (default: True)
providers: Explicit ONNX providers list (overrides use_gpu if provided)
"""
if not SEMANTIC_AVAILABLE:
raise ImportError(
"Semantic search dependencies not available. "
"Install with: pip install codexlens[semantic]"
)
# Resolve model name from profile or use explicit name
if model_name:
self._model_name = model_name
elif profile and profile in self.MODELS:
self._model_name = self.MODELS[profile]
else:
self._model_name = self.DEFAULT_MODEL
# Configure ONNX execution providers with device_id options for GPU selection
# Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
if providers is not None:
self._providers = providers
else:
self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)
self._use_gpu = use_gpu
self._model = None
@property
def model_name(self) -> str:
"""Get model name."""
return self._model_name
@property
def embedding_dim(self) -> int:
"""Get embedding dimension for current model."""
return self.MODEL_DIMS.get(self._model_name, 768) # Default to 768 if unknown
@property
def max_tokens(self) -> int:
"""Get maximum token limit for current model.
Returns:
int: Maximum number of tokens based on model profile.
- fast: 512 (lightweight, optimized for speed)
- code: 8192 (code-optimized, larger context)
- multilingual: 512 (standard multilingual model)
- balanced: 512 (general purpose)
"""
# Determine profile from model name
profile = None
for prof, model in self.MODELS.items():
if model == self._model_name:
profile = prof
break
# Return token limit based on profile
if profile == "code":
return 8192
elif profile in ("fast", "multilingual", "balanced"):
return 512
else:
# Default for unknown models
return 512
@property
def providers(self) -> List[str]:
"""Get configured ONNX execution providers."""
return self._providers
@property
def is_gpu_enabled(self) -> bool:
"""Check if GPU acceleration is enabled for this embedder."""
gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
"DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
# Handle both string providers and tuple providers (name, options)
for p in self._providers:
provider_name = p[0] if isinstance(p, tuple) else p
if provider_name in gpu_providers:
return True
return False
def _load_model(self) -> None:
"""Lazy load the embedding model with configured providers."""
if self._model is not None:
return
from fastembed import TextEmbedding
# providers already include device_id options via get_optimal_providers(with_device_options=True)
# DO NOT pass device_ids separately - fastembed ignores it when providers is specified
# See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
try:
self._model = TextEmbedding(
model_name=self.model_name,
providers=self._providers,
)
logger.debug(f"Model loaded with providers: {self._providers}")
except TypeError:
# Fallback for older fastembed versions without providers parameter
logger.warning(
"fastembed version doesn't support 'providers' parameter. "
"Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
)
self._model = TextEmbedding(model_name=self.model_name)
def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
"""Generate embeddings for one or more texts.
Args:
texts: Single text or iterable of texts to embed.
Returns:
List of embedding vectors (each is a list of floats).
Note:
This method converts numpy arrays to Python lists for backward compatibility.
For memory-efficient processing, use embed_to_numpy() instead.
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
embeddings = list(self._model.embed(texts))
return [emb.tolist() for emb in embeddings]
def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
"""Generate embeddings for one or more texts (returns numpy arrays).
This method is more memory-efficient than embed() as it avoids converting
numpy arrays to Python lists, which can significantly reduce memory usage
during batch processing.
Args:
texts: Single text or iterable of texts to embed.
batch_size: Optional batch size for fastembed processing.
Larger values improve GPU utilization but use more memory.
Returns:
numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
"""
self._load_model()
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Pass batch_size to fastembed for optimal GPU utilization
# Default batch_size in fastembed is 256, but larger values can improve throughput
if batch_size is not None:
embeddings = list(self._model.embed(texts, batch_size=batch_size))
else:
embeddings = list(self._model.embed(texts))
return np.array(embeddings)
def embed_single(self, text: str) -> List[float]:
"""Generate embedding for a single text."""
return self.embed(text)[0]
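
A short usage sketch for the cached embedder above, assuming the optional semantic dependencies (fastembed) are installed and the model can be downloaded:

embedder = get_embedder(profile="fast", use_gpu=False)

vectors = embedder.embed_to_numpy(
    ["def parse(tokens): ...", "class Lexer: ..."],
    batch_size=64,
)
print(embedder.model_name, embedder.embedding_dim, vectors.shape)

# Release ONNX resources once embedding work is finished
clear_embedder_cache()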


@@ -0,0 +1,158 @@
"""Factory for creating embedders.
Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Dict, List, Optional
from .base import BaseEmbedder
# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)
def get_embedder(
backend: str = "fastembed",
profile: str = "code",
model: str = "default",
use_gpu: bool = True,
endpoints: Optional[List[Dict[str, Any]]] = None,
strategy: str = "latency_aware",
cooldown: float = 60.0,
**kwargs: Any,
) -> BaseEmbedder:
"""Factory function to create embedder based on backend.
Args:
backend: Embedder backend to use. Options:
- "fastembed": Use fastembed (ONNX-based) embedder (default)
- "litellm": Use ccw-litellm embedder
profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
Used only when backend="fastembed". Default: "code"
model: Model identifier for litellm backend.
Used only when backend="litellm". Default: "default"
use_gpu: Whether to use GPU acceleration when available (default: True).
Used only when backend="fastembed".
endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
Each endpoint is a dict with keys: model, api_key, api_base, weight.
Used only when backend="litellm" and multiple endpoints provided.
strategy: Selection strategy for multi-endpoint mode:
"round_robin", "latency_aware", "weighted_random".
Default: "latency_aware"
cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
**kwargs: Additional backend-specific arguments
Returns:
BaseEmbedder: Configured embedder instance
Raises:
ValueError: If backend is not recognized
ImportError: If required backend dependencies are not installed
Examples:
Create fastembed embedder with code profile:
>>> embedder = get_embedder(backend="fastembed", profile="code")
Create fastembed embedder with fast profile and CPU only:
>>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
Create litellm embedder:
>>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")
Create rotational embedder with multiple endpoints:
>>> endpoints = [
... {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
... {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
... ]
>>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
"""
# Build cache key from immutable configuration
if backend == "fastembed":
cache_key = ("fastembed", profile, None, use_gpu)
elif backend == "litellm":
# For litellm, use model as part of cache key
# Multi-endpoint mode is not cached as it's more complex
if endpoints and len(endpoints) > 1:
cache_key = None # Skip cache for multi-endpoint
else:
effective_model = endpoints[0]["model"] if endpoints else model
cache_key = ("litellm", None, effective_model, None)
else:
cache_key = None
# Check cache first (thread-safe)
if cache_key is not None:
with _cache_lock:
if cache_key in _embedder_cache:
_logger.debug("Returning cached embedder for %s", cache_key)
return _embedder_cache[cache_key]
# Create new embedder instance
embedder: Optional[BaseEmbedder] = None
if backend == "fastembed":
from .embedder import Embedder
embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
elif backend == "litellm":
# Check if multi-endpoint mode is requested
if endpoints and len(endpoints) > 1:
from .rotational_embedder import create_rotational_embedder
# Multi-endpoint is not cached
return create_rotational_embedder(
endpoints_config=endpoints,
strategy=strategy,
default_cooldown=cooldown,
)
elif endpoints and len(endpoints) == 1:
# Single endpoint in list - use it directly
ep = endpoints[0]
ep_kwargs = {**kwargs}
if "api_key" in ep:
ep_kwargs["api_key"] = ep["api_key"]
if "api_base" in ep:
ep_kwargs["api_base"] = ep["api_base"]
from .litellm_embedder import LiteLLMEmbedderWrapper
embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
else:
# No endpoints list - use model parameter
from .litellm_embedder import LiteLLMEmbedderWrapper
embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
else:
raise ValueError(
f"Unknown backend: {backend}. "
f"Supported backends: 'fastembed', 'litellm'"
)
# Cache the embedder for future use (thread-safe)
if cache_key is not None and embedder is not None:
with _cache_lock:
# Double-check to avoid race condition
if cache_key not in _embedder_cache:
_embedder_cache[cache_key] = embedder
_logger.debug("Cached new embedder for %s", cache_key)
else:
# Another thread created it already, use that one
embedder = _embedder_cache[cache_key]
return embedder # type: ignore
def clear_embedder_cache() -> int:
"""Clear the embedder cache.
Returns:
Number of embedders cleared from cache
"""
with _cache_lock:
count = len(_embedder_cache)
_embedder_cache.clear()
_logger.debug("Cleared %d embedders from cache", count)
return count
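
A short sketch of the caching behaviour described above: identical configuration returns the same instance, and clearing reports how many entries were dropped. Assumes the fastembed extra is installed:

a = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
b = get_embedder(backend="fastembed", profile="fast", use_gpu=False)
assert a is b  # served from the module-level cache

cleared = clear_embedder_cache()
print(f"Cleared {cleared} cached embedder(s)")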


@@ -0,0 +1,431 @@
"""GPU acceleration support for semantic embeddings.
This module provides GPU detection, initialization, and fallback handling
for ONNX-based embedding generation.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional
logger = logging.getLogger(__name__)
@dataclass
class GPUDevice:
"""Individual GPU device info."""
device_id: int
name: str
is_discrete: bool # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
vendor: str # "nvidia", "amd", "intel", "unknown"
@dataclass
class GPUInfo:
"""GPU availability and configuration info."""
gpu_available: bool = False
cuda_available: bool = False
gpu_count: int = 0
gpu_name: Optional[str] = None
onnx_providers: Optional[List[str]] = None
devices: Optional[List[GPUDevice]] = None # List of detected GPU devices
preferred_device_id: Optional[int] = None # Preferred GPU for embedding
def __post_init__(self):
if self.onnx_providers is None:
self.onnx_providers = ["CPUExecutionProvider"]
if self.devices is None:
self.devices = []
_gpu_info_cache: Optional[GPUInfo] = None
def _enumerate_gpus() -> List[GPUDevice]:
"""Enumerate available GPU devices using WMI on Windows.
Returns:
List of GPUDevice with device info, ordered by device_id (empty on non-Windows platforms).
"""
devices = []
try:
import subprocess
import sys
if sys.platform == "win32":
# Use PowerShell to query GPU information via WMI
cmd = [
"powershell", "-NoProfile", "-Command",
"Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode == 0 and result.stdout.strip():
import json
gpu_data = json.loads(result.stdout)
# Handle single GPU case (returns dict instead of list)
if isinstance(gpu_data, dict):
gpu_data = [gpu_data]
for idx, gpu in enumerate(gpu_data):
name = gpu.get("Name", "Unknown GPU")
compat = gpu.get("AdapterCompatibility", "").lower()
# Determine vendor
name_lower = name.lower()
if "nvidia" in name_lower or "nvidia" in compat:
vendor = "nvidia"
is_discrete = True
elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
vendor = "amd"
is_discrete = True
elif "intel" in name_lower or "intel" in compat:
vendor = "intel"
# Intel UHD/Iris are integrated, Intel Arc is discrete
is_discrete = "arc" in name_lower
else:
vendor = "unknown"
is_discrete = False
devices.append(GPUDevice(
device_id=idx,
name=name,
is_discrete=is_discrete,
vendor=vendor
))
logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")
except Exception as e:
logger.debug(f"GPU enumeration failed: {e}")
return devices
def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
"""Determine the preferred GPU device_id for embedding.
Preference order:
1. NVIDIA discrete GPU (best DirectML/CUDA support)
2. AMD discrete GPU
3. Intel Arc (discrete)
4. Intel integrated (fallback)
Returns:
device_id of preferred GPU, or None to use default.
"""
if not devices:
return None
# Priority: NVIDIA > AMD > Intel Arc > Intel integrated
priority_order = [
("nvidia", True), # NVIDIA discrete
("amd", True), # AMD discrete
("intel", True), # Intel Arc (discrete)
("intel", False), # Intel integrated (fallback)
]
for target_vendor, target_discrete in priority_order:
for device in devices:
if device.vendor == target_vendor and device.is_discrete == target_discrete:
logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
return device.device_id
# If no match, use first device
if devices:
return devices[0].device_id
return None
def detect_gpu(force_refresh: bool = False) -> GPUInfo:
"""Detect available GPU resources for embedding acceleration.
Args:
force_refresh: If True, re-detect GPU even if cached.
Returns:
GPUInfo with detection results.
"""
global _gpu_info_cache
if _gpu_info_cache is not None and not force_refresh:
return _gpu_info_cache
info = GPUInfo()
# Enumerate GPU devices first
info.devices = _enumerate_gpus()
info.gpu_count = len(info.devices)
if info.devices:
# Set preferred device (discrete GPU preferred over integrated)
info.preferred_device_id = _get_preferred_device_id(info.devices)
# Set gpu_name to preferred device name
for dev in info.devices:
if dev.device_id == info.preferred_device_id:
info.gpu_name = dev.name
break
# Check PyTorch CUDA availability (most reliable detection)
try:
import torch
if torch.cuda.is_available():
info.cuda_available = True
info.gpu_available = True
info.gpu_count = torch.cuda.device_count()
if info.gpu_count > 0:
info.gpu_name = torch.cuda.get_device_name(0)
logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
except ImportError:
logger.debug("PyTorch not available for GPU detection")
# Check ONNX Runtime providers with validation
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
# Build provider list with priority order
providers = []
# Test each provider to ensure it actually works
def test_provider(provider_name: str) -> bool:
"""Test if a provider actually works by creating a dummy session."""
try:
# Create a minimal ONNX model to test provider
import numpy as np
# Simple test: just check if provider can be instantiated
sess_options = ort.SessionOptions()
sess_options.log_severity_level = 4 # Suppress warnings
return True
except Exception:
return False
# CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
if "CUDAExecutionProvider" in available_providers:
# Verify CUDA is actually usable by checking for cuBLAS
cuda_works = False
try:
import ctypes
# Try to load cuBLAS to verify CUDA installation
try:
ctypes.CDLL("cublas64_12.dll")
cuda_works = True
except OSError:
try:
ctypes.CDLL("cublas64_11.dll")
cuda_works = True
except OSError:
pass
except Exception:
pass
if cuda_works:
providers.append("CUDAExecutionProvider")
info.gpu_available = True
logger.debug("ONNX CUDAExecutionProvider available and working")
else:
logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")
# TensorRT provider (optimized NVIDIA inference)
if "TensorrtExecutionProvider" in available_providers:
# TensorRT requires additional libraries, skip for now
logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")
# DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
if "DmlExecutionProvider" in available_providers:
providers.append("DmlExecutionProvider")
info.gpu_available = True
logger.debug("ONNX DmlExecutionProvider available (DirectML)")
# ROCm provider (AMD GPU on Linux)
if "ROCMExecutionProvider" in available_providers:
providers.append("ROCMExecutionProvider")
info.gpu_available = True
logger.debug("ONNX ROCMExecutionProvider available (AMD)")
# CoreML provider (Apple Silicon)
if "CoreMLExecutionProvider" in available_providers:
providers.append("CoreMLExecutionProvider")
info.gpu_available = True
logger.debug("ONNX CoreMLExecutionProvider available (Apple)")
# Always include CPU as fallback
providers.append("CPUExecutionProvider")
info.onnx_providers = providers
except ImportError:
logger.debug("ONNX Runtime not available")
info.onnx_providers = ["CPUExecutionProvider"]
_gpu_info_cache = info
return info
def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
"""Get optimal ONNX execution providers based on availability.
Args:
use_gpu: If True, include GPU providers when available.
If False, force CPU-only execution.
with_device_options: If True, return providers as tuples with device_id options
for proper GPU device selection (required for DirectML).
Returns:
List of provider names or tuples (provider_name, options_dict) in priority order.
"""
if not use_gpu:
return ["CPUExecutionProvider"]
gpu_info = detect_gpu()
# Check if GPU was requested but not available - log warning
if not gpu_info.gpu_available:
try:
import onnxruntime as ort
available_providers = ort.get_available_providers()
except ImportError:
available_providers = []
logger.warning(
"GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
f"was found. Available providers: {available_providers}. Falling back to CPU."
)
else:
# Log which GPU provider is being used
gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
if gpu_providers:
logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")
if not with_device_options:
return gpu_info.onnx_providers
# Build providers with device_id options for GPU providers
device_id = get_selected_device_id()
providers = []
for provider in gpu_info.onnx_providers:
if provider == "DmlExecutionProvider" and device_id is not None:
# DirectML requires device_id in provider_options tuple
providers.append(("DmlExecutionProvider", {"device_id": device_id}))
logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
elif provider == "CUDAExecutionProvider" and device_id is not None:
# CUDA also supports device_id in provider_options
providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
elif provider == "ROCMExecutionProvider" and device_id is not None:
# ROCm supports device_id
providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
else:
# CPU and other providers don't need device_id
providers.append(provider)
return providers
def is_gpu_available() -> bool:
"""Check if any GPU acceleration is available."""
return detect_gpu().gpu_available
def get_gpu_summary() -> str:
"""Get human-readable GPU status summary."""
info = detect_gpu()
if not info.gpu_available:
return "GPU: Not available (using CPU)"
parts = []
if info.gpu_name:
parts.append(f"GPU: {info.gpu_name}")
if info.gpu_count > 1:
parts.append(f"({info.gpu_count} devices)")
# Show active providers (excluding CPU fallback)
gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
if gpu_providers:
parts.append(f"Providers: {', '.join(gpu_providers)}")
return " | ".join(parts) if parts else "GPU: Available"
def clear_gpu_cache() -> None:
"""Clear cached GPU detection info."""
global _gpu_info_cache
_gpu_info_cache = None
# User-selected device ID (overrides auto-detection)
_selected_device_id: Optional[int] = None
def get_gpu_devices() -> List[dict]:
"""Get list of available GPU devices for frontend selection.
Returns:
List of dicts with device info for each GPU.
"""
info = detect_gpu()
devices = []
for dev in info.devices:
devices.append({
"device_id": dev.device_id,
"name": dev.name,
"vendor": dev.vendor,
"is_discrete": dev.is_discrete,
"is_preferred": dev.device_id == info.preferred_device_id,
"is_selected": dev.device_id == get_selected_device_id(),
})
return devices
def get_selected_device_id() -> Optional[int]:
"""Get the user-selected GPU device_id.
Returns:
User-selected device_id, or auto-detected preferred device_id if not set.
"""
global _selected_device_id
if _selected_device_id is not None:
return _selected_device_id
# Fall back to auto-detected preferred device
info = detect_gpu()
return info.preferred_device_id
def set_selected_device_id(device_id: Optional[int]) -> bool:
"""Set the GPU device_id to use for embeddings.
Args:
device_id: GPU device_id to use, or None to use auto-detection.
Returns:
True if device_id is valid, False otherwise.
"""
global _selected_device_id
if device_id is None:
_selected_device_id = None
logger.info("GPU selection reset to auto-detection")
return True
# Validate device_id exists
info = detect_gpu()
valid_ids = [dev.device_id for dev in info.devices]
if device_id in valid_ids:
_selected_device_id = device_id
device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
logger.info(f"GPU selection set to device {device_id}: {device_name}")
return True
else:
logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
return False
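# Illustrative usage sketch (assumption: not part of the original commit): prefer a
# discrete GPU when one is present, then build providers for that device.
if __name__ == "__main__":  # pragma: no cover - example only
    for dev in get_gpu_devices():
        if dev["is_discrete"]:
            set_selected_device_id(dev["device_id"])
            break
    print(get_gpu_summary())
    print(get_optimal_providers(use_gpu=True, with_device_options=True))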

View File

@@ -0,0 +1,144 @@
"""LiteLLM embedder wrapper for CodexLens.
Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
"""
from __future__ import annotations
from typing import Iterable
import numpy as np
from .base import BaseEmbedder
class LiteLLMEmbedderWrapper(BaseEmbedder):
"""Wrapper for ccw-litellm LiteLLMEmbedder.
This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
BaseEmbedder interface, enabling seamless integration with CodexLens
semantic search functionality.
Args:
model: Model identifier for LiteLLM (default: "default")
**kwargs: Additional arguments passed to LiteLLMEmbedder
Raises:
ImportError: If ccw-litellm package is not installed
"""
def __init__(self, model: str = "default", **kwargs) -> None:
"""Initialize LiteLLM embedder wrapper.
Args:
model: Model identifier for LiteLLM (default: "default")
**kwargs: Additional arguments passed to LiteLLMEmbedder
Raises:
ImportError: If ccw-litellm package is not installed
"""
try:
from ccw_litellm import LiteLLMEmbedder
self._embedder = LiteLLMEmbedder(model=model, **kwargs)
except ImportError as e:
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from e
@property
def embedding_dim(self) -> int:
"""Return embedding dimensions from LiteLLMEmbedder.
Returns:
int: Dimension of the embedding vectors.
"""
return self._embedder.dimensions
@property
def model_name(self) -> str:
"""Return model name from LiteLLMEmbedder.
Returns:
str: Name or identifier of the underlying model.
"""
return self._embedder.model_name
@property
def max_tokens(self) -> int:
"""Return maximum token limit for the embedding model.
Returns:
int: Maximum number of tokens that can be embedded at once.
Reads from LiteLLM config's max_input_tokens property.
"""
# Get from LiteLLM embedder's max_input_tokens property (now exposed)
if hasattr(self._embedder, 'max_input_tokens'):
return self._embedder.max_input_tokens
# Fallback: infer from model name
model_name_lower = self.model_name.lower()
# Large models (8B or "large" in name)
if '8b' in model_name_lower or 'large' in model_name_lower:
return 32768
# OpenAI text-embedding-3-* models
if 'text-embedding-3' in model_name_lower:
return 8191
# Default fallback
return 8192
def _sanitize_text(self, text: str) -> str:
"""Sanitize text to work around ModelScope API routing bug.
ModelScope incorrectly routes text starting with lowercase 'import'
to an Ollama endpoint, causing failures. This adds a leading space
to work around the issue without affecting embedding quality.
Args:
text: Text to sanitize.
Returns:
Sanitized text safe for embedding API.
"""
if text.startswith('import'):
return ' ' + text
return text
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts to numpy array using LiteLLMEmbedder.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments (ignored for LiteLLM backend).
Accepts batch_size for API compatibility with fastembed.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
# Sanitize texts to avoid ModelScope routing bug
texts = [self._sanitize_text(t) for t in texts]
# LiteLLM handles batching internally, ignore batch_size parameter
return self._embedder.embed(texts)
def embed_single(self, text: str) -> list[float]:
"""Generate embedding for a single text.
Args:
text: Text to embed.
Returns:
list[float]: Embedding vector as a list of floats.
"""
# Sanitize text before embedding
sanitized = self._sanitize_text(text)
embedding = self._embedder.embed([sanitized])
return embedding[0].tolist()
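# Illustrative usage sketch (assumption: not part of the original commit; requires a
# configured ccw-litellm installation, and "default" must exist in its model config).
if __name__ == "__main__":  # pragma: no cover - example only
    embedder = LiteLLMEmbedderWrapper(model="default")
    vectors = embedder.embed_to_numpy(["def add(a, b):", "import os"])
    print(vectors.shape, embedder.embedding_dim, embedder.max_tokens)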

View File

@@ -0,0 +1,25 @@
"""Reranker backends for second-stage search ranking.
This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""
from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
__all__ = [
"BaseReranker",
"check_reranker_available",
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"FastEmbedReranker",
"check_fastembed_reranker_available",
"ONNXReranker",
"check_onnx_reranker_available",
]

View File

@@ -0,0 +1,403 @@
"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback."""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Try loading from .env files
try:
from codexlens.env_config import get_env
return get_env(key, workspace_root=workspace_root)
except ImportError:
return None
def check_httpx_available() -> tuple[bool, str | None]:
try:
import httpx # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"httpx not available: {exc}. Install with: pip install httpx"
return True, None
class APIReranker(BaseReranker):
"""Reranker backed by a remote reranking HTTP API."""
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
"siliconflow": {
"api_base": "https://api.siliconflow.cn",
"endpoint": "/v1/rerank",
"default_model": "BAAI/bge-reranker-v2-m3",
},
"cohere": {
"api_base": "https://api.cohere.ai",
"endpoint": "/v1/rerank",
"default_model": "rerank-english-v3.0",
},
"jina": {
"api_base": "https://api.jina.ai",
"endpoint": "/v1/rerank",
"default_model": "jina-reranker-v2-base-multilingual",
},
}
def __init__(
self,
*,
provider: str = "siliconflow",
model_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
timeout: float = 30.0,
max_retries: int = 3,
backoff_base_s: float = 0.5,
backoff_max_s: float = 8.0,
env_api_key: str = _DEFAULT_ENV_API_KEY,
workspace_root: Path | str | None = None,
max_input_tokens: int | None = None,
) -> None:
ok, err = check_httpx_available()
if not ok: # pragma: no cover - exercised via factory availability tests
raise ImportError(err)
import httpx
self._workspace_root = Path(workspace_root) if workspace_root else None
self.provider = (provider or "").strip().lower()
if self.provider not in self._PROVIDER_DEFAULTS:
raise ValueError(
f"Unknown reranker provider: {provider}. "
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
)
defaults = self._PROVIDER_DEFAULTS[self.provider]
# Load api_base from env with .env fallback
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
self.endpoint = defaults["endpoint"]
# Load model from env with .env fallback
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
# Load API key from env with .env fallback
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
resolved_key = resolved_key.strip()
if not resolved_key:
raise ValueError(
f"Missing API key for reranker provider '{self.provider}'. "
f"Pass api_key=... or set ${env_api_key}."
)
self._api_key = resolved_key
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
if self.provider == "cohere":
headers.setdefault("Cohere-Version", "2022-12-06")
self._client = httpx.Client(
base_url=self.api_base,
headers=headers,
timeout=self.timeout_s,
)
# Store max_input_tokens with model-aware defaults
if max_input_tokens is not None:
self._max_input_tokens = max_input_tokens
else:
# Infer from model name
model_lower = self.model_name.lower()
if '8b' in model_lower or 'large' in model_lower:
self._max_input_tokens = 32768
else:
self._max_input_tokens = 8192
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking."""
return self._max_input_tokens
def close(self) -> None:
try:
self._client.close()
except Exception: # pragma: no cover - defensive
return
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
if retry_after_s is not None and retry_after_s > 0:
time.sleep(min(float(retry_after_s), self.backoff_max_s))
return
exp = self.backoff_base_s * (2**attempt)
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
time.sleep(min(self.backoff_max_s, exp + jitter))
@staticmethod
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
value = (headers.get("Retry-After") or "").strip()
if not value:
return None
try:
return float(value)
except ValueError:
return None
@staticmethod
def _should_retry_status(status_code: int) -> bool:
return status_code == 429 or 500 <= status_code <= 599
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
response = self._client.post(self.endpoint, json=dict(payload))
except Exception as exc: # httpx is optional at import-time
last_exc = exc
if attempt < self.max_retries:
self._sleep_backoff(attempt)
continue
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' after "
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
) from exc
status = int(getattr(response, "status_code", 0) or 0)
if status >= 400:
body_preview = ""
try:
body_preview = (response.text or "").strip()
except Exception:
body_preview = ""
if len(body_preview) > 300:
body_preview = body_preview[:300] + "…"
if self._should_retry_status(status) and attempt < self.max_retries:
retry_after = self._parse_retry_after_seconds(response.headers)
logger.warning(
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
self.api_base,
self.endpoint,
status,
attempt + 1,
self.max_retries + 1,
)
self._sleep_backoff(attempt, retry_after_s=retry_after)
continue
if status in {401, 403}:
raise RuntimeError(
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
"Check your API key."
)
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
f"Response: {body_preview or '<empty>'}"
)
try:
data = response.json()
except Exception as exc:
raise RuntimeError(
f"Rerank response from provider '{self.provider}' is not valid JSON: "
f"{type(exc).__name__}: {exc}"
) from exc
if not isinstance(data, dict):
raise RuntimeError(
f"Rerank response from provider '{self.provider}' must be a JSON object; "
f"got {type(data).__name__}"
)
return data
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
)
@staticmethod
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
if not isinstance(results, list):
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
scores: list[float] = [0.0 for _ in range(expected)]
filled = 0
for item in results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score"))
if idx is None or score is None:
continue
try:
idx_int = int(idx)
score_f = float(score)
except (TypeError, ValueError):
continue
if 0 <= idx_int < expected:
scores[idx_int] = score_f
filled += 1
if filled != expected:
raise RuntimeError(
f"Rerank response contained {filled}/{expected} scored documents; "
"ensure top_n matches the number of documents."
)
return scores
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
payload: dict[str, Any] = {
"model": self.model_name,
"query": query,
"documents": list(documents),
"top_n": len(documents),
"return_documents": False,
}
return payload
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count using fast heuristic.
Uses len(text) // 4 as approximation (~4 chars per token for English).
Not perfectly accurate for all models/languages but sufficient for
batch sizing decisions where exact counts aren't critical.
"""
return len(text) // 4
def _create_token_aware_batches(
self,
query: str,
documents: Sequence[str],
) -> list[list[tuple[int, str]]]:
"""Split documents into batches that fit within token limits.
Uses 90% of max_input_tokens as safety margin.
Each batch includes the query tokens overhead.
"""
max_tokens = int(self._max_input_tokens * 0.9)
query_tokens = self._estimate_tokens(query)
batches: list[list[tuple[int, str]]] = []
current_batch: list[tuple[int, str]] = []
current_tokens = query_tokens # Start with query overhead
for idx, doc in enumerate(documents):
doc_tokens = self._estimate_tokens(doc)
# Warn if single document exceeds token limit (will be truncated by API)
if doc_tokens > max_tokens - query_tokens:
logger.warning(
f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
f"(limit: {max_tokens - query_tokens} after query overhead). "
"Document will likely be truncated by the API."
)
# If batch would exceed limit, start new batch
if current_tokens + doc_tokens > max_tokens and current_batch:
batches.append(current_batch)
current_batch = []
current_tokens = query_tokens
current_batch.append((idx, doc))
current_tokens += doc_tokens
if current_batch:
batches.append(current_batch)
return batches
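# Worked example of the budget arithmetic above (illustrative numbers): with
# max_input_tokens=8192 the usable budget is int(8192 * 0.9) = 7372 tokens, and a
# 100-char query costs ~25 tokens. Three ~10,000-char documents (~2,500 tokens each)
# split into two batches: 25 + 2500 + 2500 = 5025 fits, adding the third would reach
# 7525 > 7372, so it opens a new batch where the 25-token query overhead is counted again.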
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
if not documents:
return []
# Create token-aware batches
batches = self._create_token_aware_batches(query, documents)
if len(batches) == 1:
# Single batch - original behavior
payload = self._build_payload(query=query, documents=documents)
data = self._request_json(payload)
results = data.get("results")
return self._extract_scores_from_results(results, expected=len(documents))
# Multiple batches - process each and merge results
logger.info(
f"Splitting {len(documents)} documents into {len(batches)} batches "
f"(max_input_tokens: {self._max_input_tokens})"
)
all_scores: list[float] = [0.0] * len(documents)
for batch in batches:
batch_docs = [doc for _, doc in batch]
payload = self._build_payload(query=query, documents=batch_docs)
data = self._request_json(payload)
results = data.get("results")
batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
# Map scores back to original indices
for (orig_idx, _), score in zip(batch, batch_scores):
all_scores[orig_idx] = score
return all_scores
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
) -> list[float]:
if not pairs:
return []
grouped: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
grouped.setdefault(str(query), []).append((idx, str(doc)))
scores: list[float] = [0.0 for _ in range(len(pairs))]
for query, items in grouped.items():
documents = [doc for _, doc in items]
query_scores = self._rerank_one_query(query=query, documents=documents)
for (orig_idx, _), score in zip(items, query_scores):
scores[orig_idx] = float(score)
return scores
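# Illustrative usage sketch (assumption: not part of the original commit; requires httpx
# and a valid RERANKER_API_KEY in the environment or a .env file).
if __name__ == "__main__":  # pragma: no cover - example only
    reranker = APIReranker(provider="siliconflow")
    pairs = [
        ("how to parse json in python", "json.loads turns a JSON string into a dict"),
        ("how to parse json in python", "the capital of France is Paris"),
    ]
    print(reranker.score_pairs(pairs))
    reranker.close()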

View File

@@ -0,0 +1,46 @@
"""Base class for rerankers.
Defines the interface that all rerankers must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence
class BaseReranker(ABC):
"""Base class for all rerankers.
All reranker implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@property
def max_input_tokens(self) -> int:
"""Return maximum token limit for reranking.
Returns:
int: Maximum number of tokens that can be processed at once.
Default is 8192 if not overridden by implementation.
"""
return 8192
@abstractmethod
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair).
"""
...
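# Minimal subclass sketch (assumption: illustrative only, not part of the original commit)
# showing the contract: return one float per (query, doc) pair, in input order.
if __name__ == "__main__":  # pragma: no cover - example only
    class _ToyReranker(BaseReranker):
        def score_pairs(self, pairs, *, batch_size: int = 32) -> list[float]:
            # Toy heuristic: favour longer documents, capped at 1.0.
            return [min(1.0, len(doc) / 1000.0) for _, doc in pairs]

    print(_ToyReranker().score_pairs([("q", "short doc"), ("q", "x" * 2000)]))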

View File

@@ -0,0 +1,159 @@
"""Factory for creating rerankers.
Provides a unified interface for instantiating different reranker backends.
"""
from __future__ import annotations
from typing import Any
from .base import BaseReranker
def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used.
Notes:
- "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
- "onnx" redirects to "fastembed" for backward compatibility.
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
"""
backend = (backend or "").strip().lower()
if backend == "legacy":
from .legacy import check_cross_encoder_available
return check_cross_encoder_available()
if backend == "fastembed":
from .fastembed_reranker import check_fastembed_reranker_available
return check_fastembed_reranker_available()
if backend == "onnx":
# Redirect to fastembed for backward compatibility
from .fastembed_reranker import check_fastembed_reranker_available
return check_fastembed_reranker_available()
if backend == "litellm":
try:
import ccw_litellm # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
)
try:
from .litellm_reranker import LiteLLMReranker # noqa: F401
except Exception as exc: # pragma: no cover - defensive
return False, f"LiteLLM reranker backend not available: {exc}"
return True, None
if backend == "api":
from .api_reranker import check_httpx_available
return check_httpx_available()
return False, (
f"Invalid reranker backend: {backend}. "
"Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
)
def get_reranker(
backend: str = "fastembed",
model_name: str | None = None,
*,
device: str | None = None,
**kwargs: Any,
) -> BaseReranker:
"""Factory function to create reranker based on backend.
Args:
backend: Reranker backend to use. Options:
- "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
- "onnx": Redirects to fastembed for backward compatibility
- "api": HTTP API backend (remote providers)
- "litellm": LiteLLM backend (LLM-based, for API mode)
- "legacy": sentence-transformers CrossEncoder backend (optional)
model_name: Model identifier for model-based backends. Defaults depend on backend:
- fastembed: Xenova/ms-marco-MiniLM-L-6-v2
- onnx: (redirects to fastembed)
- api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
- legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
- litellm: default
device: Optional device string for backends that support it (legacy only).
**kwargs: Additional backend-specific arguments.
Returns:
BaseReranker: Configured reranker instance.
Raises:
ValueError: If backend is not recognized.
ImportError: If required backend dependencies are not installed or backend is unavailable.
"""
backend = (backend or "").strip().lower()
if backend == "fastembed":
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "onnx":
# Redirect to fastembed for backward compatibility
ok, err = check_reranker_available("fastembed")
if not ok:
raise ImportError(err)
from .fastembed_reranker import FastEmbedReranker
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
_ = device # Device selection is managed via fastembed providers.
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy":
ok, err = check_reranker_available("legacy")
if not ok:
raise ImportError(err)
from .legacy import CrossEncoderReranker
resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
return CrossEncoderReranker(model_name=resolved_model_name, device=device)
if backend == "litellm":
ok, err = check_reranker_available("litellm")
if not ok:
raise ImportError(err)
from .litellm_reranker import LiteLLMReranker
_ = device # Device selection is not applicable to remote LLM backends.
resolved_model_name = (model_name or "").strip() or "default"
return LiteLLMReranker(model=resolved_model_name, **kwargs)
if backend == "api":
ok, err = check_reranker_available("api")
if not ok:
raise ImportError(err)
from .api_reranker import APIReranker
_ = device # Device selection is not applicable to remote HTTP backends.
resolved_model_name = (model_name or "").strip() or None
return APIReranker(model_name=resolved_model_name, **kwargs)
raise ValueError(
f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
)
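# Illustrative usage sketch (assumption: not part of the original commit). Check the
# backend first, then construct it; the fastembed backend downloads its model on first use.
if __name__ == "__main__":  # pragma: no cover - example only
    ok, err = check_reranker_available("fastembed")
    if not ok:
        print(f"Reranker unavailable: {err}")
    else:
        reranker = get_reranker("fastembed")
        print(reranker.score_pairs([("read a file in python", "open() returns a file object")]))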

View File

@@ -0,0 +1,257 @@
"""FastEmbed-based reranker backend.
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
Install:
pip install fastembed>=0.4.0
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
"""Check whether fastembed reranker dependencies are available."""
try:
import fastembed # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
)
try:
from fastembed.rerank.cross_encoder import TextCrossEncoder # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"fastembed TextCrossEncoder not available: {exc}. "
"Upgrade with: pip install fastembed>=0.4.0",
)
return True, None
class FastEmbedReranker(BaseReranker):
"""Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
# Alternative models supported by fastembed:
# - "BAAI/bge-reranker-base"
# - "BAAI/bge-reranker-large"
# - "cross-encoder/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
cache_dir: str | None = None,
threads: int | None = None,
) -> None:
"""Initialize FastEmbed reranker.
Args:
model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
use_gpu: Whether to use GPU acceleration when available.
cache_dir: Optional directory for caching downloaded models.
threads: Optional number of threads for ONNX Runtime.
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.cache_dir = cache_dir
self.threads = threads
self._encoder: Any | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
"""Lazy-load the TextCrossEncoder model."""
if self._encoder is not None:
return
ok, err = check_fastembed_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._encoder is not None:
return
from fastembed.rerank.cross_encoder import TextCrossEncoder
# Determine providers based on GPU preference
providers: list[str] | None = None
if self.use_gpu:
try:
from ..gpu_support import get_optimal_providers
providers = get_optimal_providers(use_gpu=True, with_device_options=False)
except Exception:
# Fallback: let fastembed decide
providers = None
# Build initialization kwargs
init_kwargs: dict[str, Any] = {}
if self.cache_dir:
init_kwargs["cache_dir"] = self.cache_dir
if self.threads is not None:
init_kwargs["threads"] = self.threads
if providers:
init_kwargs["providers"] = providers
logger.debug(
"Loading FastEmbed reranker model: %s (use_gpu=%s)",
self.model_name,
self.use_gpu,
)
self._encoder = TextCrossEncoder(
model_name=self.model_name,
**init_kwargs,
)
logger.debug("FastEmbed reranker model loaded successfully")
@staticmethod
def _sigmoid(x: float) -> float:
"""Numerically stable sigmoid function."""
if x < -709:
return 0.0
if x > 709:
return 1.0
import math
return 1.0 / (1.0 + math.exp(-x))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair), normalized to [0, 1] range.
"""
if not pairs:
return []
self._load_model()
if self._encoder is None: # pragma: no cover - defensive
return []
# FastEmbed's TextCrossEncoder.rerank() expects a query and list of documents.
# To score arbitrary (query, doc) pairs we group them by query and rerank each group,
# which also avoids redundant work when the same query appears many times.
query_to_docs: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
if query not in query_to_docs:
query_to_docs[query] = []
query_to_docs[query].append((idx, doc))
# Score each query group
scores: list[float] = [0.0] * len(pairs)
for query, indexed_docs in query_to_docs.items():
docs = [doc for _, doc in indexed_docs]
indices = [idx for idx, _ in indexed_docs]
try:
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=docs,
batch_size=batch_size,
)
)
# Map scores back to original positions and normalize with sigmoid
for i, raw_score in enumerate(raw_scores):
if i < len(indices):
original_idx = indices[i]
# Normalize score to [0, 1] using stable sigmoid
scores[original_idx] = self._sigmoid(float(raw_score))
except Exception as e:
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
# Leave scores as 0.0 for failed queries
return scores
def rerank(
self,
query: str,
documents: Sequence[str],
*,
top_k: int | None = None,
batch_size: int = 32,
) -> list[tuple[float, str, int]]:
"""Rerank documents for a single query.
This is a convenience method that provides results in ranked order.
Args:
query: The query string.
documents: List of documents to rerank.
top_k: Return only top K results. None returns all.
batch_size: Batch size for scoring.
Returns:
List of (score, document, original_index) tuples, sorted by score descending.
"""
if not documents:
return []
self._load_model()
if self._encoder is None: # pragma: no cover - defensive
return []
try:
# TextCrossEncoder.rerank returns raw float scores in same order as input
raw_scores = list(
self._encoder.rerank(
query=query,
documents=list(documents),
batch_size=batch_size,
)
)
# Convert to our format: (normalized_score, document, original_index)
ranked = []
for idx, raw_score in enumerate(raw_scores):
if idx < len(documents):
# Normalize score to [0, 1] using stable sigmoid
normalized = self._sigmoid(float(raw_score))
ranked.append((normalized, documents[idx], idx))
# Sort by score descending
ranked.sort(key=lambda x: x[0], reverse=True)
if top_k is not None and top_k > 0:
ranked = ranked[:top_k]
return ranked
except Exception as e:
logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
return []
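# Illustrative usage sketch (assumption: not part of the original commit; downloads the
# default cross-encoder model on first run).
if __name__ == "__main__":  # pragma: no cover - example only
    reranker = FastEmbedReranker(use_gpu=False)
    ranked = reranker.rerank(
        "read a file in python",
        ["open() returns a file object", "the moon orbits the earth"],
        top_k=1,
    )
    print(ranked)  # [(score_in_0_1, "open() returns a file object", 0)]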

View File

@@ -0,0 +1,91 @@
"""Legacy sentence-transformers cross-encoder reranker.
Install with: pip install codexlens[reranker-legacy]
"""
from __future__ import annotations
import logging
import threading
from typing import List, Sequence, Tuple
from .base import BaseReranker
logger = logging.getLogger(__name__)
try:
from sentence_transformers import CrossEncoder as _CrossEncoder
CROSS_ENCODER_AVAILABLE = True
_import_error: str | None = None
except ImportError as exc: # pragma: no cover - optional dependency
_CrossEncoder = None # type: ignore[assignment]
CROSS_ENCODER_AVAILABLE = False
_import_error = str(exc)
def check_cross_encoder_available() -> tuple[bool, str | None]:
if CROSS_ENCODER_AVAILABLE:
return True, None
return (
False,
_import_error
or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
)
class CrossEncoderReranker(BaseReranker):
"""Cross-encoder reranker with lazy model loading."""
def __init__(self, model_name: str, *, device: str | None = None) -> None:
self.model_name = (model_name or "").strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.device = (device or "").strip() or None
self._model = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None:
return
ok, err = check_cross_encoder_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None:
return
try:
if self.device:
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
else:
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
except Exception as exc:
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
raise
def score_pairs(
self,
pairs: Sequence[Tuple[str, str]],
*,
batch_size: int = 32,
) -> List[float]:
"""Score (query, doc) pairs using the cross-encoder.
Returns:
List of scores (one per pair) in the model's native scale (usually logits).
"""
if not pairs:
return []
self._load_model()
if self._model is None: # pragma: no cover - defensive
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
return [float(s) for s in scores]
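# Illustrative usage sketch (assumption: not part of the original commit; requires
# codexlens[reranker-legacy]). Scores come back on the model's native scale.
if __name__ == "__main__":  # pragma: no cover - example only
    reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-6-v2")
    print(reranker.score_pairs([("query", "relevant document"), ("query", "unrelated text")]))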

View File

@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.
Notes:
- This backend is experimental and may be slow/expensive compared to local
rerankers.
- It relies on `ccw-litellm` for a unified LLM API across providers.
"""
from __future__ import annotations
import json
import logging
import re
import threading
import time
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
def _coerce_score_to_unit_interval(score: float) -> float:
"""Coerce a numeric score into [0, 1].
The prompt asks for a float in [0, 1], but some models may respond with 0-10
or 0-100 scales. This function attempts a conservative normalization.
"""
if 0.0 <= score <= 1.0:
return score
if 0.0 <= score <= 10.0:
return score / 10.0
if 0.0 <= score <= 100.0:
return score / 100.0
return max(0.0, min(1.0, score))
def _extract_score(text: str) -> float | None:
"""Extract a numeric relevance score from an LLM response."""
content = (text or "").strip()
if not content:
return None
# Prefer JSON if present.
if "{" in content and "}" in content:
try:
start = content.index("{")
end = content.rindex("}") + 1
payload = json.loads(content[start:end])
if isinstance(payload, dict) and "score" in payload:
return float(payload["score"])
except Exception:
pass
match = _NUMBER_RE.search(content)
if not match:
return None
try:
return float(match.group(0))
except ValueError:
return None
class LiteLLMReranker(BaseReranker):
"""Experimental reranker that uses a LiteLLM-compatible model.
This reranker scores each (query, doc) pair in isolation (single-pair mode)
to improve prompt reliability across providers.
"""
_SYSTEM_PROMPT = (
"You are a relevance scoring assistant.\n"
"Given a search query and a document snippet, output a single numeric "
"relevance score between 0 and 1.\n\n"
"Scoring guidance:\n"
"- 1.0: The document directly answers the query.\n"
"- 0.5: The document is partially relevant.\n"
"- 0.0: The document is unrelated.\n\n"
"Output requirements:\n"
"- Output ONLY the number (e.g., 0.73).\n"
"- Do not include any other text."
)
def __init__(
self,
model: str = "default",
*,
requests_per_minute: float | None = None,
min_interval_seconds: float | None = None,
default_score: float = 0.0,
max_doc_chars: int = 8000,
**litellm_kwargs: Any,
) -> None:
"""Initialize the reranker.
Args:
model: Model name from ccw-litellm configuration (default: "default").
requests_per_minute: Optional rate limit in requests per minute.
min_interval_seconds: Optional minimum interval between requests. If set,
it takes precedence over requests_per_minute.
default_score: Score to use when an API call fails or parsing fails.
max_doc_chars: Maximum number of document characters to include in the prompt.
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
Raises:
ImportError: If ccw-litellm is not installed.
ValueError: If model is blank.
"""
self.model_name = (model or "").strip()
if not self.model_name:
raise ValueError("model cannot be blank")
self.default_score = float(default_score)
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
if min_interval_seconds is not None:
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
elif requests_per_minute is not None and float(requests_per_minute) > 0:
self._min_interval_seconds = 60.0 / float(requests_per_minute)
else:
self._min_interval_seconds = 0.0
# Prefer deterministic output by default; allow overrides via kwargs.
litellm_kwargs = dict(litellm_kwargs)
litellm_kwargs.setdefault("temperature", 0.0)
litellm_kwargs.setdefault("max_tokens", 16)
try:
from ccw_litellm import ChatMessage, LiteLLMClient
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from exc
self._ChatMessage = ChatMessage
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
self._lock = threading.RLock()
self._last_request_at = 0.0
def _sanitize_text(self, text: str) -> str:
# Keep consistent with LiteLLMEmbedderWrapper workaround.
if text.startswith("import"):
return " " + text
return text
def _rate_limit(self) -> None:
if self._min_interval_seconds <= 0:
return
with self._lock:
now = time.monotonic()
elapsed = now - self._last_request_at
if elapsed < self._min_interval_seconds:
time.sleep(self._min_interval_seconds - elapsed)
self._last_request_at = time.monotonic()
def _build_user_prompt(self, query: str, doc: str) -> str:
sanitized_query = self._sanitize_text(query or "")
sanitized_doc = self._sanitize_text(doc or "")
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
sanitized_doc = sanitized_doc[: self.max_doc_chars]
return (
"Query:\n"
f"{sanitized_query}\n\n"
"Document:\n"
f"{sanitized_doc}\n\n"
"Return the relevance score (0 to 1) as a single number:"
)
def _score_single_pair(self, query: str, doc: str) -> float:
messages = [
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
]
try:
self._rate_limit()
response = self._client.chat(messages)
except Exception as exc:
logger.debug("LiteLLM reranker request failed: %s", exc)
return self.default_score
raw = getattr(response, "content", "") or ""
score = _extract_score(raw)
if score is None:
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
return self.default_score
return _coerce_score_to_unit_interval(float(score))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with per-pair LLM calls."""
if not pairs:
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for i in range(0, len(pairs), bs):
batch = pairs[i : i + bs]
for query, doc in batch:
scores.append(self._score_single_pair(query, doc))
return scores
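# Illustrative usage sketch (assumption: not part of the original commit; requires a
# configured ccw-litellm model and may incur API costs).
if __name__ == "__main__":  # pragma: no cover - example only
    reranker = LiteLLMReranker(model="default", requests_per_minute=30)
    print(reranker.score_pairs([("binary search", "def bisect(arr, x): ...")]))
    # Scores are coerced into [0, 1]; failed or unparseable responses fall back to default_score.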

View File

@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Iterable, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_onnx_reranker_available() -> tuple[bool, str | None]:
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
try:
import numpy # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
class ONNXReranker(BaseReranker):
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
providers: list[Any] | None = None,
max_length: int | None = None,
) -> None:
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.providers = providers
self.max_length = int(max_length) if max_length is not None else None
self._tokenizer: Any | None = None
self._model: Any | None = None
self._model_input_names: set[str] | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_onnx_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
if self.providers is None:
from ..gpu_support import get_optimal_providers
# Include device_id options for DirectML/CUDA selection when available.
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=True
)
# Some Optimum versions accept `providers`, others accept a single `provider`.
# Prefer passing the full providers list, with a conservative fallback.
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception:
model_kwargs = {}
try:
self._model = ORTModelForSequenceClassification.from_pretrained(
self.model_name,
**model_kwargs,
)
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments.
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache model input names to filter tokenizer outputs defensively.
input_names: set[str] | None = None
for attr in ("input_names", "model_input_names"):
names = getattr(self._model, attr, None)
if isinstance(names, (list, tuple)) and names:
input_names = {str(n) for n in names}
break
if input_names is None:
try:
session = getattr(self._model, "model", None)
if session is not None and hasattr(session, "get_inputs"):
input_names = {i.name for i in session.get_inputs()}
except Exception:
input_names = None
self._model_input_names = input_names
@staticmethod
def _sigmoid(x: "Any") -> "Any":
import numpy as np
x = np.clip(x, -50.0, 50.0)
return 1.0 / (1.0 + np.exp(-x))
@staticmethod
def _select_relevance_logit(logits: "Any") -> "Any":
import numpy as np
arr = np.asarray(logits)
if arr.ndim == 0:
return arr.reshape(1)
if arr.ndim == 1:
return arr
if arr.ndim >= 2:
# Common cases:
# - Regression: (batch, 1)
# - Binary classification: (batch, 2)
if arr.shape[-1] == 1:
return arr[..., 0]
if arr.shape[-1] == 2:
# Convert 2-logit softmax into a single logit via difference.
return arr[..., 1] - arr[..., 0]
return arr.max(axis=-1)
return arr.reshape(-1)
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
queries = [q for q, _ in batch]
docs = [d for _, d in batch]
tokenizer_kwargs: dict[str, Any] = {
"text": queries,
"text_pair": docs,
"padding": True,
"truncation": True,
"return_tensors": "np",
}
max_len = self.max_length
if max_len is None:
try:
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
if 0 < model_max < 10_000:
max_len = model_max
else:
max_len = 512
except Exception:
max_len = 512
if max_len is not None and max_len > 0:
tokenizer_kwargs["max_length"] = int(max_len)
encoded = self._tokenizer(**tokenizer_kwargs)
inputs = dict(encoded)
# Some models do not accept token_type_ids; filter to known input names if available.
if self._model_input_names:
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
return inputs
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
if self._model is None:
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
outputs = self._model(**inputs)
if hasattr(outputs, "logits"):
return outputs.logits
if isinstance(outputs, dict) and "logits" in outputs:
return outputs["logits"]
if isinstance(outputs, (list, tuple)) and outputs:
return outputs[0]
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
if not pairs:
return []
self._load_model()
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
return []
import numpy as np
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for batch in _iter_batches(list(pairs), bs):
inputs = self._tokenize_batch(batch)
logits = self._forward_logits(inputs)
rel_logits = self._select_relevance_logit(logits)
probs = self._sigmoid(rel_logits)
probs = np.clip(probs, 0.0, 1.0)
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
if len(scores) != len(pairs):
logger.debug(
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
)
return scores[: len(pairs)]
return scores

View File

@@ -0,0 +1,434 @@
"""Rotational embedder for multi-endpoint API load balancing.
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
to maximize throughput while respecting rate limits.
"""
from __future__ import annotations
import logging
import random
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional
import numpy as np
from .base import BaseEmbedder
logger = logging.getLogger(__name__)
class EndpointStatus(Enum):
"""Status of an API endpoint."""
AVAILABLE = "available"
COOLING = "cooling" # Rate limited, temporarily unavailable
FAILED = "failed" # Permanent failure (auth error, etc.)
class SelectionStrategy(Enum):
"""Strategy for selecting endpoints."""
ROUND_ROBIN = "round_robin"
LATENCY_AWARE = "latency_aware"
WEIGHTED_RANDOM = "weighted_random"
@dataclass
class EndpointConfig:
"""Configuration for a single API endpoint."""
model: str
api_key: Optional[str] = None
api_base: Optional[str] = None
weight: float = 1.0 # Higher weight = more requests
max_concurrent: int = 4 # Max concurrent requests to this endpoint
@dataclass
class EndpointState:
"""Runtime state for an endpoint."""
config: EndpointConfig
embedder: Any = None # LiteLLMEmbedderWrapper instance
# Health metrics
status: EndpointStatus = EndpointStatus.AVAILABLE
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
# Performance metrics
total_requests: int = 0
total_failures: int = 0
avg_latency_ms: float = 0.0
last_latency_ms: float = 0.0
# Concurrency tracking
active_requests: int = 0
lock: threading.Lock = field(default_factory=threading.Lock)
def is_available(self) -> bool:
"""Check if endpoint is available for requests."""
if self.status == EndpointStatus.FAILED:
return False
if self.status == EndpointStatus.COOLING:
if time.time() >= self.cooldown_until:
self.status = EndpointStatus.AVAILABLE
return True
return False
return True
def set_cooldown(self, seconds: float) -> None:
"""Put endpoint in cooldown state."""
self.status = EndpointStatus.COOLING
self.cooldown_until = time.time() + seconds
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
def mark_failed(self) -> None:
"""Mark endpoint as permanently failed."""
self.status = EndpointStatus.FAILED
logger.error(f"Endpoint {self.config.model} marked as failed")
def record_success(self, latency_ms: float) -> None:
"""Record successful request."""
self.total_requests += 1
self.last_latency_ms = latency_ms
# Exponential moving average for latency
alpha = 0.3
if self.avg_latency_ms == 0:
self.avg_latency_ms = latency_ms
else:
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
def record_failure(self) -> None:
"""Record failed request."""
self.total_requests += 1
self.total_failures += 1
@property
def health_score(self) -> float:
"""Calculate health score (0-1) based on metrics."""
if not self.is_available():
return 0.0
# Base score from success rate
if self.total_requests > 0:
success_rate = 1 - (self.total_failures / self.total_requests)
else:
success_rate = 1.0
# Latency factor (faster = higher score)
# Normalize: 100ms = 1.0, 1000ms = 0.1
if self.avg_latency_ms > 0:
latency_factor = min(1.0, 100 / self.avg_latency_ms)
else:
latency_factor = 1.0
# Availability factor (less concurrent = more available)
if self.config.max_concurrent > 0:
availability = 1 - (self.active_requests / self.config.max_concurrent)
else:
availability = 1.0
# Combined score with weights
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
class RotationalEmbedder(BaseEmbedder):
"""Embedder that load balances across multiple API endpoints.
Features:
- Intelligent endpoint selection based on latency and health
- Automatic failover on rate limits (429) and server errors
- Cooldown management to respect rate limits
- Thread-safe concurrent request handling
Args:
endpoints: List of endpoint configurations
strategy: Selection strategy (default: latency_aware)
default_cooldown: Default cooldown seconds for rate limits (default: 60)
max_retries: Maximum retry attempts across all endpoints (default: 3)
"""
def __init__(
self,
endpoints: List[EndpointConfig],
strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
default_cooldown: float = 60.0,
max_retries: int = 3,
) -> None:
if not endpoints:
raise ValueError("At least one endpoint must be provided")
self.strategy = strategy
self.default_cooldown = default_cooldown
self.max_retries = max_retries
# Initialize endpoint states
self._endpoints: List[EndpointState] = []
self._lock = threading.Lock()
self._round_robin_index = 0
# Create embedder instances for each endpoint
from .litellm_embedder import LiteLLMEmbedderWrapper
for config in endpoints:
# Build kwargs for LiteLLMEmbedderWrapper
kwargs: Dict[str, Any] = {}
if config.api_key:
kwargs["api_key"] = config.api_key
if config.api_base:
kwargs["api_base"] = config.api_base
try:
embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
state = EndpointState(config=config, embedder=embedder)
self._endpoints.append(state)
logger.info(f"Initialized endpoint: {config.model}")
except Exception as e:
logger.error(f"Failed to initialize endpoint {config.model}: {e}")
if not self._endpoints:
raise ValueError("Failed to initialize any endpoints")
# Cache embedding properties from first endpoint
self._embedding_dim = self._endpoints[0].embedder.embedding_dim
self._model_name = f"rotational({len(self._endpoints)} endpoints)"
self._max_tokens = self._endpoints[0].embedder.max_tokens
@property
def embedding_dim(self) -> int:
"""Return embedding dimensions."""
return self._embedding_dim
@property
def model_name(self) -> str:
"""Return model name."""
return self._model_name
@property
def max_tokens(self) -> int:
"""Return maximum token limit."""
return self._max_tokens
@property
def endpoint_count(self) -> int:
"""Return number of configured endpoints."""
return len(self._endpoints)
@property
def available_endpoint_count(self) -> int:
"""Return number of available endpoints."""
return sum(1 for ep in self._endpoints if ep.is_available())
def get_endpoint_stats(self) -> List[Dict[str, Any]]:
"""Get statistics for all endpoints."""
stats = []
for ep in self._endpoints:
stats.append({
"model": ep.config.model,
"status": ep.status.value,
"total_requests": ep.total_requests,
"total_failures": ep.total_failures,
"avg_latency_ms": round(ep.avg_latency_ms, 2),
"health_score": round(ep.health_score, 3),
"active_requests": ep.active_requests,
})
return stats
def _select_endpoint(self) -> Optional[EndpointState]:
"""Select best available endpoint based on strategy."""
available = [ep for ep in self._endpoints if ep.is_available()]
if not available:
return None
if self.strategy == SelectionStrategy.ROUND_ROBIN:
with self._lock:
self._round_robin_index = (self._round_robin_index + 1) % len(available)
return available[self._round_robin_index]
elif self.strategy == SelectionStrategy.LATENCY_AWARE:
# Sort by health score (descending) and pick top candidate
# Add small random factor to prevent thundering herd
scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
scored.sort(key=lambda x: x[1], reverse=True)
return scored[0][0]
elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
# Weighted random selection based on health scores
scores = [ep.health_score for ep in available]
total = sum(scores)
if total == 0:
return random.choice(available)
weights = [s / total for s in scores]
return random.choices(available, weights=weights, k=1)[0]
return available[0]
def _parse_retry_after(self, error: Exception) -> Optional[float]:
"""Extract Retry-After value from error if available."""
error_str = str(error)
# Try to find Retry-After in error message
import re
match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
if match:
return float(match.group(1))
return None
def _is_rate_limit_error(self, error: Exception) -> bool:
"""Check if error is a rate limit error."""
error_str = str(error).lower()
return any(x in error_str for x in ["429", "rate limit", "too many requests"])
def _is_retryable_error(self, error: Exception) -> bool:
"""Check if error is retryable (not auth/config error)."""
error_str = str(error).lower()
# Retryable errors
if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
"timeout", "connection", "service unavailable"]):
return True
# Non-retryable errors (auth, config)
if any(x in error_str for x in ["401", "403", "invalid", "authentication",
"unauthorized", "api key"]):
return False
# Default to retryable for unknown errors
return True
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
"""Embed texts using load-balanced endpoint selection.
Args:
texts: Single text or iterable of texts to embed.
**kwargs: Additional arguments passed to underlying embedder.
Returns:
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
Raises:
RuntimeError: If all endpoints fail after retries.
"""
if isinstance(texts, str):
texts = [texts]
else:
texts = list(texts)
last_error: Optional[Exception] = None
tried_endpoints: set = set()
for attempt in range(self.max_retries + 1):
endpoint = self._select_endpoint()
if endpoint is None:
# All endpoints unavailable, wait for shortest cooldown
min_cooldown = min(
(ep.cooldown_until - time.time() for ep in self._endpoints
if ep.status == EndpointStatus.COOLING),
default=self.default_cooldown
)
if min_cooldown > 0 and attempt < self.max_retries:
wait_time = min(min_cooldown, 30) # Cap wait at 30s
logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
time.sleep(wait_time)
continue
break
# Track tried endpoints to avoid infinite loops
endpoint_id = id(endpoint)
if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
# Already tried all endpoints
break
tried_endpoints.add(endpoint_id)
# Acquire slot
with endpoint.lock:
endpoint.active_requests += 1
try:
start_time = time.time()
result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
latency_ms = (time.time() - start_time) * 1000
# Record success
endpoint.record_success(latency_ms)
return result
except Exception as e:
last_error = e
endpoint.record_failure()
if self._is_rate_limit_error(e):
# Rate limited - set cooldown
retry_after = self._parse_retry_after(e) or self.default_cooldown
endpoint.set_cooldown(retry_after)
logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
f"cooling for {retry_after}s")
elif not self._is_retryable_error(e):
# Permanent failure (auth error, etc.)
endpoint.mark_failed()
logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")
else:
# Temporary error - short cooldown
endpoint.set_cooldown(5.0)
logger.warning(f"Endpoint {endpoint.config.model} error: {e}")
finally:
with endpoint.lock:
endpoint.active_requests -= 1
# All retries exhausted
available = self.available_endpoint_count
raise RuntimeError(
f"All embedding attempts failed after {self.max_retries + 1} tries. "
f"Available endpoints: {available}/{len(self._endpoints)}. "
f"Last error: {last_error}"
)
def create_rotational_embedder(
endpoints_config: List[Dict[str, Any]],
strategy: str = "latency_aware",
default_cooldown: float = 60.0,
) -> RotationalEmbedder:
"""Factory function to create RotationalEmbedder from config dicts.
Args:
endpoints_config: List of endpoint configuration dicts with keys:
- model: Model identifier (required)
- api_key: API key (optional)
- api_base: API base URL (optional)
- weight: Request weight (optional, default 1.0)
- max_concurrent: Max concurrent requests (optional, default 4)
strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
default_cooldown: Default cooldown seconds for rate limits
Returns:
Configured RotationalEmbedder instance
Example config:
endpoints_config = [
{"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
{"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
]
"""
endpoints = []
for cfg in endpoints_config:
endpoints.append(EndpointConfig(
model=cfg["model"],
api_key=cfg.get("api_key"),
api_base=cfg.get("api_base"),
weight=cfg.get("weight", 1.0),
max_concurrent=cfg.get("max_concurrent", 4),
))
    try:
        strategy_enum = SelectionStrategy[strategy.upper()]
    except KeyError as exc:
        raise ValueError(
            f"Unknown strategy: {strategy!r}. Use round_robin, latency_aware, or weighted_random."
        ) from exc
return RotationalEmbedder(
endpoints=endpoints,
strategy=strategy_enum,
default_cooldown=default_cooldown,
)
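# Usage sketch: a minimal, hedged example of the factory above. The model
# names, API keys, and weights are illustrative placeholders; a real call to
# embed_to_numpy requires valid litellm-compatible credentials.
if __name__ == "__main__":
    _demo_embedder = create_rotational_embedder(
        endpoints_config=[
            {"model": "openai/text-embedding-3-small", "api_key": "sk-placeholder-1"},
            {"model": "openai/text-embedding-3-small", "api_key": "sk-placeholder-2",
             "weight": 2.0, "max_concurrent": 8},
        ],
        strategy="latency_aware",
        default_cooldown=60.0,
    )
    _vectors = _demo_embedder.embed_to_numpy(["def add(a, b): return a + b"])
    print(_vectors.shape)  # expected: (1, embedding_dim)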

View File

@@ -0,0 +1,567 @@
"""ONNX-optimized SPLADE sparse encoder for code search.
This module provides SPLADE (Sparse Lexical and Expansion model) encoding using ONNX Runtime
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
that combine the interpretability of BM25 with neural relevance modeling.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
Install (GPU):
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
"""
from __future__ import annotations
import logging
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def check_splade_available() -> Tuple[bool, Optional[str]]:
"""Check whether SPLADE dependencies are available.
Returns:
Tuple of (available: bool, error_message: Optional[str])
"""
try:
import numpy # noqa: F401
except ImportError as exc:
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc:
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
except ImportError as exc:
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc:
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
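# Usage sketch: gate SPLADE features on dependency availability before building
# an encoder. _maybe_get_encoder is a hypothetical helper (not used elsewhere);
# callers are assumed to fall back to lexical-only search when it returns None.
def _maybe_get_encoder(use_gpu: bool = False) -> Optional["SpladeEncoder"]:
    ok, err = check_splade_available()
    if not ok:
        logger.warning("SPLADE unavailable, falling back to lexical search: %s", err)
        return None
    return get_splade_encoder(use_gpu=use_gpu)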
# Global cache for SPLADE encoders (singleton pattern)
_splade_cache: Dict[str, "SpladeEncoder"] = {}
_cache_lock = threading.RLock()
def get_splade_encoder(
model_name: str = "naver/splade-cocondenser-ensembledistil",
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
cache_dir: Optional[str] = None,
) -> "SpladeEncoder":
"""Get or create cached SPLADE encoder (thread-safe singleton).
This function provides significant performance improvement by reusing
SpladeEncoder instances across multiple searches, avoiding repeated model
loading overhead.
Args:
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
Returns:
Cached SpladeEncoder instance for the given configuration
"""
global _splade_cache
    # Cache key covers the parameters that affect encoding output; cache_dir is not part of the key
cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"
with _cache_lock:
encoder = _splade_cache.get(cache_key)
if encoder is not None:
return encoder
# Create new encoder and cache it
encoder = SpladeEncoder(
model_name=model_name,
use_gpu=use_gpu,
max_length=max_length,
sparsity_threshold=sparsity_threshold,
cache_dir=cache_dir,
)
# Pre-load model to ensure it's ready
encoder._load_model()
_splade_cache[cache_key] = encoder
return encoder
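# Usage sketch for the cached accessor above. _example_splade_query is a
# hypothetical demo helper and the query text is illustrative; the first call
# downloads/exports the model, so warmup() is used to absorb that latency.
def _example_splade_query() -> None:
    encoder = get_splade_encoder(use_gpu=False)
    encoder.warmup()
    sparse = encoder.encode_text("parse json config file")
    for token, weight in encoder.get_top_tokens(sparse, top_k=5):
        logger.info("top token %r weight=%.3f", token, weight)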
def clear_splade_cache() -> None:
"""Clear the SPLADE encoder cache and release ONNX resources.
    This function ensures proper cleanup of ONNX model resources to prevent
memory leaks when encoders are no longer needed.
"""
global _splade_cache
with _cache_lock:
# Release ONNX resources before clearing cache
for encoder in _splade_cache.values():
if encoder._model is not None:
del encoder._model
encoder._model = None
if encoder._tokenizer is not None:
del encoder._tokenizer
encoder._tokenizer = None
_splade_cache.clear()
class SpladeEncoder:
"""ONNX-optimized SPLADE sparse encoder.
Produces sparse vectors with vocabulary-aligned dimensions.
Output: Dict[int, float] mapping token_id to weight.
SPLADE activation formula:
splade_repr = log(1 + ReLU(logits)) * attention_mask
splade_vec = max_pooling(splade_repr, axis=sequence_length)
References:
- SPLADE: https://arxiv.org/abs/2107.05720
- SPLADE v2: https://arxiv.org/abs/2109.10086
"""
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
def __init__(
self,
model_name: str = DEFAULT_MODEL,
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
providers: Optional[List[Any]] = None,
cache_dir: Optional[str] = None,
) -> None:
"""Initialize SPLADE encoder.
Args:
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
providers: Explicit ONNX providers list (overrides use_gpu)
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.max_length = int(max_length) if max_length > 0 else 512
self.sparsity_threshold = float(sparsity_threshold)
self.providers = providers
# Setup ONNX cache directory
if cache_dir:
self._cache_dir = Path(cache_dir)
else:
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
self._tokenizer: Any | None = None
self._model: Any | None = None
self._vocab_size: int | None = None
self._lock = threading.RLock()
def _get_local_cache_path(self) -> Path:
"""Get local cache path for this model's ONNX files.
Returns:
Path to the local ONNX cache directory for this model
"""
# Replace / with -- for filesystem-safe naming
safe_name = self.model_name.replace("/", "--")
return self._cache_dir / safe_name
def _load_model(self) -> None:
"""Lazy load ONNX model and tokenizer.
First checks local cache for ONNX model, falling back to
HuggingFace download and conversion if not cached.
"""
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_splade_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForMaskedLM
from transformers import AutoTokenizer
if self.providers is None:
from .gpu_support import get_optimal_providers, get_selected_device_id
# Get providers as pure string list (cache-friendly)
# NOTE: with_device_options=False to avoid tuple-based providers
# which break optimum's caching mechanism
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=False
)
# Get device_id separately for provider_options
self._device_id = get_selected_device_id() if self.use_gpu else None
# Some Optimum versions accept `providers`, others accept a single `provider`
# Prefer passing the full providers list, with a conservative fallback
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
# Pass device_id via provider_options for GPU selection
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
# Build provider_options dict for each GPU provider
provider_options = {}
for p in self.providers:
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
provider_options[p] = {"device_id": self._device_id}
if provider_options:
model_kwargs["provider_options"] = provider_options
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception as e:
logger.debug(f"Failed to inspect ORTModel signature: {e}")
model_kwargs = {}
# Check for local ONNX cache first
local_cache = self._get_local_cache_path()
onnx_model_path = local_cache / "model.onnx"
if onnx_model_path.exists():
# Load from local cache
logger.info(f"Loading SPLADE from local cache: {local_cache}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
str(local_cache),
**model_kwargs,
)
self._tokenizer = AutoTokenizer.from_pretrained(
str(local_cache), use_fast=True
)
self._vocab_size = len(self._tokenizer)
logger.info(
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
)
return
except Exception as e:
logger.warning(f"Failed to load from cache, redownloading: {e}")
# Download and convert from HuggingFace
logger.info(f"Downloading SPLADE model: {self.model_name}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True, # Export to ONNX
**model_kwargs,
)
logger.debug(f"SPLADE model loaded: {self.model_name}")
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True,
)
logger.warning(
"Optimum version doesn't support provider parameters. "
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache vocabulary size
self._vocab_size = len(self._tokenizer)
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
# Save to local cache for future use
try:
local_cache.mkdir(parents=True, exist_ok=True)
self._model.save_pretrained(str(local_cache))
self._tokenizer.save_pretrained(str(local_cache))
logger.info(f"SPLADE model cached to: {local_cache}")
except Exception as e:
logger.warning(f"Failed to cache SPLADE model: {e}")
@staticmethod
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
"""Apply SPLADE activation function to model outputs.
Formula: log(1 + ReLU(logits)) * attention_mask
Args:
logits: Model output logits (batch, seq_len, vocab_size)
attention_mask: Attention mask (batch, seq_len)
Returns:
SPLADE representations (batch, seq_len, vocab_size)
"""
import numpy as np
# ReLU activation
relu_logits = np.maximum(0, logits)
# Log(1 + x) transformation
log_relu = np.log1p(relu_logits)
# Apply attention mask (expand to match vocab dimension)
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
mask_expanded = np.expand_dims(attention_mask, axis=-1)
# Element-wise multiplication
splade_repr = log_relu * mask_expanded
return splade_repr
@staticmethod
def _max_pooling(splade_repr: Any) -> Any:
"""Max pooling over sequence length dimension.
Args:
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
Returns:
Pooled sparse vectors (batch, vocab_size)
"""
import numpy as np
# Max pooling over sequence dimension (axis=1)
return np.max(splade_repr, axis=1)
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
"""Convert dense vector to sparse dictionary.
Args:
dense_vec: Dense vector (vocab_size,)
Returns:
Sparse dictionary {token_id: weight} with weights above threshold
"""
import numpy as np
# Find non-zero indices above threshold
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
# Create sparse dictionary
sparse_dict = {
int(idx): float(dense_vec[idx])
for idx in nonzero_indices
}
return sparse_dict
def warmup(self, text: str = "warmup query") -> None:
"""Warmup the encoder by running a dummy inference.
First-time model inference includes initialization overhead.
Call this method once before the first real search to avoid
latency spikes.
Args:
text: Dummy text for warmup (default: "warmup query")
"""
logger.info("Warming up SPLADE encoder...")
# Trigger model loading and first inference
_ = self.encode_text(text)
logger.info("SPLADE encoder warmup complete")
def encode_text(self, text: str) -> Dict[int, float]:
"""Encode text to sparse vector {token_id: weight}.
Args:
text: Input text to encode
Returns:
Sparse vector as dictionary mapping token_id to weight
"""
self._load_model()
if self._model is None or self._tokenizer is None:
raise RuntimeError("Model not loaded")
# Tokenize input
encoded = self._tokenizer(
text,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="np",
)
# Forward pass through model
outputs = self._model(**encoded)
# Extract logits
if hasattr(outputs, "logits"):
logits = outputs.logits
elif isinstance(outputs, dict) and "logits" in outputs:
logits = outputs["logits"]
elif isinstance(outputs, (list, tuple)) and outputs:
logits = outputs[0]
else:
raise RuntimeError("Unexpected model output format")
# Apply SPLADE activation
attention_mask = encoded["attention_mask"]
splade_repr = self._splade_activation(logits, attention_mask)
# Max pooling over sequence length
splade_vec = self._max_pooling(splade_repr)
# Convert to sparse dictionary (single item batch)
sparse_dict = self._to_sparse_dict(splade_vec[0])
return sparse_dict
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
"""Batch encode texts to sparse vectors.
Args:
texts: List of input texts to encode
batch_size: Batch size for encoding (default: 32)
Returns:
List of sparse vectors as dictionaries
"""
if not texts:
return []
self._load_model()
if self._model is None or self._tokenizer is None:
raise RuntimeError("Model not loaded")
results: List[Dict[int, float]] = []
# Process in batches
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# Tokenize batch
encoded = self._tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="np",
)
# Forward pass through model
outputs = self._model(**encoded)
# Extract logits
if hasattr(outputs, "logits"):
logits = outputs.logits
elif isinstance(outputs, dict) and "logits" in outputs:
logits = outputs["logits"]
elif isinstance(outputs, (list, tuple)) and outputs:
logits = outputs[0]
else:
raise RuntimeError("Unexpected model output format")
# Apply SPLADE activation
attention_mask = encoded["attention_mask"]
splade_repr = self._splade_activation(logits, attention_mask)
# Max pooling over sequence length
splade_vecs = self._max_pooling(splade_repr)
# Convert each vector to sparse dictionary
for vec in splade_vecs:
sparse_dict = self._to_sparse_dict(vec)
results.append(sparse_dict)
return results
@property
def vocab_size(self) -> int:
"""Return vocabulary size (~30k for BERT-based models).
Returns:
Vocabulary size (number of tokens in tokenizer)
"""
if self._vocab_size is not None:
return self._vocab_size
self._load_model()
return self._vocab_size or 0
def get_token(self, token_id: int) -> str:
"""Convert token_id to string (for debugging).
Args:
token_id: Token ID to convert
Returns:
Token string
"""
self._load_model()
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded")
return self._tokenizer.decode([token_id])
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
"""Get top-k tokens with highest weights from sparse vector.
Useful for debugging and understanding what the model is focusing on.
Args:
sparse_vec: Sparse vector as {token_id: weight}
top_k: Number of top tokens to return
Returns:
List of (token_string, weight) tuples, sorted by weight descending
"""
self._load_model()
if not sparse_vec:
return []
# Sort by weight descending
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
# Take top-k and convert token_ids to strings
top_items = sorted_items[:top_k]
return [
(self.get_token(token_id), weight)
for token_id, weight in top_items
]
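# Toy walk-through of the activation, pooling, and sparsification steps
# implemented by SpladeEncoder above. _splade_math_sketch is a hypothetical
# helper; the logits tensor is fabricated (batch=1, seq_len=2, vocab=4) rather
# than produced by a real model.
def _splade_math_sketch() -> None:
    import numpy as np
    logits = np.array([[[1.0, -2.0, 0.5, 0.0],
                        [3.0, 0.2, -1.0, 0.0]]])  # (batch=1, seq_len=2, vocab=4)
    attention_mask = np.array([[1, 1]])           # (batch=1, seq_len=2)
    splade_repr = SpladeEncoder._splade_activation(logits, attention_mask)
    dense_vec = SpladeEncoder._max_pooling(splade_repr)[0]  # shape (4,)
    # Per-token max over the sequence axis of log1p(relu(logits)):
    # dense_vec == [log1p(3.0), log1p(0.2), log1p(0.5), 0.0]
    sparse = {i: float(w) for i, w in enumerate(dense_vec) if w > 0.01}  # 0.01 mirrors the default sparsity_threshold
    logger.debug("toy sparse vector: %s", sparse)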

File diff suppressed because it is too large