feat: Enhance embedding generation and search capabilities

- Added pre-calculation of estimated chunk count for HNSW capacity in `generate_dense_embeddings_centralized` to optimize indexing performance.
- Implemented binary vector generation with memory-mapped storage for efficient cascade search, including metadata saving.
- Introduced SPLADE sparse index generation with improved handling and metadata storage.
- Updated `ChainSearchEngine` to prefer centralized binary searcher for improved performance and added fallback to legacy binary index.
- Deprecated `BinaryANNIndex` in favor of `BinarySearcher` for better memory management and performance.
- Enhanced `SpladeEncoder` with warmup functionality to reduce latency spikes during first-time inference.
- Improved `SpladeIndex` with cache size adjustments for better query performance.
- Added methods for managing binary vectors in `VectorMetadataStore`, including batch insertion and retrieval.
- Created a new `BinarySearcher` class for efficient binary vector search using Hamming distance, supporting both memory-mapped and database loading modes.
This commit is contained in:
catlog22
2026-01-02 23:57:55 +08:00
parent 96b44e1482
commit 54fd94547c
12 changed files with 945 additions and 167 deletions

View File

@@ -15,6 +15,7 @@ Requires-Dist: tree-sitter-python>=0.25
Requires-Dist: tree-sitter-javascript>=0.25
Requires-Dist: tree-sitter-typescript>=0.23
Requires-Dist: pathspec>=0.11
Requires-Dist: watchdog>=3.0
Provides-Extra: semantic
Requires-Dist: numpy>=1.24; extra == "semantic"
Requires-Dist: fastembed>=0.2; extra == "semantic"
@@ -29,6 +30,26 @@ Requires-Dist: numpy>=1.24; extra == "semantic-directml"
Requires-Dist: fastembed>=0.2; extra == "semantic-directml"
Requires-Dist: hnswlib>=0.8.0; extra == "semantic-directml"
Requires-Dist: onnxruntime-directml>=1.15.0; extra == "semantic-directml"
Provides-Extra: reranker-onnx
Requires-Dist: optimum>=1.16; extra == "reranker-onnx"
Requires-Dist: onnxruntime>=1.15; extra == "reranker-onnx"
Requires-Dist: transformers>=4.36; extra == "reranker-onnx"
Provides-Extra: reranker-api
Requires-Dist: httpx>=0.25; extra == "reranker-api"
Provides-Extra: reranker-litellm
Requires-Dist: ccw-litellm>=0.1; extra == "reranker-litellm"
Provides-Extra: reranker-legacy
Requires-Dist: sentence-transformers>=2.2; extra == "reranker-legacy"
Provides-Extra: reranker
Requires-Dist: optimum>=1.16; extra == "reranker"
Requires-Dist: onnxruntime>=1.15; extra == "reranker"
Requires-Dist: transformers>=4.36; extra == "reranker"
Provides-Extra: splade
Requires-Dist: transformers>=4.36; extra == "splade"
Requires-Dist: optimum[onnxruntime]>=1.16; extra == "splade"
Provides-Extra: splade-gpu
Requires-Dist: transformers>=4.36; extra == "splade-gpu"
Requires-Dist: optimum[onnxruntime-gpu]>=1.16; extra == "splade-gpu"
Provides-Extra: encoding
Requires-Dist: chardet>=5.0; extra == "encoding"
Provides-Extra: full

View File

@@ -8,6 +8,7 @@ src/codexlens/__init__.py
src/codexlens/__main__.py
src/codexlens/config.py
src/codexlens/entities.py
src/codexlens/env_config.py
src/codexlens/errors.py
src/codexlens/cli/__init__.py
src/codexlens/cli/commands.py
@@ -15,6 +16,7 @@ src/codexlens/cli/embedding_manager.py
src/codexlens/cli/model_manager.py
src/codexlens/cli/output.py
src/codexlens/indexing/__init__.py
src/codexlens/indexing/embedding.py
src/codexlens/indexing/symbol_extractor.py
src/codexlens/parsers/__init__.py
src/codexlens/parsers/encoding.py
@@ -24,6 +26,7 @@ src/codexlens/parsers/treesitter_parser.py
src/codexlens/search/__init__.py
src/codexlens/search/chain_search.py
src/codexlens/search/enrichment.py
src/codexlens/search/graph_expander.py
src/codexlens/search/hybrid_search.py
src/codexlens/search/query_parser.py
src/codexlens/search/ranking.py
@@ -37,28 +40,52 @@ src/codexlens/semantic/factory.py
src/codexlens/semantic/gpu_support.py
src/codexlens/semantic/litellm_embedder.py
src/codexlens/semantic/rotational_embedder.py
src/codexlens/semantic/splade_encoder.py
src/codexlens/semantic/vector_store.py
src/codexlens/semantic/reranker/__init__.py
src/codexlens/semantic/reranker/api_reranker.py
src/codexlens/semantic/reranker/base.py
src/codexlens/semantic/reranker/factory.py
src/codexlens/semantic/reranker/legacy.py
src/codexlens/semantic/reranker/litellm_reranker.py
src/codexlens/semantic/reranker/onnx_reranker.py
src/codexlens/storage/__init__.py
src/codexlens/storage/dir_index.py
src/codexlens/storage/file_cache.py
src/codexlens/storage/global_index.py
src/codexlens/storage/index_tree.py
src/codexlens/storage/merkle_tree.py
src/codexlens/storage/migration_manager.py
src/codexlens/storage/path_mapper.py
src/codexlens/storage/registry.py
src/codexlens/storage/splade_index.py
src/codexlens/storage/sqlite_store.py
src/codexlens/storage/sqlite_utils.py
src/codexlens/storage/vector_meta_store.py
src/codexlens/storage/migrations/__init__.py
src/codexlens/storage/migrations/migration_001_normalize_keywords.py
src/codexlens/storage/migrations/migration_002_add_token_metadata.py
src/codexlens/storage/migrations/migration_004_dual_fts.py
src/codexlens/storage/migrations/migration_005_cleanup_unused_fields.py
src/codexlens/storage/migrations/migration_006_enhance_relationships.py
src/codexlens/storage/migrations/migration_007_add_graph_neighbors.py
src/codexlens/storage/migrations/migration_008_add_merkle_hashes.py
src/codexlens/storage/migrations/migration_009_add_splade.py
src/codexlens/storage/migrations/migration_010_add_multi_vector_chunks.py
src/codexlens/watcher/__init__.py
src/codexlens/watcher/events.py
src/codexlens/watcher/file_watcher.py
src/codexlens/watcher/incremental_indexer.py
src/codexlens/watcher/manager.py
tests/test_ann_index.py
tests/test_api_reranker.py
tests/test_chain_search.py
tests/test_cli_hybrid_search.py
tests/test_cli_output.py
tests/test_code_extractor.py
tests/test_config.py
tests/test_dual_fts.py
tests/test_embedder.py
tests/test_embedding_backend_availability.py
tests/test_encoding.py
tests/test_enrichment.py
@@ -67,15 +94,22 @@ tests/test_errors.py
tests/test_file_cache.py
tests/test_global_index.py
tests/test_global_symbol_index.py
tests/test_graph_expansion.py
tests/test_hybrid_chunker.py
tests/test_hybrid_search_e2e.py
tests/test_hybrid_search_reranker_backend.py
tests/test_incremental_indexing.py
tests/test_litellm_reranker.py
tests/test_merkle_detection.py
tests/test_parser_integration.py
tests/test_parsers.py
tests/test_performance_optimizations.py
tests/test_pure_vector_search.py
tests/test_query_parser.py
tests/test_recursive_splitting.py
tests/test_registry.py
tests/test_reranker_backends.py
tests/test_reranker_factory.py
tests/test_result_grouping.py
tests/test_rrf_fusion.py
tests/test_schema_cleanup_migration.py
@@ -85,11 +119,14 @@ tests/test_search_full_coverage.py
tests/test_search_performance.py
tests/test_semantic.py
tests/test_semantic_search.py
tests/test_sqlite_store.py
tests/test_storage.py
tests/test_storage_concurrency.py
tests/test_symbol_extractor.py
tests/test_token_chunking.py
tests/test_token_storage.py
tests/test_tokenizer.py
tests/test_tokenizer_performance.py
tests/test_treesitter_parser.py
tests/test_vector_search_full.py
tests/test_vector_search_full.py
tests/test_vector_store.py

View File

@@ -6,6 +6,7 @@ tree-sitter-python>=0.25
tree-sitter-javascript>=0.25
tree-sitter-typescript>=0.23
pathspec>=0.11
watchdog>=3.0
[encoding]
chardet>=5.0
@@ -13,6 +14,25 @@ chardet>=5.0
[full]
tiktoken>=0.5.0
[reranker]
optimum>=1.16
onnxruntime>=1.15
transformers>=4.36
[reranker-api]
httpx>=0.25
[reranker-legacy]
sentence-transformers>=2.2
[reranker-litellm]
ccw-litellm>=0.1
[reranker-onnx]
optimum>=1.16
onnxruntime>=1.15
transformers>=4.36
[semantic]
numpy>=1.24
fastembed>=0.2
@@ -29,3 +49,11 @@ numpy>=1.24
fastembed>=0.2
hnswlib>=0.8.0
onnxruntime-gpu>=1.15.0
[splade]
transformers>=4.36
optimum[onnxruntime]>=1.16
[splade-gpu]
transformers>=4.36
optimum[onnxruntime-gpu]>=1.16

View File

@@ -36,6 +36,27 @@ from .output import (
app = typer.Typer(help="CodexLens CLI — local code indexing and search.")
# Index subcommand group for reorganized commands
index_app = typer.Typer(help="Index management commands (embeddings, SPLADE, migrations).")
app.add_typer(index_app, name="index")
def _deprecated_command_warning(old_name: str, new_name: str) -> None:
"""Display deprecation warning for renamed commands.
Args:
old_name: The old command name being deprecated
new_name: The new command name to use instead
"""
console.print(
f"[yellow]Warning:[/yellow] '{old_name}' is deprecated. "
f"Use '{new_name}' instead."
)
# Index management subcommand group
index_app = typer.Typer(help="Index management commands (init, embeddings, splade, binary, status, migrate, all)")
app.add_typer(index_app, name="index")
def _configure_logging(verbose: bool, json_mode: bool = False) -> None:
"""Configure logging level.
@@ -96,8 +117,8 @@ def _get_registry_path() -> Path:
return Path.home() / ".codexlens" / "registry.db"
@app.command()
def init(
@index_app.command("init")
def index_init(
path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to index."),
language: Optional[List[str]] = typer.Option(
None,
@@ -108,8 +129,8 @@ def init(
workers: Optional[int] = typer.Option(None, "--workers", "-w", min=1, help="Parallel worker processes (default: auto-detect based on CPU count)."),
force: bool = typer.Option(False, "--force", "-f", help="Force full reindex (skip incremental mode)."),
no_embeddings: bool = typer.Option(False, "--no-embeddings", help="Skip automatic embedding generation (if semantic deps installed)."),
embedding_backend: str = typer.Option("fastembed", "--embedding-backend", help="Embedding backend: fastembed (local) or litellm (remote API)."),
embedding_model: str = typer.Option("code", "--embedding-model", help="Embedding model: profile name for fastembed (fast/code/multilingual/balanced) or model name for litellm (e.g. text-embedding-3-small)."),
backend: str = typer.Option("fastembed", "--backend", "-b", help="Embedding backend: fastembed (local) or litellm (remote API)."),
model: str = typer.Option("code", "--model", "-m", help="Embedding model: profile name for fastembed (fast/code/multilingual/balanced) or model name for litellm (e.g. text-embedding-3-small)."),
max_workers: int = typer.Option(1, "--max-workers", min=1, help="Max concurrent API calls for embedding generation. Recommended: 4-8 for litellm backend."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
@@ -125,11 +146,11 @@ def init(
If semantic search dependencies are installed, automatically generates embeddings
after indexing completes. Use --no-embeddings to skip this step.
Embedding Backend Options:
Backend Options (--backend):
- fastembed: Local ONNX-based embeddings (default, no API calls)
- litellm: Remote API embeddings via ccw-litellm (requires API keys)
Embedding Model Options:
Model Options (--model):
- For fastembed backend: Use profile names (fast, code, multilingual, balanced)
- For litellm backend: Use model names (e.g., text-embedding-3-small, text-embedding-ada-002)
"""
@@ -182,15 +203,15 @@ def init(
# Validate embedding backend
valid_backends = ["fastembed", "litellm"]
if embedding_backend not in valid_backends:
error_msg = f"Invalid embedding backend: {embedding_backend}. Must be one of: {', '.join(valid_backends)}"
if backend not in valid_backends:
error_msg = f"Invalid embedding backend: {backend}. Must be one of: {', '.join(valid_backends)}"
if json_mode:
print_json(success=False, error=error_msg)
else:
console.print(f"[red]Error:[/red] {error_msg}")
raise typer.Exit(code=1)
backend_available, backend_error = is_embedding_backend_available(embedding_backend)
backend_available, backend_error = is_embedding_backend_available(backend)
if backend_available:
# Use the index root directory (not the _index.db file)
@@ -198,8 +219,8 @@ def init(
if not json_mode:
console.print("\n[bold]Generating embeddings...[/bold]")
console.print(f"Backend: [cyan]{embedding_backend}[/cyan]")
console.print(f"Model: [cyan]{embedding_model}[/cyan]")
console.print(f"Backend: [cyan]{backend}[/cyan]")
console.print(f"Model: [cyan]{model}[/cyan]")
else:
# Output progress message for JSON mode (parsed by Node.js)
print("Generating embeddings...", flush=True)
@@ -219,8 +240,8 @@ def init(
embed_result = generate_embeddings_recursive(
index_root,
embedding_backend=embedding_backend,
model_profile=embedding_model,
embedding_backend=backend,
model_profile=model,
force=False, # Don't force regenerate during init
chunk_size=2000,
progress_callback=progress_update, # Always use callback
@@ -266,7 +287,7 @@ def init(
}
else:
if not json_mode and verbose:
console.print(f"[dim]Embedding backend '{embedding_backend}' not available. Skipping embeddings.[/dim]")
console.print(f"[dim]Embedding backend '{backend}' not available. Skipping embeddings.[/dim]")
result["embeddings"] = {
"generated": False,
"error": backend_error or "Embedding backend not available",
@@ -410,22 +431,20 @@ def watch(
@app.command()
def search(
query: str = typer.Argument(..., help="FTS query to run."),
query: str = typer.Argument(..., help="Search query."),
path: Path = typer.Option(Path("."), "--path", "-p", help="Directory to search from."),
limit: int = typer.Option(20, "--limit", "-n", min=1, max=500, help="Max results."),
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited, 0 = current only)."),
files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
method: str = typer.Option("hybrid", "--method", "-m", help="Search method: fts, vector, splade, hybrid, cascade."),
use_fuzzy: bool = typer.Option(False, "--use-fuzzy", help="Enable fuzzy matching in FTS method."),
weights: Optional[str] = typer.Option(
None,
"--weights", "-w",
help="RRF weights as key=value pairs (e.g., 'splade=0.4,vector=0.6' or 'exact=0.3,fuzzy=0.1,vector=0.6'). Default: auto-detect based on available backends."
),
use_fts: bool = typer.Option(
False,
"--use-fts",
help="Use FTS (exact+fuzzy) instead of SPLADE for sparse retrieval"
help="RRF weights as key=value pairs (e.g., 'splade=0.4,vector=0.6' or 'fts=0.4,vector=0.6'). Default: auto-detect based on available backends."
),
# Hidden deprecated parameter for backward compatibility
mode: Optional[str] = typer.Option(None, "--mode", hidden=True, help="[DEPRECATED] Use --method instead."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
@@ -434,64 +453,95 @@ def search(
Uses chain search across directory indexes.
Use --depth to limit search recursion (0 = current dir only).
Search Modes:
- auto: Auto-detect (hybrid if embeddings exist, exact otherwise) [default]
- exact: Exact FTS using unicode61 tokenizer - for code identifiers
- fuzzy: Fuzzy FTS using trigram tokenizer - for typo-tolerant search
- hybrid: RRF fusion of sparse + dense search (recommended) - best recall
- vector: Vector search with sparse fallback - semantic + keyword
- pure-vector: Pure semantic vector search only - natural language queries
Search Methods:
- fts: Full-text search using FTS5 (unicode61 tokenizer). Use --use-fuzzy for typo tolerance.
- vector: Pure semantic vector search - for natural language queries.
- splade: SPLADE sparse neural search - semantic term expansion.
- hybrid: RRF fusion of sparse + dense search (default) - best recall.
- cascade: Two-stage retrieval (binary coarse + dense rerank) - fast + accurate.
SPLADE Mode:
When SPLADE is available (pip install codex-lens[splade]), it automatically
replaces FTS (exact+fuzzy) as the sparse retrieval backend. SPLADE provides
semantic term expansion for better synonym handling.
Use --use-fts to force FTS mode instead of SPLADE.
Method Selection Guide:
- Code identifiers (function/class names): fts
- Natural language queries: vector or hybrid
- Typo-tolerant search: fts --use-fuzzy
- Best overall quality: hybrid (default)
- Large codebase performance: cascade
Vector Search Requirements:
Vector search modes require pre-generated embeddings.
Vector, hybrid, and cascade methods require pre-generated embeddings.
Use 'codexlens embeddings-generate' to create embeddings first.
Hybrid Mode Weights:
Use --weights to adjust RRF fusion weights:
- SPLADE mode: 'splade=0.4,vector=0.6' (default)
- FTS mode: 'exact=0.3,fuzzy=0.1,vector=0.6' (default)
Legacy format also supported: '0.3,0.1,0.6' (exact,fuzzy,vector)
- FTS mode: 'fts=0.4,vector=0.6' (default)
Examples:
# Auto-detect mode (uses hybrid if embeddings available)
# Default hybrid search
codexlens search "authentication"
# Explicit exact code search
codexlens search "authenticate_user" --mode exact
# Exact code identifier search
codexlens search "authenticate_user" --method fts
# Semantic search (requires embeddings)
codexlens search "how to verify user credentials" --mode pure-vector
# Typo-tolerant fuzzy search
codexlens search "authentcate" --method fts --use-fuzzy
# Force hybrid mode with custom weights
codexlens search "authentication" --mode hybrid --weights splade=0.5,vector=0.5
# Force FTS instead of SPLADE
codexlens search "authentication" --use-fts
# Pure semantic search
codexlens search "how to verify user credentials" --method vector
# SPLADE sparse neural search
codexlens search "user login flow" --method splade
# Fast cascade retrieval for large codebases
codexlens search "authentication" --method cascade
# Hybrid with custom weights
codexlens search "authentication" --method hybrid --weights splade=0.5,vector=0.5
"""
_configure_logging(verbose, json_mode)
search_path = path.expanduser().resolve()
# Configure search with FTS fallback if requested
config = Config()
if use_fts:
config.use_fts_fallback = True
# Validate mode
valid_modes = ["auto", "exact", "fuzzy", "hybrid", "vector", "pure-vector"]
if mode not in valid_modes:
if json_mode:
print_json(success=False, error=f"Invalid mode: {mode}. Must be one of: {', '.join(valid_modes)}")
# Handle deprecated --mode parameter
actual_method = method
if mode is not None:
# Show deprecation warning
if not json_mode:
console.print("[yellow]Warning: --mode is deprecated, use --method instead.[/yellow]")
# Map old mode values to new method values
mode_to_method = {
"auto": "hybrid",
"exact": "fts",
"fuzzy": "fts", # with use_fuzzy=True
"hybrid": "hybrid",
"vector": "vector",
"pure-vector": "vector",
}
if mode in mode_to_method:
actual_method = mode_to_method[mode]
# Enable fuzzy for old fuzzy mode
if mode == "fuzzy":
use_fuzzy = True
else:
console.print(f"[red]Invalid mode:[/red] {mode}")
console.print(f"[dim]Valid modes: {', '.join(valid_modes)}[/dim]")
if json_mode:
print_json(success=False, error=f"Invalid deprecated mode: {mode}. Use --method instead.")
else:
console.print(f"[red]Invalid deprecated mode:[/red] {mode}")
console.print("[dim]Use --method with: fts, vector, splade, hybrid, cascade[/dim]")
raise typer.Exit(code=1)
# Configure search
config = Config()
# Validate method
valid_methods = ["fts", "vector", "splade", "hybrid", "cascade"]
if actual_method not in valid_methods:
if json_mode:
print_json(success=False, error=f"Invalid method: {actual_method}. Must be one of: {', '.join(valid_methods)}")
else:
console.print(f"[red]Invalid method:[/red] {actual_method}")
console.print(f"[dim]Valid methods: {', '.join(valid_methods)}[/dim]")
raise typer.Exit(code=1)
# Parse custom weights if provided
@@ -557,48 +607,49 @@ def search(
engine = ChainSearchEngine(registry, mapper, config=config)
# Auto-detect mode if set to "auto"
actual_mode = mode
if mode == "auto":
# Check if embeddings are available by looking for project in registry
project_record = registry.find_by_source_path(str(search_path))
has_embeddings = False
if project_record:
# Check if index has embeddings
index_path = Path(project_record["index_root"]) / "_index.db"
try:
from codexlens.cli.embedding_manager import check_embeddings_status
embed_status = check_embeddings_status(index_path)
if embed_status["success"]:
embed_data = embed_status["result"]
has_embeddings = embed_data["has_embeddings"] and embed_data["chunks_count"] > 0
except Exception:
pass
# Choose mode based on embedding availability
if has_embeddings:
actual_mode = "hybrid"
if not json_mode and verbose:
console.print("[dim]Auto-detected mode: hybrid (embeddings available)[/dim]")
else:
actual_mode = "exact"
if not json_mode and verbose:
console.print("[dim]Auto-detected mode: exact (no embeddings)[/dim]")
# Map mode to options
if actual_mode == "exact":
hybrid_mode, enable_fuzzy, enable_vector, pure_vector = False, False, False, False
elif actual_mode == "fuzzy":
hybrid_mode, enable_fuzzy, enable_vector, pure_vector = False, True, False, False
elif actual_mode == "vector":
hybrid_mode, enable_fuzzy, enable_vector, pure_vector = True, False, True, False # Vector + exact fallback
elif actual_mode == "pure-vector":
hybrid_mode, enable_fuzzy, enable_vector, pure_vector = True, False, True, True # Pure vector only
elif actual_mode == "hybrid":
hybrid_mode, enable_fuzzy, enable_vector, pure_vector = True, True, True, False
# Map method to SearchOptions flags
# fts: FTS-only search (optionally with fuzzy)
# vector: Pure vector semantic search
# splade: SPLADE sparse neural search
# hybrid: RRF fusion of sparse + dense
# cascade: Two-stage binary + dense retrieval
if actual_method == "fts":
hybrid_mode = False
enable_fuzzy = use_fuzzy
enable_vector = False
pure_vector = False
enable_splade = False
enable_cascade = False
elif actual_method == "vector":
hybrid_mode = True
enable_fuzzy = False
enable_vector = True
pure_vector = True
enable_splade = False
enable_cascade = False
elif actual_method == "splade":
hybrid_mode = True
enable_fuzzy = False
enable_vector = False
pure_vector = False
enable_splade = True
enable_cascade = False
elif actual_method == "hybrid":
hybrid_mode = True
enable_fuzzy = use_fuzzy
enable_vector = True
pure_vector = False
enable_splade = True # SPLADE is preferred sparse in hybrid
enable_cascade = False
elif actual_method == "cascade":
hybrid_mode = True
enable_fuzzy = False
enable_vector = True
pure_vector = False
enable_splade = False
enable_cascade = True
else:
raise ValueError(f"Invalid mode: {actual_mode}")
raise ValueError(f"Invalid method: {actual_method}")
options = SearchOptions(
depth=depth,
@@ -1960,8 +2011,8 @@ def embeddings_status(
console.print(f" [cyan]codexlens embeddings-generate {index_path}[/cyan]")
@app.command(name="embeddings-generate")
def embeddings_generate(
@index_app.command("embeddings")
def index_embeddings(
path: Path = typer.Argument(
...,
exists=True,
@@ -2000,10 +2051,10 @@ def embeddings_generate(
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
centralized: bool = typer.Option(
False,
"--centralized",
"-c",
help="Use centralized vector storage (single HNSW index at project root).",
True,
"--centralized/--distributed",
"-c/-d",
help="Use centralized vector storage (default) or distributed per-directory indexes.",
),
) -> None:
"""Generate semantic embeddings for code search.
@@ -2033,11 +2084,11 @@ def embeddings_generate(
- Any model supported by ccw-litellm
Examples:
codexlens embeddings-generate ~/projects/my-app # Auto-find index (fastembed, code profile)
codexlens embeddings-generate ~/.codexlens/indexes/project/_index.db # Specific index
codexlens embeddings-generate ~/projects/my-app --backend litellm --model text-embedding-3-small # Use LiteLLM
codexlens embeddings-generate ~/projects/my-app --model fast --force # Regenerate with fast profile
codexlens embeddings-generate ~/projects/my-app --centralized # Centralized vector storage
codexlens index embeddings ~/projects/my-app # Auto-find index (fastembed, code profile)
codexlens index embeddings ~/.codexlens/indexes/project/_index.db # Specific index
codexlens index embeddings ~/projects/my-app --backend litellm --model text-embedding-3-small # Use LiteLLM
codexlens index embeddings ~/projects/my-app --model fast --force # Regenerate with fast profile
codexlens index embeddings ~/projects/my-app --centralized # Centralized vector storage
"""
_configure_logging(verbose, json_mode)
@@ -2072,25 +2123,20 @@ def embeddings_generate(
index_path = target_path
index_root = target_path.parent
elif target_path.is_dir():
# Directory: Try to find index for this project
if centralized:
# Centralized mode uses directory as root
index_root = target_path
else:
# Single index mode: find the specific index
registry = RegistryStore()
try:
registry.initialize()
mapper = PathMapper()
index_path = mapper.source_to_index_db(target_path)
# Directory: Find index location from registry
registry = RegistryStore()
try:
registry.initialize()
mapper = PathMapper()
index_path = mapper.source_to_index_db(target_path)
if not index_path.exists():
console.print(f"[red]Error:[/red] No index found for {target_path}")
console.print("Run 'codexlens init' first to create an index")
raise typer.Exit(code=1)
index_root = index_path.parent
finally:
registry.close()
if not index_path.exists():
console.print(f"[red]Error:[/red] No index found for {target_path}")
console.print("Run 'codexlens init' first to create an index")
raise typer.Exit(code=1)
index_root = index_path.parent # Use index directory for both modes
finally:
registry.close()
else:
console.print(f"[red]Error:[/red] Path must be _index.db file or directory")
raise typer.Exit(code=1)
@@ -2442,8 +2488,8 @@ def gpu_reset(
# ==================== SPLADE Commands ====================
@app.command("splade-index")
def splade_index_command(
@index_app.command("splade")
def index_splade(
path: Path = typer.Argument(..., help="Project path to index"),
rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
@@ -2457,8 +2503,8 @@ def splade_index_command(
index directory and builds SPLADE encodings for chunks across all of them.
Examples:
codexlens splade-index ~/projects/my-app
codexlens splade-index . --rebuild
codexlens index splade ~/projects/my-app
codexlens index splade . --rebuild
"""
_configure_logging(verbose)

View File

@@ -1170,6 +1170,22 @@ def generate_dense_embeddings_centralized(
if progress_callback:
progress_callback(f"Found {len(index_files)} index databases for centralized embedding")
# Pre-calculate estimated chunk count for HNSW capacity
# This avoids expensive resize operations during indexing
estimated_total_files = 0
for index_path in index_files:
try:
with sqlite3.connect(index_path) as conn:
cursor = conn.execute("SELECT COUNT(*) FROM files")
estimated_total_files += cursor.fetchone()[0]
except Exception:
pass
# Heuristic: ~15 chunks per file on average
estimated_chunks = max(100000, estimated_total_files * 15)
if progress_callback:
progress_callback(f"Estimated {estimated_total_files} files, ~{estimated_chunks} chunks")
# Check for existing centralized index
central_hnsw_path = index_root / VECTORS_HNSW_NAME
if central_hnsw_path.exists() and not force:
@@ -1217,11 +1233,12 @@ def generate_dense_embeddings_centralized(
"error": f"Failed to initialize components: {str(e)}",
}
# Create centralized ANN index
# Create centralized ANN index with pre-calculated capacity
# Using estimated_chunks avoids expensive resize operations during indexing
central_ann_index = ANNIndex.create_central(
index_root=index_root,
dim=embedder.embedding_dim,
initial_capacity=100000, # Larger capacity for centralized index
initial_capacity=estimated_chunks,
auto_save=False,
)
@@ -1360,6 +1377,148 @@ def generate_dense_embeddings_centralized(
logger.warning("Failed to store vector metadata: %s", e)
# Non-fatal: continue without centralized metadata
# --- Binary Vector Generation for Cascade Search (Memory-Mapped) ---
binary_success = False
binary_count = 0
try:
from codexlens.config import Config, BINARY_VECTORS_MMAP_NAME
config = Config.load()
if getattr(config, 'enable_binary_cascade', True) and all_embeddings:
import numpy as np
if progress_callback:
progress_callback(f"Generating binary vectors for {len(all_embeddings)} chunks...")
# Binarize dense vectors: sign(x) -> 1 if x > 0, 0 otherwise
# Pack into bytes for efficient storage and Hamming distance computation
embeddings_matrix = np.vstack(all_embeddings)
binary_matrix = (embeddings_matrix > 0).astype(np.uint8)
# Pack bits into bytes (8 bits per byte) - vectorized for all rows
packed_matrix = np.packbits(binary_matrix, axis=1)
binary_count = len(packed_matrix)
# Save as memory-mapped file for efficient loading
binary_mmap_path = index_root / BINARY_VECTORS_MMAP_NAME
mmap_array = np.memmap(
str(binary_mmap_path),
dtype=np.uint8,
mode='w+',
shape=packed_matrix.shape
)
mmap_array[:] = packed_matrix
mmap_array.flush()
del mmap_array # Close the memmap
# Save metadata (shape and chunk_ids) to sidecar JSON
import json
meta_path = binary_mmap_path.with_suffix('.meta.json')
with open(meta_path, 'w') as f:
json.dump({
'shape': list(packed_matrix.shape),
'chunk_ids': all_chunk_ids,
'embedding_dim': embeddings_matrix.shape[1],
}, f)
# Also store in DB for backward compatibility
from codexlens.storage.vector_meta_store import VectorMetadataStore
binary_packed_bytes = [row.tobytes() for row in packed_matrix]
with VectorMetadataStore(vectors_meta_path) as meta_store:
meta_store.add_binary_vectors(all_chunk_ids, binary_packed_bytes)
binary_success = True
if progress_callback:
progress_callback(f"Generated {binary_count} binary vectors ({embeddings_matrix.shape[1]} dims -> {packed_matrix.shape[1]} bytes, mmap: {binary_mmap_path.name})")
except Exception as e:
logger.warning("Binary vector generation failed: %s", e)
# Non-fatal: continue without binary vectors
# --- SPLADE Sparse Index Generation (Centralized) ---
splade_success = False
splade_chunks_count = 0
try:
from codexlens.config import Config
config = Config.load()
if config.enable_splade and chunk_id_to_info:
from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
from codexlens.storage.splade_index import SpladeIndex
import json
ok, err = check_splade_available()
if ok:
if progress_callback:
progress_callback(f"Generating SPLADE sparse vectors for {len(chunk_id_to_info)} chunks...")
# Initialize SPLADE encoder and index
splade_encoder = get_splade_encoder(use_gpu=use_gpu)
splade_db_path = index_root / SPLADE_DB_NAME
splade_index = SpladeIndex(splade_db_path)
splade_index.create_tables()
# Batch encode for efficiency
SPLADE_BATCH_SIZE = 32
all_postings = []
all_chunk_metadata = []
# Create batches from chunk_id_to_info
chunk_items = list(chunk_id_to_info.items())
for i in range(0, len(chunk_items), SPLADE_BATCH_SIZE):
batch_items = chunk_items[i:i + SPLADE_BATCH_SIZE]
chunk_ids = [item[0] for item in batch_items]
chunk_contents = [item[1]["content"] for item in batch_items]
# Generate sparse vectors
sparse_vecs = splade_encoder.encode_batch(chunk_contents, batch_size=SPLADE_BATCH_SIZE)
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
all_postings.append((cid, sparse_vec))
if progress_callback and (i + SPLADE_BATCH_SIZE) % 100 == 0:
progress_callback(f"SPLADE encoding: {min(i + SPLADE_BATCH_SIZE, len(chunk_items))}/{len(chunk_items)}")
# Batch insert all postings
if all_postings:
splade_index.add_postings_batch(all_postings)
# CRITICAL FIX: Populate splade_chunks table
for cid, info in chunk_id_to_info.items():
metadata_str = json.dumps(info.get("metadata", {})) if info.get("metadata") else None
all_chunk_metadata.append((
cid,
info["file_path"],
info["content"],
metadata_str,
info.get("source_index_db")
))
if all_chunk_metadata:
splade_index.add_chunks_metadata_batch(all_chunk_metadata)
splade_chunks_count = len(all_chunk_metadata)
# Set metadata
splade_index.set_metadata(
model_name=splade_encoder.model_name,
vocab_size=splade_encoder.vocab_size
)
splade_index.close()
splade_success = True
if progress_callback:
progress_callback(f"SPLADE index created: {len(all_postings)} postings, {splade_chunks_count} chunks")
else:
if progress_callback:
progress_callback(f"SPLADE not available, skipping sparse index: {err}")
except Exception as e:
logger.warning("SPLADE encoding failed: %s", e)
if progress_callback:
progress_callback(f"SPLADE encoding failed: {e}")
elapsed_time = time.time() - start_time
# Cleanup
@@ -1380,6 +1539,10 @@ def generate_dense_embeddings_centralized(
"model_name": embedder.model_name,
"central_index_path": str(central_hnsw_path),
"failed_files": failed_files[:5],
"splade_success": splade_success,
"splade_chunks": splade_chunks_count,
"binary_success": binary_success,
"binary_count": binary_count,
},
}

View File

@@ -25,6 +25,7 @@ SPLADE_DB_NAME = "_splade.db"
# Dense vector storage names (centralized storage)
VECTORS_HNSW_NAME = "_vectors.hnsw"
VECTORS_META_DB_NAME = "_vectors_meta.db"
BINARY_VECTORS_MMAP_NAME = "_binary_vectors.mmap"
log = logging.getLogger(__name__)

View File

@@ -0,0 +1,277 @@
"""Binary vector searcher for cascade search.
This module provides fast binary vector search using Hamming distance
for the first stage of cascade search (coarse filtering).
Supports two loading modes:
1. Memory-mapped file (preferred): Low memory footprint, OS-managed paging
2. Database loading (fallback): Loads all vectors into RAM
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
# Pre-computed popcount lookup table for vectorized Hamming distance
# Each byte value (0-255) maps to its bit count
_POPCOUNT_TABLE = np.array([bin(i).count('1') for i in range(256)], dtype=np.uint8)
class BinarySearcher:
"""Fast binary vector search using Hamming distance.
This class implements the first stage of cascade search:
fast, approximate retrieval using binary vectors and Hamming distance.
The binary vectors are derived from dense embeddings by thresholding:
binary[i] = 1 if dense[i] > 0 else 0
Hamming distance between two binary vectors counts the number of
differing bits, which can be computed very efficiently using XOR
and population count.
Supports two loading modes:
- Memory-mapped file (preferred): Uses np.memmap for minimal RAM usage
- Database (fallback): Loads all vectors into memory from SQLite
"""
def __init__(self, index_root_or_meta_path: Path) -> None:
"""Initialize BinarySearcher.
Args:
index_root_or_meta_path: Either:
- Path to index root directory (containing _binary_vectors.mmap)
- Path to _vectors_meta.db (legacy mode, loads from DB)
"""
path = Path(index_root_or_meta_path)
# Determine if this is an index root or a specific DB path
if path.suffix == '.db':
# Legacy mode: specific DB path
self.index_root = path.parent
self.meta_store_path = path
else:
# New mode: index root directory
self.index_root = path
self.meta_store_path = path / "_vectors_meta.db"
self._chunk_ids: Optional[np.ndarray] = None
self._binary_matrix: Optional[np.ndarray] = None
self._is_memmap = False
self._loaded = False
def load(self) -> bool:
"""Load binary vectors using memory-mapped file or database fallback.
Tries to load from memory-mapped file first (preferred for large indexes),
falls back to database loading if mmap file doesn't exist.
Returns:
True if vectors were loaded successfully.
"""
if self._loaded:
return True
# Try memory-mapped file first (preferred)
mmap_path = self.index_root / "_binary_vectors.mmap"
meta_path = mmap_path.with_suffix('.meta.json')
if mmap_path.exists() and meta_path.exists():
try:
with open(meta_path, 'r') as f:
meta = json.load(f)
shape = tuple(meta['shape'])
self._chunk_ids = np.array(meta['chunk_ids'], dtype=np.int64)
# Memory-map the binary matrix (read-only)
self._binary_matrix = np.memmap(
str(mmap_path),
dtype=np.uint8,
mode='r',
shape=shape
)
self._is_memmap = True
self._loaded = True
logger.info(
"Memory-mapped %d binary vectors (%d bytes each)",
len(self._chunk_ids), shape[1]
)
return True
except Exception as e:
logger.warning("Failed to load mmap binary vectors, falling back to DB: %s", e)
# Fallback: load from database
return self._load_from_db()
def _load_from_db(self) -> bool:
"""Load binary vectors from database (legacy/fallback mode).
Returns:
True if vectors were loaded successfully.
"""
try:
from codexlens.storage.vector_meta_store import VectorMetadataStore
with VectorMetadataStore(self.meta_store_path) as store:
rows = store.get_all_binary_vectors()
if not rows:
logger.warning("No binary vectors found in %s", self.meta_store_path)
return False
# Convert to numpy arrays for fast computation
self._chunk_ids = np.array([r[0] for r in rows], dtype=np.int64)
# Unpack bytes to numpy array
binary_arrays = []
for _, vec_bytes in rows:
arr = np.frombuffer(vec_bytes, dtype=np.uint8)
binary_arrays.append(arr)
self._binary_matrix = np.vstack(binary_arrays)
self._is_memmap = False
self._loaded = True
logger.info(
"Loaded %d binary vectors from DB (%d bytes each)",
len(self._chunk_ids), self._binary_matrix.shape[1]
)
return True
except Exception as e:
logger.error("Failed to load binary vectors: %s", e)
return False
def search(
self,
query_vector: np.ndarray,
top_k: int = 100
) -> List[Tuple[int, int]]:
"""Search for similar vectors using Hamming distance.
Args:
query_vector: Dense query vector (will be binarized).
top_k: Number of top results to return.
Returns:
List of (chunk_id, hamming_distance) tuples sorted by distance.
"""
if not self._loaded and not self.load():
return []
# Binarize query vector
query_binary = (query_vector > 0).astype(np.uint8)
query_packed = np.packbits(query_binary)
# Compute Hamming distances using XOR and popcount
# XOR gives 1 for differing bits
xor_result = np.bitwise_xor(self._binary_matrix, query_packed)
# Vectorized popcount using lookup table (orders of magnitude faster)
# Sum the bit counts for each byte across all columns
distances = np.sum(_POPCOUNT_TABLE[xor_result], axis=1, dtype=np.int32)
# Get top-k with smallest distances
if top_k >= len(distances):
top_indices = np.argsort(distances)
else:
# Partial sort for efficiency
top_indices = np.argpartition(distances, top_k)[:top_k]
top_indices = top_indices[np.argsort(distances[top_indices])]
results = [
(int(self._chunk_ids[i]), int(distances[i]))
for i in top_indices
]
return results
def search_with_rerank(
self,
query_dense: np.ndarray,
dense_vectors: np.ndarray,
dense_chunk_ids: np.ndarray,
top_k: int = 10,
candidates: int = 100
) -> List[Tuple[int, float]]:
"""Two-stage cascade search: binary filter + dense rerank.
Args:
query_dense: Dense query vector.
dense_vectors: Dense vectors for reranking (from HNSW or stored).
dense_chunk_ids: Chunk IDs corresponding to dense_vectors.
top_k: Final number of results.
candidates: Number of candidates from binary search.
Returns:
List of (chunk_id, cosine_similarity) tuples.
"""
# Stage 1: Binary filtering
binary_results = self.search(query_dense, top_k=candidates)
if not binary_results:
return []
candidate_ids = {r[0] for r in binary_results}
# Stage 2: Dense reranking
# Find indices of candidates in dense_vectors
candidate_mask = np.isin(dense_chunk_ids, list(candidate_ids))
candidate_indices = np.where(candidate_mask)[0]
if len(candidate_indices) == 0:
# Fallback: return binary results with normalized distance
max_dist = max(r[1] for r in binary_results) if binary_results else 1
return [(r[0], 1.0 - r[1] / max_dist) for r in binary_results[:top_k]]
# Compute cosine similarities for candidates
candidate_vectors = dense_vectors[candidate_indices]
candidate_ids_array = dense_chunk_ids[candidate_indices]
# Normalize vectors
query_norm = query_dense / (np.linalg.norm(query_dense) + 1e-8)
cand_norms = candidate_vectors / (
np.linalg.norm(candidate_vectors, axis=1, keepdims=True) + 1e-8
)
# Cosine similarities
similarities = np.dot(cand_norms, query_norm)
# Sort by similarity (descending)
sorted_indices = np.argsort(-similarities)[:top_k]
results = [
(int(candidate_ids_array[i]), float(similarities[i]))
for i in sorted_indices
]
return results
@property
def vector_count(self) -> int:
"""Get number of loaded binary vectors."""
return len(self._chunk_ids) if self._chunk_ids is not None else 0
@property
def is_memmap(self) -> bool:
"""Check if using memory-mapped file (vs in-memory array)."""
return self._is_memmap
def clear(self) -> None:
"""Clear loaded vectors from memory."""
# For memmap, just delete the reference (OS will handle cleanup)
if self._is_memmap and self._binary_matrix is not None:
del self._binary_matrix
self._chunk_ids = None
self._binary_matrix = None
self._is_memmap = False
self._loaded = False

View File

@@ -541,26 +541,55 @@ class ChainSearchEngine:
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
# Search all indexes for binary candidates
# Try centralized BinarySearcher first (preferred for mmap indexes)
# The index root is the parent of the first index path
index_root = index_paths[0].parent if index_paths else None
all_candidates: List[Tuple[int, int, Path]] = [] # (chunk_id, distance, index_path)
used_centralized = False
for index_path in index_paths:
try:
# Get or create binary index for this path
binary_index = self._get_or_create_binary_index(index_path)
if binary_index is None or binary_index.count() == 0:
continue
if index_root:
centralized_searcher = self._get_centralized_binary_searcher(index_root)
if centralized_searcher is not None:
try:
# BinarySearcher expects dense vector, not packed binary
from codexlens.semantic.embedder import Embedder
embedder = Embedder()
query_dense = embedder.embed_to_numpy([query])[0]
# Search binary index
ids, distances = binary_index.search(query_binary_packed, coarse_k)
for chunk_id, dist in zip(ids, distances):
all_candidates.append((chunk_id, dist, index_path))
# Centralized search - returns (chunk_id, hamming_distance) tuples
results = centralized_searcher.search(query_dense, top_k=coarse_k)
for chunk_id, dist in results:
all_candidates.append((chunk_id, dist, index_root))
used_centralized = True
self.logger.debug(
"Centralized binary search found %d candidates", len(results)
)
except Exception as exc:
self.logger.debug(
"Centralized binary search failed: %s, falling back to per-directory",
exc
)
centralized_searcher.clear()
except Exception as exc:
self.logger.debug(
"Binary search failed for %s: %s", index_path, exc
)
stats.errors.append(f"Binary search failed for {index_path}: {exc}")
# Fallback: Search per-directory indexes with legacy BinaryANNIndex
if not used_centralized:
for index_path in index_paths:
try:
# Get or create binary index for this path (uses deprecated BinaryANNIndex)
binary_index = self._get_or_create_binary_index(index_path)
if binary_index is None or binary_index.count() == 0:
continue
# Search binary index
ids, distances = binary_index.search(query_binary_packed, coarse_k)
for chunk_id, dist in zip(ids, distances):
all_candidates.append((chunk_id, dist, index_path))
except Exception as exc:
self.logger.debug(
"Binary search failed for %s: %s", index_path, exc
)
stats.errors.append(f"Binary search failed for {index_path}: {exc}")
if not all_candidates:
self.logger.debug("No binary candidates found, falling back to hybrid")
@@ -743,6 +772,10 @@ class ChainSearchEngine:
def _get_or_create_binary_index(self, index_path: Path) -> Optional[Any]:
"""Get or create a BinaryANNIndex for the given index path.
.. deprecated::
This method uses the deprecated BinaryANNIndex. For centralized indexes,
use _get_centralized_binary_searcher() instead.
Attempts to load an existing binary index from disk. If not found,
returns None (binary index should be built during indexing).
@@ -753,16 +786,48 @@ class ChainSearchEngine:
BinaryANNIndex instance or None if not available
"""
try:
from codexlens.semantic.ann_index import BinaryANNIndex
import warnings
# Suppress deprecation warning since we're using it intentionally for legacy support
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
from codexlens.semantic.ann_index import BinaryANNIndex
binary_index = BinaryANNIndex(index_path, dim=256)
if binary_index.load():
return binary_index
binary_index = BinaryANNIndex(index_path, dim=256)
if binary_index.load():
return binary_index
return None
except Exception as exc:
self.logger.debug("Failed to load binary index for %s: %s", index_path, exc)
return None
def _get_centralized_binary_searcher(self, index_root: Path) -> Optional[Any]:
"""Get centralized BinarySearcher for memory-mapped binary vectors.
This is the preferred method for centralized indexes, providing faster
search via memory-mapped files.
Args:
index_root: Root directory containing centralized index files
Returns:
BinarySearcher instance or None if not available
"""
try:
from codexlens.search.binary_searcher import BinarySearcher
binary_searcher = BinarySearcher(index_root)
if binary_searcher.load():
self.logger.debug(
"Using centralized BinarySearcher with %d vectors (mmap=%s)",
binary_searcher.vector_count,
binary_searcher.is_memmap
)
return binary_searcher
return None
except Exception as exc:
self.logger.debug("Failed to load centralized binary searcher: %s", exc)
return None
def _compute_cosine_similarity(
self,
query_vec: "np.ndarray",

View File

@@ -508,6 +508,10 @@ class ANNIndex:
class BinaryANNIndex:
"""Binary vector ANN index using Hamming distance for fast coarse retrieval.
.. deprecated::
This class is deprecated. Use :class:`codexlens.search.binary_searcher.BinarySearcher`
instead, which provides faster memory-mapped search with centralized storage.
Optimized for binary vectors (256-bit / 32 bytes per vector).
Uses packed binary representation for memory efficiency.
@@ -553,6 +557,14 @@ class BinaryANNIndex:
"Install with: pip install codexlens[semantic]"
)
import warnings
warnings.warn(
"BinaryANNIndex is deprecated. Use codexlens.search.binary_searcher.BinarySearcher "
"instead for faster memory-mapped search with centralized storage.",
DeprecationWarning,
stacklevel=2
)
if dim <= 0 or dim % 8 != 0:
raise ValueError(
f"Invalid dimension: {dim}. Must be positive and divisible by 8."

View File

@@ -220,12 +220,16 @@ class SpladeEncoder:
from transformers import AutoTokenizer
if self.providers is None:
from .gpu_support import get_optimal_providers
from .gpu_support import get_optimal_providers, get_selected_device_id
# Include device_id options for DirectML/CUDA selection when available
# Get providers as pure string list (cache-friendly)
# NOTE: with_device_options=False to avoid tuple-based providers
# which break optimum's caching mechanism
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=True
use_gpu=self.use_gpu, with_device_options=False
)
# Get device_id separately for provider_options
self._device_id = get_selected_device_id() if self.use_gpu else None
# Some Optimum versions accept `providers`, others accept a single `provider`
# Prefer passing the full providers list, with a conservative fallback
@@ -234,6 +238,15 @@ class SpladeEncoder:
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
# Pass device_id via provider_options for GPU selection
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
# Build provider_options dict for each GPU provider
provider_options = {}
for p in self.providers:
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
provider_options[p] = {"device_id": self._device_id}
if provider_options:
model_kwargs["provider_options"] = provider_options
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
@@ -369,6 +382,21 @@ class SpladeEncoder:
return sparse_dict
def warmup(self, text: str = "warmup query") -> None:
"""Warmup the encoder by running a dummy inference.
First-time model inference includes initialization overhead.
Call this method once before the first real search to avoid
latency spikes.
Args:
text: Dummy text for warmup (default: "warmup query")
"""
logger.info("Warming up SPLADE encoder...")
# Trigger model loading and first inference
_ = self.encode_text(text)
logger.info("SPLADE encoder warmup complete")
def encode_text(self, text: str) -> Dict[int, float]:
"""Encode text to sparse vector {token_id: weight}.

View File

@@ -59,6 +59,8 @@ class SpladeIndex:
conn.execute("PRAGMA foreign_keys=ON")
# Limit mmap to 1GB to avoid OOM on smaller systems
conn.execute("PRAGMA mmap_size=1073741824")
# Increase cache size for better query performance (20MB = -20000 pages)
conn.execute("PRAGMA cache_size=-20000")
self._local.conn = conn
return conn
@@ -385,25 +387,29 @@ class SpladeIndex:
self,
query_sparse: Dict[int, float],
limit: int = 50,
min_score: float = 0.0
min_score: float = 0.0,
max_query_terms: int = 64
) -> List[Tuple[int, float]]:
"""Search for similar chunks using dot-product scoring.
Implements efficient sparse dot-product via SQL JOIN:
score(q, d) = sum(q[t] * d[t]) for all tokens t
Args:
query_sparse: Query sparse vector as {token_id: weight}.
limit: Maximum number of results.
min_score: Minimum score threshold.
max_query_terms: Maximum query terms to use (default: 64).
Pruning to top-K terms reduces search time with minimal impact on quality.
Set to 0 or negative to disable pruning (use all terms).
Returns:
List of (chunk_id, score) tuples, ordered by score descending.
"""
if not query_sparse:
logger.warning("Empty query sparse vector")
return []
with self._lock:
conn = self._get_connection()
try:
@@ -414,10 +420,20 @@ class SpladeIndex:
for token_id, weight in query_sparse.items()
if weight > 0
]
if not query_terms:
logger.warning("No non-zero query terms")
return []
# Query pruning: keep only top-K terms by weight
# max_query_terms <= 0 means no limit (use all terms)
if max_query_terms > 0 and len(query_terms) > max_query_terms:
query_terms = sorted(query_terms, key=lambda x: x[1], reverse=True)[:max_query_terms]
logger.debug(
"Query pruned from %d to %d terms",
len(query_sparse),
len(query_terms)
)
# Create CTE for query terms using parameterized VALUES
# Build placeholders and params to prevent SQL injection

View File

@@ -96,6 +96,13 @@ class VectorMetadataStore:
'CREATE INDEX IF NOT EXISTS idx_chunk_category '
'ON chunk_metadata(category)'
)
# Binary vectors table for cascade search
conn.execute('''
CREATE TABLE IF NOT EXISTS binary_vectors (
chunk_id INTEGER PRIMARY KEY,
vector BLOB NOT NULL
)
''')
conn.commit()
logger.debug("VectorMetadataStore schema created/verified")
except sqlite3.Error as e:
@@ -329,3 +336,80 @@ class VectorMetadataStore:
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Context manager exit."""
self.close()
# ============= Binary Vector Methods for Cascade Search =============
def add_binary_vectors(
self, chunk_ids: List[int], binary_vectors: List[bytes]
) -> None:
"""Batch insert binary vectors for cascade search.
Args:
chunk_ids: List of chunk IDs.
binary_vectors: List of packed binary vectors (as bytes).
"""
if not chunk_ids or len(chunk_ids) != len(binary_vectors):
return
with self._lock:
conn = self._get_connection()
try:
data = list(zip(chunk_ids, binary_vectors))
conn.executemany(
"INSERT OR REPLACE INTO binary_vectors (chunk_id, vector) VALUES (?, ?)",
data
)
conn.commit()
logger.debug("Added %d binary vectors", len(chunk_ids))
except sqlite3.Error as e:
raise StorageError(
f"Failed to add binary vectors: {e}",
db_path=str(self.db_path),
operation="add_binary_vectors"
) from e
def get_all_binary_vectors(self) -> List[tuple]:
"""Get all binary vectors for cascade search.
Returns:
List of (chunk_id, vector_bytes) tuples.
"""
conn = self._get_connection()
try:
rows = conn.execute(
"SELECT chunk_id, vector FROM binary_vectors"
).fetchall()
return [(row[0], row[1]) for row in rows]
except sqlite3.Error as e:
logger.error("Failed to get binary vectors: %s", e)
return []
def get_binary_vector_count(self) -> int:
"""Get total number of binary vectors.
Returns:
Binary vector count.
"""
conn = self._get_connection()
try:
row = conn.execute(
"SELECT COUNT(*) FROM binary_vectors"
).fetchone()
return row[0] if row else 0
except sqlite3.Error:
return 0
def clear_binary_vectors(self) -> None:
"""Clear all binary vectors."""
with self._lock:
conn = self._get_connection()
try:
conn.execute("DELETE FROM binary_vectors")
conn.commit()
logger.info("Cleared all binary vectors")
except sqlite3.Error as e:
raise StorageError(
f"Failed to clear binary vectors: {e}",
db_path=str(self.db_path),
operation="clear_binary_vectors"
) from e