Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-12 02:37:45 +08:00)

Commit: Refactor code structure and remove redundant changes
codex-lens/build/lib/codexlens/semantic/__init__.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""Optional semantic search module for CodexLens.

Install with: pip install codexlens[semantic]
Uses fastembed (ONNX-based, lightweight ~200MB)

GPU Acceleration:
- Automatic GPU detection and usage when available
- Supports CUDA (NVIDIA), TensorRT, DirectML (Windows), ROCm (AMD), CoreML (Apple)
- Install GPU support: pip install onnxruntime-gpu (NVIDIA) or onnxruntime-directml (Windows)
"""

from __future__ import annotations

SEMANTIC_AVAILABLE = False
SEMANTIC_BACKEND: str | None = None
GPU_AVAILABLE = False
LITELLM_AVAILABLE = False
_import_error: str | None = None


def _detect_backend() -> tuple[bool, str | None, bool, str | None]:
    """Detect if fastembed and GPU are available."""
    try:
        import numpy as np
    except ImportError as e:
        return False, None, False, f"numpy not available: {e}"

    try:
        from fastembed import TextEmbedding
    except ImportError:
        return False, None, False, "fastembed not available. Install with: pip install codexlens[semantic]"

    # Check GPU availability
    gpu_available = False
    try:
        from .gpu_support import is_gpu_available
        gpu_available = is_gpu_available()
    except ImportError:
        pass

    return True, "fastembed", gpu_available, None


# Initialize on module load
SEMANTIC_AVAILABLE, SEMANTIC_BACKEND, GPU_AVAILABLE, _import_error = _detect_backend()


def check_semantic_available() -> tuple[bool, str | None]:
    """Check if semantic search dependencies are available."""
    return SEMANTIC_AVAILABLE, _import_error


def check_gpu_available() -> tuple[bool, str]:
    """Check if GPU acceleration is available.

    Returns:
        Tuple of (is_available, status_message)
    """
    if not SEMANTIC_AVAILABLE:
        return False, "Semantic search not available"

    try:
        from .gpu_support import is_gpu_available, get_gpu_summary
        if is_gpu_available():
            return True, get_gpu_summary()
        return False, "No GPU detected (using CPU)"
    except ImportError:
        return False, "GPU support module not available"


# Export embedder components
# BaseEmbedder is always available (abstract base class)
from .base import BaseEmbedder

# Factory function for creating embedders
from .factory import get_embedder as get_embedder_factory

# Optional: LiteLLMEmbedderWrapper (only if ccw-litellm is installed)
try:
    import ccw_litellm  # noqa: F401
    from .litellm_embedder import LiteLLMEmbedderWrapper
    LITELLM_AVAILABLE = True
except ImportError:
    LiteLLMEmbedderWrapper = None
    LITELLM_AVAILABLE = False


def is_embedding_backend_available(backend: str) -> tuple[bool, str | None]:
    """Check whether a specific embedding backend can be used.

    Notes:
        - "fastembed" requires the optional semantic deps (pip install codexlens[semantic]).
        - "litellm" requires ccw-litellm to be installed in the same environment.
    """
    backend = (backend or "").strip().lower()
    if backend == "fastembed":
        if SEMANTIC_AVAILABLE:
            return True, None
        return False, _import_error or "fastembed not available. Install with: pip install codexlens[semantic]"
    if backend == "litellm":
        if LITELLM_AVAILABLE:
            return True, None
        return False, "ccw-litellm not available. Install with: pip install ccw-litellm"
    return False, f"Invalid embedding backend: {backend}. Must be 'fastembed' or 'litellm'."


__all__ = [
    "SEMANTIC_AVAILABLE",
    "SEMANTIC_BACKEND",
    "GPU_AVAILABLE",
    "LITELLM_AVAILABLE",
    "check_semantic_available",
    "is_embedding_backend_available",
    "check_gpu_available",
    "BaseEmbedder",
    "get_embedder_factory",
    "LiteLLMEmbedderWrapper",
]
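Note: a minimal usage sketch of the availability checks above, assuming the package is importable as codexlens.semantic; the calling script itself is hypothetical and not part of this commit.

    from codexlens.semantic import (
        check_semantic_available,
        check_gpu_available,
        is_embedding_backend_available,
    )

    ok, err = check_semantic_available()
    if not ok:
        print(f"semantic search disabled: {err}")
    else:
        gpu_ok, gpu_msg = check_gpu_available()
        print(f"GPU status: {gpu_msg}")
        # Returns (False, "ccw-litellm not available. ...") unless ccw-litellm is installed.
        print(is_embedding_backend_available("litellm"))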
codex-lens/build/lib/codexlens/semantic/ann_index.py (new file, 1068 lines)
File diff suppressed because it is too large.
codex-lens/build/lib/codexlens/semantic/base.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""Base class for embedders.

Defines the interface that all embedders must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable

import numpy as np


class BaseEmbedder(ABC):
    """Base class for all embedders.

    All embedder implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Return embedding dimensions.

        Returns:
            int: Dimension of the embedding vectors.
        """
        ...

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Return model name.

        Returns:
            str: Name or identifier of the underlying model.
        """
        ...

    @property
    def max_tokens(self) -> int:
        """Return maximum token limit for embeddings.

        Returns:
            int: Maximum number of tokens that can be embedded at once.
                Default is 8192 if not overridden by implementation.
        """
        return 8192

    @abstractmethod
    def embed_to_numpy(self, texts: str | Iterable[str]) -> np.ndarray:
        """Embed texts to numpy array.

        Args:
            texts: Single text or iterable of texts to embed.

        Returns:
            numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
        """
        ...
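Note: a minimal sketch of what a concrete embedder must provide, based only on the BaseEmbedder interface shown above; DummyEmbedder is hypothetical and not part of this commit.

    import numpy as np
    from codexlens.semantic.base import BaseEmbedder

    class DummyEmbedder(BaseEmbedder):
        @property
        def embedding_dim(self) -> int:
            return 8

        @property
        def model_name(self) -> str:
            return "dummy-hash-embedder"

        def embed_to_numpy(self, texts):
            if isinstance(texts, str):
                texts = [texts]
            # Deterministic pseudo-embeddings seeded from each text's hash.
            rows = [
                np.random.default_rng(abs(hash(t)) % (2**32)).random(self.embedding_dim)
                for t in texts
            ]
            return np.vstack(rows)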
codex-lens/build/lib/codexlens/semantic/chunker.py (new file, 821 lines)
@@ -0,0 +1,821 @@
"""Code chunking strategies for semantic search.

This module provides various chunking strategies for breaking down source code
into semantic chunks suitable for embedding and search.

Lightweight Mode:
    The ChunkConfig supports a `skip_token_count` option for performance optimization.
    When enabled, token counting uses a fast character-based estimation (char/4)
    instead of expensive tiktoken encoding.

Use cases for lightweight mode:
    - Large-scale indexing where speed is critical
    - Scenarios where approximate token counts are acceptable
    - Memory-constrained environments
    - Initial prototyping and development

Example:
    # Default mode (accurate tiktoken encoding)
    config = ChunkConfig()
    chunker = Chunker(config)

    # Lightweight mode (fast char/4 estimation)
    config = ChunkConfig(skip_token_count=True)
    chunker = Chunker(config)
    chunks = chunker.chunk_file(content, symbols, path, language)
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

from codexlens.entities import SemanticChunk, Symbol
from codexlens.parsers.tokenizer import get_default_tokenizer


@dataclass
class ChunkConfig:
    """Configuration for chunking strategies."""
    max_chunk_size: int = 1000  # Max characters per chunk
    overlap: int = 200  # Overlap for sliding window (increased from 100 for better context)
    strategy: str = "auto"  # Chunking strategy: auto, symbol, sliding_window, hybrid
    min_chunk_size: int = 50  # Minimum chunk size
    skip_token_count: bool = False  # Skip expensive token counting (use char/4 estimate)
    strip_comments: bool = True  # Remove comments from chunk content for embedding
    strip_docstrings: bool = True  # Remove docstrings from chunk content for embedding
    preserve_original: bool = True  # Store original content in metadata when stripping


class CommentStripper:
    """Remove comments from source code while preserving structure."""

    @staticmethod
    def strip_python_comments(content: str) -> str:
        """Strip Python comments (# style) but preserve docstrings.

        Args:
            content: Python source code

        Returns:
            Code with comments removed
        """
        lines = content.splitlines(keepends=True)
        result_lines: List[str] = []
        in_string = False
        string_char = None

        for line in lines:
            new_line = []
            i = 0
            while i < len(line):
                char = line[i]

                # Handle string literals
                if char in ('"', "'") and not in_string:
                    # Check for triple quotes
                    if line[i:i+3] in ('"""', "'''"):
                        in_string = True
                        string_char = line[i:i+3]
                        new_line.append(line[i:i+3])
                        i += 3
                        continue
                    else:
                        in_string = True
                        string_char = char
                elif in_string:
                    if string_char and len(string_char) == 3:
                        if line[i:i+3] == string_char:
                            in_string = False
                            new_line.append(line[i:i+3])
                            i += 3
                            string_char = None
                            continue
                    elif char == string_char:
                        # Check for escape
                        if i > 0 and line[i-1] != '\\':
                            in_string = False
                            string_char = None

                # Handle comments (only outside strings)
                if char == '#' and not in_string:
                    # Rest of line is comment, skip it
                    new_line.append('\n' if line.endswith('\n') else '')
                    break

                new_line.append(char)
                i += 1

            result_lines.append(''.join(new_line))

        return ''.join(result_lines)

    @staticmethod
    def strip_c_style_comments(content: str) -> str:
        """Strip C-style comments (// and /* */) from code.

        Args:
            content: Source code with C-style comments

        Returns:
            Code with comments removed
        """
        result = []
        i = 0
        in_string = False
        string_char = None
        in_multiline_comment = False

        while i < len(content):
            # Handle multi-line comment end
            if in_multiline_comment:
                if content[i:i+2] == '*/':
                    in_multiline_comment = False
                    i += 2
                    continue
                i += 1
                continue

            char = content[i]

            # Handle string literals
            if char in ('"', "'", '`') and not in_string:
                in_string = True
                string_char = char
                result.append(char)
                i += 1
                continue
            elif in_string:
                result.append(char)
                if char == string_char and (i == 0 or content[i-1] != '\\'):
                    in_string = False
                    string_char = None
                i += 1
                continue

            # Handle comments
            if content[i:i+2] == '//':
                # Single line comment - skip to end of line
                while i < len(content) and content[i] != '\n':
                    i += 1
                if i < len(content):
                    result.append('\n')
                    i += 1
                continue

            if content[i:i+2] == '/*':
                in_multiline_comment = True
                i += 2
                continue

            result.append(char)
            i += 1

        return ''.join(result)

    @classmethod
    def strip_comments(cls, content: str, language: str) -> str:
        """Strip comments based on language.

        Args:
            content: Source code content
            language: Programming language

        Returns:
            Code with comments removed
        """
        if language == "python":
            return cls.strip_python_comments(content)
        elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
            return cls.strip_c_style_comments(content)
        return content


class DocstringStripper:
    """Remove docstrings from source code."""

    @staticmethod
    def strip_python_docstrings(content: str) -> str:
        """Strip Python docstrings (triple-quoted strings at module/class/function level).

        Args:
            content: Python source code

        Returns:
            Code with docstrings removed
        """
        lines = content.splitlines(keepends=True)
        result_lines: List[str] = []
        i = 0

        while i < len(lines):
            line = lines[i]
            stripped = line.strip()

            # Check for docstring start
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote_type = '"""' if stripped.startswith('"""') else "'''"

                # Single line docstring
                if stripped.count(quote_type) >= 2:
                    # Skip this line (docstring)
                    i += 1
                    continue

                # Multi-line docstring - skip until closing
                i += 1
                while i < len(lines):
                    if quote_type in lines[i]:
                        i += 1
                        break
                    i += 1
                continue

            result_lines.append(line)
            i += 1

        return ''.join(result_lines)

    @staticmethod
    def strip_jsdoc_comments(content: str) -> str:
        """Strip JSDoc comments (/** ... */) from code.

        Args:
            content: JavaScript/TypeScript source code

        Returns:
            Code with JSDoc comments removed
        """
        result = []
        i = 0
        in_jsdoc = False

        while i < len(content):
            if in_jsdoc:
                if content[i:i+2] == '*/':
                    in_jsdoc = False
                    i += 2
                    continue
                i += 1
                continue

            # Check for JSDoc start (/** but not /*)
            if content[i:i+3] == '/**':
                in_jsdoc = True
                i += 3
                continue

            result.append(content[i])
            i += 1

        return ''.join(result)

    @classmethod
    def strip_docstrings(cls, content: str, language: str) -> str:
        """Strip docstrings based on language.

        Args:
            content: Source code content
            language: Programming language

        Returns:
            Code with docstrings removed
        """
        if language == "python":
            return cls.strip_python_docstrings(content)
        elif language in {"javascript", "typescript"}:
            return cls.strip_jsdoc_comments(content)
        return content


class Chunker:
    """Chunk code files for semantic embedding."""

    def __init__(self, config: ChunkConfig | None = None) -> None:
        self.config = config or ChunkConfig()
        self._tokenizer = get_default_tokenizer()
        self._comment_stripper = CommentStripper()
        self._docstring_stripper = DocstringStripper()

    def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
        """Process chunk content by stripping comments/docstrings if configured.

        Args:
            content: Original chunk content
            language: Programming language

        Returns:
            Tuple of (processed_content, original_content_if_preserved)
        """
        original = content if self.config.preserve_original else None
        processed = content

        if self.config.strip_comments:
            processed = self._comment_stripper.strip_comments(processed, language)

        if self.config.strip_docstrings:
            processed = self._docstring_stripper.strip_docstrings(processed, language)

        # If nothing changed, don't store original
        if processed == content:
            original = None

        return processed, original

    def _estimate_token_count(self, text: str) -> int:
        """Estimate token count based on config.

        If skip_token_count is True, uses character-based estimation (char/4).
        Otherwise, uses accurate tiktoken encoding.

        Args:
            text: Text to count tokens for

        Returns:
            Estimated token count
        """
        if self.config.skip_token_count:
            # Fast character-based estimation: ~4 chars per token
            return max(1, len(text) // 4)
        return self._tokenizer.count_tokens(text)

    def chunk_by_symbol(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk code by extracted symbols (functions, classes).

        Each symbol becomes one chunk with its full content.
        Large symbols exceeding max_chunk_size are recursively split using sliding window.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []
        lines = content.splitlines(keepends=True)

        for symbol in symbols:
            start_line, end_line = symbol.range
            # Convert to 0-indexed
            start_idx = max(0, start_line - 1)
            end_idx = min(len(lines), end_line)

            chunk_content = "".join(lines[start_idx:end_idx])
            if len(chunk_content.strip()) < self.config.min_chunk_size:
                continue

            # Check if symbol content exceeds max_chunk_size
            if len(chunk_content) > self.config.max_chunk_size:
                # Create line mapping for correct line number tracking
                line_mapping = list(range(start_line, end_line + 1))

                # Use sliding window to split large symbol
                sub_chunks = self.chunk_sliding_window(
                    chunk_content,
                    file_path=file_path,
                    language=language,
                    line_mapping=line_mapping
                )

                # Update sub_chunks with parent symbol metadata
                for sub_chunk in sub_chunks:
                    sub_chunk.metadata["symbol_name"] = symbol.name
                    sub_chunk.metadata["symbol_kind"] = symbol.kind
                    sub_chunk.metadata["strategy"] = "symbol_split"
                    sub_chunk.metadata["chunk_type"] = "code"
                    sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)

                chunks.extend(sub_chunks)
            else:
                # Process content (strip comments/docstrings if configured)
                processed_content, original_content = self._process_content(chunk_content, language)

                # Skip if processed content is too small
                if len(processed_content.strip()) < self.config.min_chunk_size:
                    continue

                # Calculate token count if not provided
                token_count = None
                if symbol_token_counts and symbol.name in symbol_token_counts:
                    token_count = symbol_token_counts[symbol.name]
                else:
                    token_count = self._estimate_token_count(processed_content)

                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "symbol_name": symbol.name,
                    "symbol_kind": symbol.kind,
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "symbol",
                    "chunk_type": "code",
                    "token_count": token_count,
                }

                # Store original content if it was modified
                if original_content is not None:
                    metadata["original_content"] = original_content

                chunks.append(SemanticChunk(
                    content=processed_content,
                    embedding=None,
                    metadata=metadata
                ))

        return chunks

    def chunk_sliding_window(
        self,
        content: str,
        file_path: str | Path,
        language: str,
        line_mapping: Optional[List[int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk code using sliding window approach.

        Used for files without clear symbol boundaries or very long functions.

        Args:
            content: Source code content
            file_path: Path to source file
            language: Programming language
            line_mapping: Optional list mapping content line indices to original line numbers
                (1-indexed). If provided, line_mapping[i] is the original line number
                for the i-th line in content.
        """
        chunks: List[SemanticChunk] = []
        lines = content.splitlines(keepends=True)

        if not lines:
            return chunks

        # Calculate lines per chunk based on average line length
        avg_line_len = len(content) / max(len(lines), 1)
        lines_per_chunk = max(10, int(self.config.max_chunk_size / max(avg_line_len, 1)))
        overlap_lines = max(2, int(self.config.overlap / max(avg_line_len, 1)))
        # Ensure overlap is less than chunk size to prevent infinite loop
        overlap_lines = min(overlap_lines, lines_per_chunk - 1)

        start = 0
        chunk_idx = 0

        while start < len(lines):
            end = min(start + lines_per_chunk, len(lines))
            chunk_content = "".join(lines[start:end])

            if len(chunk_content.strip()) >= self.config.min_chunk_size:
                # Process content (strip comments/docstrings if configured)
                processed_content, original_content = self._process_content(chunk_content, language)

                # Skip if processed content is too small
                if len(processed_content.strip()) < self.config.min_chunk_size:
                    # Move window forward
                    step = lines_per_chunk - overlap_lines
                    if step <= 0:
                        step = 1
                    start += step
                    continue

                token_count = self._estimate_token_count(processed_content)

                # Calculate correct line numbers
                if line_mapping:
                    # Use line mapping to get original line numbers
                    start_line = line_mapping[start]
                    end_line = line_mapping[end - 1]
                else:
                    # Default behavior: treat content as starting at line 1
                    start_line = start + 1
                    end_line = end

                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_index": chunk_idx,
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "sliding_window",
                    "chunk_type": "code",
                    "token_count": token_count,
                }

                # Store original content if it was modified
                if original_content is not None:
                    metadata["original_content"] = original_content

                chunks.append(SemanticChunk(
                    content=processed_content,
                    embedding=None,
                    metadata=metadata
                ))
                chunk_idx += 1

            # Move window, accounting for overlap
            step = lines_per_chunk - overlap_lines
            if step <= 0:
                step = 1  # Failsafe to prevent infinite loop
            start += step

            # Break if we've reached the end
            if end >= len(lines):
                break

        return chunks

    def chunk_file(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk a file using the best strategy.

        Uses symbol-based chunking if symbols available,
        falls back to sliding window for files without symbols.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        if symbols:
            return self.chunk_by_symbol(content, symbols, file_path, language, symbol_token_counts)
        return self.chunk_sliding_window(content, file_path, language)


class DocstringExtractor:
    """Extract docstrings from source code."""

    @staticmethod
    def extract_python_docstrings(content: str) -> List[Tuple[str, int, int]]:
        """Extract Python docstrings with their line ranges.

        Returns: List of (docstring_content, start_line, end_line) tuples
        """
        docstrings: List[Tuple[str, int, int]] = []
        lines = content.splitlines(keepends=True)

        i = 0
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote_type = '"""' if stripped.startswith('"""') else "'''"
                start_line = i + 1

                if stripped.count(quote_type) >= 2:
                    docstring_content = line
                    end_line = i + 1
                    docstrings.append((docstring_content, start_line, end_line))
                    i += 1
                    continue

                docstring_lines = [line]
                i += 1
                while i < len(lines):
                    docstring_lines.append(lines[i])
                    if quote_type in lines[i]:
                        break
                    i += 1

                end_line = i + 1
                docstring_content = "".join(docstring_lines)
                docstrings.append((docstring_content, start_line, end_line))

            i += 1

        return docstrings

    @staticmethod
    def extract_jsdoc_comments(content: str) -> List[Tuple[str, int, int]]:
        """Extract JSDoc comments with their line ranges.

        Returns: List of (comment_content, start_line, end_line) tuples
        """
        comments: List[Tuple[str, int, int]] = []
        lines = content.splitlines(keepends=True)

        i = 0
        while i < len(lines):
            line = lines[i]
            stripped = line.strip()

            if stripped.startswith('/**'):
                start_line = i + 1
                comment_lines = [line]
                i += 1

                while i < len(lines):
                    comment_lines.append(lines[i])
                    if '*/' in lines[i]:
                        break
                    i += 1

                end_line = i + 1
                comment_content = "".join(comment_lines)
                comments.append((comment_content, start_line, end_line))

            i += 1

        return comments

    @classmethod
    def extract_docstrings(
        cls,
        content: str,
        language: str
    ) -> List[Tuple[str, int, int]]:
        """Extract docstrings based on language.

        Returns: List of (docstring_content, start_line, end_line) tuples
        """
        if language == "python":
            return cls.extract_python_docstrings(content)
        elif language in {"javascript", "typescript"}:
            return cls.extract_jsdoc_comments(content)
        return []


class HybridChunker:
    """Hybrid chunker that prioritizes docstrings before symbol-based chunking.

    Composition-based strategy that:
    1. Extracts docstrings as dedicated chunks
    2. For remaining code, uses base chunker (symbol or sliding window)
    """

    def __init__(
        self,
        base_chunker: Chunker | None = None,
        config: ChunkConfig | None = None
    ) -> None:
        """Initialize hybrid chunker.

        Args:
            base_chunker: Chunker to use for non-docstring content
            config: Configuration for chunking
        """
        self.config = config or ChunkConfig()
        self.base_chunker = base_chunker or Chunker(self.config)
        self.docstring_extractor = DocstringExtractor()

    def _get_excluded_line_ranges(
        self,
        docstrings: List[Tuple[str, int, int]]
    ) -> set[int]:
        """Get set of line numbers that are part of docstrings."""
        excluded_lines: set[int] = set()
        for _, start_line, end_line in docstrings:
            for line_num in range(start_line, end_line + 1):
                excluded_lines.add(line_num)
        return excluded_lines

    def _filter_symbols_outside_docstrings(
        self,
        symbols: List[Symbol],
        excluded_lines: set[int]
    ) -> List[Symbol]:
        """Filter symbols to exclude those completely within docstrings."""
        filtered: List[Symbol] = []
        for symbol in symbols:
            start_line, end_line = symbol.range
            symbol_lines = set(range(start_line, end_line + 1))
            if not symbol_lines.issubset(excluded_lines):
                filtered.append(symbol)
        return filtered

    def _find_parent_symbol(
        self,
        start_line: int,
        end_line: int,
        symbols: List[Symbol],
    ) -> Optional[Symbol]:
        """Find the smallest symbol range that fully contains a docstring span."""
        candidates: List[Symbol] = []
        for symbol in symbols:
            sym_start, sym_end = symbol.range
            if sym_start <= start_line and end_line <= sym_end:
                candidates.append(symbol)
        if not candidates:
            return None
        return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))

    def chunk_file(
        self,
        content: str,
        symbols: List[Symbol],
        file_path: str | Path,
        language: str,
        symbol_token_counts: Optional[dict[str, int]] = None,
    ) -> List[SemanticChunk]:
        """Chunk file using hybrid strategy.

        Extracts docstrings first, then chunks remaining code.

        Args:
            content: Source code content
            symbols: List of extracted symbols
            file_path: Path to source file
            language: Programming language
            symbol_token_counts: Optional dict mapping symbol names to token counts
        """
        chunks: List[SemanticChunk] = []

        # Step 1: Extract docstrings as dedicated chunks
        docstrings: List[Tuple[str, int, int]] = []
        if language == "python":
            # Fast path: avoid expensive docstring extraction if delimiters are absent.
            if '"""' in content or "'''" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        elif language in {"javascript", "typescript"}:
            if "/**" in content:
                docstrings = self.docstring_extractor.extract_docstrings(content, language)
        else:
            docstrings = self.docstring_extractor.extract_docstrings(content, language)

        # Fast path: no docstrings -> delegate to base chunker directly.
        if not docstrings:
            if symbols:
                base_chunks = self.base_chunker.chunk_by_symbol(
                    content, symbols, file_path, language, symbol_token_counts
                )
            else:
                base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)

            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
            return base_chunks

        for docstring_content, start_line, end_line in docstrings:
            if len(docstring_content.strip()) >= self.config.min_chunk_size:
                parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
                # Use base chunker's token estimation method
                token_count = self.base_chunker._estimate_token_count(docstring_content)
                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_type": "docstring",
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "hybrid",
                    "token_count": token_count,
                }
                if parent_symbol is not None:
                    metadata["parent_symbol"] = parent_symbol.name
                    metadata["parent_symbol_kind"] = parent_symbol.kind
                    metadata["parent_symbol_range"] = parent_symbol.range
                chunks.append(SemanticChunk(
                    content=docstring_content,
                    embedding=None,
                    metadata=metadata
                ))

        # Step 2: Get line ranges occupied by docstrings
        excluded_lines = self._get_excluded_line_ranges(docstrings)

        # Step 3: Filter symbols to exclude docstring-only ranges
        filtered_symbols = self._filter_symbols_outside_docstrings(symbols, excluded_lines)

        # Step 4: Chunk remaining content using base chunker
        if filtered_symbols:
            base_chunks = self.base_chunker.chunk_by_symbol(
                content, filtered_symbols, file_path, language, symbol_token_counts
            )
            for chunk in base_chunks:
                chunk.metadata["strategy"] = "hybrid"
                chunk.metadata["chunk_type"] = "code"
                chunks.append(chunk)
        else:
            lines = content.splitlines(keepends=True)
            remaining_lines: List[str] = []

            for i, line in enumerate(lines, start=1):
                if i not in excluded_lines:
                    remaining_lines.append(line)

            if remaining_lines:
                remaining_content = "".join(remaining_lines)
                if len(remaining_content.strip()) >= self.config.min_chunk_size:
                    base_chunks = self.base_chunker.chunk_sliding_window(
                        remaining_content, file_path, language
                    )
                    for chunk in base_chunks:
                        chunk.metadata["strategy"] = "hybrid"
                        chunk.metadata["chunk_type"] = "code"
                        chunks.append(chunk)

        return chunks
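Note: a small end-to-end sketch of the hybrid strategy above. The symbols list would normally come from the codexlens parsers; the file name and config values here are illustrative only.

    config = ChunkConfig(max_chunk_size=800, skip_token_count=True)
    chunker = HybridChunker(config=config)
    chunks = chunker.chunk_file(source_code, symbols, "example.py", "python")
    for chunk in chunks:
        meta = chunk.metadata
        print(meta["chunk_type"], meta["strategy"], meta["start_line"], meta["end_line"])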
codex-lens/build/lib/codexlens/semantic/code_extractor.py (new file, 274 lines)
@@ -0,0 +1,274 @@
"""Smart code extraction for complete code blocks."""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Tuple

from codexlens.entities import SearchResult, Symbol


def extract_complete_code_block(
    result: SearchResult,
    source_file_path: Optional[str] = None,
    context_lines: int = 0,
) -> str:
    """Extract complete code block from a search result.

    Args:
        result: SearchResult from semantic search.
        source_file_path: Optional path to source file for re-reading.
        context_lines: Additional lines of context to include above/below.

    Returns:
        Complete code block as string.
    """
    # If we have full content stored, use it
    if result.content:
        if context_lines == 0:
            return result.content
        # Need to add context, read from file

    # Try to read from source file
    file_path = source_file_path or result.path
    if not file_path or not Path(file_path).exists():
        # Fall back to excerpt
        return result.excerpt or ""

    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        # Get line range
        start_line = result.start_line or 1
        end_line = result.end_line or len(lines)

        # Add context
        start_idx = max(0, start_line - 1 - context_lines)
        end_idx = min(len(lines), end_line + context_lines)

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return result.excerpt or result.content or ""


def extract_symbol_with_context(
    file_path: str,
    symbol: Symbol,
    include_docstring: bool = True,
    include_decorators: bool = True,
) -> str:
    """Extract a symbol (function/class) with its docstring and decorators.

    Args:
        file_path: Path to source file.
        symbol: Symbol to extract.
        include_docstring: Include docstring if present.
        include_decorators: Include decorators/annotations above symbol.

    Returns:
        Complete symbol code with context.
    """
    try:
        content = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        lines = content.splitlines()

        start_line, end_line = symbol.range
        start_idx = start_line - 1
        end_idx = end_line

        # Look for decorators above the symbol
        if include_decorators and start_idx > 0:
            decorator_start = start_idx
            # Search backwards for decorators
            i = start_idx - 1
            while i >= 0 and i >= start_idx - 20:  # Look up to 20 lines back
                line = lines[i].strip()
                if line.startswith("@"):
                    decorator_start = i
                    i -= 1
                elif line == "" or line.startswith("#"):
                    # Skip empty lines and comments, continue looking
                    i -= 1
                elif line.startswith("//") or line.startswith("/*") or line.startswith("*"):
                    # JavaScript/Java style comments
                    decorator_start = i
                    i -= 1
                else:
                    # Found non-decorator, non-comment line, stop
                    break
            start_idx = decorator_start

        return "\n".join(lines[start_idx:end_idx])
    except Exception:
        return ""


def format_search_result_code(
    result: SearchResult,
    max_lines: Optional[int] = None,
    show_line_numbers: bool = True,
    highlight_match: bool = False,
) -> str:
    """Format search result code for display.

    Args:
        result: SearchResult to format.
        max_lines: Maximum lines to show (None for all).
        show_line_numbers: Include line numbers in output.
        highlight_match: Add markers for matched region.

    Returns:
        Formatted code string.
    """
    content = result.content or result.excerpt or ""
    if not content:
        return ""

    lines = content.splitlines()

    # Truncate if needed
    truncated = False
    if max_lines and len(lines) > max_lines:
        lines = lines[:max_lines]
        truncated = True

    # Format with line numbers
    if show_line_numbers:
        start = result.start_line or 1
        formatted_lines = []
        for i, line in enumerate(lines):
            line_num = start + i
            formatted_lines.append(f"{line_num:4d} | {line}")
        output = "\n".join(formatted_lines)
    else:
        output = "\n".join(lines)

    if truncated:
        output += "\n... (truncated)"

    return output


def get_code_block_summary(result: SearchResult) -> str:
    """Get a concise summary of a code block.

    Args:
        result: SearchResult to summarize.

    Returns:
        Summary string like "function hello_world (lines 10-25)"
    """
    parts = []

    if result.symbol_kind:
        parts.append(result.symbol_kind)

    if result.symbol_name:
        parts.append(f"`{result.symbol_name}`")
    elif result.excerpt:
        # Extract first meaningful identifier
        first_line = result.excerpt.split("\n")[0][:50]
        parts.append(f'"{first_line}..."')

    if result.start_line and result.end_line:
        if result.start_line == result.end_line:
            parts.append(f"(line {result.start_line})")
        else:
            parts.append(f"(lines {result.start_line}-{result.end_line})")

    if result.path:
        file_name = Path(result.path).name
        parts.append(f"in {file_name}")

    return " ".join(parts) if parts else "unknown code block"


class CodeBlockResult:
    """Enhanced search result with complete code block."""

    def __init__(self, result: SearchResult, source_path: Optional[str] = None):
        self.result = result
        self.source_path = source_path or result.path
        self._full_code: Optional[str] = None

    @property
    def score(self) -> float:
        return self.result.score

    @property
    def path(self) -> str:
        return self.result.path

    @property
    def file_name(self) -> str:
        return Path(self.result.path).name

    @property
    def symbol_name(self) -> Optional[str]:
        return self.result.symbol_name

    @property
    def symbol_kind(self) -> Optional[str]:
        return self.result.symbol_kind

    @property
    def line_range(self) -> Tuple[int, int]:
        return (
            self.result.start_line or 1,
            self.result.end_line or 1
        )

    @property
    def full_code(self) -> str:
        """Get full code block content."""
        if self._full_code is None:
            self._full_code = extract_complete_code_block(self.result, self.source_path)
        return self._full_code

    @property
    def excerpt(self) -> str:
        """Get short excerpt."""
        return self.result.excerpt or ""

    @property
    def summary(self) -> str:
        """Get code block summary."""
        return get_code_block_summary(self.result)

    def format(
        self,
        max_lines: Optional[int] = None,
        show_line_numbers: bool = True,
    ) -> str:
        """Format code for display."""
        # Use full code if available
        display_result = SearchResult(
            path=self.result.path,
            score=self.result.score,
            content=self.full_code,
            start_line=self.result.start_line,
            end_line=self.result.end_line,
        )
        return format_search_result_code(
            display_result,
            max_lines=max_lines,
            show_line_numbers=show_line_numbers
        )

    def __repr__(self) -> str:
        return f"<CodeBlockResult {self.summary} score={self.score:.3f}>"


def enhance_search_results(
    results: List[SearchResult],
) -> List[CodeBlockResult]:
    """Enhance search results with complete code block access.

    Args:
        results: List of SearchResult from semantic search.

    Returns:
        List of CodeBlockResult with full code access.
    """
    return [CodeBlockResult(r) for r in results]
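Note: a brief sketch of how the helpers above compose. It assumes a list of SearchResult objects obtained from the semantic search API elsewhere in codexlens; the search call itself is not part of this diff.

    enhanced = enhance_search_results(results)
    for block in enhanced:
        print(block.summary)                                  # e.g. "function `add` (lines 10-25) in math_utils.py"
        print(block.format(max_lines=20, show_line_numbers=True))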
codex-lens/build/lib/codexlens/semantic/embedder.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""Embedder for semantic code search using fastembed.

Supports GPU acceleration via ONNX execution providers (CUDA, TensorRT, DirectML, ROCm, CoreML).
GPU acceleration is automatic when available, with transparent CPU fallback.
"""

from __future__ import annotations

import gc
import logging
import threading
from typing import Dict, Iterable, List, Optional

import numpy as np

from . import SEMANTIC_AVAILABLE
from .base import BaseEmbedder
from .gpu_support import get_optimal_providers, is_gpu_available, get_gpu_summary, get_selected_device_id

logger = logging.getLogger(__name__)

# Global embedder cache for singleton pattern
_embedder_cache: Dict[str, "Embedder"] = {}
_cache_lock = threading.RLock()


def get_embedder(profile: str = "code", use_gpu: bool = True) -> "Embedder":
    """Get or create a cached Embedder instance (thread-safe singleton).

    This function provides significant performance improvement by reusing
    Embedder instances across multiple searches, avoiding repeated model
    loading overhead (~0.8s per load).

    Args:
        profile: Model profile ("fast", "code", "multilingual", "balanced")
        use_gpu: If True, use GPU acceleration when available (default: True)

    Returns:
        Cached Embedder instance for the given profile
    """
    global _embedder_cache

    # Cache key includes GPU preference to support mixed configurations
    cache_key = f"{profile}:{'gpu' if use_gpu else 'cpu'}"

    # All cache access is protected by _cache_lock to avoid races with
    # clear_embedder_cache() during concurrent access.
    with _cache_lock:
        embedder = _embedder_cache.get(cache_key)
        if embedder is not None:
            return embedder

        # Create new embedder and cache it
        embedder = Embedder(profile=profile, use_gpu=use_gpu)
        # Pre-load model to ensure it's ready
        embedder._load_model()
        _embedder_cache[cache_key] = embedder

        # Log GPU status on first embedder creation
        if use_gpu and is_gpu_available():
            logger.info(f"Embedder initialized with GPU: {get_gpu_summary()}")
        elif use_gpu:
            logger.debug("GPU not available, using CPU for embeddings")

        return embedder


def clear_embedder_cache() -> None:
    """Clear the embedder cache and release ONNX resources.

    This method ensures proper cleanup of ONNX model resources to prevent
    memory leaks when embedders are no longer needed.
    """
    global _embedder_cache
    with _cache_lock:
        # Release ONNX resources before clearing cache
        for embedder in _embedder_cache.values():
            if embedder._model is not None:
                del embedder._model
                embedder._model = None
        _embedder_cache.clear()
        gc.collect()


class Embedder(BaseEmbedder):
    """Generate embeddings for code chunks using fastembed (ONNX-based).

    Supported Model Profiles:
        - fast: BAAI/bge-small-en-v1.5 (384 dim) - Fast, lightweight, English-optimized
        - code: jinaai/jina-embeddings-v2-base-code (768 dim) - Code-optimized, best for programming languages
        - multilingual: intfloat/multilingual-e5-large (1024 dim) - Multilingual + code support
        - balanced: mixedbread-ai/mxbai-embed-large-v1 (1024 dim) - High accuracy, general purpose
    """

    # Model profiles for different use cases
    MODELS = {
        "fast": "BAAI/bge-small-en-v1.5",  # 384 dim - Fast, lightweight
        "code": "jinaai/jina-embeddings-v2-base-code",  # 768 dim - Code-optimized
        "multilingual": "intfloat/multilingual-e5-large",  # 1024 dim - Multilingual
        "balanced": "mixedbread-ai/mxbai-embed-large-v1",  # 1024 dim - High accuracy
    }

    # Dimension mapping for each model
    MODEL_DIMS = {
        "BAAI/bge-small-en-v1.5": 384,
        "jinaai/jina-embeddings-v2-base-code": 768,
        "intfloat/multilingual-e5-large": 1024,
        "mixedbread-ai/mxbai-embed-large-v1": 1024,
    }

    # Default model (fast profile)
    DEFAULT_MODEL = "BAAI/bge-small-en-v1.5"
    DEFAULT_PROFILE = "fast"

    def __init__(
        self,
        model_name: str | None = None,
        profile: str | None = None,
        use_gpu: bool = True,
        providers: List[str] | None = None,
    ) -> None:
        """Initialize embedder with model or profile.

        Args:
            model_name: Explicit model name (e.g., "jinaai/jina-embeddings-v2-base-code")
            profile: Model profile shortcut ("fast", "code", "multilingual", "balanced")
                If both provided, model_name takes precedence.
            use_gpu: If True, use GPU acceleration when available (default: True)
            providers: Explicit ONNX providers list (overrides use_gpu if provided)
        """
        if not SEMANTIC_AVAILABLE:
            raise ImportError(
                "Semantic search dependencies not available. "
                "Install with: pip install codexlens[semantic]"
            )

        # Resolve model name from profile or use explicit name
        if model_name:
            self._model_name = model_name
        elif profile and profile in self.MODELS:
            self._model_name = self.MODELS[profile]
        else:
            self._model_name = self.DEFAULT_MODEL

        # Configure ONNX execution providers with device_id options for GPU selection
        # Using with_device_options=True ensures DirectML/CUDA device_id is passed correctly
        if providers is not None:
            self._providers = providers
        else:
            self._providers = get_optimal_providers(use_gpu=use_gpu, with_device_options=True)

        self._use_gpu = use_gpu
        self._model = None

    @property
    def model_name(self) -> str:
        """Get model name."""
        return self._model_name

    @property
    def embedding_dim(self) -> int:
        """Get embedding dimension for current model."""
        return self.MODEL_DIMS.get(self._model_name, 768)  # Default to 768 if unknown

    @property
    def max_tokens(self) -> int:
        """Get maximum token limit for current model.

        Returns:
            int: Maximum number of tokens based on model profile.
                - fast: 512 (lightweight, optimized for speed)
                - code: 8192 (code-optimized, larger context)
                - multilingual: 512 (standard multilingual model)
                - balanced: 512 (general purpose)
        """
        # Determine profile from model name
        profile = None
        for prof, model in self.MODELS.items():
            if model == self._model_name:
                profile = prof
                break

        # Return token limit based on profile
        if profile == "code":
            return 8192
        elif profile in ("fast", "multilingual", "balanced"):
            return 512
        else:
            # Default for unknown models
            return 512

    @property
    def providers(self) -> List[str]:
        """Get configured ONNX execution providers."""
        return self._providers

    @property
    def is_gpu_enabled(self) -> bool:
        """Check if GPU acceleration is enabled for this embedder."""
        gpu_providers = {"CUDAExecutionProvider", "TensorrtExecutionProvider",
                         "DmlExecutionProvider", "ROCMExecutionProvider", "CoreMLExecutionProvider"}
        # Handle both string providers and tuple providers (name, options)
        for p in self._providers:
            provider_name = p[0] if isinstance(p, tuple) else p
            if provider_name in gpu_providers:
                return True
        return False

    def _load_model(self) -> None:
        """Lazy load the embedding model with configured providers."""
        if self._model is not None:
            return

        from fastembed import TextEmbedding

        # providers already include device_id options via get_optimal_providers(with_device_options=True)
        # DO NOT pass device_ids separately - fastembed ignores it when providers is specified
        # See: fastembed/text/onnx_embedding.py - device_ids is only used with cuda=True
        try:
            self._model = TextEmbedding(
                model_name=self.model_name,
                providers=self._providers,
            )
            logger.debug(f"Model loaded with providers: {self._providers}")
        except TypeError:
            # Fallback for older fastembed versions without providers parameter
            logger.warning(
                "fastembed version doesn't support 'providers' parameter. "
                "Upgrade fastembed for GPU acceleration: pip install --upgrade fastembed"
            )
            self._model = TextEmbedding(model_name=self.model_name)

    def embed(self, texts: str | Iterable[str]) -> List[List[float]]:
        """Generate embeddings for one or more texts.

        Args:
            texts: Single text or iterable of texts to embed.

        Returns:
            List of embedding vectors (each is a list of floats).

        Note:
            This method converts numpy arrays to Python lists for backward compatibility.
            For memory-efficient processing, use embed_to_numpy() instead.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        embeddings = list(self._model.embed(texts))
        return [emb.tolist() for emb in embeddings]

    def embed_to_numpy(self, texts: str | Iterable[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Generate embeddings for one or more texts (returns numpy arrays).

        This method is more memory-efficient than embed() as it avoids converting
        numpy arrays to Python lists, which can significantly reduce memory usage
        during batch processing.

        Args:
            texts: Single text or iterable of texts to embed.
            batch_size: Optional batch size for fastembed processing.
                Larger values improve GPU utilization but use more memory.

        Returns:
            numpy.ndarray of shape (n_texts, embedding_dim) containing embeddings.
        """
        self._load_model()

        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)

        # Pass batch_size to fastembed for optimal GPU utilization
        # Default batch_size in fastembed is 256, but larger values can improve throughput
        if batch_size is not None:
            embeddings = list(self._model.embed(texts, batch_size=batch_size))
        else:
            embeddings = list(self._model.embed(texts))
        return np.array(embeddings)

    def embed_single(self, text: str) -> List[float]:
        """Generate embedding for a single text."""
        return self.embed(text)[0]
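Note: a minimal sketch of the cached embedder path above, assuming the optional semantic extra is installed (pip install codexlens[semantic]); the sample text is illustrative.

    from codexlens.semantic.embedder import get_embedder, clear_embedder_cache

    emb = get_embedder(profile="code", use_gpu=True)   # cached singleton per profile/GPU flag
    vectors = emb.embed_to_numpy(["def add(a, b): return a + b"], batch_size=64)
    print(emb.model_name, emb.embedding_dim, vectors.shape)
    clear_embedder_cache()                             # release ONNX resources when done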
codex-lens/build/lib/codexlens/semantic/factory.py (new file, 158 lines)
@@ -0,0 +1,158 @@
"""Factory for creating embedders.

Provides a unified interface for instantiating different embedder backends.
Includes caching to avoid repeated model loading overhead.
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Dict, List, Optional

from .base import BaseEmbedder

# Module-level cache for embedder instances
# Key: (backend, profile, model, use_gpu) -> embedder instance
_embedder_cache: Dict[tuple, BaseEmbedder] = {}
_cache_lock = threading.Lock()
_logger = logging.getLogger(__name__)


def get_embedder(
    backend: str = "fastembed",
    profile: str = "code",
    model: str = "default",
    use_gpu: bool = True,
    endpoints: Optional[List[Dict[str, Any]]] = None,
    strategy: str = "latency_aware",
    cooldown: float = 60.0,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create embedder based on backend.

    Args:
        backend: Embedder backend to use. Options:
            - "fastembed": Use fastembed (ONNX-based) embedder (default)
            - "litellm": Use ccw-litellm embedder
        profile: Model profile for fastembed backend ("fast", "code", "multilingual", "balanced")
            Used only when backend="fastembed". Default: "code"
        model: Model identifier for litellm backend.
            Used only when backend="litellm". Default: "default"
        use_gpu: Whether to use GPU acceleration when available (default: True).
            Used only when backend="fastembed".
        endpoints: Optional list of endpoint configurations for multi-endpoint load balancing.
            Each endpoint is a dict with keys: model, api_key, api_base, weight.
            Used only when backend="litellm" and multiple endpoints provided.
        strategy: Selection strategy for multi-endpoint mode:
            "round_robin", "latency_aware", "weighted_random".
            Default: "latency_aware"
        cooldown: Default cooldown seconds for rate-limited endpoints (default: 60.0)
        **kwargs: Additional backend-specific arguments

    Returns:
        BaseEmbedder: Configured embedder instance

    Raises:
        ValueError: If backend is not recognized
        ImportError: If required backend dependencies are not installed

    Examples:
        Create fastembed embedder with code profile:
        >>> embedder = get_embedder(backend="fastembed", profile="code")

        Create fastembed embedder with fast profile and CPU only:
        >>> embedder = get_embedder(backend="fastembed", profile="fast", use_gpu=False)

        Create litellm embedder:
        >>> embedder = get_embedder(backend="litellm", model="text-embedding-3-small")

        Create rotational embedder with multiple endpoints:
        >>> endpoints = [
        ...     {"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
        ...     {"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
        ... ]
        >>> embedder = get_embedder(backend="litellm", endpoints=endpoints)
    """
    # Build cache key from immutable configuration
    if backend == "fastembed":
        cache_key = ("fastembed", profile, None, use_gpu)
    elif backend == "litellm":
        # For litellm, use model as part of cache key
        # Multi-endpoint mode is not cached as it's more complex
        if endpoints and len(endpoints) > 1:
            cache_key = None  # Skip cache for multi-endpoint
        else:
            effective_model = endpoints[0]["model"] if endpoints else model
            cache_key = ("litellm", None, effective_model, None)
    else:
        cache_key = None

    # Check cache first (thread-safe)
    if cache_key is not None:
        with _cache_lock:
            if cache_key in _embedder_cache:
                _logger.debug("Returning cached embedder for %s", cache_key)
                return _embedder_cache[cache_key]

    # Create new embedder instance
    embedder: Optional[BaseEmbedder] = None

    if backend == "fastembed":
        from .embedder import Embedder
        embedder = Embedder(profile=profile, use_gpu=use_gpu, **kwargs)
    elif backend == "litellm":
        # Check if multi-endpoint mode is requested
        if endpoints and len(endpoints) > 1:
            from .rotational_embedder import create_rotational_embedder
            # Multi-endpoint is not cached
            return create_rotational_embedder(
                endpoints_config=endpoints,
                strategy=strategy,
                default_cooldown=cooldown,
            )
        elif endpoints and len(endpoints) == 1:
            # Single endpoint in list - use it directly
            ep = endpoints[0]
            ep_kwargs = {**kwargs}
            if "api_key" in ep:
                ep_kwargs["api_key"] = ep["api_key"]
            if "api_base" in ep:
                ep_kwargs["api_base"] = ep["api_base"]
            from .litellm_embedder import LiteLLMEmbedderWrapper
            embedder = LiteLLMEmbedderWrapper(model=ep["model"], **ep_kwargs)
        else:
            # No endpoints list - use model parameter
            from .litellm_embedder import LiteLLMEmbedderWrapper
            embedder = LiteLLMEmbedderWrapper(model=model, **kwargs)
    else:
        raise ValueError(
            f"Unknown backend: {backend}. "
            f"Supported backends: 'fastembed', 'litellm'"
        )

    # Cache the embedder for future use (thread-safe)
    if cache_key is not None and embedder is not None:
        with _cache_lock:
            # Double-check to avoid race condition
            if cache_key not in _embedder_cache:
                _embedder_cache[cache_key] = embedder
                _logger.debug("Cached new embedder for %s", cache_key)
            else:
                # Another thread created it already, use that one
                embedder = _embedder_cache[cache_key]

    return embedder  # type: ignore


def clear_embedder_cache() -> int:
    """Clear the embedder cache.

    Returns:
        Number of embedders cleared from cache
    """
    with _cache_lock:
        count = len(_embedder_cache)
        _embedder_cache.clear()
        _logger.debug("Cleared %d embedders from cache", count)
        return count
431
codex-lens/build/lib/codexlens/semantic/gpu_support.py
Normal file
431
codex-lens/build/lib/codexlens/semantic/gpu_support.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""GPU acceleration support for semantic embeddings.
|
||||
|
||||
This module provides GPU detection, initialization, and fallback handling
|
||||
for ONNX-based embedding generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUDevice:
|
||||
"""Individual GPU device info."""
|
||||
device_id: int
|
||||
name: str
|
||||
is_discrete: bool # True for discrete GPU (NVIDIA, AMD), False for integrated (Intel UHD)
|
||||
vendor: str # "nvidia", "amd", "intel", "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUInfo:
|
||||
"""GPU availability and configuration info."""
|
||||
|
||||
gpu_available: bool = False
|
||||
cuda_available: bool = False
|
||||
gpu_count: int = 0
|
||||
gpu_name: Optional[str] = None
|
||||
onnx_providers: List[str] = None
|
||||
devices: List[GPUDevice] = None # List of detected GPU devices
|
||||
preferred_device_id: Optional[int] = None # Preferred GPU for embedding
|
||||
|
||||
def __post_init__(self):
|
||||
if self.onnx_providers is None:
|
||||
self.onnx_providers = ["CPUExecutionProvider"]
|
||||
if self.devices is None:
|
||||
self.devices = []
|
||||
|
||||
|
||||
_gpu_info_cache: Optional[GPUInfo] = None
|
||||
|
||||
|
||||
def _enumerate_gpus() -> List[GPUDevice]:
|
||||
"""Enumerate available GPU devices using WMI on Windows.
|
||||
|
||||
Returns:
|
||||
List of GPUDevice with device info, ordered by device_id.
|
||||
"""
|
||||
devices = []
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
if sys.platform == "win32":
|
||||
# Use PowerShell to query GPU information via WMI
|
||||
cmd = [
|
||||
"powershell", "-NoProfile", "-Command",
|
||||
"Get-WmiObject Win32_VideoController | Select-Object DeviceID, Name, AdapterCompatibility | ConvertTo-Json"
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
import json
|
||||
gpu_data = json.loads(result.stdout)
|
||||
|
||||
# Handle single GPU case (returns dict instead of list)
|
||||
if isinstance(gpu_data, dict):
|
||||
gpu_data = [gpu_data]
|
||||
|
||||
for idx, gpu in enumerate(gpu_data):
|
||||
name = gpu.get("Name", "Unknown GPU")
|
||||
compat = gpu.get("AdapterCompatibility", "").lower()
|
||||
|
||||
# Determine vendor
|
||||
name_lower = name.lower()
|
||||
if "nvidia" in name_lower or "nvidia" in compat:
|
||||
vendor = "nvidia"
|
||||
is_discrete = True
|
||||
elif "amd" in name_lower or "radeon" in name_lower or "amd" in compat:
|
||||
vendor = "amd"
|
||||
is_discrete = True
|
||||
elif "intel" in name_lower or "intel" in compat:
|
||||
vendor = "intel"
|
||||
# Intel UHD/Iris are integrated, Intel Arc is discrete
|
||||
is_discrete = "arc" in name_lower
|
||||
else:
|
||||
vendor = "unknown"
|
||||
is_discrete = False
|
||||
|
||||
devices.append(GPUDevice(
|
||||
device_id=idx,
|
||||
name=name,
|
||||
is_discrete=is_discrete,
|
||||
vendor=vendor
|
||||
))
|
||||
logger.debug(f"Detected GPU {idx}: {name} (vendor={vendor}, discrete={is_discrete})")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"GPU enumeration failed: {e}")
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
def _get_preferred_device_id(devices: List[GPUDevice]) -> Optional[int]:
|
||||
"""Determine the preferred GPU device_id for embedding.
|
||||
|
||||
Preference order:
|
||||
1. NVIDIA discrete GPU (best DirectML/CUDA support)
|
||||
2. AMD discrete GPU
|
||||
3. Intel Arc (discrete)
|
||||
4. Intel integrated (fallback)
|
||||
|
||||
Returns:
|
||||
device_id of preferred GPU, or None to use default.
|
||||
"""
|
||||
if not devices:
|
||||
return None
|
||||
|
||||
# Priority: NVIDIA > AMD > Intel Arc > Intel integrated
|
||||
priority_order = [
|
||||
("nvidia", True), # NVIDIA discrete
|
||||
("amd", True), # AMD discrete
|
||||
("intel", True), # Intel Arc (discrete)
|
||||
("intel", False), # Intel integrated (fallback)
|
||||
]
|
||||
|
||||
for target_vendor, target_discrete in priority_order:
|
||||
for device in devices:
|
||||
if device.vendor == target_vendor and device.is_discrete == target_discrete:
|
||||
logger.info(f"Preferred GPU: {device.name} (device_id={device.device_id})")
|
||||
return device.device_id
|
||||
|
||||
# If no match, use first device
|
||||
if devices:
|
||||
return devices[0].device_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def detect_gpu(force_refresh: bool = False) -> GPUInfo:
|
||||
"""Detect available GPU resources for embedding acceleration.
|
||||
|
||||
Args:
|
||||
force_refresh: If True, re-detect GPU even if cached.
|
||||
|
||||
Returns:
|
||||
GPUInfo with detection results.
|
||||
"""
|
||||
global _gpu_info_cache
|
||||
|
||||
if _gpu_info_cache is not None and not force_refresh:
|
||||
return _gpu_info_cache
|
||||
|
||||
info = GPUInfo()
|
||||
|
||||
# Enumerate GPU devices first
|
||||
info.devices = _enumerate_gpus()
|
||||
info.gpu_count = len(info.devices)
|
||||
if info.devices:
|
||||
# Set preferred device (discrete GPU preferred over integrated)
|
||||
info.preferred_device_id = _get_preferred_device_id(info.devices)
|
||||
# Set gpu_name to preferred device name
|
||||
for dev in info.devices:
|
||||
if dev.device_id == info.preferred_device_id:
|
||||
info.gpu_name = dev.name
|
||||
break
|
||||
|
||||
# Check PyTorch CUDA availability (most reliable detection)
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
info.cuda_available = True
|
||||
info.gpu_available = True
|
||||
info.gpu_count = torch.cuda.device_count()
|
||||
if info.gpu_count > 0:
|
||||
info.gpu_name = torch.cuda.get_device_name(0)
|
||||
logger.debug(f"PyTorch CUDA detected: {info.gpu_count} GPU(s)")
|
||||
except ImportError:
|
||||
logger.debug("PyTorch not available for GPU detection")
|
||||
|
||||
# Check ONNX Runtime providers with validation
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
available_providers = ort.get_available_providers()
|
||||
|
||||
# Build provider list with priority order
|
||||
providers = []
|
||||
|
||||
# Test each provider to ensure it actually works
|
||||
def test_provider(provider_name: str) -> bool:
|
||||
"""Test if a provider actually works by creating a dummy session."""
|
||||
try:
|
||||
# Create a minimal ONNX model to test provider
|
||||
import numpy as np
|
||||
# Simple test: just check if provider can be instantiated
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.log_severity_level = 4 # Suppress warnings
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# CUDA provider (NVIDIA GPU) - check if CUDA runtime is available
|
||||
if "CUDAExecutionProvider" in available_providers:
|
||||
# Verify CUDA is actually usable by checking for cuBLAS
|
||||
cuda_works = False
|
||||
try:
|
||||
import ctypes
|
||||
# Try to load cuBLAS to verify CUDA installation
|
||||
try:
|
||||
ctypes.CDLL("cublas64_12.dll")
|
||||
cuda_works = True
|
||||
except OSError:
|
||||
try:
|
||||
ctypes.CDLL("cublas64_11.dll")
|
||||
cuda_works = True
|
||||
except OSError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if cuda_works:
|
||||
providers.append("CUDAExecutionProvider")
|
||||
info.gpu_available = True
|
||||
logger.debug("ONNX CUDAExecutionProvider available and working")
|
||||
else:
|
||||
logger.debug("ONNX CUDAExecutionProvider listed but CUDA runtime not found")
|
||||
|
||||
# TensorRT provider (optimized NVIDIA inference)
|
||||
if "TensorrtExecutionProvider" in available_providers:
|
||||
# TensorRT requires additional libraries, skip for now
|
||||
logger.debug("ONNX TensorrtExecutionProvider available (requires TensorRT SDK)")
|
||||
|
||||
# DirectML provider (Windows GPU - AMD/Intel/NVIDIA)
|
||||
if "DmlExecutionProvider" in available_providers:
|
||||
providers.append("DmlExecutionProvider")
|
||||
info.gpu_available = True
|
||||
logger.debug("ONNX DmlExecutionProvider available (DirectML)")
|
||||
|
||||
# ROCm provider (AMD GPU on Linux)
|
||||
if "ROCMExecutionProvider" in available_providers:
|
||||
providers.append("ROCMExecutionProvider")
|
||||
info.gpu_available = True
|
||||
logger.debug("ONNX ROCMExecutionProvider available (AMD)")
|
||||
|
||||
# CoreML provider (Apple Silicon)
|
||||
if "CoreMLExecutionProvider" in available_providers:
|
||||
providers.append("CoreMLExecutionProvider")
|
||||
info.gpu_available = True
|
||||
logger.debug("ONNX CoreMLExecutionProvider available (Apple)")
|
||||
|
||||
# Always include CPU as fallback
|
||||
providers.append("CPUExecutionProvider")
|
||||
|
||||
info.onnx_providers = providers
|
||||
|
||||
except ImportError:
|
||||
logger.debug("ONNX Runtime not available")
|
||||
info.onnx_providers = ["CPUExecutionProvider"]
|
||||
|
||||
_gpu_info_cache = info
|
||||
return info
|
||||
|
||||
|
||||
def get_optimal_providers(use_gpu: bool = True, with_device_options: bool = False) -> list:
|
||||
"""Get optimal ONNX execution providers based on availability.
|
||||
|
||||
Args:
|
||||
use_gpu: If True, include GPU providers when available.
|
||||
If False, force CPU-only execution.
|
||||
with_device_options: If True, return providers as tuples with device_id options
|
||||
for proper GPU device selection (required for DirectML).
|
||||
|
||||
Returns:
|
||||
List of provider names or tuples (provider_name, options_dict) in priority order.
|
||||
"""
|
||||
if not use_gpu:
|
||||
return ["CPUExecutionProvider"]
|
||||
|
||||
gpu_info = detect_gpu()
|
||||
|
||||
# Check if GPU was requested but not available - log warning
|
||||
if not gpu_info.gpu_available:
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
available_providers = ort.get_available_providers()
|
||||
except ImportError:
|
||||
available_providers = []
|
||||
logger.warning(
|
||||
"GPU acceleration was requested, but no supported GPU provider (CUDA, DirectML) "
|
||||
f"was found. Available providers: {available_providers}. Falling back to CPU."
|
||||
)
|
||||
else:
|
||||
# Log which GPU provider is being used
|
||||
gpu_providers = [p for p in gpu_info.onnx_providers if p != "CPUExecutionProvider"]
|
||||
if gpu_providers:
|
||||
logger.info(f"Using {gpu_providers[0]} for ONNX GPU acceleration")
|
||||
|
||||
if not with_device_options:
|
||||
return gpu_info.onnx_providers
|
||||
|
||||
# Build providers with device_id options for GPU providers
|
||||
device_id = get_selected_device_id()
|
||||
providers = []
|
||||
|
||||
for provider in gpu_info.onnx_providers:
|
||||
if provider == "DmlExecutionProvider" and device_id is not None:
|
||||
# DirectML requires device_id in provider_options tuple
|
||||
providers.append(("DmlExecutionProvider", {"device_id": device_id}))
|
||||
logger.debug(f"DmlExecutionProvider configured with device_id={device_id}")
|
||||
elif provider == "CUDAExecutionProvider" and device_id is not None:
|
||||
# CUDA also supports device_id in provider_options
|
||||
providers.append(("CUDAExecutionProvider", {"device_id": device_id}))
|
||||
logger.debug(f"CUDAExecutionProvider configured with device_id={device_id}")
|
||||
elif provider == "ROCMExecutionProvider" and device_id is not None:
|
||||
# ROCm supports device_id
|
||||
providers.append(("ROCMExecutionProvider", {"device_id": device_id}))
|
||||
logger.debug(f"ROCMExecutionProvider configured with device_id={device_id}")
|
||||
else:
|
||||
# CPU and other providers don't need device_id
|
||||
providers.append(provider)
|
||||
|
||||
return providers
|
||||
|
||||
|
||||
def is_gpu_available() -> bool:
|
||||
"""Check if any GPU acceleration is available."""
|
||||
return detect_gpu().gpu_available
|
||||
|
||||
|
||||
def get_gpu_summary() -> str:
|
||||
"""Get human-readable GPU status summary."""
|
||||
info = detect_gpu()
|
||||
|
||||
if not info.gpu_available:
|
||||
return "GPU: Not available (using CPU)"
|
||||
|
||||
parts = []
|
||||
if info.gpu_name:
|
||||
parts.append(f"GPU: {info.gpu_name}")
|
||||
if info.gpu_count > 1:
|
||||
parts.append(f"({info.gpu_count} devices)")
|
||||
|
||||
# Show active providers (excluding CPU fallback)
|
||||
gpu_providers = [p for p in info.onnx_providers if p != "CPUExecutionProvider"]
|
||||
if gpu_providers:
|
||||
parts.append(f"Providers: {', '.join(gpu_providers)}")
|
||||
|
||||
return " | ".join(parts) if parts else "GPU: Available"
|
||||
|
||||
|
||||
def clear_gpu_cache() -> None:
|
||||
"""Clear cached GPU detection info."""
|
||||
global _gpu_info_cache
|
||||
_gpu_info_cache = None
|
||||
|
||||
|
||||
# User-selected device ID (overrides auto-detection)
|
||||
_selected_device_id: Optional[int] = None
|
||||
|
||||
|
||||
def get_gpu_devices() -> List[dict]:
|
||||
"""Get list of available GPU devices for frontend selection.
|
||||
|
||||
Returns:
|
||||
List of dicts with device info for each GPU.
|
||||
"""
|
||||
info = detect_gpu()
|
||||
devices = []
|
||||
|
||||
for dev in info.devices:
|
||||
devices.append({
|
||||
"device_id": dev.device_id,
|
||||
"name": dev.name,
|
||||
"vendor": dev.vendor,
|
||||
"is_discrete": dev.is_discrete,
|
||||
"is_preferred": dev.device_id == info.preferred_device_id,
|
||||
"is_selected": dev.device_id == get_selected_device_id(),
|
||||
})
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
def get_selected_device_id() -> Optional[int]:
|
||||
"""Get the user-selected GPU device_id.
|
||||
|
||||
Returns:
|
||||
User-selected device_id, or auto-detected preferred device_id if not set.
|
||||
"""
|
||||
global _selected_device_id
|
||||
|
||||
if _selected_device_id is not None:
|
||||
return _selected_device_id
|
||||
|
||||
# Fall back to auto-detected preferred device
|
||||
info = detect_gpu()
|
||||
return info.preferred_device_id
|
||||
|
||||
|
||||
def set_selected_device_id(device_id: Optional[int]) -> bool:
|
||||
"""Set the GPU device_id to use for embeddings.
|
||||
|
||||
Args:
|
||||
device_id: GPU device_id to use, or None to use auto-detection.
|
||||
|
||||
Returns:
|
||||
True if device_id is valid, False otherwise.
|
||||
"""
|
||||
global _selected_device_id
|
||||
|
||||
if device_id is None:
|
||||
_selected_device_id = None
|
||||
logger.info("GPU selection reset to auto-detection")
|
||||
return True
|
||||
|
||||
# Validate device_id exists
|
||||
info = detect_gpu()
|
||||
valid_ids = [dev.device_id for dev in info.devices]
|
||||
|
||||
if device_id in valid_ids:
|
||||
_selected_device_id = device_id
|
||||
device_name = next((dev.name for dev in info.devices if dev.device_id == device_id), "Unknown")
|
||||
logger.info(f"GPU selection set to device {device_id}: {device_name}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Invalid device_id {device_id}. Valid IDs: {valid_ids}")
|
||||
return False
|
||||
144
codex-lens/build/lib/codexlens/semantic/litellm_embedder.py
Normal file
144
codex-lens/build/lib/codexlens/semantic/litellm_embedder.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""LiteLLM embedder wrapper for CodexLens.
|
||||
|
||||
Provides integration with ccw-litellm's LiteLLMEmbedder for embedding generation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
|
||||
class LiteLLMEmbedderWrapper(BaseEmbedder):
|
||||
"""Wrapper for ccw-litellm LiteLLMEmbedder.
|
||||
|
||||
This wrapper adapts the ccw-litellm LiteLLMEmbedder to the CodexLens
|
||||
BaseEmbedder interface, enabling seamless integration with CodexLens
|
||||
semantic search functionality.
|
||||
|
||||
Args:
|
||||
model: Model identifier for LiteLLM (default: "default")
|
||||
**kwargs: Additional arguments passed to LiteLLMEmbedder
|
||||
|
||||
Raises:
|
||||
ImportError: If ccw-litellm package is not installed
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "default", **kwargs) -> None:
|
||||
"""Initialize LiteLLM embedder wrapper.
|
||||
|
||||
Args:
|
||||
model: Model identifier for LiteLLM (default: "default")
|
||||
**kwargs: Additional arguments passed to LiteLLMEmbedder
|
||||
|
||||
Raises:
|
||||
ImportError: If ccw-litellm package is not installed
|
||||
"""
|
||||
try:
|
||||
from ccw_litellm import LiteLLMEmbedder
|
||||
self._embedder = LiteLLMEmbedder(model=model, **kwargs)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ccw-litellm not installed. Install with: pip install ccw-litellm"
|
||||
) from e
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimensions from LiteLLMEmbedder.
|
||||
|
||||
Returns:
|
||||
int: Dimension of the embedding vectors.
|
||||
"""
|
||||
return self._embedder.dimensions
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Return model name from LiteLLMEmbedder.
|
||||
|
||||
Returns:
|
||||
str: Name or identifier of the underlying model.
|
||||
"""
|
||||
return self._embedder.model_name
|
||||
|
||||
@property
|
||||
def max_tokens(self) -> int:
|
||||
"""Return maximum token limit for the embedding model.
|
||||
|
||||
Returns:
|
||||
int: Maximum number of tokens that can be embedded at once.
|
||||
Reads from LiteLLM config's max_input_tokens property.
|
||||
"""
|
||||
# Get from LiteLLM embedder's max_input_tokens property (now exposed)
|
||||
if hasattr(self._embedder, 'max_input_tokens'):
|
||||
return self._embedder.max_input_tokens
|
||||
|
||||
# Fallback: infer from model name
|
||||
model_name_lower = self.model_name.lower()
|
||||
|
||||
# Large models (8B or "large" in name)
|
||||
if '8b' in model_name_lower or 'large' in model_name_lower:
|
||||
return 32768
|
||||
|
||||
# OpenAI text-embedding-3-* models
|
||||
if 'text-embedding-3' in model_name_lower:
|
||||
return 8191
|
||||
|
||||
# Default fallback
|
||||
return 8192
|
||||
|
||||
def _sanitize_text(self, text: str) -> str:
|
||||
"""Sanitize text to work around ModelScope API routing bug.
|
||||
|
||||
ModelScope incorrectly routes text starting with lowercase 'import'
|
||||
to an Ollama endpoint, causing failures. This adds a leading space
|
||||
to work around the issue without affecting embedding quality.
|
||||
|
||||
Args:
|
||||
text: Text to sanitize.
|
||||
|
||||
Returns:
|
||||
Sanitized text safe for embedding API.
|
||||
"""
|
||||
if text.startswith('import'):
|
||||
return ' ' + text
|
||||
return text
|
||||
|
||||
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
|
||||
"""Embed texts to numpy array using LiteLLMEmbedder.
|
||||
|
||||
Args:
|
||||
texts: Single text or iterable of texts to embed.
|
||||
**kwargs: Additional arguments (ignored for LiteLLM backend).
|
||||
Accepts batch_size for API compatibility with fastembed.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
else:
|
||||
texts = list(texts)
|
||||
|
||||
# Sanitize texts to avoid ModelScope routing bug
|
||||
texts = [self._sanitize_text(t) for t in texts]
|
||||
|
||||
# LiteLLM handles batching internally, ignore batch_size parameter
|
||||
return self._embedder.embed(texts)
|
||||
|
||||
def embed_single(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed.
|
||||
|
||||
Returns:
|
||||
list[float]: Embedding vector as a list of floats.
|
||||
"""
|
||||
# Sanitize text before embedding
|
||||
sanitized = self._sanitize_text(text)
|
||||
embedding = self._embedder.embed([sanitized])
|
||||
return embedding[0].tolist()
|
||||
|
||||
25
codex-lens/build/lib/codexlens/semantic/reranker/__init__.py
Normal file
25
codex-lens/build/lib/codexlens/semantic/reranker/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Reranker backends for second-stage search ranking.
|
||||
|
||||
This subpackage provides a unified interface and factory for different reranking
|
||||
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import BaseReranker
|
||||
from .factory import check_reranker_available, get_reranker
|
||||
from .fastembed_reranker import FastEmbedReranker, check_fastembed_reranker_available
|
||||
from .legacy import CrossEncoderReranker, check_cross_encoder_available
|
||||
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
|
||||
|
||||
__all__ = [
|
||||
"BaseReranker",
|
||||
"check_reranker_available",
|
||||
"get_reranker",
|
||||
"CrossEncoderReranker",
|
||||
"check_cross_encoder_available",
|
||||
"FastEmbedReranker",
|
||||
"check_fastembed_reranker_available",
|
||||
"ONNXReranker",
|
||||
"check_onnx_reranker_available",
|
||||
]
|
||||
403
codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py
Normal file
403
codex-lens/build/lib/codexlens/semantic/reranker/api_reranker.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""API-based reranker using a remote HTTP provider.
|
||||
|
||||
Supported providers:
|
||||
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
|
||||
- Cohere: https://api.cohere.ai/v1/rerank
|
||||
- Jina: https://api.jina.ai/v1/rerank
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
|
||||
|
||||
|
||||
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
|
||||
"""Get environment variable with .env file fallback."""
|
||||
# Check os.environ first
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
# Try loading from .env files
|
||||
try:
|
||||
from codexlens.env_config import get_env
|
||||
return get_env(key, workspace_root=workspace_root)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def check_httpx_available() -> tuple[bool, str | None]:
|
||||
try:
|
||||
import httpx # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return False, f"httpx not available: {exc}. Install with: pip install httpx"
|
||||
return True, None
|
||||
|
||||
|
||||
class APIReranker(BaseReranker):
|
||||
"""Reranker backed by a remote reranking HTTP API."""
|
||||
|
||||
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
|
||||
"siliconflow": {
|
||||
"api_base": "https://api.siliconflow.cn",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "BAAI/bge-reranker-v2-m3",
|
||||
},
|
||||
"cohere": {
|
||||
"api_base": "https://api.cohere.ai",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "rerank-english-v3.0",
|
||||
},
|
||||
"jina": {
|
||||
"api_base": "https://api.jina.ai",
|
||||
"endpoint": "/v1/rerank",
|
||||
"default_model": "jina-reranker-v2-base-multilingual",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: str = "siliconflow",
|
||||
model_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
backoff_base_s: float = 0.5,
|
||||
backoff_max_s: float = 8.0,
|
||||
env_api_key: str = _DEFAULT_ENV_API_KEY,
|
||||
workspace_root: Path | str | None = None,
|
||||
max_input_tokens: int | None = None,
|
||||
) -> None:
|
||||
ok, err = check_httpx_available()
|
||||
if not ok: # pragma: no cover - exercised via factory availability tests
|
||||
raise ImportError(err)
|
||||
|
||||
import httpx
|
||||
|
||||
self._workspace_root = Path(workspace_root) if workspace_root else None
|
||||
|
||||
self.provider = (provider or "").strip().lower()
|
||||
if self.provider not in self._PROVIDER_DEFAULTS:
|
||||
raise ValueError(
|
||||
f"Unknown reranker provider: {provider}. "
|
||||
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
|
||||
)
|
||||
|
||||
defaults = self._PROVIDER_DEFAULTS[self.provider]
|
||||
|
||||
# Load api_base from env with .env fallback
|
||||
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
|
||||
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
|
||||
self.endpoint = defaults["endpoint"]
|
||||
|
||||
# Load model from env with .env fallback
|
||||
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
|
||||
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
# Load API key from env with .env fallback
|
||||
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
|
||||
resolved_key = resolved_key.strip()
|
||||
if not resolved_key:
|
||||
raise ValueError(
|
||||
f"Missing API key for reranker provider '{self.provider}'. "
|
||||
f"Pass api_key=... or set ${env_api_key}."
|
||||
)
|
||||
self._api_key = resolved_key
|
||||
|
||||
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
|
||||
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
|
||||
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
|
||||
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.provider == "cohere":
|
||||
headers.setdefault("Cohere-Version", "2022-12-06")
|
||||
|
||||
self._client = httpx.Client(
|
||||
base_url=self.api_base,
|
||||
headers=headers,
|
||||
timeout=self.timeout_s,
|
||||
)
|
||||
|
||||
# Store max_input_tokens with model-aware defaults
|
||||
if max_input_tokens is not None:
|
||||
self._max_input_tokens = max_input_tokens
|
||||
else:
|
||||
# Infer from model name
|
||||
model_lower = self.model_name.lower()
|
||||
if '8b' in model_lower or 'large' in model_lower:
|
||||
self._max_input_tokens = 32768
|
||||
else:
|
||||
self._max_input_tokens = 8192
|
||||
|
||||
@property
|
||||
def max_input_tokens(self) -> int:
|
||||
"""Return maximum token limit for reranking."""
|
||||
return self._max_input_tokens
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return
|
||||
|
||||
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
|
||||
if retry_after_s is not None and retry_after_s > 0:
|
||||
time.sleep(min(float(retry_after_s), self.backoff_max_s))
|
||||
return
|
||||
|
||||
exp = self.backoff_base_s * (2**attempt)
|
||||
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
|
||||
time.sleep(min(self.backoff_max_s, exp + jitter))
|
||||
|
||||
@staticmethod
|
||||
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
|
||||
value = (headers.get("Retry-After") or "").strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _should_retry_status(status_code: int) -> bool:
|
||||
return status_code == 429 or 500 <= status_code <= 599
|
||||
|
||||
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
last_exc: Exception | None = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
response = self._client.post(self.endpoint, json=dict(payload))
|
||||
except Exception as exc: # httpx is optional at import-time
|
||||
last_exc = exc
|
||||
if attempt < self.max_retries:
|
||||
self._sleep_backoff(attempt)
|
||||
continue
|
||||
raise RuntimeError(
|
||||
f"Rerank request failed for provider '{self.provider}' after "
|
||||
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
|
||||
) from exc
|
||||
|
||||
status = int(getattr(response, "status_code", 0) or 0)
|
||||
if status >= 400:
|
||||
body_preview = ""
|
||||
try:
|
||||
body_preview = (response.text or "").strip()
|
||||
except Exception:
|
||||
body_preview = ""
|
||||
if len(body_preview) > 300:
|
||||
body_preview = body_preview[:300] + "…"
|
||||
|
||||
if self._should_retry_status(status) and attempt < self.max_retries:
|
||||
retry_after = self._parse_retry_after_seconds(response.headers)
|
||||
logger.warning(
|
||||
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
|
||||
self.api_base,
|
||||
self.endpoint,
|
||||
status,
|
||||
attempt + 1,
|
||||
self.max_retries + 1,
|
||||
)
|
||||
self._sleep_backoff(attempt, retry_after_s=retry_after)
|
||||
continue
|
||||
|
||||
if status in {401, 403}:
|
||||
raise RuntimeError(
|
||||
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
|
||||
"Check your API key."
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
|
||||
f"Response: {body_preview or '<empty>'}"
|
||||
)
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(
|
||||
f"Rerank response from provider '{self.provider}' is not valid JSON: "
|
||||
f"{type(exc).__name__}: {exc}"
|
||||
) from exc
|
||||
|
||||
if not isinstance(data, dict):
|
||||
raise RuntimeError(
|
||||
f"Rerank response from provider '{self.provider}' must be a JSON object; "
|
||||
f"got {type(data).__name__}"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
raise RuntimeError(
|
||||
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
|
||||
if not isinstance(results, list):
|
||||
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
|
||||
|
||||
scores: list[float] = [0.0 for _ in range(expected)]
|
||||
filled = 0
|
||||
|
||||
for item in results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
idx = item.get("index")
|
||||
score = item.get("relevance_score", item.get("score"))
|
||||
if idx is None or score is None:
|
||||
continue
|
||||
try:
|
||||
idx_int = int(idx)
|
||||
score_f = float(score)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if 0 <= idx_int < expected:
|
||||
scores[idx_int] = score_f
|
||||
filled += 1
|
||||
|
||||
if filled != expected:
|
||||
raise RuntimeError(
|
||||
f"Rerank response contained {filled}/{expected} scored documents; "
|
||||
"ensure top_n matches the number of documents."
|
||||
)
|
||||
|
||||
return scores
|
||||
|
||||
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": list(documents),
|
||||
"top_n": len(documents),
|
||||
"return_documents": False,
|
||||
}
|
||||
return payload
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""Estimate token count using fast heuristic.
|
||||
|
||||
Uses len(text) // 4 as approximation (~4 chars per token for English).
|
||||
Not perfectly accurate for all models/languages but sufficient for
|
||||
batch sizing decisions where exact counts aren't critical.
|
||||
"""
|
||||
return len(text) // 4
|
||||
|
||||
def _create_token_aware_batches(
|
||||
self,
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
) -> list[list[tuple[int, str]]]:
|
||||
"""Split documents into batches that fit within token limits.
|
||||
|
||||
Uses 90% of max_input_tokens as safety margin.
|
||||
Each batch includes the query tokens overhead.
|
||||
"""
|
||||
max_tokens = int(self._max_input_tokens * 0.9)
|
||||
query_tokens = self._estimate_tokens(query)
|
||||
|
||||
batches: list[list[tuple[int, str]]] = []
|
||||
current_batch: list[tuple[int, str]] = []
|
||||
current_tokens = query_tokens # Start with query overhead
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
doc_tokens = self._estimate_tokens(doc)
|
||||
|
||||
# Warn if single document exceeds token limit (will be truncated by API)
|
||||
if doc_tokens > max_tokens - query_tokens:
|
||||
logger.warning(
|
||||
f"Document {idx} exceeds token limit: ~{doc_tokens} tokens "
|
||||
f"(limit: {max_tokens - query_tokens} after query overhead). "
|
||||
"Document will likely be truncated by the API."
|
||||
)
|
||||
|
||||
# If batch would exceed limit, start new batch
|
||||
if current_tokens + doc_tokens > max_tokens and current_batch:
|
||||
batches.append(current_batch)
|
||||
current_batch = []
|
||||
current_tokens = query_tokens
|
||||
|
||||
current_batch.append((idx, doc))
|
||||
current_tokens += doc_tokens
|
||||
|
||||
if current_batch:
|
||||
batches.append(current_batch)
|
||||
|
||||
return batches
|
||||
|
||||
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
# Create token-aware batches
|
||||
batches = self._create_token_aware_batches(query, documents)
|
||||
|
||||
if len(batches) == 1:
|
||||
# Single batch - original behavior
|
||||
payload = self._build_payload(query=query, documents=documents)
|
||||
data = self._request_json(payload)
|
||||
results = data.get("results")
|
||||
return self._extract_scores_from_results(results, expected=len(documents))
|
||||
|
||||
# Multiple batches - process each and merge results
|
||||
logger.info(
|
||||
f"Splitting {len(documents)} documents into {len(batches)} batches "
|
||||
f"(max_input_tokens: {self._max_input_tokens})"
|
||||
)
|
||||
|
||||
all_scores: list[float] = [0.0] * len(documents)
|
||||
|
||||
for batch in batches:
|
||||
batch_docs = [doc for _, doc in batch]
|
||||
payload = self._build_payload(query=query, documents=batch_docs)
|
||||
data = self._request_json(payload)
|
||||
results = data.get("results")
|
||||
batch_scores = self._extract_scores_from_results(results, expected=len(batch_docs))
|
||||
|
||||
# Map scores back to original indices
|
||||
for (orig_idx, _), score in zip(batch, batch_scores):
|
||||
all_scores[orig_idx] = score
|
||||
|
||||
return all_scores
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
|
||||
) -> list[float]:
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
grouped: dict[str, list[tuple[int, str]]] = {}
|
||||
for idx, (query, doc) in enumerate(pairs):
|
||||
grouped.setdefault(str(query), []).append((idx, str(doc)))
|
||||
|
||||
scores: list[float] = [0.0 for _ in range(len(pairs))]
|
||||
|
||||
for query, items in grouped.items():
|
||||
documents = [doc for _, doc in items]
|
||||
query_scores = self._rerank_one_query(query=query, documents=documents)
|
||||
for (orig_idx, _), score in zip(items, query_scores):
|
||||
scores[orig_idx] = float(score)
|
||||
|
||||
return scores
|
||||
46
codex-lens/build/lib/codexlens/semantic/reranker/base.py
Normal file
46
codex-lens/build/lib/codexlens/semantic/reranker/base.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Base class for rerankers.
|
||||
|
||||
Defines the interface that all rerankers must implement.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
class BaseReranker(ABC):
|
||||
"""Base class for all rerankers.
|
||||
|
||||
All reranker implementations must inherit from this class and implement
|
||||
the abstract methods to ensure a consistent interface.
|
||||
"""
|
||||
|
||||
@property
|
||||
def max_input_tokens(self) -> int:
|
||||
"""Return maximum token limit for reranking.
|
||||
|
||||
Returns:
|
||||
int: Maximum number of tokens that can be processed at once.
|
||||
Default is 8192 if not overridden by implementation.
|
||||
"""
|
||||
return 8192
|
||||
|
||||
@abstractmethod
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> list[float]:
|
||||
"""Score (query, doc) pairs.
|
||||
|
||||
Args:
|
||||
pairs: Sequence of (query, doc) string pairs to score.
|
||||
batch_size: Batch size for scoring.
|
||||
|
||||
Returns:
|
||||
List of scores (one per pair).
|
||||
"""
|
||||
...
|
||||
|
||||
159
codex-lens/build/lib/codexlens/semantic/reranker/factory.py
Normal file
159
codex-lens/build/lib/codexlens/semantic/reranker/factory.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Factory for creating rerankers.
|
||||
|
||||
Provides a unified interface for instantiating different reranker backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
|
||||
def check_reranker_available(backend: str) -> tuple[bool, str | None]:
|
||||
"""Check whether a specific reranker backend can be used.
|
||||
|
||||
Notes:
|
||||
- "fastembed" uses fastembed TextCrossEncoder (pip install fastembed>=0.4.0). [Recommended]
|
||||
- "onnx" redirects to "fastembed" for backward compatibility.
|
||||
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
|
||||
- "api" uses a remote reranking HTTP API (requires httpx).
|
||||
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
|
||||
"""
|
||||
backend = (backend or "").strip().lower()
|
||||
|
||||
if backend == "legacy":
|
||||
from .legacy import check_cross_encoder_available
|
||||
|
||||
return check_cross_encoder_available()
|
||||
|
||||
if backend == "fastembed":
|
||||
from .fastembed_reranker import check_fastembed_reranker_available
|
||||
|
||||
return check_fastembed_reranker_available()
|
||||
|
||||
if backend == "onnx":
|
||||
# Redirect to fastembed for backward compatibility
|
||||
from .fastembed_reranker import check_fastembed_reranker_available
|
||||
|
||||
return check_fastembed_reranker_available()
|
||||
|
||||
if backend == "litellm":
|
||||
try:
|
||||
import ccw_litellm # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
|
||||
)
|
||||
|
||||
try:
|
||||
from .litellm_reranker import LiteLLMReranker # noqa: F401
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
return False, f"LiteLLM reranker backend not available: {exc}"
|
||||
|
||||
return True, None
|
||||
|
||||
if backend == "api":
|
||||
from .api_reranker import check_httpx_available
|
||||
|
||||
return check_httpx_available()
|
||||
|
||||
return False, (
|
||||
f"Invalid reranker backend: {backend}. "
|
||||
"Must be 'fastembed', 'onnx', 'api', 'litellm', or 'legacy'."
|
||||
)
|
||||
|
||||
|
||||
def get_reranker(
|
||||
backend: str = "fastembed",
|
||||
model_name: str | None = None,
|
||||
*,
|
||||
device: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseReranker:
|
||||
"""Factory function to create reranker based on backend.
|
||||
|
||||
Args:
|
||||
backend: Reranker backend to use. Options:
|
||||
- "fastembed": FastEmbed TextCrossEncoder backend (default, recommended)
|
||||
- "onnx": Redirects to fastembed for backward compatibility
|
||||
- "api": HTTP API backend (remote providers)
|
||||
- "litellm": LiteLLM backend (LLM-based, for API mode)
|
||||
- "legacy": sentence-transformers CrossEncoder backend (optional)
|
||||
model_name: Model identifier for model-based backends. Defaults depend on backend:
|
||||
- fastembed: Xenova/ms-marco-MiniLM-L-6-v2
|
||||
- onnx: (redirects to fastembed)
|
||||
- api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
|
||||
- legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
|
||||
- litellm: default
|
||||
device: Optional device string for backends that support it (legacy only).
|
||||
**kwargs: Additional backend-specific arguments.
|
||||
|
||||
Returns:
|
||||
BaseReranker: Configured reranker instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If backend is not recognized.
|
||||
ImportError: If required backend dependencies are not installed or backend is unavailable.
|
||||
"""
|
||||
backend = (backend or "").strip().lower()
|
||||
|
||||
if backend == "fastembed":
|
||||
ok, err = check_reranker_available("fastembed")
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
from .fastembed_reranker import FastEmbedReranker
|
||||
|
||||
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
|
||||
_ = device # Device selection is managed via fastembed providers.
|
||||
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
|
||||
|
||||
if backend == "onnx":
|
||||
# Redirect to fastembed for backward compatibility
|
||||
ok, err = check_reranker_available("fastembed")
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
from .fastembed_reranker import FastEmbedReranker
|
||||
|
||||
resolved_model_name = (model_name or "").strip() or FastEmbedReranker.DEFAULT_MODEL
|
||||
_ = device # Device selection is managed via fastembed providers.
|
||||
return FastEmbedReranker(model_name=resolved_model_name, **kwargs)
|
||||
|
||||
if backend == "legacy":
|
||||
ok, err = check_reranker_available("legacy")
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
from .legacy import CrossEncoderReranker
|
||||
|
||||
resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
return CrossEncoderReranker(model_name=resolved_model_name, device=device)
|
||||
|
||||
if backend == "litellm":
|
||||
ok, err = check_reranker_available("litellm")
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
from .litellm_reranker import LiteLLMReranker
|
||||
|
||||
_ = device # Device selection is not applicable to remote LLM backends.
|
||||
resolved_model_name = (model_name or "").strip() or "default"
|
||||
return LiteLLMReranker(model=resolved_model_name, **kwargs)
|
||||
|
||||
if backend == "api":
|
||||
ok, err = check_reranker_available("api")
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
from .api_reranker import APIReranker
|
||||
|
||||
_ = device # Device selection is not applicable to remote HTTP backends.
|
||||
resolved_model_name = (model_name or "").strip() or None
|
||||
return APIReranker(model_name=resolved_model_name, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown backend: {backend}. Supported backends: 'fastembed', 'onnx', 'api', 'litellm', 'legacy'"
|
||||
)
|
||||
@@ -0,0 +1,257 @@
|
||||
"""FastEmbed-based reranker backend.
|
||||
|
||||
This reranker uses fastembed's TextCrossEncoder for cross-encoder reranking.
|
||||
FastEmbed is ONNX-based internally but provides a cleaner, unified API.
|
||||
|
||||
Install:
|
||||
pip install fastembed>=0.4.0
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_fastembed_reranker_available() -> tuple[bool, str | None]:
|
||||
"""Check whether fastembed reranker dependencies are available."""
|
||||
try:
|
||||
import fastembed # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"fastembed not available: {exc}. Install with: pip install fastembed>=0.4.0",
|
||||
)
|
||||
|
||||
try:
|
||||
from fastembed.rerank.cross_encoder import TextCrossEncoder # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"fastembed TextCrossEncoder not available: {exc}. "
|
||||
"Upgrade with: pip install fastembed>=0.4.0",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
class FastEmbedReranker(BaseReranker):
|
||||
"""Cross-encoder reranker using fastembed's TextCrossEncoder with lazy loading."""
|
||||
|
||||
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
# Alternative models supported by fastembed:
|
||||
# - "BAAI/bge-reranker-base"
|
||||
# - "BAAI/bge-reranker-large"
|
||||
# - "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str | None = None,
|
||||
*,
|
||||
use_gpu: bool = True,
|
||||
cache_dir: str | None = None,
|
||||
threads: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize FastEmbed reranker.
|
||||
|
||||
Args:
|
||||
model_name: Model identifier. Defaults to Xenova/ms-marco-MiniLM-L-6-v2.
|
||||
use_gpu: Whether to use GPU acceleration when available.
|
||||
cache_dir: Optional directory for caching downloaded models.
|
||||
threads: Optional number of threads for ONNX Runtime.
|
||||
"""
|
||||
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.use_gpu = bool(use_gpu)
|
||||
self.cache_dir = cache_dir
|
||||
self.threads = threads
|
||||
|
||||
self._encoder: Any | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Lazy-load the TextCrossEncoder model."""
|
||||
if self._encoder is not None:
|
||||
return
|
||||
|
||||
ok, err = check_fastembed_reranker_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._encoder is not None:
|
||||
return
|
||||
|
||||
from fastembed.rerank.cross_encoder import TextCrossEncoder
|
||||
|
||||
# Determine providers based on GPU preference
|
||||
providers: list[str] | None = None
|
||||
if self.use_gpu:
|
||||
try:
|
||||
from ..gpu_support import get_optimal_providers
|
||||
|
||||
providers = get_optimal_providers(use_gpu=True, with_device_options=False)
|
||||
except Exception:
|
||||
# Fallback: let fastembed decide
|
||||
providers = None
|
||||
|
||||
# Build initialization kwargs
|
||||
init_kwargs: dict[str, Any] = {}
|
||||
if self.cache_dir:
|
||||
init_kwargs["cache_dir"] = self.cache_dir
|
||||
if self.threads is not None:
|
||||
init_kwargs["threads"] = self.threads
|
||||
if providers:
|
||||
init_kwargs["providers"] = providers
|
||||
|
||||
logger.debug(
|
||||
"Loading FastEmbed reranker model: %s (use_gpu=%s)",
|
||||
self.model_name,
|
||||
self.use_gpu,
|
||||
)
|
||||
|
||||
self._encoder = TextCrossEncoder(
|
||||
model_name=self.model_name,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
logger.debug("FastEmbed reranker model loaded successfully")
|
||||
|
||||
@staticmethod
|
||||
def _sigmoid(x: float) -> float:
|
||||
"""Numerically stable sigmoid function."""
|
||||
if x < -709:
|
||||
return 0.0
|
||||
if x > 709:
|
||||
return 1.0
|
||||
import math
|
||||
return 1.0 / (1.0 + math.exp(-x))
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> list[float]:
|
||||
"""Score (query, doc) pairs.
|
||||
|
||||
Args:
|
||||
pairs: Sequence of (query, doc) string pairs to score.
|
||||
batch_size: Batch size for scoring.
|
||||
|
||||
Returns:
|
||||
List of scores (one per pair), normalized to [0, 1] range.
|
||||
"""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._encoder is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
# FastEmbed's TextCrossEncoder.rerank() expects a query and list of documents.
|
||||
# For batch scoring of multiple query-doc pairs, we need to process them.
|
||||
# Group by query for efficiency when same query appears multiple times.
|
||||
query_to_docs: dict[str, list[tuple[int, str]]] = {}
|
||||
for idx, (query, doc) in enumerate(pairs):
|
||||
if query not in query_to_docs:
|
||||
query_to_docs[query] = []
|
||||
query_to_docs[query].append((idx, doc))
|
||||
|
||||
# Score each query group
|
||||
scores: list[float] = [0.0] * len(pairs)
|
||||
|
||||
for query, indexed_docs in query_to_docs.items():
|
||||
docs = [doc for _, doc in indexed_docs]
|
||||
indices = [idx for idx, _ in indexed_docs]
|
||||
|
||||
try:
|
||||
# TextCrossEncoder.rerank returns raw float scores in same order as input
|
||||
raw_scores = list(
|
||||
self._encoder.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Map scores back to original positions and normalize with sigmoid
|
||||
for i, raw_score in enumerate(raw_scores):
|
||||
if i < len(indices):
|
||||
original_idx = indices[i]
|
||||
# Normalize score to [0, 1] using stable sigmoid
|
||||
scores[original_idx] = self._sigmoid(float(raw_score))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("FastEmbed rerank failed for query: %s", str(e)[:100])
|
||||
# Leave scores as 0.0 for failed queries
|
||||
|
||||
return scores
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
query: str,
|
||||
documents: Sequence[str],
|
||||
*,
|
||||
top_k: int | None = None,
|
||||
batch_size: int = 32,
|
||||
) -> list[tuple[float, str, int]]:
|
||||
"""Rerank documents for a single query.
|
||||
|
||||
This is a convenience method that provides results in ranked order.
|
||||
|
||||
Args:
|
||||
query: The query string.
|
||||
documents: List of documents to rerank.
|
||||
top_k: Return only top K results. None returns all.
|
||||
batch_size: Batch size for scoring.
|
||||
|
||||
Returns:
|
||||
List of (score, document, original_index) tuples, sorted by score descending.
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._encoder is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
try:
|
||||
# TextCrossEncoder.rerank returns raw float scores in same order as input
|
||||
raw_scores = list(
|
||||
self._encoder.rerank(
|
||||
query=query,
|
||||
documents=list(documents),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to our format: (normalized_score, document, original_index)
|
||||
ranked = []
|
||||
for idx, raw_score in enumerate(raw_scores):
|
||||
if idx < len(documents):
|
||||
# Normalize score to [0, 1] using stable sigmoid
|
||||
normalized = self._sigmoid(float(raw_score))
|
||||
ranked.append((normalized, documents[idx], idx))
|
||||
|
||||
# Sort by score descending
|
||||
ranked.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
if top_k is not None and top_k > 0:
|
||||
ranked = ranked[:top_k]
|
||||
|
||||
return ranked
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("FastEmbed rerank failed: %s", str(e)[:100])
|
||||
return []
|
||||
91
codex-lens/build/lib/codexlens/semantic/reranker/legacy.py
Normal file
91
codex-lens/build/lib/codexlens/semantic/reranker/legacy.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Legacy sentence-transformers cross-encoder reranker.
|
||||
|
||||
Install with: pip install codexlens[reranker-legacy]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder as _CrossEncoder
|
||||
|
||||
CROSS_ENCODER_AVAILABLE = True
|
||||
_import_error: str | None = None
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
_CrossEncoder = None # type: ignore[assignment]
|
||||
CROSS_ENCODER_AVAILABLE = False
|
||||
_import_error = str(exc)
|
||||
|
||||
|
||||
def check_cross_encoder_available() -> tuple[bool, str | None]:
|
||||
if CROSS_ENCODER_AVAILABLE:
|
||||
return True, None
|
||||
return (
|
||||
False,
|
||||
_import_error
|
||||
or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
|
||||
)
|
||||
|
||||
|
||||
class CrossEncoderReranker(BaseReranker):
|
||||
"""Cross-encoder reranker with lazy model loading."""
|
||||
|
||||
def __init__(self, model_name: str, *, device: str | None = None) -> None:
|
||||
self.model_name = (model_name or "").strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.device = (device or "").strip() or None
|
||||
self._model = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
ok, err = check_cross_encoder_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.device:
|
||||
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
|
||||
else:
|
||||
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
|
||||
raise
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[Tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> List[float]:
|
||||
"""Score (query, doc) pairs using the cross-encoder.
|
||||
|
||||
Returns:
|
||||
List of scores (one per pair) in the model's native scale (usually logits).
|
||||
"""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
|
||||
return [float(s) for s in scores]
|
||||
@@ -0,0 +1,214 @@
|
||||
"""Experimental LiteLLM reranker backend.
|
||||
|
||||
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
|
||||
relevance of a single (query, document) pair per request.
|
||||
|
||||
Notes:
|
||||
- This backend is experimental and may be slow/expensive compared to local
|
||||
rerankers.
|
||||
- It relies on `ccw-litellm` for a unified LLM API across providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
|
||||
|
||||
|
||||
def _coerce_score_to_unit_interval(score: float) -> float:
|
||||
"""Coerce a numeric score into [0, 1].
|
||||
|
||||
The prompt asks for a float in [0, 1], but some models may respond with 0-10
|
||||
or 0-100 scales. This function attempts a conservative normalization.
|
||||
"""
|
||||
if 0.0 <= score <= 1.0:
|
||||
return score
|
||||
if 0.0 <= score <= 10.0:
|
||||
return score / 10.0
|
||||
if 0.0 <= score <= 100.0:
|
||||
return score / 100.0
|
||||
return max(0.0, min(1.0, score))
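# Illustration of the normalization above (values follow directly from the branches):
#   _coerce_score_to_unit_interval(0.73) -> 0.73  (already in [0, 1])
#   _coerce_score_to_unit_interval(7.5)  -> 0.75  (interpreted as a 0-10 scale)
#   _coerce_score_to_unit_interval(85.0) -> 0.85  (interpreted as a 0-100 scale)
#   _coerce_score_to_unit_interval(-3.0) -> 0.0   (clamped into range)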
|
||||
|
||||
|
||||
def _extract_score(text: str) -> float | None:
|
||||
"""Extract a numeric relevance score from an LLM response."""
|
||||
content = (text or "").strip()
|
||||
if not content:
|
||||
return None
|
||||
|
||||
# Prefer JSON if present.
|
||||
if "{" in content and "}" in content:
|
||||
try:
|
||||
start = content.index("{")
|
||||
end = content.rindex("}") + 1
|
||||
payload = json.loads(content[start:end])
|
||||
if isinstance(payload, dict) and "score" in payload:
|
||||
return float(payload["score"])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
match = _NUMBER_RE.search(content)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
return float(match.group(0))
|
||||
except ValueError:
|
||||
return None
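# Examples of inputs _extract_score handles (derived from the parsing logic above):
#   '{"score": 0.73}'   -> 0.73  (JSON payload takes precedence)
#   "Relevance: 0.4"    -> 0.4   (first number found in free text)
#   ""                  -> None  (empty response)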
|
||||
|
||||
|
||||
class LiteLLMReranker(BaseReranker):
|
||||
"""Experimental reranker that uses a LiteLLM-compatible model.
|
||||
|
||||
This reranker scores each (query, doc) pair in isolation (single-pair mode)
|
||||
to improve prompt reliability across providers.
|
||||
"""
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a relevance scoring assistant.\n"
|
||||
"Given a search query and a document snippet, output a single numeric "
|
||||
"relevance score between 0 and 1.\n\n"
|
||||
"Scoring guidance:\n"
|
||||
"- 1.0: The document directly answers the query.\n"
|
||||
"- 0.5: The document is partially relevant.\n"
|
||||
"- 0.0: The document is unrelated.\n\n"
|
||||
"Output requirements:\n"
|
||||
"- Output ONLY the number (e.g., 0.73).\n"
|
||||
"- Do not include any other text."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "default",
|
||||
*,
|
||||
requests_per_minute: float | None = None,
|
||||
min_interval_seconds: float | None = None,
|
||||
default_score: float = 0.0,
|
||||
max_doc_chars: int = 8000,
|
||||
**litellm_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the reranker.
|
||||
|
||||
Args:
|
||||
model: Model name from ccw-litellm configuration (default: "default").
|
||||
requests_per_minute: Optional rate limit in requests per minute.
|
||||
min_interval_seconds: Optional minimum interval between requests. If set,
|
||||
it takes precedence over requests_per_minute.
|
||||
default_score: Score to use when an API call fails or parsing fails.
|
||||
max_doc_chars: Maximum number of document characters to include in the prompt.
|
||||
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
|
||||
|
||||
Raises:
|
||||
ImportError: If ccw-litellm is not installed.
|
||||
ValueError: If model is blank.
|
||||
"""
|
||||
self.model_name = (model or "").strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model cannot be blank")
|
||||
|
||||
self.default_score = float(default_score)
|
||||
|
||||
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
|
||||
|
||||
if min_interval_seconds is not None:
|
||||
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
|
||||
elif requests_per_minute is not None and float(requests_per_minute) > 0:
|
||||
self._min_interval_seconds = 60.0 / float(requests_per_minute)
|
||||
else:
|
||||
self._min_interval_seconds = 0.0
|
||||
|
||||
# Prefer deterministic output by default; allow overrides via kwargs.
|
||||
litellm_kwargs = dict(litellm_kwargs)
|
||||
litellm_kwargs.setdefault("temperature", 0.0)
|
||||
litellm_kwargs.setdefault("max_tokens", 16)
|
||||
|
||||
try:
|
||||
from ccw_litellm import ChatMessage, LiteLLMClient
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
raise ImportError(
|
||||
"ccw-litellm not installed. Install with: pip install ccw-litellm"
|
||||
) from exc
|
||||
|
||||
self._ChatMessage = ChatMessage
|
||||
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
|
||||
|
||||
self._lock = threading.RLock()
|
||||
self._last_request_at = 0.0
|
||||
|
||||
def _sanitize_text(self, text: str) -> str:
|
||||
# Keep consistent with LiteLLMEmbedderWrapper workaround.
|
||||
if text.startswith("import"):
|
||||
return " " + text
|
||||
return text
|
||||
|
||||
def _rate_limit(self) -> None:
|
||||
if self._min_interval_seconds <= 0:
|
||||
return
|
||||
with self._lock:
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_at
|
||||
if elapsed < self._min_interval_seconds:
|
||||
time.sleep(self._min_interval_seconds - elapsed)
|
||||
self._last_request_at = time.monotonic()
|
||||
|
||||
def _build_user_prompt(self, query: str, doc: str) -> str:
|
||||
sanitized_query = self._sanitize_text(query or "")
|
||||
sanitized_doc = self._sanitize_text(doc or "")
|
||||
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
|
||||
sanitized_doc = sanitized_doc[: self.max_doc_chars]
|
||||
|
||||
return (
|
||||
"Query:\n"
|
||||
f"{sanitized_query}\n\n"
|
||||
"Document:\n"
|
||||
f"{sanitized_doc}\n\n"
|
||||
"Return the relevance score (0 to 1) as a single number:"
|
||||
)
|
||||
|
||||
def _score_single_pair(self, query: str, doc: str) -> float:
|
||||
messages = [
|
||||
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
|
||||
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
|
||||
]
|
||||
|
||||
try:
|
||||
self._rate_limit()
|
||||
response = self._client.chat(messages)
|
||||
except Exception as exc:
|
||||
logger.debug("LiteLLM reranker request failed: %s", exc)
|
||||
return self.default_score
|
||||
|
||||
raw = getattr(response, "content", "") or ""
|
||||
score = _extract_score(raw)
|
||||
if score is None:
|
||||
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
|
||||
return self.default_score
|
||||
return _coerce_score_to_unit_interval(float(score))
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> list[float]:
|
||||
"""Score (query, doc) pairs with per-pair LLM calls."""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||
|
||||
scores: list[float] = []
|
||||
for i in range(0, len(pairs), bs):
|
||||
batch = pairs[i : i + bs]
|
||||
for query, doc in batch:
|
||||
scores.append(self._score_single_pair(query, doc))
|
||||
return scores
|
||||
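A brief sketch of how the experimental backend above might be driven (illustrative only; it assumes ccw-litellm is installed with a "default" model configured, and candidate_docs is a hypothetical list of snippets):

# Hypothetical example; names below are illustrative.
candidate_docs = [
    "Prune stale vectors during the nightly compaction pass.",
    "Changelog entry for an unrelated UI tweak.",
]
reranker = LiteLLMReranker(model="default", requests_per_minute=30)
pairs = [("vector store pruning", doc) for doc in candidate_docs]
scores = reranker.score_pairs(pairs)  # one LLM request per pair, rate limited
ranked = sorted(zip(scores, candidate_docs), key=lambda item: item[0], reverse=True)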
@@ -0,0 +1,268 @@
|
||||
"""Optimum + ONNX Runtime reranker backend.
|
||||
|
||||
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
|
||||
classification models. It is designed to run without requiring PyTorch at
|
||||
runtime by using numpy tensors and ONNX Runtime execution providers.
|
||||
|
||||
Install (CPU):
|
||||
pip install onnxruntime optimum[onnxruntime] transformers
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Iterable, Sequence
|
||||
|
||||
from .base import BaseReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_onnx_reranker_available() -> tuple[bool, str | None]:
|
||||
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
|
||||
try:
|
||||
import numpy # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return False, f"numpy not available: {exc}. Install with: pip install numpy"
|
||||
|
||||
try:
|
||||
import onnxruntime # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
|
||||
)
|
||||
|
||||
try:
|
||||
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: F401
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
return (
|
||||
False,
|
||||
f"transformers not available: {exc}. Install with: pip install transformers",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
|
||||
for i in range(0, len(items), batch_size):
|
||||
yield items[i : i + batch_size]
|
||||
|
||||
|
||||
class ONNXReranker(BaseReranker):
|
||||
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
|
||||
|
||||
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str | None = None,
|
||||
*,
|
||||
use_gpu: bool = True,
|
||||
providers: list[Any] | None = None,
|
||||
max_length: int | None = None,
|
||||
) -> None:
|
||||
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.use_gpu = bool(use_gpu)
|
||||
self.providers = providers
|
||||
|
||||
self.max_length = int(max_length) if max_length is not None else None
|
||||
|
||||
self._tokenizer: Any | None = None
|
||||
self._model: Any | None = None
|
||||
self._model_input_names: set[str] | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
ok, err = check_onnx_reranker_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from optimum.onnxruntime import ORTModelForSequenceClassification
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if self.providers is None:
|
||||
from ..gpu_support import get_optimal_providers
|
||||
|
||||
# Include device_id options for DirectML/CUDA selection when available.
|
||||
self.providers = get_optimal_providers(
|
||||
use_gpu=self.use_gpu, with_device_options=True
|
||||
)
|
||||
|
||||
# Some Optimum versions accept `providers`, others accept a single `provider`.
|
||||
# Prefer passing the full providers list, with a conservative fallback.
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
try:
|
||||
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
|
||||
if "providers" in params:
|
||||
model_kwargs["providers"] = self.providers
|
||||
elif "provider" in params:
|
||||
provider_name = "CPUExecutionProvider"
|
||||
if self.providers:
|
||||
first = self.providers[0]
|
||||
provider_name = first[0] if isinstance(first, tuple) else str(first)
|
||||
model_kwargs["provider"] = provider_name
|
||||
except Exception:
|
||||
model_kwargs = {}
|
||||
|
||||
try:
|
||||
self._model = ORTModelForSequenceClassification.from_pretrained(
|
||||
self.model_name,
|
||||
**model_kwargs,
|
||||
)
|
||||
except TypeError:
|
||||
# Fallback for older Optimum versions: retry without provider arguments.
|
||||
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||
|
||||
# Cache model input names to filter tokenizer outputs defensively.
|
||||
input_names: set[str] | None = None
|
||||
for attr in ("input_names", "model_input_names"):
|
||||
names = getattr(self._model, attr, None)
|
||||
if isinstance(names, (list, tuple)) and names:
|
||||
input_names = {str(n) for n in names}
|
||||
break
|
||||
if input_names is None:
|
||||
try:
|
||||
session = getattr(self._model, "model", None)
|
||||
if session is not None and hasattr(session, "get_inputs"):
|
||||
input_names = {i.name for i in session.get_inputs()}
|
||||
except Exception:
|
||||
input_names = None
|
||||
self._model_input_names = input_names
|
||||
|
||||
@staticmethod
|
||||
def _sigmoid(x: "Any") -> "Any":
|
||||
import numpy as np
|
||||
|
||||
x = np.clip(x, -50.0, 50.0)
|
||||
return 1.0 / (1.0 + np.exp(-x))
|
||||
|
||||
@staticmethod
|
||||
def _select_relevance_logit(logits: "Any") -> "Any":
|
||||
import numpy as np
|
||||
|
||||
arr = np.asarray(logits)
|
||||
if arr.ndim == 0:
|
||||
return arr.reshape(1)
|
||||
if arr.ndim == 1:
|
||||
return arr
|
||||
if arr.ndim >= 2:
|
||||
# Common cases:
|
||||
# - Regression: (batch, 1)
|
||||
# - Binary classification: (batch, 2)
|
||||
if arr.shape[-1] == 1:
|
||||
return arr[..., 0]
|
||||
if arr.shape[-1] == 2:
|
||||
# Convert 2-logit softmax into a single logit via difference.
|
||||
return arr[..., 1] - arr[..., 0]
|
||||
return arr.max(axis=-1)
|
||||
return arr.reshape(-1)
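    # Example of the selection above: a binary cross-encoder head returning logits of
    # shape (batch, 2) is collapsed via logits[..., 1] - logits[..., 0], so [[-1.0, 2.0]]
    # becomes [3.0]; the sigmoid in score_pairs then maps it to roughly 0.95.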
|
||||
|
||||
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
|
||||
|
||||
queries = [q for q, _ in batch]
|
||||
docs = [d for _, d in batch]
|
||||
|
||||
tokenizer_kwargs: dict[str, Any] = {
|
||||
"text": queries,
|
||||
"text_pair": docs,
|
||||
"padding": True,
|
||||
"truncation": True,
|
||||
"return_tensors": "np",
|
||||
}
|
||||
|
||||
max_len = self.max_length
|
||||
if max_len is None:
|
||||
try:
|
||||
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
|
||||
if 0 < model_max < 10_000:
|
||||
max_len = model_max
|
||||
else:
|
||||
max_len = 512
|
||||
except Exception:
|
||||
max_len = 512
|
||||
if max_len is not None and max_len > 0:
|
||||
tokenizer_kwargs["max_length"] = int(max_len)
|
||||
|
||||
encoded = self._tokenizer(**tokenizer_kwargs)
|
||||
inputs = dict(encoded)
|
||||
|
||||
# Some models do not accept token_type_ids; filter to known input names if available.
|
||||
if self._model_input_names:
|
||||
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
|
||||
|
||||
return inputs
|
||||
|
||||
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
|
||||
if self._model is None:
|
||||
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
|
||||
|
||||
outputs = self._model(**inputs)
|
||||
if hasattr(outputs, "logits"):
|
||||
return outputs.logits
|
||||
if isinstance(outputs, dict) and "logits" in outputs:
|
||||
return outputs["logits"]
|
||||
if isinstance(outputs, (list, tuple)) and outputs:
|
||||
return outputs[0]
|
||||
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> list[float]:
|
||||
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
import numpy as np
|
||||
|
||||
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||
scores: list[float] = []
|
||||
|
||||
for batch in _iter_batches(list(pairs), bs):
|
||||
inputs = self._tokenize_batch(batch)
|
||||
logits = self._forward_logits(inputs)
|
||||
rel_logits = self._select_relevance_logit(logits)
|
||||
probs = self._sigmoid(rel_logits)
|
||||
probs = np.clip(probs, 0.0, 1.0)
|
||||
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
|
||||
|
||||
if len(scores) != len(pairs):
|
||||
logger.debug(
|
||||
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
|
||||
)
|
||||
return scores[: len(pairs)]
|
||||
|
||||
return scores
|
||||
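For completeness, a usage sketch of the ONNX-backed reranker above (a sketch under the assumption that onnxruntime, optimum[onnxruntime] and transformers are installed; the default Xenova/ms-marco-MiniLM-L-6-v2 export is downloaded on first use):

# Hypothetical example; query and document text are illustrative.
reranker = ONNXReranker(use_gpu=False)  # model and tokenizer load lazily on first call
scores = reranker.score_pairs(
    [("splade sparse vectors", "SPLADE expands queries over the model vocabulary.")],
    batch_size=8,
)
# Unlike the legacy backend, scores here are sigmoid-normalized into [0, 1].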
434
codex-lens/build/lib/codexlens/semantic/rotational_embedder.py
Normal file
@@ -0,0 +1,434 @@
|
||||
"""Rotational embedder for multi-endpoint API load balancing.
|
||||
|
||||
Provides intelligent load balancing across multiple LiteLLM embedding endpoints
|
||||
to maximize throughput while respecting rate limits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndpointStatus(Enum):
|
||||
"""Status of an API endpoint."""
|
||||
AVAILABLE = "available"
|
||||
COOLING = "cooling" # Rate limited, temporarily unavailable
|
||||
FAILED = "failed" # Permanent failure (auth error, etc.)
|
||||
|
||||
|
||||
class SelectionStrategy(Enum):
|
||||
"""Strategy for selecting endpoints."""
|
||||
ROUND_ROBIN = "round_robin"
|
||||
LATENCY_AWARE = "latency_aware"
|
||||
WEIGHTED_RANDOM = "weighted_random"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointConfig:
|
||||
"""Configuration for a single API endpoint."""
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
weight: float = 1.0 # Higher weight = more requests
|
||||
max_concurrent: int = 4 # Max concurrent requests to this endpoint
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndpointState:
|
||||
"""Runtime state for an endpoint."""
|
||||
config: EndpointConfig
|
||||
embedder: Any = None # LiteLLMEmbedderWrapper instance
|
||||
|
||||
# Health metrics
|
||||
status: EndpointStatus = EndpointStatus.AVAILABLE
|
||||
cooldown_until: float = 0.0 # Unix timestamp when cooldown ends
|
||||
|
||||
# Performance metrics
|
||||
total_requests: int = 0
|
||||
total_failures: int = 0
|
||||
avg_latency_ms: float = 0.0
|
||||
last_latency_ms: float = 0.0
|
||||
|
||||
# Concurrency tracking
|
||||
active_requests: int = 0
|
||||
lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if endpoint is available for requests."""
|
||||
if self.status == EndpointStatus.FAILED:
|
||||
return False
|
||||
if self.status == EndpointStatus.COOLING:
|
||||
if time.time() >= self.cooldown_until:
|
||||
self.status = EndpointStatus.AVAILABLE
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_cooldown(self, seconds: float) -> None:
|
||||
"""Put endpoint in cooldown state."""
|
||||
self.status = EndpointStatus.COOLING
|
||||
self.cooldown_until = time.time() + seconds
|
||||
logger.warning(f"Endpoint {self.config.model} cooling down for {seconds:.1f}s")
|
||||
|
||||
def mark_failed(self) -> None:
|
||||
"""Mark endpoint as permanently failed."""
|
||||
self.status = EndpointStatus.FAILED
|
||||
logger.error(f"Endpoint {self.config.model} marked as failed")
|
||||
|
||||
def record_success(self, latency_ms: float) -> None:
|
||||
"""Record successful request."""
|
||||
self.total_requests += 1
|
||||
self.last_latency_ms = latency_ms
|
||||
# Exponential moving average for latency
|
||||
alpha = 0.3
|
||||
if self.avg_latency_ms == 0:
|
||||
self.avg_latency_ms = latency_ms
|
||||
else:
|
||||
self.avg_latency_ms = alpha * latency_ms + (1 - alpha) * self.avg_latency_ms
|
||||
|
||||
def record_failure(self) -> None:
|
||||
"""Record failed request."""
|
||||
self.total_requests += 1
|
||||
self.total_failures += 1
|
||||
|
||||
@property
|
||||
def health_score(self) -> float:
|
||||
"""Calculate health score (0-1) based on metrics."""
|
||||
if not self.is_available():
|
||||
return 0.0
|
||||
|
||||
# Base score from success rate
|
||||
if self.total_requests > 0:
|
||||
success_rate = 1 - (self.total_failures / self.total_requests)
|
||||
else:
|
||||
success_rate = 1.0
|
||||
|
||||
# Latency factor (faster = higher score)
|
||||
# Normalize: 100ms = 1.0, 1000ms = 0.1
|
||||
if self.avg_latency_ms > 0:
|
||||
latency_factor = min(1.0, 100 / self.avg_latency_ms)
|
||||
else:
|
||||
latency_factor = 1.0
|
||||
|
||||
# Availability factor (less concurrent = more available)
|
||||
if self.config.max_concurrent > 0:
|
||||
availability = 1 - (self.active_requests / self.config.max_concurrent)
|
||||
else:
|
||||
availability = 1.0
|
||||
|
||||
# Combined score with weights
|
||||
return (success_rate * 0.4 + latency_factor * 0.3 + availability * 0.3) * self.config.weight
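    # Worked example of the weighting above (weight=1.0, max_concurrent=4):
    #   success_rate = 0.9 (9/10 requests ok), avg_latency_ms = 200 -> latency_factor = 0.5,
    #   active_requests = 1 -> availability = 0.75
    #   health_score = 0.9*0.4 + 0.5*0.3 + 0.75*0.3 = 0.735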
|
||||
|
||||
|
||||
class RotationalEmbedder(BaseEmbedder):
|
||||
"""Embedder that load balances across multiple API endpoints.
|
||||
|
||||
Features:
|
||||
- Intelligent endpoint selection based on latency and health
|
||||
- Automatic failover on rate limits (429) and server errors
|
||||
- Cooldown management to respect rate limits
|
||||
- Thread-safe concurrent request handling
|
||||
|
||||
Args:
|
||||
endpoints: List of endpoint configurations
|
||||
strategy: Selection strategy (default: latency_aware)
|
||||
default_cooldown: Default cooldown seconds for rate limits (default: 60)
|
||||
max_retries: Maximum retry attempts across all endpoints (default: 3)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoints: List[EndpointConfig],
|
||||
strategy: SelectionStrategy = SelectionStrategy.LATENCY_AWARE,
|
||||
default_cooldown: float = 60.0,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
if not endpoints:
|
||||
raise ValueError("At least one endpoint must be provided")
|
||||
|
||||
self.strategy = strategy
|
||||
self.default_cooldown = default_cooldown
|
||||
self.max_retries = max_retries
|
||||
|
||||
# Initialize endpoint states
|
||||
self._endpoints: List[EndpointState] = []
|
||||
self._lock = threading.Lock()
|
||||
self._round_robin_index = 0
|
||||
|
||||
# Create embedder instances for each endpoint
|
||||
from .litellm_embedder import LiteLLMEmbedderWrapper
|
||||
|
||||
for config in endpoints:
|
||||
# Build kwargs for LiteLLMEmbedderWrapper
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if config.api_key:
|
||||
kwargs["api_key"] = config.api_key
|
||||
if config.api_base:
|
||||
kwargs["api_base"] = config.api_base
|
||||
|
||||
try:
|
||||
embedder = LiteLLMEmbedderWrapper(model=config.model, **kwargs)
|
||||
state = EndpointState(config=config, embedder=embedder)
|
||||
self._endpoints.append(state)
|
||||
logger.info(f"Initialized endpoint: {config.model}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize endpoint {config.model}: {e}")
|
||||
|
||||
if not self._endpoints:
|
||||
raise ValueError("Failed to initialize any endpoints")
|
||||
|
||||
# Cache embedding properties from first endpoint
|
||||
self._embedding_dim = self._endpoints[0].embedder.embedding_dim
|
||||
self._model_name = f"rotational({len(self._endpoints)} endpoints)"
|
||||
self._max_tokens = self._endpoints[0].embedder.max_tokens
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimensions."""
|
||||
return self._embedding_dim
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Return model name."""
|
||||
return self._model_name
|
||||
|
||||
@property
|
||||
def max_tokens(self) -> int:
|
||||
"""Return maximum token limit."""
|
||||
return self._max_tokens
|
||||
|
||||
@property
|
||||
def endpoint_count(self) -> int:
|
||||
"""Return number of configured endpoints."""
|
||||
return len(self._endpoints)
|
||||
|
||||
@property
|
||||
def available_endpoint_count(self) -> int:
|
||||
"""Return number of available endpoints."""
|
||||
return sum(1 for ep in self._endpoints if ep.is_available())
|
||||
|
||||
def get_endpoint_stats(self) -> List[Dict[str, Any]]:
|
||||
"""Get statistics for all endpoints."""
|
||||
stats = []
|
||||
for ep in self._endpoints:
|
||||
stats.append({
|
||||
"model": ep.config.model,
|
||||
"status": ep.status.value,
|
||||
"total_requests": ep.total_requests,
|
||||
"total_failures": ep.total_failures,
|
||||
"avg_latency_ms": round(ep.avg_latency_ms, 2),
|
||||
"health_score": round(ep.health_score, 3),
|
||||
"active_requests": ep.active_requests,
|
||||
})
|
||||
return stats
|
||||
|
||||
def _select_endpoint(self) -> Optional[EndpointState]:
|
||||
"""Select best available endpoint based on strategy."""
|
||||
available = [ep for ep in self._endpoints if ep.is_available()]
|
||||
|
||||
if not available:
|
||||
return None
|
||||
|
||||
if self.strategy == SelectionStrategy.ROUND_ROBIN:
|
||||
with self._lock:
|
||||
self._round_robin_index = (self._round_robin_index + 1) % len(available)
|
||||
return available[self._round_robin_index]
|
||||
|
||||
elif self.strategy == SelectionStrategy.LATENCY_AWARE:
|
||||
# Sort by health score (descending) and pick top candidate
|
||||
# Add small random factor to prevent thundering herd
|
||||
scored = [(ep, ep.health_score + random.uniform(0, 0.1)) for ep in available]
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
return scored[0][0]
|
||||
|
||||
elif self.strategy == SelectionStrategy.WEIGHTED_RANDOM:
|
||||
# Weighted random selection based on health scores
|
||||
scores = [ep.health_score for ep in available]
|
||||
total = sum(scores)
|
||||
if total == 0:
|
||||
return random.choice(available)
|
||||
|
||||
weights = [s / total for s in scores]
|
||||
return random.choices(available, weights=weights, k=1)[0]
|
||||
|
||||
return available[0]
|
||||
|
||||
def _parse_retry_after(self, error: Exception) -> Optional[float]:
|
||||
"""Extract Retry-After value from error if available."""
|
||||
error_str = str(error)
|
||||
|
||||
# Try to find Retry-After in error message
|
||||
import re
|
||||
match = re.search(r'[Rr]etry[- ][Aa]fter[:\s]+(\d+)', error_str)
|
||||
if match:
|
||||
return float(match.group(1))
|
||||
|
||||
return None
|
||||
|
||||
def _is_rate_limit_error(self, error: Exception) -> bool:
|
||||
"""Check if error is a rate limit error."""
|
||||
error_str = str(error).lower()
|
||||
return any(x in error_str for x in ["429", "rate limit", "too many requests"])
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
"""Check if error is retryable (not auth/config error)."""
|
||||
error_str = str(error).lower()
|
||||
# Retryable errors
|
||||
if any(x in error_str for x in ["429", "rate limit", "502", "503", "504",
|
||||
"timeout", "connection", "service unavailable"]):
|
||||
return True
|
||||
# Non-retryable errors (auth, config)
|
||||
if any(x in error_str for x in ["401", "403", "invalid", "authentication",
|
||||
"unauthorized", "api key"]):
|
||||
return False
|
||||
# Default to retryable for unknown errors
|
||||
return True
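    # How the two classifiers above steer retry handling in embed_to_numpy:
    #   "429 Too Many Requests"          -> rate limited: cool down for Retry-After (or the default)
    #   "503 Service Unavailable"        -> retryable: short 5 s cooldown, try another endpoint
    #   "401 Unauthorized: bad api key"  -> not retryable: endpoint is marked failed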
|
||||
|
||||
def embed_to_numpy(self, texts: str | Iterable[str], **kwargs) -> np.ndarray:
|
||||
"""Embed texts using load-balanced endpoint selection.
|
||||
|
||||
Args:
|
||||
texts: Single text or iterable of texts to embed.
|
||||
**kwargs: Additional arguments passed to underlying embedder.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: Array of shape (n_texts, embedding_dim) containing embeddings.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all endpoints fail after retries.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
else:
|
||||
texts = list(texts)
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
tried_endpoints: set = set()
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
endpoint = self._select_endpoint()
|
||||
|
||||
if endpoint is None:
|
||||
# All endpoints unavailable, wait for shortest cooldown
|
||||
min_cooldown = min(
|
||||
(ep.cooldown_until - time.time() for ep in self._endpoints
|
||||
if ep.status == EndpointStatus.COOLING),
|
||||
default=self.default_cooldown
|
||||
)
|
||||
if min_cooldown > 0 and attempt < self.max_retries:
|
||||
wait_time = min(min_cooldown, 30) # Cap wait at 30s
|
||||
logger.warning(f"All endpoints busy, waiting {wait_time:.1f}s...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
break
|
||||
|
||||
# Track tried endpoints to avoid infinite loops
|
||||
endpoint_id = id(endpoint)
|
||||
if endpoint_id in tried_endpoints and len(tried_endpoints) >= len(self._endpoints):
|
||||
# Already tried all endpoints
|
||||
break
|
||||
tried_endpoints.add(endpoint_id)
|
||||
|
||||
# Acquire slot
|
||||
with endpoint.lock:
|
||||
endpoint.active_requests += 1
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
result = endpoint.embedder.embed_to_numpy(texts, **kwargs)
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
# Record success
|
||||
endpoint.record_success(latency_ms)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
endpoint.record_failure()
|
||||
|
||||
if self._is_rate_limit_error(e):
|
||||
# Rate limited - set cooldown
|
||||
retry_after = self._parse_retry_after(e) or self.default_cooldown
|
||||
endpoint.set_cooldown(retry_after)
|
||||
logger.warning(f"Endpoint {endpoint.config.model} rate limited, "
|
||||
f"cooling for {retry_after}s")
|
||||
|
||||
elif not self._is_retryable_error(e):
|
||||
# Permanent failure (auth error, etc.)
|
||||
endpoint.mark_failed()
|
||||
logger.error(f"Endpoint {endpoint.config.model} failed permanently: {e}")
|
||||
|
||||
else:
|
||||
# Temporary error - short cooldown
|
||||
endpoint.set_cooldown(5.0)
|
||||
logger.warning(f"Endpoint {endpoint.config.model} error: {e}")
|
||||
|
||||
finally:
|
||||
with endpoint.lock:
|
||||
endpoint.active_requests -= 1
|
||||
|
||||
# All retries exhausted
|
||||
available = self.available_endpoint_count
|
||||
raise RuntimeError(
|
||||
f"All embedding attempts failed after {self.max_retries + 1} tries. "
|
||||
f"Available endpoints: {available}/{len(self._endpoints)}. "
|
||||
f"Last error: {last_error}"
|
||||
)
|
||||
|
||||
|
||||
def create_rotational_embedder(
|
||||
endpoints_config: List[Dict[str, Any]],
|
||||
strategy: str = "latency_aware",
|
||||
default_cooldown: float = 60.0,
|
||||
) -> RotationalEmbedder:
|
||||
"""Factory function to create RotationalEmbedder from config dicts.
|
||||
|
||||
Args:
|
||||
endpoints_config: List of endpoint configuration dicts with keys:
|
||||
- model: Model identifier (required)
|
||||
- api_key: API key (optional)
|
||||
- api_base: API base URL (optional)
|
||||
- weight: Request weight (optional, default 1.0)
|
||||
- max_concurrent: Max concurrent requests (optional, default 4)
|
||||
strategy: Selection strategy name (round_robin, latency_aware, weighted_random)
|
||||
default_cooldown: Default cooldown seconds for rate limits
|
||||
|
||||
Returns:
|
||||
Configured RotationalEmbedder instance
|
||||
|
||||
Example config:
|
||||
endpoints_config = [
|
||||
{"model": "openai/text-embedding-3-small", "api_key": "sk-..."},
|
||||
{"model": "azure/my-embedding", "api_base": "https://...", "api_key": "..."},
|
||||
]
|
||||
"""
|
||||
endpoints = []
|
||||
for cfg in endpoints_config:
|
||||
endpoints.append(EndpointConfig(
|
||||
model=cfg["model"],
|
||||
api_key=cfg.get("api_key"),
|
||||
api_base=cfg.get("api_base"),
|
||||
weight=cfg.get("weight", 1.0),
|
||||
max_concurrent=cfg.get("max_concurrent", 4),
|
||||
))
|
||||
|
||||
strategy_enum = SelectionStrategy[strategy.upper()]
|
||||
|
||||
return RotationalEmbedder(
|
||||
endpoints=endpoints,
|
||||
strategy=strategy_enum,
|
||||
default_cooldown=default_cooldown,
|
||||
)
|
||||
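A short sketch tying the factory above together (hypothetical endpoint values; it assumes ccw-litellm is installed and that the listed models are reachable):

# Hypothetical example; API keys and model names are placeholders.
embedder = create_rotational_embedder(
    endpoints_config=[
        {"model": "openai/text-embedding-3-small", "api_key": "sk-primary-..."},
        {"model": "openai/text-embedding-3-small", "api_key": "sk-backup-...", "weight": 0.5},
    ],
    strategy="round_robin",
)
vectors = embedder.embed_to_numpy(["def hello(): pass", "class Foo: ..."])
print(vectors.shape)                  # (2, embedding_dim)
print(embedder.get_endpoint_stats())  # per-endpoint health and latency metrics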
567
codex-lens/build/lib/codexlens/semantic/splade_encoder.py
Normal file
@@ -0,0 +1,567 @@
|
||||
"""ONNX-optimized SPLADE sparse encoder for code search.
|
||||
|
||||
This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
|
||||
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
|
||||
that combine the interpretability of BM25 with neural relevance modeling.
|
||||
|
||||
Install (CPU):
|
||||
pip install onnxruntime optimum[onnxruntime] transformers
|
||||
|
||||
Install (GPU):
|
||||
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_splade_available() -> Tuple[bool, Optional[str]]:
|
||||
"""Check whether SPLADE dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (available: bool, error_message: Optional[str])
|
||||
"""
|
||||
try:
|
||||
import numpy # noqa: F401
|
||||
except ImportError as exc:
|
||||
return False, f"numpy not available: {exc}. Install with: pip install numpy"
|
||||
|
||||
try:
|
||||
import onnxruntime # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
|
||||
)
|
||||
|
||||
try:
|
||||
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"transformers not available: {exc}. Install with: pip install transformers",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# Global cache for SPLADE encoders (singleton pattern)
|
||||
_splade_cache: Dict[str, "SpladeEncoder"] = {}
|
||||
_cache_lock = threading.RLock()
|
||||
|
||||
|
||||
def get_splade_encoder(
|
||||
model_name: str = "naver/splade-cocondenser-ensembledistil",
|
||||
use_gpu: bool = True,
|
||||
max_length: int = 512,
|
||||
sparsity_threshold: float = 0.01,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> "SpladeEncoder":
|
||||
"""Get or create cached SPLADE encoder (thread-safe singleton).
|
||||
|
||||
This function provides significant performance improvement by reusing
|
||||
SpladeEncoder instances across multiple searches, avoiding repeated model
|
||||
loading overhead.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||
use_gpu: If True, use GPU acceleration when available
|
||||
max_length: Maximum sequence length for tokenization
|
||||
sparsity_threshold: Minimum weight to include in sparse vector
|
||||
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
|
||||
|
||||
Returns:
|
||||
Cached SpladeEncoder instance for the given configuration
|
||||
"""
|
||||
global _splade_cache
|
||||
|
||||
# Cache key includes all configuration parameters
|
||||
cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"
|
||||
|
||||
with _cache_lock:
|
||||
encoder = _splade_cache.get(cache_key)
|
||||
if encoder is not None:
|
||||
return encoder
|
||||
|
||||
# Create new encoder and cache it
|
||||
encoder = SpladeEncoder(
|
||||
model_name=model_name,
|
||||
use_gpu=use_gpu,
|
||||
max_length=max_length,
|
||||
sparsity_threshold=sparsity_threshold,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
# Pre-load model to ensure it's ready
|
||||
encoder._load_model()
|
||||
_splade_cache[cache_key] = encoder
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def clear_splade_cache() -> None:
|
||||
"""Clear the SPLADE encoder cache and release ONNX resources.
|
||||
|
||||
This method ensures proper cleanup of ONNX model resources to prevent
|
||||
memory leaks when encoders are no longer needed.
|
||||
"""
|
||||
global _splade_cache
|
||||
with _cache_lock:
|
||||
# Release ONNX resources before clearing cache
|
||||
for encoder in _splade_cache.values():
|
||||
if encoder._model is not None:
|
||||
del encoder._model
|
||||
encoder._model = None
|
||||
if encoder._tokenizer is not None:
|
||||
del encoder._tokenizer
|
||||
encoder._tokenizer = None
|
||||
_splade_cache.clear()
|
||||
|
||||
|
||||
class SpladeEncoder:
|
||||
"""ONNX-optimized SPLADE sparse encoder.
|
||||
|
||||
Produces sparse vectors with vocabulary-aligned dimensions.
|
||||
Output: Dict[int, float] mapping token_id to weight.
|
||||
|
||||
SPLADE activation formula:
|
||||
splade_repr = log(1 + ReLU(logits)) * attention_mask
|
||||
splade_vec = max_pooling(splade_repr, axis=sequence_length)
|
||||
|
||||
References:
|
||||
- SPLADE: https://arxiv.org/abs/2107.05720
|
||||
- SPLADE v2: https://arxiv.org/abs/2109.10086
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DEFAULT_MODEL,
|
||||
use_gpu: bool = True,
|
||||
max_length: int = 512,
|
||||
sparsity_threshold: float = 0.01,
|
||||
providers: Optional[List[Any]] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize SPLADE encoder.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||
use_gpu: If True, use GPU acceleration when available
|
||||
max_length: Maximum sequence length for tokenization
|
||||
sparsity_threshold: Minimum weight to include in sparse vector
|
||||
providers: Explicit ONNX providers list (overrides use_gpu)
|
||||
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
|
||||
"""
|
||||
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.use_gpu = bool(use_gpu)
|
||||
self.max_length = int(max_length) if max_length > 0 else 512
|
||||
self.sparsity_threshold = float(sparsity_threshold)
|
||||
self.providers = providers
|
||||
|
||||
# Setup ONNX cache directory
|
||||
if cache_dir:
|
||||
self._cache_dir = Path(cache_dir)
|
||||
else:
|
||||
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
|
||||
|
||||
self._tokenizer: Any | None = None
|
||||
self._model: Any | None = None
|
||||
self._vocab_size: int | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _get_local_cache_path(self) -> Path:
|
||||
"""Get local cache path for this model's ONNX files.
|
||||
|
||||
Returns:
|
||||
Path to the local ONNX cache directory for this model
|
||||
"""
|
||||
# Replace / with -- for filesystem-safe naming
|
||||
safe_name = self.model_name.replace("/", "--")
|
||||
return self._cache_dir / safe_name
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Lazy load ONNX model and tokenizer.
|
||||
|
||||
First checks local cache for ONNX model, falling back to
|
||||
HuggingFace download and conversion if not cached.
|
||||
"""
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
ok, err = check_splade_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from optimum.onnxruntime import ORTModelForMaskedLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if self.providers is None:
|
||||
from .gpu_support import get_optimal_providers, get_selected_device_id
|
||||
|
||||
# Get providers as pure string list (cache-friendly)
|
||||
# NOTE: with_device_options=False to avoid tuple-based providers
|
||||
# which break optimum's caching mechanism
|
||||
self.providers = get_optimal_providers(
|
||||
use_gpu=self.use_gpu, with_device_options=False
|
||||
)
|
||||
# Get device_id separately for provider_options
|
||||
self._device_id = get_selected_device_id() if self.use_gpu else None
|
||||
|
||||
# Some Optimum versions accept `providers`, others accept a single `provider`
|
||||
# Prefer passing the full providers list, with a conservative fallback
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
try:
|
||||
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
|
||||
if "providers" in params:
|
||||
model_kwargs["providers"] = self.providers
|
||||
# Pass device_id via provider_options for GPU selection
|
||||
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
|
||||
# Build provider_options dict for each GPU provider
|
||||
provider_options = {}
|
||||
for p in self.providers:
|
||||
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
|
||||
provider_options[p] = {"device_id": self._device_id}
|
||||
if provider_options:
|
||||
model_kwargs["provider_options"] = provider_options
|
||||
elif "provider" in params:
|
||||
provider_name = "CPUExecutionProvider"
|
||||
if self.providers:
|
||||
first = self.providers[0]
|
||||
provider_name = first[0] if isinstance(first, tuple) else str(first)
|
||||
model_kwargs["provider"] = provider_name
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to inspect ORTModel signature: {e}")
|
||||
model_kwargs = {}
|
||||
|
||||
# Check for local ONNX cache first
|
||||
local_cache = self._get_local_cache_path()
|
||||
onnx_model_path = local_cache / "model.onnx"
|
||||
|
||||
if onnx_model_path.exists():
|
||||
# Load from local cache
|
||||
logger.info(f"Loading SPLADE from local cache: {local_cache}")
|
||||
try:
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
str(local_cache),
|
||||
**model_kwargs,
|
||||
)
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
str(local_cache), use_fast=True
|
||||
)
|
||||
self._vocab_size = len(self._tokenizer)
|
||||
logger.info(
|
||||
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load from cache, redownloading: {e}")
|
||||
|
||||
# Download and convert from HuggingFace
|
||||
logger.info(f"Downloading SPLADE model: {self.model_name}")
|
||||
try:
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
self.model_name,
|
||||
export=True, # Export to ONNX
|
||||
**model_kwargs,
|
||||
)
|
||||
logger.debug(f"SPLADE model loaded: {self.model_name}")
|
||||
except TypeError:
|
||||
# Fallback for older Optimum versions: retry without provider arguments
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
self.model_name,
|
||||
export=True,
|
||||
)
|
||||
logger.warning(
|
||||
"Optimum version doesn't support provider parameters. "
|
||||
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
|
||||
)
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||
|
||||
# Cache vocabulary size
|
||||
self._vocab_size = len(self._tokenizer)
|
||||
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
|
||||
|
||||
# Save to local cache for future use
|
||||
try:
|
||||
local_cache.mkdir(parents=True, exist_ok=True)
|
||||
self._model.save_pretrained(str(local_cache))
|
||||
self._tokenizer.save_pretrained(str(local_cache))
|
||||
logger.info(f"SPLADE model cached to: {local_cache}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache SPLADE model: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
|
||||
"""Apply SPLADE activation function to model outputs.
|
||||
|
||||
Formula: log(1 + ReLU(logits)) * attention_mask
|
||||
|
||||
Args:
|
||||
logits: Model output logits (batch, seq_len, vocab_size)
|
||||
attention_mask: Attention mask (batch, seq_len)
|
||||
|
||||
Returns:
|
||||
SPLADE representations (batch, seq_len, vocab_size)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# ReLU activation
|
||||
relu_logits = np.maximum(0, logits)
|
||||
|
||||
# Log(1 + x) transformation
|
||||
log_relu = np.log1p(relu_logits)
|
||||
|
||||
# Apply attention mask (expand to match vocab dimension)
|
||||
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
|
||||
mask_expanded = np.expand_dims(attention_mask, axis=-1)
|
||||
|
||||
# Element-wise multiplication
|
||||
splade_repr = log_relu * mask_expanded
|
||||
|
||||
return splade_repr
|
||||
|
||||
@staticmethod
|
||||
def _max_pooling(splade_repr: Any) -> Any:
|
||||
"""Max pooling over sequence length dimension.
|
||||
|
||||
Args:
|
||||
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
|
||||
|
||||
Returns:
|
||||
Pooled sparse vectors (batch, vocab_size)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Max pooling over sequence dimension (axis=1)
|
||||
return np.max(splade_repr, axis=1)
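    # Tiny numeric example of the two steps above, for one document, seq_len=2 and a
    # 3-token vocabulary (attention mask all ones):
    #   logits            = [[[1.0, -2.0, 0.0], [3.0, 0.5, -1.0]]]
    #   log1p(relu(.))    = [[[0.693, 0.0, 0.0], [1.386, 0.405, 0.0]]]
    #   max over seq axis = [[1.386, 0.405, 0.0]] -> sparse dict {0: 1.386, 1: 0.405}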
|
||||
|
||||
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
|
||||
"""Convert dense vector to sparse dictionary.
|
||||
|
||||
Args:
|
||||
dense_vec: Dense vector (vocab_size,)
|
||||
|
||||
Returns:
|
||||
Sparse dictionary {token_id: weight} with weights above threshold
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Find non-zero indices above threshold
|
||||
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
|
||||
|
||||
# Create sparse dictionary
|
||||
sparse_dict = {
|
||||
int(idx): float(dense_vec[idx])
|
||||
for idx in nonzero_indices
|
||||
}
|
||||
|
||||
return sparse_dict
|
||||
|
||||
def warmup(self, text: str = "warmup query") -> None:
|
||||
"""Warmup the encoder by running a dummy inference.
|
||||
|
||||
First-time model inference includes initialization overhead.
|
||||
Call this method once before the first real search to avoid
|
||||
latency spikes.
|
||||
|
||||
Args:
|
||||
text: Dummy text for warmup (default: "warmup query")
|
||||
"""
|
||||
logger.info("Warming up SPLADE encoder...")
|
||||
# Trigger model loading and first inference
|
||||
_ = self.encode_text(text)
|
||||
logger.info("SPLADE encoder warmup complete")
|
||||
|
||||
def encode_text(self, text: str) -> Dict[int, float]:
|
||||
"""Encode text to sparse vector {token_id: weight}.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Sparse vector as dictionary mapping token_id to weight
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Tokenize input
|
||||
encoded = self._tokenizer(
|
||||
text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Forward pass through model
|
||||
outputs = self._model(**encoded)
|
||||
|
||||
# Extract logits
|
||||
if hasattr(outputs, "logits"):
|
||||
logits = outputs.logits
|
||||
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||
logits = outputs["logits"]
|
||||
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||
logits = outputs[0]
|
||||
else:
|
||||
raise RuntimeError("Unexpected model output format")
|
||||
|
||||
# Apply SPLADE activation
|
||||
attention_mask = encoded["attention_mask"]
|
||||
splade_repr = self._splade_activation(logits, attention_mask)
|
||||
|
||||
# Max pooling over sequence length
|
||||
splade_vec = self._max_pooling(splade_repr)
|
||||
|
||||
# Convert to sparse dictionary (single item batch)
|
||||
sparse_dict = self._to_sparse_dict(splade_vec[0])
|
||||
|
||||
return sparse_dict
|
||||
|
||||
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
|
||||
"""Batch encode texts to sparse vectors.
|
||||
|
||||
Args:
|
||||
texts: List of input texts to encode
|
||||
batch_size: Batch size for encoding (default: 32)
|
||||
|
||||
Returns:
|
||||
List of sparse vectors as dictionaries
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
import numpy as np
|
||||
|
||||
results: List[Dict[int, float]] = []
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Tokenize batch
|
||||
encoded = self._tokenizer(
|
||||
batch_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Forward pass through model
|
||||
outputs = self._model(**encoded)
|
||||
|
||||
# Extract logits
|
||||
if hasattr(outputs, "logits"):
|
||||
logits = outputs.logits
|
||||
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||
logits = outputs["logits"]
|
||||
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||
logits = outputs[0]
|
||||
else:
|
||||
raise RuntimeError("Unexpected model output format")
|
||||
|
||||
# Apply SPLADE activation
|
||||
attention_mask = encoded["attention_mask"]
|
||||
splade_repr = self._splade_activation(logits, attention_mask)
|
||||
|
||||
# Max pooling over sequence length
|
||||
splade_vecs = self._max_pooling(splade_repr)
|
||||
|
||||
# Convert each vector to sparse dictionary
|
||||
for vec in splade_vecs:
|
||||
sparse_dict = self._to_sparse_dict(vec)
|
||||
results.append(sparse_dict)
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""Return vocabulary size (~30k for BERT-based models).
|
||||
|
||||
Returns:
|
||||
Vocabulary size (number of tokens in tokenizer)
|
||||
"""
|
||||
if self._vocab_size is not None:
|
||||
return self._vocab_size
|
||||
|
||||
self._load_model()
|
||||
return self._vocab_size or 0
|
||||
|
||||
def get_token(self, token_id: int) -> str:
|
||||
"""Convert token_id to string (for debugging).
|
||||
|
||||
Args:
|
||||
token_id: Token ID to convert
|
||||
|
||||
Returns:
|
||||
Token string
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError("Tokenizer not loaded")
|
||||
|
||||
return self._tokenizer.decode([token_id])
|
||||
|
||||
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
|
||||
"""Get top-k tokens with highest weights from sparse vector.
|
||||
|
||||
Useful for debugging and understanding what the model is focusing on.
|
||||
|
||||
Args:
|
||||
sparse_vec: Sparse vector as {token_id: weight}
|
||||
top_k: Number of top tokens to return
|
||||
|
||||
Returns:
|
||||
List of (token_string, weight) tuples, sorted by weight descending
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if not sparse_vec:
|
||||
return []
|
||||
|
||||
# Sort by weight descending
|
||||
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Take top-k and convert token_ids to strings
|
||||
top_items = sorted_items[:top_k]
|
||||
|
||||
return [
|
||||
(self.get_token(token_id), weight)
|
||||
for token_id, weight in top_items
|
||||
]
|
||||
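Finally, a usage sketch for the SPLADE encoder (illustrative; it assumes the optimum/onnxruntime extras are installed and that the default naver/splade-cocondenser-ensembledistil model can be downloaded):

# Hypothetical example; the query text is illustrative.
encoder = get_splade_encoder(use_gpu=False)  # cached singleton per configuration
encoder.warmup()
sparse = encoder.encode_text("read a parquet file into a dataframe")
print(len(sparse))                            # number of active vocabulary terms
print(encoder.get_top_tokens(sparse, top_k=5))
clear_splade_cache()                          # release ONNX resources when finished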
1278
codex-lens/build/lib/codexlens/semantic/vector_store.py
Normal file
File diff suppressed because it is too large