Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-10 02:24:35 +08:00
refactor: remove SPLADE and hybrid_cascade, streamline the search architecture

Remove the SPLADE sparse neural search backend and the hybrid_cascade strategy,
simplifying the search architecture from six backends to four (FTS Exact/Fuzzy, Binary Vector, Dense Vector, LSP).

Key changes:
- Delete four files: splade_encoder.py, splade_index.py, migration_009, and related modules
- Remove SPLADE-related configuration from config.py (enable_splade, splade_model, etc.)
- Change DEFAULT_WEIGHTS to FTS weights {exact: 0.25, fuzzy: 0.1, vector: 0.5, lsp: 0.15}
- Delete hybrid_cascade_search(); all cascade fallbacks now call self.search()
- Map API fusion_strategy='hybrid' to binary_rerank for backward compatibility
- Delete the CLI index_splade/splade_status commands and --method splade
- Update tests, benchmarks, and documentation (a sketch of the new mapping follows)
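A minimal sketch of the backward-compatibility mapping described above. The names (DEFAULT_WEIGHTS, binary_rerank, dense_rerank) come from this commit's diff; the dispatch shape itself is an illustrative assumption, not the verbatim implementation.

# Sketch only: the removed 'hybrid' strategy is aliased to a surviving one
# so that old callers keep working after the refactor.
DEFAULT_WEIGHTS = {"exact": 0.25, "fuzzy": 0.1, "vector": 0.5, "lsp": 0.15}

_COMPAT_ALIASES = {"hybrid": "binary_rerank"}
_VALID = {"rrf", "staged", "binary", "binary_rerank", "dense_rerank"}

def resolve_fusion_strategy(requested: str) -> str:
    """Map a requested fusion_strategy onto the post-refactor strategy set."""
    strategy = _COMPAT_ALIASES.get(requested, requested)
    if strategy not in _VALID:
        raise ValueError(f"Invalid fusion_strategy: {requested!r}")
    return strategy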
@@ -48,7 +48,8 @@ def semantic_search(
 - rrf: Reciprocal Rank Fusion (recommended, default)
 - staged: Staged cascade -> staged_cascade_search
 - binary: Binary rerank cascade -> binary_cascade_search
- - hybrid: Hybrid cascade -> hybrid_cascade_search
+ - hybrid: Binary rerank cascade (backward compat) -> binary_rerank_cascade_search
 - dense_rerank: Dense rerank cascade -> dense_rerank_cascade_search
 kind_filter: Symbol type filter (e.g., ["function", "class"])
 limit: Max return count (default 20)
 include_match_reason: Generate match reason (heuristic, not LLM)
@@ -215,7 +216,8 @@ def _execute_search(
 - rrf: Standard hybrid search with RRF fusion
 - staged: staged_cascade_search
 - binary: binary_cascade_search
- - hybrid: hybrid_cascade_search
+ - hybrid: binary_rerank_cascade_search (backward compat)
 - dense_rerank: dense_rerank_cascade_search

 Args:
 engine: ChainSearchEngine instance
@@ -249,8 +251,8 @@ def _execute_search(
 options=options,
 )
 elif fusion_strategy == "hybrid":
-# Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
-return engine.hybrid_cascade_search(
+# Backward compat: hybrid now maps to binary_rerank_cascade_search
+return engine.binary_rerank_cascade_search(
 query=query,
 source_path=source_path,
 k=limit,
@@ -342,8 +344,6 @@ def _transform_results(
 fts_scores.append(source_scores["exact"])
 if "fuzzy" in source_scores:
 fts_scores.append(source_scores["fuzzy"])
-if "splade" in source_scores:
-fts_scores.append(source_scores["splade"])

 if fts_scores:
 structural_score = max(fts_scores)
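With SPLADE gone, the structural score in _transform_results reduces to the max of the remaining FTS signals. A hedged sketch of that reduction (the exact shape of source_scores is assumed from the hunk above):

# Illustrative only, not the verbatim implementation.
def structural_score(source_scores: dict) -> float:
    """Max over the remaining FTS signals (exact, fuzzy); 0.0 if neither fired."""
    fts_scores = [source_scores[key] for key in ("exact", "fuzzy") if key in source_scores]
    return max(fts_scores) if fts_scores else 0.0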
@@ -6,7 +6,6 @@ import json
 import logging
 import os
 import shutil
 import sqlite3
 from pathlib import Path
 from typing import Annotated, Any, Dict, Iterable, List, Optional

@@ -37,7 +36,7 @@ from .output import (
 app = typer.Typer(help="CodexLens CLI — local code indexing and search.")

 # Index subcommand group for reorganized commands
-index_app = typer.Typer(help="Index management commands (init, embeddings, splade, binary, status, migrate, all)")
+index_app = typer.Typer(help="Index management commands (init, embeddings, binary, status, migrate, all)")
 app.add_typer(index_app, name="index")


@@ -521,15 +520,15 @@ def search(
 print_json(success=False, error=f"Invalid deprecated mode: {mode}. Use --method instead.")
 else:
 console.print(f"[red]Invalid deprecated mode:[/red] {mode}")
-console.print("[dim]Use --method with: fts, vector, splade, hybrid, cascade[/dim]")
+console.print("[dim]Use --method with: fts, vector, hybrid, cascade[/dim]")
 raise typer.Exit(code=1)

 # Configure search (load settings from file)
 config = Config.load()

 # Validate method - simplified interface exposes only dense_rerank and fts
-# Other methods (vector, splade, hybrid, cascade) are hidden but still work for backward compatibility
-valid_methods = ["fts", "dense_rerank", "vector", "splade", "hybrid", "cascade"]
+# Other methods (vector, hybrid, cascade) are hidden but still work for backward compatibility
+valid_methods = ["fts", "dense_rerank", "vector", "hybrid", "cascade"]
 if actual_method not in valid_methods:
 if json_mode:
 print_json(success=False, error=f"Invalid method: {actual_method}. Use 'dense_rerank' (semantic) or 'fts' (exact keyword).")
@@ -561,7 +560,7 @@ def search(
 try:
 # Check if using key=value format (new) or legacy comma-separated format
 if "=" in weights:
-# New format: splade=0.4,vector=0.6 or exact=0.3,fuzzy=0.1,vector=0.6
+# New format: exact=0.3,fuzzy=0.1,vector=0.6
 weight_dict = {}
 for pair in weights.split(","):
 if "=" in pair:
@@ -592,17 +591,6 @@ def search(
 "fuzzy": weight_parts[1],
 "vector": weight_parts[2],
 }
-elif len(weight_parts) == 2:
-# Two values: assume splade,vector
-weight_sum = sum(weight_parts)
-if abs(weight_sum - 1.0) > 0.01:
-if not json_mode:
-console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
-weight_parts = [w / weight_sum for w in weight_parts]
-hybrid_weights = {
-"splade": weight_parts[0],
-"vector": weight_parts[1],
-}
 else:
 if not json_mode:
 console.print("[yellow]Warning: Invalid weights format. Using defaults.[/yellow]")
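The surviving key=value weights path accepts strings like exact=0.3,fuzzy=0.1,vector=0.6. A minimal sketch of that parse-and-normalize step (the helper name is hypothetical; the 0.01 tolerance mirrors the legacy path removed above):

# Hypothetical helper mirroring the CLI behavior shown above; not the actual code.
def parse_weights(weights: str) -> dict:
    """Parse 'exact=0.3,fuzzy=0.1,vector=0.6' and normalize to sum to 1.0."""
    weight_dict = {}
    for pair in weights.split(","):
        if "=" in pair:
            key, _, value = pair.partition("=")
            weight_dict[key.strip()] = float(value)
    total = sum(weight_dict.values())
    if total > 0 and abs(total - 1.0) > 0.01:  # tolerance matches the legacy path
        weight_dict = {k: v / total for k, v in weight_dict.items()}
    return weight_dict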
@@ -621,7 +609,6 @@ def search(
 # Map method to SearchOptions flags
 # fts: FTS-only search (optionally with fuzzy)
 # vector: Pure vector semantic search
-# splade: SPLADE sparse neural search
 # hybrid: RRF fusion of sparse + dense
 # cascade: Two-stage binary + dense retrieval
 if actual_method == "fts":
@@ -629,35 +616,24 @@ def search(
 enable_fuzzy = use_fuzzy
 enable_vector = False
 pure_vector = False
-enable_splade = False
 enable_cascade = False
 elif actual_method == "vector":
 hybrid_mode = True
 enable_fuzzy = False
 enable_vector = True
 pure_vector = True
-enable_splade = False
 enable_cascade = False
-elif actual_method == "splade":
-hybrid_mode = True
-enable_fuzzy = False
-enable_vector = False
-pure_vector = False
-enable_splade = True
-enable_cascade = False
 elif actual_method == "hybrid":
 hybrid_mode = True
 enable_fuzzy = use_fuzzy
 enable_vector = True
 pure_vector = False
-enable_splade = True  # SPLADE is preferred sparse in hybrid
 enable_cascade = False
 elif actual_method == "cascade":
 hybrid_mode = True
 enable_fuzzy = False
 enable_vector = True
 pure_vector = False
-enable_splade = False
 enable_cascade = True
 else:
 raise ValueError(f"Invalid method: {actual_method}")
@@ -678,7 +654,6 @@ def search(
 enable_fuzzy=enable_fuzzy,
 enable_vector=enable_vector,
 pure_vector=pure_vector,
-enable_splade=enable_splade,
 enable_cascade=enable_cascade,
 hybrid_weights=hybrid_weights,
 )
@@ -2857,251 +2832,8 @@ def gpu_reset(



-# ==================== SPLADE Commands ====================
-
-@index_app.command("splade")
-def index_splade(
-path: Path = typer.Argument(..., help="Project path to index"),
-rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
-verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
-) -> None:
-"""Generate SPLADE sparse index for existing codebase.
-
-Encodes all semantic chunks with SPLADE model and builds inverted index
-for efficient sparse retrieval.
-
-This command discovers all _index.db files recursively in the project's
-index directory and builds SPLADE encodings for chunks across all of them.
-
-Examples:
-codexlens index splade ~/projects/my-app
-codexlens index splade . --rebuild
-"""
-_configure_logging(verbose)
-
-from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
-from codexlens.storage.splade_index import SpladeIndex
-from codexlens.semantic.vector_store import VectorStore
-
-# Check SPLADE availability
-ok, err = check_splade_available()
-if not ok:
-console.print(f"[red]SPLADE not available: {err}[/red]")
-console.print("[dim]Install with: pip install transformers torch[/dim]")
-raise typer.Exit(1)
-
-# Find index root directory
-target_path = path.expanduser().resolve()
-
-# Determine index root directory (containing _index.db files)
-if target_path.is_file() and target_path.name == "_index.db":
-index_root = target_path.parent
-elif target_path.is_dir():
-# Check for local .codexlens/_index.db
-local_index = target_path / ".codexlens" / "_index.db"
-if local_index.exists():
-index_root = local_index.parent
-else:
-# Try to find via registry
-registry = RegistryStore()
-try:
-registry.initialize()
-mapper = PathMapper()
-index_db = mapper.source_to_index_db(target_path)
-if not index_db.exists():
-console.print(f"[red]Error:[/red] No index found for {target_path}")
-console.print("Run 'codexlens init' first to create an index")
-raise typer.Exit(1)
-index_root = index_db.parent
-finally:
-registry.close()
-else:
-console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
-raise typer.Exit(1)
-
-# Discover all _index.db files recursively
-all_index_dbs = sorted(index_root.rglob("_index.db"))
-if not all_index_dbs:
-console.print(f"[red]Error:[/red] No _index.db files found in {index_root}")
-raise typer.Exit(1)
-
-console.print(f"[blue]Discovered {len(all_index_dbs)} index databases[/blue]")
-
-# SPLADE index is stored alongside the root _index.db
-from codexlens.config import SPLADE_DB_NAME
-splade_db = index_root / SPLADE_DB_NAME
-
-if splade_db.exists() and not rebuild:
-console.print("[yellow]SPLADE index exists. Use --rebuild to regenerate.[/yellow]")
-return
-
-# If rebuild, delete existing splade database
-if splade_db.exists() and rebuild:
-splade_db.unlink()
-
-# Collect all chunks from all distributed index databases
-# Assign globally unique IDs to avoid collisions (each DB starts with ID 1)
-console.print(f"[blue]Loading chunks from {len(all_index_dbs)} distributed indexes...[/blue]")
-all_chunks = []  # (global_id, chunk) pairs
-total_files_checked = 0
-indexes_with_chunks = 0
-global_id = 0  # Sequential global ID across all databases
-
-for index_db in all_index_dbs:
-total_files_checked += 1
-try:
-vector_store = VectorStore(index_db)
-chunks = vector_store.get_all_chunks()
-if chunks:
-indexes_with_chunks += 1
-# Assign sequential global IDs to avoid collisions
-for chunk in chunks:
-global_id += 1
-all_chunks.append((global_id, chunk, index_db))
-if verbose:
-console.print(f" [dim]{index_db.parent.name}: {len(chunks)} chunks[/dim]")
-vector_store.close()
-except Exception as e:
-if verbose:
-console.print(f" [yellow]Warning: Failed to read {index_db}: {e}[/yellow]")
-
-if not all_chunks:
-console.print("[yellow]No chunks found in any index database[/yellow]")
-console.print(f"[dim]Checked {total_files_checked} index files, found 0 chunks[/dim]")
-console.print("[dim]Generate embeddings first with 'codexlens embeddings-generate --recursive'[/dim]")
-raise typer.Exit(1)
-
-console.print(f"[blue]Found {len(all_chunks)} chunks across {indexes_with_chunks} indexes[/blue]")
-console.print(f"[blue]Encoding with SPLADE...[/blue]")
-
-# Initialize SPLADE
-encoder = get_splade_encoder()
-splade_index = SpladeIndex(splade_db)
-splade_index.create_tables()
-
-# Encode in batches with progress bar
-chunk_metadata_batch = []
-with Progress(
-SpinnerColumn(),
-TextColumn("[progress.description]{task.description}"),
-BarColumn(),
-TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
-TimeElapsedColumn(),
-console=console,
-) as progress:
-task = progress.add_task("Encoding...", total=len(all_chunks))
-for global_id, chunk, source_db_path in all_chunks:
-sparse_vec = encoder.encode_text(chunk.content)
-splade_index.add_posting(global_id, sparse_vec)
-# Store chunk metadata for self-contained search
-# Serialize metadata dict to JSON string
-metadata_str = None
-if hasattr(chunk, 'metadata') and chunk.metadata:
-try:
-metadata_str = json.dumps(chunk.metadata) if isinstance(chunk.metadata, dict) else chunk.metadata
-except Exception:
-pass
-chunk_metadata_batch.append((
-global_id,
-chunk.file_path or "",
-chunk.content,
-metadata_str,
-str(source_db_path)
-))
-progress.advance(task)
-
-# Batch insert chunk metadata
-if chunk_metadata_batch:
-splade_index.add_chunks_metadata_batch(chunk_metadata_batch)
-
-# Set metadata
-splade_index.set_metadata(
-model_name=encoder.model_name,
-vocab_size=encoder.vocab_size
-)
-
-stats = splade_index.get_stats()
-console.print(f"[green]OK[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings")
-console.print(f" Source indexes: {indexes_with_chunks}")
-console.print(f" Database: [dim]{splade_db}[/dim]")
-
-
-@app.command("splade-status", hidden=True, deprecated=True)
-def splade_status_command(
-path: Path = typer.Argument(..., help="Project path"),
-verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
-) -> None:
-"""[Deprecated] Use 'codexlens index status' instead.
-
-Show SPLADE index status and statistics.
-
-Examples:
-codexlens splade-status ~/projects/my-app
-codexlens splade-status .
-"""
-_deprecated_command_warning("splade-status", "index status")
-_configure_logging(verbose)
-
-from codexlens.storage.splade_index import SpladeIndex
-from codexlens.semantic.splade_encoder import check_splade_available
-from codexlens.config import SPLADE_DB_NAME
-
-# Find index database
-target_path = path.expanduser().resolve()
-
-if target_path.is_file() and target_path.name == "_index.db":
-splade_db = target_path.parent / SPLADE_DB_NAME
-elif target_path.is_dir():
-# Check for local .codexlens/_splade.db
-local_splade = target_path / ".codexlens" / SPLADE_DB_NAME
-if local_splade.exists():
-splade_db = local_splade
-else:
-# Try to find via registry
-registry = RegistryStore()
-try:
-registry.initialize()
-mapper = PathMapper()
-index_db = mapper.source_to_index_db(target_path)
-splade_db = index_db.parent / SPLADE_DB_NAME
-finally:
-registry.close()
-else:
-console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
-raise typer.Exit(1)
-
-if not splade_db.exists():
-console.print("[yellow]No SPLADE index found[/yellow]")
-console.print(f"[dim]Run 'codexlens splade-index {path}' to create one[/dim]")
-return
-
-splade_index = SpladeIndex(splade_db)
-
-if not splade_index.has_index():
-console.print("[yellow]SPLADE tables not initialized[/yellow]")
-return
-
-metadata = splade_index.get_metadata()
-stats = splade_index.get_stats()
-
-# Create status table
-table = Table(title="SPLADE Index Status", show_header=False)
-table.add_column("Property", style="cyan")
-table.add_column("Value")
-
-table.add_row("Database", str(splade_db))
-if metadata:
-table.add_row("Model", metadata['model_name'])
-table.add_row("Vocab Size", str(metadata['vocab_size']))
-table.add_row("Chunks", str(stats['unique_chunks']))
-table.add_row("Unique Tokens", str(stats['unique_tokens']))
-table.add_row("Total Postings", str(stats['total_postings']))
-
-ok, err = check_splade_available()
-status_text = "[green]Yes[/green]" if ok else f"[red]No[/red] - {err}"
-table.add_row("SPLADE Available", status_text)
-
-console.print(table)


 # ==================== Watch Command ====================
@@ -3516,11 +3248,10 @@ def index_status(
 json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
 verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
 ) -> None:
-"""Show comprehensive index status (embeddings + SPLADE).
+"""Show comprehensive index status (embeddings).

 Shows combined status for all index types:
 - Dense vector embeddings (HNSW)
-- SPLADE sparse embeddings
 - Binary cascade embeddings

 Examples:
@@ -3531,9 +3262,6 @@ def index_status(
 _configure_logging(verbose, json_mode)

 from codexlens.cli.embedding_manager import check_index_embeddings, get_embedding_stats_summary
-from codexlens.storage.splade_index import SpladeIndex
-from codexlens.semantic.splade_encoder import check_splade_available
-from codexlens.config import SPLADE_DB_NAME

 # Determine target path and index root
 if path is None:
@@ -3571,36 +3299,11 @@ def index_status(
 # Get embeddings status
 embeddings_result = get_embedding_stats_summary(index_root)

-# Get SPLADE status
-splade_db = index_root / SPLADE_DB_NAME
-splade_status = {
-"available": False,
-"has_index": False,
-"stats": None,
-"metadata": None,
-}
-
-splade_available, splade_err = check_splade_available()
-splade_status["available"] = splade_available
-
-if splade_db.exists():
-try:
-splade_index = SpladeIndex(splade_db)
-if splade_index.has_index():
-splade_status["has_index"] = True
-splade_status["stats"] = splade_index.get_stats()
-splade_status["metadata"] = splade_index.get_metadata()
-splade_index.close()
-except Exception as e:
-if verbose:
-console.print(f"[yellow]Warning: Failed to read SPLADE index: {e}[/yellow]")
-
 # Build combined result
 result = {
 "index_root": str(index_root),
 "embeddings": embeddings_result.get("result") if embeddings_result.get("success") else None,
 "embeddings_error": embeddings_result.get("error") if not embeddings_result.get("success") else None,
-"splade": splade_status,
 }

 if json_mode:
@@ -3623,27 +3326,6 @@ def index_status(
 else:
 console.print(f" [yellow]--[/yellow] {embeddings_result.get('error', 'Not available')}")

-# SPLADE section
-console.print("\n[bold]SPLADE Sparse Index:[/bold]")
-if splade_status["has_index"]:
-stats = splade_status["stats"] or {}
-metadata = splade_status["metadata"] or {}
-console.print(f" [green]OK[/green] SPLADE index available")
-console.print(f" Chunks: {stats.get('unique_chunks', 0):,}")
-console.print(f" Unique tokens: {stats.get('unique_tokens', 0):,}")
-console.print(f" Total postings: {stats.get('total_postings', 0):,}")
-if metadata.get("model_name"):
-console.print(f" Model: {metadata['model_name']}")
-elif splade_available:
-console.print(f" [yellow]--[/yellow] No SPLADE index found")
-console.print(f" [dim]Run 'codexlens index splade <path>' to create one[/dim]")
-else:
-console.print(f" [yellow]--[/yellow] SPLADE not available: {splade_err}")
-
-# Runtime availability
-console.print("\n[bold]Runtime Availability:[/bold]")
-console.print(f" SPLADE encoder: {'[green]Yes[/green]' if splade_available else f'[red]No[/red] ({splade_err})'}")


 # ==================== Index Update Command ====================
@@ -3739,22 +3421,19 @@ def index_all(
 backend: str = typer.Option("fastembed", "--backend", "-b", help="Embedding backend: fastembed or litellm."),
 model: str = typer.Option("code", "--model", "-m", help="Embedding model profile or name."),
 max_workers: int = typer.Option(1, "--max-workers", min=1, help="Max concurrent API calls."),
-skip_splade: bool = typer.Option(False, "--skip-splade", help="Skip SPLADE index generation."),
 json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
 verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
 ) -> None:
-"""Run all indexing operations in sequence (init, embeddings, splade).
+"""Run all indexing operations in sequence (init, embeddings).

 This is a convenience command that runs the complete indexing pipeline:
 1. FTS index initialization (index init)
 2. Dense vector embeddings (index embeddings)
-3. SPLADE sparse index (index splade) - unless --skip-splade

 Examples:
 codexlens index all ~/projects/my-app
 codexlens index all . --force
 codexlens index all . --backend litellm --model text-embedding-3-small
-codexlens index all . --skip-splade
 """
 _configure_logging(verbose, json_mode)

@@ -3766,7 +3445,7 @@ def index_all(

 # Step 1: Run init
 if not json_mode:
-console.print(f"[bold]Step 1/3: Initializing FTS index...[/bold]")
+console.print(f"[bold]Step 1/2: Initializing FTS index...[/bold]")

 try:
 # Import and call the init function directly
@@ -3810,7 +3489,7 @@ def index_all(

 # Step 2: Generate embeddings
 if not json_mode:
-console.print(f"\n[bold]Step 2/3: Generating dense embeddings...[/bold]")
+console.print(f"\n[bold]Step 2/2: Generating dense embeddings...[/bold]")

 try:
 from codexlens.cli.embedding_manager import generate_dense_embeddings_centralized
@@ -3851,103 +3530,6 @@ def index_all(
 if not json_mode:
 console.print(f" [yellow]Warning:[/yellow] {e}")

-# Step 3: Generate SPLADE index (unless skipped)
-if not skip_splade:
-if not json_mode:
-console.print(f"\n[bold]Step 3/3: Generating SPLADE index...[/bold]")
-
-try:
-from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
-from codexlens.storage.splade_index import SpladeIndex
-from codexlens.semantic.vector_store import VectorStore
-from codexlens.config import SPLADE_DB_NAME
-
-ok, err = check_splade_available()
-if not ok:
-results["steps"]["splade"] = {"success": False, "error": f"SPLADE not available: {err}"}
-if not json_mode:
-console.print(f" [yellow]Skipped:[/yellow] SPLADE not available ({err})")
-else:
-# Discover all _index.db files
-all_index_dbs = sorted(index_root.rglob("_index.db"))
-if not all_index_dbs:
-results["steps"]["splade"] = {"success": False, "error": "No index databases found"}
-if not json_mode:
-console.print(f" [yellow]Skipped:[/yellow] No index databases found")
-else:
-# Collect chunks
-all_chunks = []
-global_id = 0
-for index_db in all_index_dbs:
-try:
-vector_store = VectorStore(index_db)
-chunks = vector_store.get_all_chunks()
-for chunk in chunks:
-global_id += 1
-all_chunks.append((global_id, chunk, index_db))
-vector_store.close()
-except Exception:
-pass
-
-if all_chunks:
-splade_db = index_root / SPLADE_DB_NAME
-if splade_db.exists() and force:
-splade_db.unlink()
-
-encoder = get_splade_encoder()
-splade_index = SpladeIndex(splade_db)
-splade_index.create_tables()
-
-chunk_metadata_batch = []
-import json as json_module
-for gid, chunk, source_db_path in all_chunks:
-sparse_vec = encoder.encode_text(chunk.content)
-splade_index.add_posting(gid, sparse_vec)
-metadata_str = None
-if hasattr(chunk, 'metadata') and chunk.metadata:
-try:
-metadata_str = json_module.dumps(chunk.metadata) if isinstance(chunk.metadata, dict) else chunk.metadata
-except Exception:
-pass
-chunk_metadata_batch.append((
-gid,
-chunk.file_path or "",
-chunk.content,
-metadata_str,
-str(source_db_path)
-))
-
-if chunk_metadata_batch:
-splade_index.add_chunks_metadata_batch(chunk_metadata_batch)
-
-splade_index.set_metadata(
-model_name=encoder.model_name,
-vocab_size=encoder.vocab_size
-)
-
-stats = splade_index.get_stats()
-results["steps"]["splade"] = {
-"success": True,
-"chunks": stats['unique_chunks'],
-"postings": stats['total_postings'],
-}
-if not json_mode:
-console.print(f" [green]OK[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings")
-else:
-results["steps"]["splade"] = {"success": False, "error": "No chunks found"}
-if not json_mode:
-console.print(f" [yellow]Skipped:[/yellow] No chunks found in indexes")
-
-except Exception as e:
-results["steps"]["splade"] = {"success": False, "error": str(e)}
-if not json_mode:
-console.print(f" [yellow]Warning:[/yellow] {e}")
-else:
-results["steps"]["splade"] = {"success": True, "skipped": True}
-if not json_mode:
-console.print(f"\n[bold]Step 3/3: SPLADE index...[/bold]")
-console.print(f" [dim]Skipped (--skip-splade)[/dim]")

 # Summary
 if json_mode:
 print_json(success=True, result=results)
@@ -3955,10 +3537,8 @@ def index_all(
 console.print(f"\n[bold]Indexing Complete[/bold]")
 init_ok = results["steps"].get("init", {}).get("success", False)
 emb_ok = results["steps"].get("embeddings", {}).get("success", False)
-splade_ok = results["steps"].get("splade", {}).get("success", False)
 console.print(f" FTS Index: {'[green]OK[/green]' if init_ok else '[red]Failed[/red]'}")
 console.print(f" Embeddings: {'[green]OK[/green]' if emb_ok else '[yellow]Partial/Skipped[/yellow]'}")
-console.print(f" SPLADE: {'[green]OK[/green]' if splade_ok else '[yellow]Partial/Skipped[/yellow]'}")


 # ==================== Index Migration Commands ====================
@@ -3997,50 +3577,6 @@ def _set_index_version(index_root: Path, version: str) -> None:
 version_file.write_text(version, encoding="utf-8")


-def _discover_distributed_splade(index_root: Path) -> List[Dict[str, Any]]:
-"""Discover distributed SPLADE data in _index.db files.
-
-Scans all _index.db files for embedded splade_postings tables.
-This is the old distributed format that needs migration.
-
-Args:
-index_root: Root directory to scan
-
-Returns:
-List of dicts with db_path, posting_count, chunk_count
-"""
-results = []
-
-for db_path in index_root.rglob("_index.db"):
-try:
-conn = sqlite3.connect(db_path, timeout=5.0)
-conn.row_factory = sqlite3.Row
-
-# Check if splade_postings table exists (old embedded format)
-cursor = conn.execute(
-"SELECT name FROM sqlite_master WHERE type='table' AND name='splade_postings'"
-)
-if cursor.fetchone():
-# Count postings and chunks
-try:
-row = conn.execute(
-"SELECT COUNT(*) as postings, COUNT(DISTINCT chunk_id) as chunks FROM splade_postings"
-).fetchone()
-results.append({
-"db_path": db_path,
-"posting_count": row["postings"] if row else 0,
-"chunk_count": row["chunks"] if row else 0,
-})
-except Exception:
-pass
-
-conn.close()
-except Exception:
-pass
-
-return results
-
-
 def _discover_distributed_hnsw(index_root: Path) -> List[Dict[str, Any]]:
 """Discover distributed HNSW index files.
@@ -4075,33 +3611,18 @@ def _check_centralized_storage(index_root: Path) -> Dict[str, Any]:
 index_root: Root directory to check

 Returns:
-Dict with has_splade, has_vectors, splade_stats, vector_stats
+Dict with has_vectors, vector_stats
 """
-from codexlens.config import SPLADE_DB_NAME, VECTORS_HNSW_NAME
+from codexlens.config import VECTORS_HNSW_NAME

-splade_db = index_root / SPLADE_DB_NAME
 vectors_hnsw = index_root / VECTORS_HNSW_NAME

 result = {
-"has_splade": splade_db.exists(),
 "has_vectors": vectors_hnsw.exists(),
-"splade_path": str(splade_db) if splade_db.exists() else None,
 "vectors_path": str(vectors_hnsw) if vectors_hnsw.exists() else None,
-"splade_stats": None,
 "vector_stats": None,
 }

-# Get SPLADE stats if exists
-if splade_db.exists():
-try:
-from codexlens.storage.splade_index import SpladeIndex
-splade = SpladeIndex(splade_db)
-if splade.has_index():
-result["splade_stats"] = splade.get_stats()
-splade.close()
-except Exception:
-pass
-
 # Get vector stats if exists
 if vectors_hnsw.exists():
 try:
@@ -4125,21 +3646,19 @@ def index_migrate_cmd(
 """Migrate old distributed index to new centralized architecture.

 This command upgrades indexes from the old distributed storage format
-(where SPLADE/vectors were stored in each _index.db) to the new centralized
-format (single _splade.db and _vectors.hnsw at index root).
+(where vectors were stored in each _index.db) to the new centralized
+format (single _vectors.hnsw at index root).

 Migration Steps:
 1. Detect if migration is needed (check version marker)
-2. Discover distributed SPLADE data in _index.db files
-3. Discover distributed .hnsw files
-4. Report current status
-5. Create version marker (unless --dry-run)
+2. Discover distributed .hnsw files
+3. Report current status
+4. Create version marker (unless --dry-run)

 Use --dry-run to preview what would be migrated without making changes.
 Use --force to re-run migration even if version marker exists.

-Note: For full data migration (SPLADE/vectors consolidation), run:
-codexlens index splade <path> --rebuild
+Note: For full data migration (vectors consolidation), run:
+codexlens index embeddings <path> --force

 Examples:
@@ -4222,7 +3741,6 @@ def index_migrate_cmd(
 return

 # Discover distributed data
-distributed_splade = _discover_distributed_splade(index_root)
 distributed_hnsw = _discover_distributed_hnsw(index_root)
 centralized = _check_centralized_storage(index_root)

@@ -4239,8 +3757,6 @@ def index_migrate_cmd(
 "needs_migration": needs_migration,
 "discovery": {
 "total_index_dbs": len(all_index_dbs),
-"distributed_splade_count": len(distributed_splade),
-"distributed_splade_total_postings": sum(d["posting_count"] for d in distributed_splade),
 "distributed_hnsw_count": len(distributed_hnsw),
 "distributed_hnsw_total_bytes": sum(d["size_bytes"] for d in distributed_hnsw),
 },
@@ -4249,17 +3765,12 @@ def index_migrate_cmd(
 }

 # Generate recommendations
-if distributed_splade and not centralized["has_splade"]:
-migration_report["recommendations"].append(
-f"Run 'codexlens splade-index {target_path} --rebuild' to consolidate SPLADE data"
-)
-
 if distributed_hnsw and not centralized["has_vectors"]:
 migration_report["recommendations"].append(
 f"Run 'codexlens embeddings-generate {target_path} --recursive --force' to consolidate vector data"
 )

-if not distributed_splade and not distributed_hnsw:
+if not distributed_hnsw:
 migration_report["recommendations"].append(
 "No distributed data found. Index may already be using centralized storage."
 )
@@ -4280,23 +3791,6 @@ def index_migrate_cmd(
 console.print(f" Total _index.db files: {len(all_index_dbs)}")
 console.print()

-# Distributed SPLADE
-console.print("[bold]Distributed SPLADE Data:[/bold]")
-if distributed_splade:
-total_postings = sum(d["posting_count"] for d in distributed_splade)
-total_chunks = sum(d["chunk_count"] for d in distributed_splade)
-console.print(f" Found in {len(distributed_splade)} _index.db files")
-console.print(f" Total postings: {total_postings:,}")
-console.print(f" Total chunks: {total_chunks:,}")
-if verbose:
-for d in distributed_splade[:5]:
-console.print(f" [dim]{d['db_path'].parent.name}: {d['posting_count']} postings[/dim]")
-if len(distributed_splade) > 5:
-console.print(f" [dim]... and {len(distributed_splade) - 5} more[/dim]")
-else:
-console.print(" [dim]None found (already centralized or not generated)[/dim]")
-console.print()
-
 # Distributed HNSW
 console.print("[bold]Distributed HNSW Files:[/bold]")
 if distributed_hnsw:
@@ -4314,15 +3808,6 @@ def index_migrate_cmd(

 # Centralized storage status
 console.print("[bold]Centralized Storage:[/bold]")
-if centralized["has_splade"]:
-stats = centralized.get("splade_stats") or {}
-console.print(f" [green]OK[/green] _splade.db exists")
-if stats:
-console.print(f" Chunks: {stats.get('unique_chunks', 0):,}")
-console.print(f" Postings: {stats.get('total_postings', 0):,}")
-else:
-console.print(f" [yellow]--[/yellow] _splade.db not found")
-
 if centralized["has_vectors"]:
 stats = centralized.get("vector_stats") or {}
 size_mb = stats.get("size_bytes", 0) / (1024 * 1024)
@@ -4440,20 +3925,6 @@ def init_deprecated(
 )


-@app.command("splade-index", hidden=True, deprecated=True)
-def splade_index_deprecated(
-path: Path = typer.Argument(..., help="Project path to index"),
-rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
-verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
-) -> None:
-"""[Deprecated] Use 'codexlens index splade' instead."""
-_deprecated_command_warning("splade-index", "index splade")
-index_splade(
-path=path,
-rebuild=rebuild,
-verbose=verbose,
-)
-
-
 @app.command("cascade-index", hidden=True, deprecated=True)
 def cascade_index_deprecated(
@@ -151,15 +151,6 @@ def _cleanup_fastembed_resources() -> None:
 pass


-def _cleanup_splade_resources() -> None:
-"""Release SPLADE encoder ONNX resources."""
-try:
-from codexlens.semantic.splade_encoder import clear_splade_cache
-clear_splade_cache()
-except Exception:
-pass
-
-
 def _generate_chunks_from_cursor(
 cursor,
 chunker,
@@ -398,7 +389,6 @@ def generate_embeddings(
 endpoints: Optional[List] = None,
 strategy: Optional[str] = None,
 cooldown: Optional[float] = None,
-splade_db_path: Optional[Path] = None,
 ) -> Dict[str, any]:
 """Generate embeddings for an index using memory-efficient batch processing.

@@ -428,9 +418,6 @@ def generate_embeddings(
 Each dict has keys: model, api_key, api_base, weight.
 strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware).
 cooldown: Default cooldown seconds for rate-limited endpoints.
-splade_db_path: Optional path to centralized SPLADE database. If None, SPLADE
-is written to index_path (legacy behavior). Use index_root / SPLADE_DB_NAME
-for centralized storage.

 Returns:
 Result dictionary with generation statistics
@@ -822,97 +809,10 @@ def generate_embeddings(
 if progress_callback:
 progress_callback(f"Finalizing index... Building ANN index for {total_chunks_created} chunks")

-# --- SPLADE SPARSE ENCODING (after dense embeddings) ---
-# Add SPLADE encoding if enabled in config
-splade_success = False
-splade_error = None
-
-try:
-from codexlens.config import Config, SPLADE_DB_NAME
-config = Config.load()
-
-if config.enable_splade:
-from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
-from codexlens.storage.splade_index import SpladeIndex
-
-ok, err = check_splade_available()
-if ok:
-if progress_callback:
-progress_callback(f"Generating SPLADE sparse vectors for {total_chunks_created} chunks...")
-
-# Initialize SPLADE encoder and index
-splade_encoder = get_splade_encoder(use_gpu=use_gpu)
-# Use centralized SPLADE database if provided, otherwise fallback to index_path
-effective_splade_path = splade_db_path if splade_db_path else index_path
-splade_index = SpladeIndex(effective_splade_path)
-splade_index.create_tables()
-
-# Retrieve all chunks from database for SPLADE encoding
-with sqlite3.connect(index_path) as conn:
-conn.row_factory = sqlite3.Row
-cursor = conn.execute("SELECT id, content FROM semantic_chunks ORDER BY id")
-
-# Batch encode for efficiency
-SPLADE_BATCH_SIZE = 32
-batch_postings = []
-chunk_batch = []
-chunk_ids = []
-
-for row in cursor:
-chunk_id = row["id"]
-content = row["content"]
-
-chunk_ids.append(chunk_id)
-chunk_batch.append(content)
-
-# Process batch when full
-if len(chunk_batch) >= SPLADE_BATCH_SIZE:
-sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
-for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
-batch_postings.append((cid, sparse_vec))
-
-chunk_batch = []
-chunk_ids = []
-
-# Process remaining chunks
-if chunk_batch:
-sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
-for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
-batch_postings.append((cid, sparse_vec))
-
-# Batch insert all postings
-if batch_postings:
-splade_index.add_postings_batch(batch_postings)
-
-# Set metadata
-splade_index.set_metadata(
-model_name=splade_encoder.model_name,
-vocab_size=splade_encoder.vocab_size
-)
-
-splade_success = True
-if progress_callback:
-stats = splade_index.get_stats()
-progress_callback(
-f"SPLADE index created: {stats['total_postings']} postings, "
-f"{stats['unique_tokens']} unique tokens"
-)
-else:
-logger.debug("SPLADE not available: %s", err)
-splade_error = f"SPLADE not available: {err}"
-except Exception as e:
-splade_error = str(e)
-logger.warning("SPLADE encoding failed: %s", e)
-
-# Report SPLADE status after processing
-if progress_callback and not splade_success and splade_error:
-progress_callback(f"SPLADE index: FAILED - {splade_error}")
-
 except Exception as e:
 # Cleanup on error to prevent process hanging
 try:
 _cleanup_fastembed_resources()
-_cleanup_splade_resources()
 gc.collect()
 except Exception:
 pass
@@ -924,7 +824,6 @@ def generate_embeddings(
 # This is critical - without it, ONNX Runtime threads prevent Python from exiting
 try:
 _cleanup_fastembed_resources()
-_cleanup_splade_resources()
 gc.collect()
 except Exception:
 pass
@@ -1098,10 +997,6 @@ def generate_embeddings_recursive(
 if progress_callback:
 progress_callback(f"Found {len(index_files)} index databases to process")

-# Calculate centralized SPLADE database path
-from codexlens.config import SPLADE_DB_NAME
-splade_db_path = index_root / SPLADE_DB_NAME
-
 # Process each index database
 all_results = []
 total_chunks = 0
@@ -1131,7 +1026,6 @@ def generate_embeddings_recursive(
 endpoints=endpoints,
 strategy=strategy,
 cooldown=cooldown,
-splade_db_path=splade_db_path,  # Use centralized SPLADE storage
 )

 all_results.append({
@@ -1153,7 +1047,6 @@ def generate_embeddings_recursive(
 # Each generate_embeddings() call does its own cleanup, but do a final one to be safe
 try:
 _cleanup_fastembed_resources()
-_cleanup_splade_resources()
 gc.collect()
 except Exception:
 pass
@@ -1197,7 +1090,6 @@ def generate_dense_embeddings_centralized(
 Target architecture:
 <index_root>/
 |-- _vectors.hnsw   # Centralized dense vector ANN index
-|-- _splade.db      # Centralized sparse vector index
 |-- src/
 |-- _index.db       # No longer contains .hnsw file

@@ -1219,7 +1111,7 @@ def generate_dense_embeddings_centralized(
 Returns:
 Result dictionary with generation statistics
 """
-from codexlens.config import VECTORS_HNSW_NAME, SPLADE_DB_NAME
+from codexlens.config import VECTORS_HNSW_NAME

 # Get defaults from config if not specified
 (default_backend, default_model, default_gpu,
@@ -1543,90 +1435,6 @@ def generate_dense_embeddings_centralized(
 logger.warning("Binary vector generation failed: %s", e)
 # Non-fatal: continue without binary vectors

-# --- SPLADE Sparse Index Generation (Centralized) ---
-splade_success = False
-splade_chunks_count = 0
-try:
-from codexlens.config import Config
-config = Config.load()
-
-if config.enable_splade and chunk_id_to_info:
-from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
-from codexlens.storage.splade_index import SpladeIndex
-import json
-
-ok, err = check_splade_available()
-if ok:
-if progress_callback:
-progress_callback(f"Generating SPLADE sparse vectors for {len(chunk_id_to_info)} chunks...")
-
-# Initialize SPLADE encoder and index
-splade_encoder = get_splade_encoder(use_gpu=use_gpu)
-splade_db_path = index_root / SPLADE_DB_NAME
-splade_index = SpladeIndex(splade_db_path)
-splade_index.create_tables()
-
-# Batch encode for efficiency
-SPLADE_BATCH_SIZE = 32
-all_postings = []
-all_chunk_metadata = []
-
-# Create batches from chunk_id_to_info
-chunk_items = list(chunk_id_to_info.items())
-
-for i in range(0, len(chunk_items), SPLADE_BATCH_SIZE):
-batch_items = chunk_items[i:i + SPLADE_BATCH_SIZE]
-chunk_ids = [item[0] for item in batch_items]
-chunk_contents = [item[1]["content"] for item in batch_items]
-
-# Generate sparse vectors
-sparse_vecs = splade_encoder.encode_batch(chunk_contents, batch_size=SPLADE_BATCH_SIZE)
-for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
-all_postings.append((cid, sparse_vec))
-
-if progress_callback and (i + SPLADE_BATCH_SIZE) % 100 == 0:
-progress_callback(f"SPLADE encoding: {min(i + SPLADE_BATCH_SIZE, len(chunk_items))}/{len(chunk_items)}")
-
-# Batch insert all postings
-if all_postings:
-splade_index.add_postings_batch(all_postings)
-
-# CRITICAL FIX: Populate splade_chunks table
-for cid, info in chunk_id_to_info.items():
-metadata_str = json.dumps(info.get("metadata", {})) if info.get("metadata") else None
-all_chunk_metadata.append((
-cid,
-info["file_path"],
-info["content"],
-metadata_str,
-info.get("source_index_db")
-))
-
-if all_chunk_metadata:
-splade_index.add_chunks_metadata_batch(all_chunk_metadata)
-splade_chunks_count = len(all_chunk_metadata)
-
-# Set metadata
-splade_index.set_metadata(
-model_name=splade_encoder.model_name,
-vocab_size=splade_encoder.vocab_size
-)
-
-splade_index.close()
-splade_success = True
-
-if progress_callback:
-progress_callback(f"SPLADE index created: {len(all_postings)} postings, {splade_chunks_count} chunks")
-
-else:
-if progress_callback:
-progress_callback(f"SPLADE not available, skipping sparse index: {err}")
-
-except Exception as e:
-logger.warning("SPLADE encoding failed: %s", e)
-if progress_callback:
-progress_callback(f"SPLADE encoding failed: {e}")
-
 elapsed_time = time.time() - start_time

 # Cleanup
@@ -1647,8 +1455,6 @@ def generate_dense_embeddings_centralized(
 "model_name": embedder.model_name,
 "central_index_path": str(central_hnsw_path),
 "failed_files": failed_files[:5],
-"splade_success": splade_success,
-"splade_chunks": splade_chunks_count,
 "binary_success": binary_success,
 "binary_count": binary_count,
 },
@@ -19,9 +19,6 @@ WORKSPACE_DIR_NAME = ".codexlens"
 # Settings file name
 SETTINGS_FILE_NAME = "settings.json"

-# SPLADE index database name (centralized storage)
-SPLADE_DB_NAME = "_splade.db"
-
 # Dense vector storage names (centralized storage)
 VECTORS_HNSW_NAME = "_vectors.hnsw"
 VECTORS_META_DB_NAME = "_vectors_meta.db"
@@ -113,15 +110,6 @@ class Config:
 # For litellm: model name from config (e.g., "qwen3-embedding")
 embedding_use_gpu: bool = True  # For fastembed: whether to use GPU acceleration

-# SPLADE sparse retrieval configuration
-enable_splade: bool = False  # Disable SPLADE by default (slow ~360ms, use FTS instead)
-splade_model: str = "naver/splade-cocondenser-ensembledistil"
-splade_threshold: float = 0.01  # Min weight to store in index
-splade_onnx_path: Optional[str] = None  # Custom ONNX model path
-
-# FTS fallback (disabled by default, available via --use-fts)
-use_fts_fallback: bool = True  # Use FTS for sparse search (fast, SPLADE disabled)
-
 # Indexing/search optimizations
 global_symbol_index_enabled: bool = True  # Enable project-wide symbol index fast path
 enable_merkle_detection: bool = True  # Enable content-hash based incremental indexing
@@ -152,7 +140,7 @@ class Config:
 enable_cascade_search: bool = False  # Enable cascade search (coarse + fine ranking)
 cascade_coarse_k: int = 100  # Number of coarse candidates from first stage
 cascade_fine_k: int = 10  # Number of final results after reranking
-cascade_strategy: str = "binary"  # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)
+cascade_strategy: str = "binary"  # "binary", "binary_rerank", "dense_rerank", or "staged"

 # Staged cascade search configuration (4-stage pipeline)
 staged_coarse_k: int = 200  # Number of coarse candidates from Stage 1 binary search
@@ -398,11 +386,11 @@ class Config:
 cascade = settings.get("cascade", {})
 if "strategy" in cascade:
 strategy = cascade["strategy"]
-if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
+if strategy in {"binary", "binary_rerank", "dense_rerank", "staged"}:
 self.cascade_strategy = strategy
 else:
 log.warning(
-"Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
+"Invalid cascade strategy in %s: %r (expected 'binary', 'binary_rerank', 'dense_rerank', or 'staged')",
 self.settings_path,
 strategy,
 )
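For reference, the cascade block read from settings.json now validates against the four remaining strategies. A hedged sketch of a payload that passes the check above (only the "strategy" key is shown being validated in the hunk; the surrounding layout is an assumption):

# Assumed settings shape; treat everything except "strategy" as illustrative.
settings = {
    "cascade": {
        "strategy": "binary_rerank",  # one of: binary, binary_rerank, dense_rerank, staged
    },
}
assert settings["cascade"]["strategy"] in {"binary", "binary_rerank", "dense_rerank", "staged"}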
@@ -55,7 +55,6 @@ class SearchOptions:
 enable_fuzzy: Enable fuzzy FTS in hybrid mode (default True)
 enable_vector: Enable vector semantic search (default False)
 pure_vector: If True, only use vector search without FTS fallback (default False)
-enable_splade: Enable SPLADE sparse neural search (default False)
 enable_cascade: Enable cascade (binary+dense) two-stage retrieval (default False)
 hybrid_weights: Custom RRF weights for hybrid search (optional)
 group_results: Enable grouping of similar results (default False)
@@ -75,7 +74,6 @@ class SearchOptions:
 enable_fuzzy: bool = True
 enable_vector: bool = False
 pure_vector: bool = False
-enable_splade: bool = False
 enable_cascade: bool = False
 hybrid_weights: Optional[Dict[str, float]] = None
 group_results: bool = False
@@ -306,154 +304,6 @@ class ChainSearchEngine:
 related_results=related_results,
 )

-def hybrid_cascade_search(
-self,
-query: str,
-source_path: Path,
-k: int = 10,
-coarse_k: int = 100,
-options: Optional[SearchOptions] = None,
-) -> ChainSearchResult:
-"""Execute two-stage cascade search with hybrid coarse retrieval and cross-encoder reranking.
-
-Hybrid cascade search process:
-1. Stage 1 (Coarse): Fast retrieval using RRF fusion of FTS + SPLADE + Vector
-to get coarse_k candidates
-2. Stage 2 (Fine): CrossEncoder reranking of candidates to get final k results
-
-This approach balances recall (from broad coarse search) with precision
-(from expensive but accurate cross-encoder scoring).
-
-Note: This method is the original hybrid approach. For binary vector cascade,
-use binary_cascade_search() instead.
-
-Args:
-query: Natural language or keyword query string
-source_path: Starting directory path
-k: Number of final results to return (default 10)
-coarse_k: Number of coarse candidates from first stage (default 100)
-options: Search configuration (uses defaults if None)
-
-Returns:
-ChainSearchResult with reranked results and statistics
-
-Examples:
->>> engine = ChainSearchEngine(registry, mapper, config=config)
->>> result = engine.hybrid_cascade_search(
-...     "how to authenticate users",
-...     Path("D:/project/src"),
-...     k=10,
-...     coarse_k=100
-... )
->>> for r in result.results:
-...     print(f"{r.path}: {r.score:.3f}")
-"""
-options = options or SearchOptions()
-start_time = time.time()
-stats = SearchStats()
-
-# Use config defaults if available
-if self._config is not None:
-if hasattr(self._config, "cascade_coarse_k"):
-coarse_k = coarse_k or self._config.cascade_coarse_k
-if hasattr(self._config, "cascade_fine_k"):
-k = k or self._config.cascade_fine_k
-
-# Step 1: Find starting index
-start_index = self._find_start_index(source_path)
-if not start_index:
-self.logger.warning(f"No index found for {source_path}")
-stats.time_ms = (time.time() - start_time) * 1000
-return ChainSearchResult(
-query=query,
-results=[],
-symbols=[],
-stats=stats
-)
-
-# Step 2: Collect all index paths
-index_paths = self._collect_index_paths(start_index, options.depth)
-stats.dirs_searched = len(index_paths)
-
-if not index_paths:
-self.logger.warning(f"No indexes collected from {start_index}")
-stats.time_ms = (time.time() - start_time) * 1000
-return ChainSearchResult(
-query=query,
-results=[],
-symbols=[],
-stats=stats
-)
-
-# Stage 1: Coarse retrieval with hybrid search (FTS + SPLADE + Vector)
-# Use hybrid mode for multi-signal retrieval
-coarse_options = SearchOptions(
-depth=options.depth,
-max_workers=1,  # Single thread for GPU safety
-limit_per_dir=max(coarse_k // len(index_paths), 20),
-total_limit=coarse_k,
-hybrid_mode=True,
-enable_fuzzy=options.enable_fuzzy,
-enable_vector=True,  # Enable vector for semantic matching
-pure_vector=False,
-hybrid_weights=options.hybrid_weights,
-)
-
-self.logger.debug(
-"Cascade Stage 1: Coarse retrieval for %d candidates", coarse_k
-)
-coarse_results, search_stats = self._search_parallel(
-index_paths, query, coarse_options
-)
-stats.errors = search_stats.errors
-
-# Merge and deduplicate coarse results
-coarse_merged = self._merge_and_rank(coarse_results, coarse_k)
-self.logger.debug(
-"Cascade Stage 1 complete: %d candidates retrieved", len(coarse_merged)
-)
-
-if not coarse_merged:
-stats.time_ms = (time.time() - start_time) * 1000
-return ChainSearchResult(
-query=query,
-results=[],
-symbols=[],
-stats=stats
-)
-
-# Stage 2: Cross-encoder reranking
-self.logger.debug(
-"Cascade Stage 2: Cross-encoder reranking %d candidates to top-%d",
-len(coarse_merged),
-k,
-)
-
-final_results = self._cross_encoder_rerank(query, coarse_merged, k)
-
-# Optional: grouping of similar results
-if options.group_results:
-from codexlens.search.ranking import group_similar_results
-final_results = group_similar_results(
-final_results, score_threshold_abs=options.grouping_threshold
-)
-
-stats.files_matched = len(final_results)
-stats.time_ms = (time.time() - start_time) * 1000
-
-self.logger.debug(
-"Cascade search complete: %d results in %.2fms",
-len(final_results),
-stats.time_ms,
-)
-
-return ChainSearchResult(
-query=query,
-results=final_results,
-symbols=[],
-stats=stats,
-)
-
 def binary_cascade_search(
 self,
 query: str,
@@ -501,9 +351,9 @@ class ChainSearchEngine:
|
||||
"""
|
||||
if not NUMPY_AVAILABLE:
|
||||
self.logger.warning(
|
||||
"NumPy not available, falling back to hybrid cascade search"
|
||||
"NumPy not available, falling back to standard search"
|
||||
)
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
return self.search(query, source_path, options=options)
|
||||
|
||||
options = options or SearchOptions()
|
||||
start_time = time.time()
|
||||
@@ -552,10 +402,10 @@ class ChainSearchEngine:
|
||||
except ImportError as exc:
|
||||
self.logger.warning(
|
||||
"Binary cascade dependencies not available: %s. "
|
||||
"Falling back to hybrid cascade search.",
|
||||
"Falling back to standard search.",
|
||||
exc
|
||||
)
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
return self.search(query, source_path, options=options)
|
||||
|
||||
# Stage 1: Binary vector coarse retrieval
|
||||
self.logger.debug(
|
||||
@@ -573,10 +423,10 @@ class ChainSearchEngine:
|
||||
except Exception as exc:
|
||||
self.logger.warning(
|
||||
"Failed to generate binary query embedding: %s. "
|
||||
"Falling back to hybrid cascade search.",
|
||||
"Falling back to standard search.",
|
||||
exc
|
||||
)
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
return self.search(query, source_path, options=options)
|
||||
|
||||
# Try centralized BinarySearcher first (preferred for mmap indexes)
|
||||
# The index root is the parent of the first index path
|
||||
@@ -629,8 +479,8 @@ class ChainSearchEngine:
|
||||
stats.errors.append(f"Binary search failed for {index_path}: {exc}")
|
||||
|
||||
if not all_candidates:
|
||||
self.logger.debug("No binary candidates found, falling back to hybrid")
|
||||
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
|
||||
self.logger.debug("No binary candidates found, falling back to standard search")
|
||||
return self.search(query, source_path, options=options)
|
||||
|
||||
# Sort by Hamming distance and take top coarse_k
|
||||
all_candidates.sort(key=lambda x: x[1])
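Hamming ranking over bit-packed binary embeddings can be done entirely in NumPy; a hedged sketch of the idea (array names are illustrative, not the engine's internals):

```python
import numpy as np

# Popcount lookup table for all 256 byte values
_POPCOUNT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

def hamming_rank(query_packed: np.ndarray, corpus_packed: np.ndarray, coarse_k: int) -> np.ndarray:
    """Return indices of the coarse_k corpus rows closest to the query.

    query_packed: (n_bytes,) uint8, bit-packed binary embedding.
    corpus_packed: (n_docs, n_bytes) uint8 matrix of packed embeddings.
    """
    xor = np.bitwise_xor(corpus_packed, query_packed)   # differing bits, byte by byte
    distances = _POPCOUNT[xor].sum(axis=1)              # Hamming distance per document
    return np.argsort(distances)[:coarse_k]             # ascending = most similar first
```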
@@ -828,13 +678,12 @@ class ChainSearchEngine:
        k: int = 10,
        coarse_k: int = 100,
        options: Optional[SearchOptions] = None,
        strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]] = None,
        strategy: Optional[Literal["binary", "binary_rerank", "dense_rerank", "staged"]] = None,
    ) -> ChainSearchResult:
        """Unified cascade search entry point with strategy selection.

        Provides a single interface for cascade search with configurable strategy:
        - "binary": Uses binary vector coarse ranking + dense fine ranking (fastest)
        - "hybrid": Uses FTS+SPLADE+Vector coarse ranking + cross-encoder reranking (original)
        - "binary_rerank": Uses binary vector coarse ranking + cross-encoder reranking (best balance)
        - "dense_rerank": Uses dense vector coarse ranking + cross-encoder reranking
        - "staged": 4-stage pipeline: binary -> LSP expand -> clustering -> optional rerank
@@ -850,7 +699,7 @@ class ChainSearchEngine:
            k: Number of final results to return (default 10)
            coarse_k: Number of coarse candidates from first stage (default 100)
            options: Search configuration (uses defaults if None)
            strategy: Cascade strategy - "binary", "hybrid", "binary_rerank", "dense_rerank", or "staged".
            strategy: Cascade strategy - "binary", "binary_rerank", "dense_rerank", or "staged".

        Returns:
            ChainSearchResult with reranked results and statistics
@@ -859,8 +708,6 @@ class ChainSearchEngine:
            >>> engine = ChainSearchEngine(registry, mapper, config=config)
            >>> # Use binary cascade (default, fastest)
            >>> result = engine.cascade_search("auth", Path("D:/project"))
            >>> # Use hybrid cascade (original behavior)
            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="hybrid")
            >>> # Use binary + cross-encoder (best balance of speed and quality)
            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
            >>> # Use 4-stage pipeline (binary + LSP expand + clustering + optional rerank)
@@ -868,7 +715,7 @@ class ChainSearchEngine:
        """
        # Strategy priority: parameter > config > default
        effective_strategy = strategy
        valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank", "staged")
        valid_strategies = ("binary", "binary_rerank", "dense_rerank", "staged")
        if effective_strategy is None:
            # Not passed via parameter, check config
            if self._config is not None:
@@ -889,7 +736,7 @@ class ChainSearchEngine:
        elif effective_strategy == "staged":
            return self.staged_cascade_search(query, source_path, k, coarse_k, options)
        else:
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.binary_cascade_search(query, source_path, k, coarse_k, options)

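After this change, "hybrid" survives only as a backward-compat alias at the API layer (mapped to binary rerank); the engine itself accepts the four remaining strategies. A hedged usage sketch, assuming an `engine` constructed as in the docstring above:

```python
from pathlib import Path

result = engine.cascade_search("auth", Path("D:/project"))                        # binary (default)
balanced = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
staged = engine.cascade_search("auth", Path("D:/project"), strategy="staged")
# "hybrid" is no longer a valid engine strategy; API callers passing
# fusion_strategy="hybrid" are remapped to binary_rerank upstream.
```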
    def staged_cascade_search(
        self,
@@ -943,9 +790,9 @@ class ChainSearchEngine:
        """
        if not NUMPY_AVAILABLE:
            self.logger.warning(
                "NumPy not available, falling back to hybrid cascade search"
                "NumPy not available, falling back to standard search"
            )
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.search(query, source_path, options=options)

        options = options or SearchOptions()
        start_time = time.time()
@@ -1002,8 +849,8 @@ class ChainSearchEngine:
        )

        if not coarse_results:
            self.logger.debug("No binary candidates found, falling back to hybrid cascade")
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            self.logger.debug("No binary candidates found, falling back to standard search")
            return self.search(query, source_path, options=options)

        # ========== Stage 2: LSP Graph Expansion ==========
        stage2_start = time.time()
@@ -1534,7 +1381,7 @@ class ChainSearchEngine:
        2. Stage 2 (Fine): Cross-encoder reranking for precise semantic ranking
           of candidates using query-document attention

        This approach is typically faster than hybrid_cascade_search while
        This approach is typically faster than binary_cascade_search while
        achieving similar or better quality through cross-encoder reranking.

        Performance characteristics:
@@ -1565,9 +1412,9 @@ class ChainSearchEngine:
        """
        if not NUMPY_AVAILABLE:
            self.logger.warning(
                "NumPy not available, falling back to hybrid cascade search"
                "NumPy not available, falling back to standard search"
            )
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.search(query, source_path, options=options)

        options = options or SearchOptions()
        start_time = time.time()
@@ -1611,10 +1458,10 @@ class ChainSearchEngine:
            from codexlens.indexing.embedding import BinaryEmbeddingBackend
        except ImportError as exc:
            self.logger.warning(
                "BinaryEmbeddingBackend not available: %s, falling back to hybrid cascade",
                "BinaryEmbeddingBackend not available: %s, falling back to standard search",
                exc
            )
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.search(query, source_path, options=options)

        # Step 4: Binary coarse search (same as binary_cascade_search)
        binary_coarse_time = time.time()
@@ -1658,7 +1505,7 @@ class ChainSearchEngine:
            query_binary = binary_backend.embed_packed([query])[0]
        except Exception as exc:
            self.logger.warning(f"Failed to generate binary query embedding: {exc}")
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.search(query, source_path, options=options)

        # Fallback to per-directory binary indexes
        for index_path in index_paths:
@@ -1676,9 +1523,9 @@ class ChainSearchEngine:
        )

        if not coarse_candidates:
            self.logger.info("No binary candidates found, falling back to hybrid cascade for reranking")
            # Fall back to hybrid_cascade_search which uses FTS+Vector coarse + cross-encoder rerank
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            self.logger.info("No binary candidates found, falling back to standard search for reranking")
            # Fall back to standard search which uses FTS+Vector
            return self.search(query, source_path, options=options)

        # Sort by Hamming distance and take top coarse_k
        coarse_candidates.sort(key=lambda x: x[1])
@@ -1785,7 +1632,7 @@ class ChainSearchEngine:
            "Retrieved %d chunks for cross-encoder reranking", len(coarse_results)
        )

        # Step 6: Cross-encoder reranking (same as hybrid_cascade_search)
        # Step 6: Cross-encoder reranking
        rerank_time = time.time()
        reranked_results = self._cross_encoder_rerank(query, coarse_results, top_k=k)

@@ -1848,9 +1695,9 @@ class ChainSearchEngine:
        """
        if not NUMPY_AVAILABLE:
            self.logger.warning(
                "NumPy not available, falling back to hybrid cascade search"
                "NumPy not available, falling back to standard search"
            )
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.search(query, source_path, options=options)

        options = options or SearchOptions()
        start_time = time.time()
@@ -1955,7 +1802,7 @@ class ChainSearchEngine:
            self.logger.debug(f"Dense query embedding: {query_dense.shape[0]}-dim via {embedding_backend}/{embedding_model}")
        except Exception as exc:
            self.logger.warning(f"Failed to generate dense query embedding: {exc}")
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            return self.search(query, source_path, options=options)

        # Step 5: Dense coarse search using centralized HNSW index
        coarse_candidates: List[Tuple[int, float, Path]] = []  # (chunk_id, distance, index_path)
@@ -2006,8 +1853,8 @@ class ChainSearchEngine:
        )

        if not coarse_candidates:
            self.logger.info("No dense candidates found, falling back to hybrid cascade")
            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
            self.logger.info("No dense candidates found, falling back to standard search")
            return self.search(query, source_path, options=options)

        # Sort by distance (ascending for cosine distance) and take top coarse_k
        coarse_candidates.sort(key=lambda x: x[1])
@@ -2972,7 +2819,6 @@ class ChainSearchEngine:
                options.enable_fuzzy,
                options.enable_vector,
                options.pure_vector,
                options.enable_splade,
                options.hybrid_weights
            ): idx_path
            for idx_path in index_paths
@@ -3001,7 +2847,6 @@ class ChainSearchEngine:
                      enable_fuzzy: bool = True,
                      enable_vector: bool = False,
                      pure_vector: bool = False,
                      enable_splade: bool = False,
                      hybrid_weights: Optional[Dict[str, float]] = None) -> List[SearchResult]:
        """Search a single index database.

@@ -3017,7 +2862,6 @@ class ChainSearchEngine:
            enable_fuzzy: Enable fuzzy FTS in hybrid mode
            enable_vector: Enable vector semantic search
            pure_vector: If True, only use vector search without FTS fallback
            enable_splade: If True, force SPLADE sparse neural search
            hybrid_weights: Custom RRF weights for hybrid search

        Returns:
@@ -3034,7 +2878,6 @@ class ChainSearchEngine:
                enable_fuzzy=enable_fuzzy,
                enable_vector=enable_vector,
                pure_vector=pure_vector,
                enable_splade=enable_splade,
            )
        else:
            # Single-FTS search (exact or fuzzy mode)

@@ -35,7 +35,6 @@ from codexlens.config import VECTORS_HNSW_NAME
from codexlens.entities import SearchResult
from codexlens.search.ranking import (
    DEFAULT_WEIGHTS,
    FTS_FALLBACK_WEIGHTS,
    QueryIntent,
    apply_symbol_boost,
    cross_encoder_rerank,
@@ -57,14 +56,6 @@ except ImportError:
    HAS_LSP = False


# Three-way fusion weights (FTS + Vector + SPLADE)
THREE_WAY_WEIGHTS = {
    "exact": 0.2,
    "splade": 0.3,
    "vector": 0.5,
}


class HybridSearchEngine:
    """Hybrid search engine with parallel execution and RRF fusion.

@@ -77,8 +68,7 @@ class HybridSearchEngine:
    """

    # NOTE: DEFAULT_WEIGHTS imported from ranking.py - single source of truth
    # Default RRF weights: SPLADE-based hybrid (splade: 0.4, vector: 0.6)
    # FTS fallback mode uses FTS_FALLBACK_WEIGHTS (exact: 0.3, fuzzy: 0.1, vector: 0.6)
    # FTS + vector hybrid mode (exact: 0.3, fuzzy: 0.1, vector: 0.6)

    def __init__(
        self,
@@ -119,7 +109,6 @@ class HybridSearchEngine:
        enable_fuzzy: bool = True,
        enable_vector: bool = False,
        pure_vector: bool = False,
        enable_splade: bool = False,
        enable_lsp_graph: bool = False,
        lsp_max_depth: int = 1,
        lsp_max_nodes: int = 20,
@@ -133,7 +122,6 @@ class HybridSearchEngine:
            enable_fuzzy: Enable fuzzy FTS search (default True)
            enable_vector: Enable vector search (default False)
            pure_vector: If True, only use vector search without FTS fallback (default False)
            enable_splade: If True, force SPLADE sparse neural search (default False)
            enable_lsp_graph: If True, enable real-time LSP graph expansion (default False)
            lsp_max_depth: Maximum depth for LSP graph BFS expansion (default 1)
            lsp_max_nodes: Maximum nodes to collect in LSP graph (default 20)
@@ -150,9 +138,6 @@ class HybridSearchEngine:
            >>> results = engine.search(Path("project/_index.db"),
            ...                         "how to authenticate users",
            ...                         enable_vector=True, pure_vector=True)
            >>> # SPLADE sparse neural search
            >>> results = engine.search(Path("project/_index.db"), "auth flow",
            ...                         enable_splade=True, enable_vector=True)
            >>> # With LSP graph expansion (real-time)
            >>> results = engine.search(Path("project/_index.db"), "auth flow",
            ...                         enable_vector=True, enable_lsp_graph=True)
@@ -180,26 +165,6 @@ class HybridSearchEngine:
        # Determine which backends to use
        backends = {}

        # Check if SPLADE is available
        splade_available = False
        # Respect config.enable_splade flag and use_fts_fallback flag
        if self._config and getattr(self._config, 'use_fts_fallback', False):
            # Config explicitly requests FTS fallback - disable SPLADE
            splade_available = False
        elif self._config and not getattr(self._config, 'enable_splade', True):
            # Config explicitly disabled SPLADE
            splade_available = False
        else:
            # Check if SPLADE dependencies are available
            try:
                from codexlens.semantic.splade_encoder import check_splade_available
                ok, _ = check_splade_available()
                if ok:
                    # SPLADE tables are in main index database, will check table existence in _search_splade
                    splade_available = True
            except Exception:
                pass

        if pure_vector:
            # Pure vector mode: only use vector search, no FTS fallback
            if enable_vector:
@@ -212,37 +177,13 @@ class HybridSearchEngine:
                    "To use pure vector search, enable vector search mode."
                )
                backends["exact"] = True
        elif enable_splade:
            # Explicit SPLADE mode requested via CLI --method splade
            if splade_available:
                backends["splade"] = True
                if enable_vector:
                    backends["vector"] = True
            else:
                # SPLADE requested but not available - warn and fallback
                self.logger.warning(
                    "SPLADE search requested but not available. "
                    "Falling back to FTS. Run 'codexlens index splade' to enable."
                )
                backends["exact"] = True
                if enable_fuzzy:
                    backends["fuzzy"] = True
                if enable_vector:
                    backends["vector"] = True
        else:
            # Hybrid mode: default to SPLADE if available, otherwise use FTS
            if splade_available:
                # Default: enable SPLADE, disable exact and fuzzy
                backends["splade"] = True
                if enable_vector:
                    backends["vector"] = True
            else:
                # Fallback mode: enable exact+fuzzy when SPLADE unavailable
                backends["exact"] = True
                if enable_fuzzy:
                    backends["fuzzy"] = True
                if enable_vector:
                    backends["vector"] = True
            # Standard hybrid mode: FTS + optional vector
            backends["exact"] = True
            if enable_fuzzy:
                backends["fuzzy"] = True
            if enable_vector:
                backends["vector"] = True

        # Add LSP graph expansion if requested and available
        if enable_lsp_graph and HAS_LSP:
@@ -502,13 +443,6 @@ class HybridSearchEngine:
            )
            future_to_source[future] = "vector"

        if backends.get("splade"):
            submit_times["splade"] = time.perf_counter()
            future = executor.submit(
                self._search_splade, index_path, query, limit
            )
            future_to_source[future] = "splade"

        if backends.get("lsp_graph"):
            submit_times["lsp_graph"] = time.perf_counter()
            future = executor.submit(
@@ -599,8 +533,7 @@ class HybridSearchEngine:
    def _find_vectors_hnsw(self, index_path: Path) -> Optional[Path]:
        """Find the centralized _vectors.hnsw file by traversing up from index_path.

        Similar to _search_splade's approach, this method searches for the
        centralized dense vector index file in parent directories.
        Searches for the centralized dense vector index file in parent directories.

        Args:
            index_path: Path to the current _index.db file
@@ -1138,124 +1071,6 @@ class HybridSearchEngine:
            self.logger.error("Vector search error: %s", exc)
            return []

    def _search_splade(
        self, index_path: Path, query: str, limit: int
    ) -> List[SearchResult]:
        """SPLADE sparse retrieval via inverted index.

        Args:
            index_path: Path to _index.db file
            query: Natural language query string
            limit: Maximum results

        Returns:
            List of SearchResult ordered by SPLADE score
        """
        try:
            from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
            from codexlens.storage.splade_index import SpladeIndex
            from codexlens.config import SPLADE_DB_NAME
            import sqlite3
            import json

            # Check dependencies
            ok, err = check_splade_available()
            if not ok:
                self.logger.debug("SPLADE not available: %s", err)
                return []

            # SPLADE index is stored in _splade.db at the project index root
            # Traverse up from the current index to find the root _splade.db
            current_dir = index_path.parent
            splade_db_path = None
            for _ in range(10):  # Limit search depth
                candidate = current_dir / SPLADE_DB_NAME
                if candidate.exists():
                    splade_db_path = candidate
                    break
                parent = current_dir.parent
                if parent == current_dir:  # Reached root
                    break
                current_dir = parent

            if not splade_db_path:
                self.logger.debug("SPLADE index not found in ancestor directories of %s", index_path)
                return []

            splade_index = SpladeIndex(splade_db_path)
            if not splade_index.has_index():
                self.logger.debug("SPLADE index not initialized")
                return []

            # Encode query to sparse vector
            encoder = get_splade_encoder(use_gpu=self._use_gpu)
            query_sparse = encoder.encode_text(query)

            # Search inverted index for top matches
            raw_results = splade_index.search(query_sparse, limit=limit, min_score=0.0)

            if not raw_results:
                return []

            # Fetch chunk details from splade_chunks table (self-contained)
            chunk_ids = [chunk_id for chunk_id, _ in raw_results]
            score_map = {chunk_id: score for chunk_id, score in raw_results}

            # Get chunk metadata from SPLADE database
            rows = splade_index.get_chunks_by_ids(chunk_ids)

            # Build SearchResult objects
            results = []
            for row in rows:
                chunk_id = row["id"]
                file_path = row["file_path"]
                content = row["content"]
                metadata_json = row["metadata"]
                metadata = json.loads(metadata_json) if metadata_json else {}

                score = score_map.get(chunk_id, 0.0)

                # Build excerpt (short preview)
                excerpt = content[:200] + "..." if len(content) > 200 else content

                # Extract symbol information from metadata
                symbol_name = metadata.get("symbol_name")
                symbol_kind = metadata.get("symbol_kind")
                start_line = metadata.get("start_line")
                end_line = metadata.get("end_line")

                # Build Symbol object if we have symbol info
                symbol = None
                if symbol_name and symbol_kind and start_line and end_line:
                    try:
                        from codexlens.entities import Symbol
                        symbol = Symbol(
                            name=symbol_name,
                            kind=symbol_kind,
                            range=(start_line, end_line)
                        )
                    except Exception:
                        pass

                results.append(SearchResult(
                    path=file_path,
                    score=score,
                    excerpt=excerpt,
                    content=content,
                    symbol=symbol,
                    metadata=metadata,
                    start_line=start_line,
                    end_line=end_line,
                    symbol_name=symbol_name,
                    symbol_kind=symbol_kind,
                ))

            return results

        except Exception as exc:
            self.logger.debug("SPLADE search error: %s", exc)
            return []

    def _search_lsp_graph(
        self,
        index_path: Path,
@@ -1295,21 +1110,14 @@ class HybridSearchEngine:
        if seeds:
            seed_source = "vector"

        # 2. Fallback to SPLADE if vector returns nothing
        # 2. Fallback to exact FTS if vector returns nothing
        if not seeds:
            self.logger.debug("Vector search returned no seeds, trying SPLADE")
            seeds = self._search_splade(index_path, query, limit=3)
            if seeds:
                seed_source = "splade"

        # 3. Fallback to exact FTS if SPLADE also fails
        if not seeds:
            self.logger.debug("SPLADE returned no seeds, trying exact FTS")
            self.logger.debug("Vector search returned no seeds, trying exact FTS")
            seeds = self._search_exact(index_path, query, limit=3)
            if seeds:
                seed_source = "exact_fts"

        # 4. No seeds available from any source
        # 3. No seeds available from any source
        if not seeds:
            self.logger.debug("No seed results available for LSP graph expansion")
            return []

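The simplified seed-selection chain is now just vector then exact FTS. A compact sketch of the pattern, where `search_vector` and `search_exact` are stand-ins for the engine's private helpers:

```python
def pick_seeds(index_path, query, search_vector, search_exact, limit=3):
    """Try progressively cheaper backends until one yields seed results."""
    for source, backend in (("vector", search_vector), ("exact_fts", search_exact)):
        seeds = backend(index_path, query, limit)
        if seeds:
            return seeds, source
    return [], None
```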
@@ -1,7 +1,7 @@
"""Ranking algorithms for hybrid search result fusion.

Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
for combining results from heterogeneous search backends (exact FTS, fuzzy FTS, vector search).
"""

from __future__ import annotations

@@ -15,19 +15,12 @@ from typing import Any, Dict, List, Optional
from codexlens.entities import SearchResult, AdditionalLocation


# Default RRF weights for SPLADE-based hybrid search
# Default RRF weights for hybrid search
DEFAULT_WEIGHTS = {
    "splade": 0.35,  # Replaces exact(0.3) + fuzzy(0.1)
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}

# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
    "exact": 0.25,
    "fuzzy": 0.1,
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
    "lsp_graph": 0.15,
}


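After this commit the single DEFAULT_WEIGHTS table carries the FTS weights named in the commit message. The weights are renormalized over whichever backends actually ran; a hedged sketch of that idea (the helper name is illustrative — the real code does this inside adjust_weights_by_intent):

```python
def select_weights(active_backends: set[str], base: dict[str, float]) -> dict[str, float]:
    """Keep only weights for backends that produced results, renormalized to 1.0."""
    filtered = {k: v for k, v in base.items() if k in active_backends}
    total = sum(filtered.values()) or 1.0
    return {k: v / total for k, v in filtered.items()}

# e.g. with only exact FTS and vector available:
# select_weights({"exact", "vector"}, DEFAULT_WEIGHTS) -> {"exact": 0.333, "vector": 0.667}
```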
@@ -105,22 +98,13 @@ def adjust_weights_by_intent(
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Adjust RRF weights based on query intent."""
    # Check if using SPLADE or FTS mode
    use_splade = "splade" in base_weights

    if intent == QueryIntent.KEYWORD:
        if use_splade:
            target = {"splade": 0.6, "vector": 0.4}
        else:
            target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
        target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
    elif intent == QueryIntent.SEMANTIC:
        if use_splade:
            target = {"splade": 0.3, "vector": 0.7}
        else:
            target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
        target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
    else:
        target = dict(base_weights)


    # Filter to active backends
    keys = list(base_weights.keys())
    filtered = {k: float(target.get(k, 0.0)) for k in keys}
@@ -225,7 +209,7 @@ def simple_weighted_fusion(

    Args:
        results_map: Dictionary mapping source name to list of SearchResult objects
            Sources: 'exact', 'fuzzy', 'vector', 'splade'
            Sources: 'exact', 'fuzzy', 'vector'
        weights: Dictionary mapping source name to weight (default: equal weights)
            Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}

@@ -331,14 +315,11 @@ def reciprocal_rank_fusion(

    RRF formula: score(d) = Σ weight_source / (k + rank_source(d))

    Supports three-way fusion with FTS, Vector, and SPLADE sources.

    Args:
        results_map: Dictionary mapping source name to list of SearchResult objects
            Sources: 'exact', 'fuzzy', 'vector', 'splade'
            Sources: 'exact', 'fuzzy', 'vector'
        weights: Dictionary mapping source name to weight (default: equal weights)
            Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
            Or: {'splade': 0.4, 'vector': 0.6}
        k: Constant to avoid division by zero and control rank influence (default 60)

    Returns:
@@ -349,14 +330,6 @@ def reciprocal_rank_fusion(
        >>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
        >>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
        >>> fused = reciprocal_rank_fusion(results_map)

        # Three-way fusion with SPLADE
        >>> results_map = {
        ...     'exact': exact_results,
        ...     'vector': vector_results,
        ...     'splade': splade_results
        ... }
        >>> fused = reciprocal_rank_fusion(results_map, k=60)
    """
    if not results_map:
        return []

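The RRF formula above is small enough to show in full. A minimal, self-contained sketch that treats each result's path as its document key (the real implementation operates on SearchResult objects, but the fusion step is the same):

```python
from typing import Dict, List

def rrf(results_map: Dict[str, List[str]], weights: Dict[str, float], k: int = 60) -> List[str]:
    """Fuse ranked doc-id lists: score(d) = sum(weight_s / (k + rank_s(d)))."""
    scores: Dict[str, float] = {}
    for source, docs in results_map.items():
        w = weights.get(source, 1.0)
        for rank, doc in enumerate(docs, start=1):
            scores[doc] = scores.get(doc, 0.0) + w / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Example: exact FTS and vector search each return ranked paths
fused = rrf(
    {"exact": ["a.py", "b.py"], "vector": ["b.py", "c.py"]},
    weights={"exact": 0.4, "vector": 0.6},
)
# b.py appears in both lists, so it accumulates the highest fused score
```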
@@ -1,225 +0,0 @@
# SPLADE Encoder Implementation

## Overview

Created `splade_encoder.py` - A complete ONNX-optimized SPLADE sparse encoder for code search.

## File Location

`src/codexlens/semantic/splade_encoder.py` (474 lines)

## Key Components

### 1. Dependency Checking

**Function**: `check_splade_available() -> Tuple[bool, Optional[str]]`
- Validates numpy, onnxruntime, optimum, transformers availability
- Returns (True, None) if all dependencies present
- Returns (False, error_message) with install instructions if missing

### 2. Caching System

**Global Cache**: Thread-safe singleton pattern
- `_splade_cache: Dict[str, SpladeEncoder]` - Global encoder cache
- `_cache_lock: threading.RLock()` - Thread safety lock

**Factory Function**: `get_splade_encoder(...) -> SpladeEncoder`
- Cache key includes: model_name, gpu/cpu, max_length, sparsity_threshold
- Pre-loads model on first access
- Returns cached instance on subsequent calls

**Cleanup Function**: `clear_splade_cache() -> None`
- Releases ONNX resources
- Clears model and tokenizer references
- Prevents memory leaks

### 3. SpladeEncoder Class

#### Initialization Parameters
- `model_name: str` - Default: "naver/splade-cocondenser-ensembledistil"
- `use_gpu: bool` - Enable GPU acceleration (default: True)
- `max_length: int` - Max sequence length (default: 512)
- `sparsity_threshold: float` - Min weight threshold (default: 0.01)
- `providers: Optional[List]` - Explicit ONNX providers (overrides use_gpu)

#### Core Methods

**`_load_model()`**: Lazy loading with GPU support
- Uses `optimum.onnxruntime.ORTModelForMaskedLM`
- Falls back to CPU if GPU unavailable
- Integrates with `gpu_support.get_optimal_providers()`
- Handles device_id options for DirectML/CUDA

**`_splade_activation(logits, attention_mask)`**: Static method
- Formula: `log(1 + ReLU(logits)) * attention_mask`
- Input: (batch, seq_len, vocab_size)
- Output: (batch, seq_len, vocab_size)

**`_max_pooling(splade_repr)`**: Static method
- Max pooling over sequence dimension
- Input: (batch, seq_len, vocab_size)
- Output: (batch, vocab_size)

**`_to_sparse_dict(dense_vec)`**: Conversion helper
- Filters by sparsity_threshold
- Returns: `Dict[int, float]` mapping token_id to weight

**`encode_text(text: str) -> Dict[int, float]`**: Single text encoding
- Tokenizes input with truncation/padding
- Forward pass through ONNX model
- Applies SPLADE activation + max pooling
- Returns sparse vector

**`encode_batch(texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]`**: Batch encoding
- Processes in batches for memory efficiency
- Returns list of sparse vectors

#### Properties

**`vocab_size: int`**: Vocabulary size (~30k for BERT)
- Cached after first model load
- Returns tokenizer length

#### Debugging Methods

**`get_token(token_id: int) -> str`**
- Converts token_id to human-readable string
- Uses tokenizer.decode()

**`get_top_tokens(sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]`**
- Extracts top-k highest-weight tokens
- Returns (token_string, weight) pairs
- Useful for understanding model focus

## Design Patterns Followed

### 1. From `onnx_reranker.py`
✓ ONNX loading with provider detection
✓ Lazy model initialization
✓ Thread-safe loading with RLock
✓ Signature inspection for backward compatibility
✓ Fallback for older Optimum versions
✓ Static helper methods for numerical operations

### 2. From `embedder.py`
✓ Global cache with thread safety
✓ Factory function pattern (get_splade_encoder)
✓ Cache cleanup function (clear_splade_cache)
✓ GPU provider configuration
✓ Batch processing support

### 3. From `gpu_support.py`
✓ `get_optimal_providers(use_gpu, with_device_options=True)`
✓ Device ID options for DirectML/CUDA
✓ Provider tuple format: (provider_name, options_dict)

## SPLADE Algorithm

### Activation Formula
```python
# Step 1: ReLU activation
relu_logits = max(0, logits)

# Step 2: Log(1 + x) transformation
log_relu = log(1 + relu_logits)

# Step 3: Apply attention mask
splade_repr = log_relu * attention_mask

# Step 4: Max pooling over sequence
splade_vec = max(splade_repr, axis=sequence_length)

# Step 5: Sparsification by threshold
sparse_dict = {token_id: weight for token_id, weight in enumerate(splade_vec) if weight > threshold}
```

### Output Format
- Sparse dictionary: `{token_id: weight}`
- Token IDs: 0 to vocab_size-1 (typically ~30,000)
- Weights: Float values > sparsity_threshold
- Interpretable: Can decode token_ids to strings

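A runnable NumPy rendering of steps 1-5 above, assuming `logits` of shape (seq_len, vocab_size) and a 0/1 `attention_mask` of shape (seq_len,):

```python
import numpy as np

def splade_sparse(logits: np.ndarray, attention_mask: np.ndarray, threshold: float = 0.01) -> dict[int, float]:
    """Collapse token logits into a sparse {token_id: weight} vector."""
    activated = np.log1p(np.maximum(0, logits))     # log(1 + ReLU(logits))
    masked = activated * attention_mask[:, None]    # zero out padding positions
    pooled = masked.max(axis=0)                     # max-pool over the sequence
    ids = np.where(pooled > threshold)[0]           # sparsify by threshold
    return {int(i): float(pooled[i]) for i in ids}
```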
## Integration Points
|
||||
|
||||
### With `splade_index.py`
|
||||
- `SpladeIndex.add_posting(chunk_id, sparse_vec: Dict[int, float])`
|
||||
- `SpladeIndex.search(query_sparse: Dict[int, float])`
|
||||
- Encoder produces the sparse vectors consumed by index
|
||||
|
||||
### With Indexing Pipeline
|
||||
```python
|
||||
encoder = get_splade_encoder(use_gpu=True)
|
||||
|
||||
# Single document
|
||||
sparse_vec = encoder.encode_text("def main():\n print('hello')")
|
||||
index.add_posting(chunk_id=1, sparse_vec=sparse_vec)
|
||||
|
||||
# Batch indexing
|
||||
texts = ["code chunk 1", "code chunk 2", ...]
|
||||
sparse_vecs = encoder.encode_batch(texts, batch_size=64)
|
||||
postings = [(chunk_id, vec) for chunk_id, vec in enumerate(sparse_vecs)]
|
||||
index.add_postings_batch(postings)
|
||||
```
|
||||
|
||||
### With Search Pipeline
|
||||
```python
|
||||
encoder = get_splade_encoder(use_gpu=True)
|
||||
query_sparse = encoder.encode_text("authentication function")
|
||||
results = index.search(query_sparse, limit=50, min_score=0.5)
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
Required packages:
|
||||
- `numpy` - Numerical operations
|
||||
- `onnxruntime` - ONNX model execution (CPU)
|
||||
- `onnxruntime-gpu` - ONNX with GPU support (optional)
|
||||
- `optimum[onnxruntime]` - Hugging Face ONNX optimization
|
||||
- `transformers` - Tokenizer and model loading
|
||||
|
||||
Install command:
|
||||
```bash
|
||||
# CPU only
|
||||
pip install numpy onnxruntime optimum[onnxruntime] transformers
|
||||
|
||||
# With GPU support
|
||||
pip install numpy onnxruntime-gpu optimum[onnxruntime-gpu] transformers
|
||||
```
|
||||
|
||||
## Testing Status
|
||||
|
||||
✓ Python syntax validation passed
|
||||
✓ Module import successful
|
||||
✓ Dependency checking works correctly
|
||||
✗ Full functional test pending (requires optimum installation)
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Install dependencies for functional testing
|
||||
2. Create unit tests in `tests/semantic/test_splade_encoder.py`
|
||||
3. Benchmark encoding performance (CPU vs GPU)
|
||||
4. Integrate with codex-lens indexing pipeline
|
||||
5. Add SPLADE option to semantic search configuration
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Memory Usage
|
||||
- Model size: ~100MB (ONNX optimized)
|
||||
- Sparse vectors: ~100-500 non-zero entries per document
|
||||
- Batch size: 32 recommended (adjust based on GPU memory)
|
||||
|
||||
### Speed Benchmarks (Expected)
|
||||
- CPU encoding: ~10-20 docs/sec
|
||||
- GPU encoding (CUDA): ~100-200 docs/sec
|
||||
- GPU encoding (DirectML): ~50-100 docs/sec
|
||||
|
||||
### Sparsity Analysis
|
||||
- Threshold 0.01: ~200-400 tokens per document
|
||||
- Threshold 0.05: ~100-200 tokens per document
|
||||
- Threshold 0.10: ~50-100 tokens per document
|
||||
|
||||
## References
|
||||
|
||||
- SPLADE paper: https://arxiv.org/abs/2107.05720
|
||||
- SPLADE v2: https://arxiv.org/abs/2109.10086
|
||||
- Naver model: https://huggingface.co/naver/splade-cocondenser-ensembledistil
|
||||
@@ -1,567 +0,0 @@
|
||||
"""ONNX-optimized SPLADE sparse encoder for code search.
|
||||
|
||||
This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
|
||||
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
|
||||
that combine the interpretability of BM25 with neural relevance modeling.
|
||||
|
||||
Install (CPU):
|
||||
pip install onnxruntime optimum[onnxruntime] transformers
|
||||
|
||||
Install (GPU):
|
||||
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_splade_available() -> Tuple[bool, Optional[str]]:
|
||||
"""Check whether SPLADE dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (available: bool, error_message: Optional[str])
|
||||
"""
|
||||
try:
|
||||
import numpy # noqa: F401
|
||||
except ImportError as exc:
|
||||
return False, f"numpy not available: {exc}. Install with: pip install numpy"
|
||||
|
||||
try:
|
||||
import onnxruntime # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
|
||||
)
|
||||
|
||||
try:
|
||||
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"transformers not available: {exc}. Install with: pip install transformers",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# Global cache for SPLADE encoders (singleton pattern)
|
||||
_splade_cache: Dict[str, "SpladeEncoder"] = {}
|
||||
_cache_lock = threading.RLock()
|
||||
|
||||
|
||||
def get_splade_encoder(
|
||||
model_name: str = "naver/splade-cocondenser-ensembledistil",
|
||||
use_gpu: bool = True,
|
||||
max_length: int = 512,
|
||||
sparsity_threshold: float = 0.01,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> "SpladeEncoder":
|
||||
"""Get or create cached SPLADE encoder (thread-safe singleton).
|
||||
|
||||
This function provides significant performance improvement by reusing
|
||||
SpladeEncoder instances across multiple searches, avoiding repeated model
|
||||
loading overhead.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||
use_gpu: If True, use GPU acceleration when available
|
||||
max_length: Maximum sequence length for tokenization
|
||||
sparsity_threshold: Minimum weight to include in sparse vector
|
||||
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
|
||||
|
||||
Returns:
|
||||
Cached SpladeEncoder instance for the given configuration
|
||||
"""
|
||||
global _splade_cache
|
||||
|
||||
# Cache key includes all configuration parameters
|
||||
cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"
|
||||
|
||||
with _cache_lock:
|
||||
encoder = _splade_cache.get(cache_key)
|
||||
if encoder is not None:
|
||||
return encoder
|
||||
|
||||
# Create new encoder and cache it
|
||||
encoder = SpladeEncoder(
|
||||
model_name=model_name,
|
||||
use_gpu=use_gpu,
|
||||
max_length=max_length,
|
||||
sparsity_threshold=sparsity_threshold,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
# Pre-load model to ensure it's ready
|
||||
encoder._load_model()
|
||||
_splade_cache[cache_key] = encoder
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def clear_splade_cache() -> None:
|
||||
"""Clear the SPLADE encoder cache and release ONNX resources.
|
||||
|
||||
This method ensures proper cleanup of ONNX model resources to prevent
|
||||
memory leaks when encoders are no longer needed.
|
||||
"""
|
||||
global _splade_cache
|
||||
with _cache_lock:
|
||||
# Release ONNX resources before clearing cache
|
||||
for encoder in _splade_cache.values():
|
||||
if encoder._model is not None:
|
||||
del encoder._model
|
||||
encoder._model = None
|
||||
if encoder._tokenizer is not None:
|
||||
del encoder._tokenizer
|
||||
encoder._tokenizer = None
|
||||
_splade_cache.clear()
|
||||
|
||||
|
||||
class SpladeEncoder:
|
||||
"""ONNX-optimized SPLADE sparse encoder.
|
||||
|
||||
Produces sparse vectors with vocabulary-aligned dimensions.
|
||||
Output: Dict[int, float] mapping token_id to weight.
|
||||
|
||||
SPLADE activation formula:
|
||||
splade_repr = log(1 + ReLU(logits)) * attention_mask
|
||||
splade_vec = max_pooling(splade_repr, axis=sequence_length)
|
||||
|
||||
References:
|
||||
- SPLADE: https://arxiv.org/abs/2107.05720
|
||||
- SPLADE v2: https://arxiv.org/abs/2109.10086
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DEFAULT_MODEL,
|
||||
use_gpu: bool = True,
|
||||
max_length: int = 512,
|
||||
sparsity_threshold: float = 0.01,
|
||||
providers: Optional[List[Any]] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Initialize SPLADE encoder.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||
use_gpu: If True, use GPU acceleration when available
|
||||
max_length: Maximum sequence length for tokenization
|
||||
sparsity_threshold: Minimum weight to include in sparse vector
|
||||
providers: Explicit ONNX providers list (overrides use_gpu)
|
||||
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
|
||||
"""
|
||||
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.use_gpu = bool(use_gpu)
|
||||
self.max_length = int(max_length) if max_length > 0 else 512
|
||||
self.sparsity_threshold = float(sparsity_threshold)
|
||||
self.providers = providers
|
||||
|
||||
# Setup ONNX cache directory
|
||||
if cache_dir:
|
||||
self._cache_dir = Path(cache_dir)
|
||||
else:
|
||||
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
|
||||
|
||||
self._tokenizer: Any | None = None
|
||||
self._model: Any | None = None
|
||||
self._vocab_size: int | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _get_local_cache_path(self) -> Path:
|
||||
"""Get local cache path for this model's ONNX files.
|
||||
|
||||
Returns:
|
||||
Path to the local ONNX cache directory for this model
|
||||
"""
|
||||
# Replace / with -- for filesystem-safe naming
|
||||
safe_name = self.model_name.replace("/", "--")
|
||||
return self._cache_dir / safe_name
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Lazy load ONNX model and tokenizer.
|
||||
|
||||
First checks local cache for ONNX model, falling back to
|
||||
HuggingFace download and conversion if not cached.
|
||||
"""
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
ok, err = check_splade_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from optimum.onnxruntime import ORTModelForMaskedLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if self.providers is None:
|
||||
from .gpu_support import get_optimal_providers, get_selected_device_id
|
||||
|
||||
# Get providers as pure string list (cache-friendly)
|
||||
# NOTE: with_device_options=False to avoid tuple-based providers
|
||||
# which break optimum's caching mechanism
|
||||
self.providers = get_optimal_providers(
|
||||
use_gpu=self.use_gpu, with_device_options=False
|
||||
)
|
||||
# Get device_id separately for provider_options
|
||||
self._device_id = get_selected_device_id() if self.use_gpu else None
|
||||
|
||||
# Some Optimum versions accept `providers`, others accept a single `provider`
|
||||
# Prefer passing the full providers list, with a conservative fallback
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
try:
|
||||
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
|
||||
if "providers" in params:
|
||||
model_kwargs["providers"] = self.providers
|
||||
# Pass device_id via provider_options for GPU selection
|
||||
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
|
||||
# Build provider_options dict for each GPU provider
|
||||
provider_options = {}
|
||||
for p in self.providers:
|
||||
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
|
||||
provider_options[p] = {"device_id": self._device_id}
|
||||
if provider_options:
|
||||
model_kwargs["provider_options"] = provider_options
|
||||
elif "provider" in params:
|
||||
provider_name = "CPUExecutionProvider"
|
||||
if self.providers:
|
||||
first = self.providers[0]
|
||||
provider_name = first[0] if isinstance(first, tuple) else str(first)
|
||||
model_kwargs["provider"] = provider_name
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to inspect ORTModel signature: {e}")
|
||||
model_kwargs = {}
|
||||
|
||||
# Check for local ONNX cache first
|
||||
local_cache = self._get_local_cache_path()
|
||||
onnx_model_path = local_cache / "model.onnx"
|
||||
|
||||
if onnx_model_path.exists():
|
||||
# Load from local cache
|
||||
logger.info(f"Loading SPLADE from local cache: {local_cache}")
|
||||
try:
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
str(local_cache),
|
||||
**model_kwargs,
|
||||
)
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
str(local_cache), use_fast=True
|
||||
)
|
||||
self._vocab_size = len(self._tokenizer)
|
||||
logger.info(
|
||||
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load from cache, redownloading: {e}")
|
||||
|
||||
# Download and convert from HuggingFace
|
||||
logger.info(f"Downloading SPLADE model: {self.model_name}")
|
||||
try:
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
self.model_name,
|
||||
export=True, # Export to ONNX
|
||||
**model_kwargs,
|
||||
)
|
||||
logger.debug(f"SPLADE model loaded: {self.model_name}")
|
||||
except TypeError:
|
||||
# Fallback for older Optimum versions: retry without provider arguments
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
self.model_name,
|
||||
export=True,
|
||||
)
|
||||
logger.warning(
|
||||
"Optimum version doesn't support provider parameters. "
|
||||
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
|
||||
)
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||
|
||||
# Cache vocabulary size
|
||||
self._vocab_size = len(self._tokenizer)
|
||||
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
|
||||
|
||||
# Save to local cache for future use
|
||||
try:
|
||||
local_cache.mkdir(parents=True, exist_ok=True)
|
||||
self._model.save_pretrained(str(local_cache))
|
||||
self._tokenizer.save_pretrained(str(local_cache))
|
||||
logger.info(f"SPLADE model cached to: {local_cache}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache SPLADE model: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
|
||||
"""Apply SPLADE activation function to model outputs.
|
||||
|
||||
Formula: log(1 + ReLU(logits)) * attention_mask
|
||||
|
||||
Args:
|
||||
logits: Model output logits (batch, seq_len, vocab_size)
|
||||
attention_mask: Attention mask (batch, seq_len)
|
||||
|
||||
Returns:
|
||||
SPLADE representations (batch, seq_len, vocab_size)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# ReLU activation
|
||||
relu_logits = np.maximum(0, logits)
|
||||
|
||||
# Log(1 + x) transformation
|
||||
log_relu = np.log1p(relu_logits)
|
||||
|
||||
# Apply attention mask (expand to match vocab dimension)
|
||||
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
|
||||
mask_expanded = np.expand_dims(attention_mask, axis=-1)
|
||||
|
||||
# Element-wise multiplication
|
||||
splade_repr = log_relu * mask_expanded
|
||||
|
||||
return splade_repr
|
||||
|
||||
@staticmethod
|
||||
def _max_pooling(splade_repr: Any) -> Any:
|
||||
"""Max pooling over sequence length dimension.
|
||||
|
||||
Args:
|
||||
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
|
||||
|
||||
Returns:
|
||||
Pooled sparse vectors (batch, vocab_size)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Max pooling over sequence dimension (axis=1)
|
||||
return np.max(splade_repr, axis=1)
|
||||
|
||||
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
|
||||
"""Convert dense vector to sparse dictionary.
|
||||
|
||||
Args:
|
||||
dense_vec: Dense vector (vocab_size,)
|
||||
|
||||
Returns:
|
||||
Sparse dictionary {token_id: weight} with weights above threshold
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Find non-zero indices above threshold
|
||||
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
|
||||
|
||||
# Create sparse dictionary
|
||||
sparse_dict = {
|
||||
int(idx): float(dense_vec[idx])
|
||||
for idx in nonzero_indices
|
||||
}
|
||||
|
||||
return sparse_dict
|
||||
|
||||
def warmup(self, text: str = "warmup query") -> None:
|
||||
"""Warmup the encoder by running a dummy inference.
|
||||
|
||||
First-time model inference includes initialization overhead.
|
||||
Call this method once before the first real search to avoid
|
||||
latency spikes.
|
||||
|
||||
Args:
|
||||
text: Dummy text for warmup (default: "warmup query")
|
||||
"""
|
||||
logger.info("Warming up SPLADE encoder...")
|
||||
# Trigger model loading and first inference
|
||||
_ = self.encode_text(text)
|
||||
logger.info("SPLADE encoder warmup complete")
|
||||
|
||||
def encode_text(self, text: str) -> Dict[int, float]:
|
||||
"""Encode text to sparse vector {token_id: weight}.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Sparse vector as dictionary mapping token_id to weight
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Tokenize input
|
||||
encoded = self._tokenizer(
|
||||
text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Forward pass through model
|
||||
outputs = self._model(**encoded)
|
||||
|
||||
# Extract logits
|
||||
if hasattr(outputs, "logits"):
|
||||
logits = outputs.logits
|
||||
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||
logits = outputs["logits"]
|
||||
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||
logits = outputs[0]
|
||||
else:
|
||||
raise RuntimeError("Unexpected model output format")
|
||||
|
||||
# Apply SPLADE activation
|
||||
attention_mask = encoded["attention_mask"]
|
||||
splade_repr = self._splade_activation(logits, attention_mask)
|
||||
|
||||
# Max pooling over sequence length
|
||||
splade_vec = self._max_pooling(splade_repr)
|
||||
|
||||
# Convert to sparse dictionary (single item batch)
|
||||
sparse_dict = self._to_sparse_dict(splade_vec[0])
|
||||
|
||||
return sparse_dict
|
||||
|
||||
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
|
||||
"""Batch encode texts to sparse vectors.
|
||||
|
||||
Args:
|
||||
texts: List of input texts to encode
|
||||
batch_size: Batch size for encoding (default: 32)
|
||||
|
||||
Returns:
|
||||
List of sparse vectors as dictionaries
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
import numpy as np
|
||||
|
||||
results: List[Dict[int, float]] = []
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Tokenize batch
|
||||
encoded = self._tokenizer(
|
||||
batch_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Forward pass through model
|
||||
outputs = self._model(**encoded)
|
||||
|
||||
# Extract logits
|
||||
if hasattr(outputs, "logits"):
|
||||
logits = outputs.logits
|
||||
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||
logits = outputs["logits"]
|
||||
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||
logits = outputs[0]
|
||||
else:
|
||||
raise RuntimeError("Unexpected model output format")
|
||||
|
||||
# Apply SPLADE activation
|
||||
attention_mask = encoded["attention_mask"]
|
||||
splade_repr = self._splade_activation(logits, attention_mask)
|
||||
|
||||
# Max pooling over sequence length
|
||||
splade_vecs = self._max_pooling(splade_repr)
|
||||
|
||||
# Convert each vector to sparse dictionary
|
||||
for vec in splade_vecs:
|
||||
sparse_dict = self._to_sparse_dict(vec)
|
||||
results.append(sparse_dict)
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""Return vocabulary size (~30k for BERT-based models).
|
||||
|
||||
Returns:
|
||||
Vocabulary size (number of tokens in tokenizer)
|
||||
"""
|
||||
if self._vocab_size is not None:
|
||||
return self._vocab_size
|
||||
|
||||
self._load_model()
|
||||
return self._vocab_size or 0
|
||||
|
||||
def get_token(self, token_id: int) -> str:
|
||||
"""Convert token_id to string (for debugging).
|
||||
|
||||
Args:
|
||||
token_id: Token ID to convert
|
||||
|
||||
Returns:
|
||||
Token string
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError("Tokenizer not loaded")
|
||||
|
||||
return self._tokenizer.decode([token_id])
|
||||
|
||||
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
|
||||
"""Get top-k tokens with highest weights from sparse vector.
|
||||
|
||||
Useful for debugging and understanding what the model is focusing on.
|
||||
|
||||
Args:
|
||||
sparse_vec: Sparse vector as {token_id: weight}
|
||||
top_k: Number of top tokens to return
|
||||
|
||||
Returns:
|
||||
List of (token_string, weight) tuples, sorted by weight descending
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if not sparse_vec:
|
||||
return []
|
||||
|
||||
# Sort by weight descending
|
||||
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Take top-k and convert token_ids to strings
|
||||
top_items = sorted_items[:top_k]
|
||||
|
||||
return [
|
||||
(self.get_token(token_id), weight)
|
||||
for token_id, weight in top_items
|
||||
]
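For reference, end-to-end use of the deleted encoder looked roughly like this; a sketch that assumes the dependencies listed in the module docstring are installed:

```python
from codexlens.semantic.splade_encoder import get_splade_encoder

encoder = get_splade_encoder(use_gpu=False)      # cached, thread-safe singleton
vec = encoder.encode_text("parse JWT auth token")
print(len(vec), "non-zero dims of", encoder.vocab_size)
for token, weight in encoder.get_top_tokens(vec, top_k=5):
    print(f"{token!r}: {weight:.3f}")            # inspect query-expansion terms
```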
@@ -1,103 +0,0 @@
"""
Migration 009: Add SPLADE sparse retrieval tables.

This migration introduces SPLADE (Sparse Lexical AnD Expansion) support:
- splade_metadata: Model configuration (model name, vocab size, ONNX path)
- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight)

The SPLADE tables are designed for efficient sparse vector retrieval:
- Token-based lookup for query expansion
- Chunk-based deletion for index maintenance
- Maintains backward compatibility with existing FTS tables
"""

import logging
from sqlite3 import Connection

log = logging.getLogger(__name__)


def upgrade(db_conn: Connection) -> None:
    """
    Adds SPLADE tables for sparse retrieval.

    Creates:
    - splade_metadata: Stores model configuration and ONNX path
    - splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings
    - Indexes for efficient token-based and chunk-based lookups

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Creating splade_metadata table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS splade_metadata (
            id INTEGER PRIMARY KEY DEFAULT 1,
            model_name TEXT NOT NULL,
            vocab_size INTEGER NOT NULL,
            onnx_path TEXT,
            created_at REAL
        )
        """
    )

    log.info("Creating splade_posting_list table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS splade_posting_list (
            token_id INTEGER NOT NULL,
            chunk_id INTEGER NOT NULL,
            weight REAL NOT NULL,
            PRIMARY KEY (token_id, chunk_id),
            FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
        )
        """
    )

    log.info("Creating indexes for splade_posting_list...")
    # Index for efficient chunk-based lookups (deletion, updates)
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
        ON splade_posting_list(chunk_id)
        """
    )

    # Index for efficient term-based retrieval
    cursor.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_splade_by_token
        ON splade_posting_list(token_id)
        """
    )

    log.info("Migration 009 completed successfully")


def downgrade(db_conn: Connection) -> None:
    """
    Removes SPLADE tables.

    Drops:
    - splade_posting_list (and associated indexes)
    - splade_metadata

    Args:
        db_conn: The SQLite database connection.
    """
    cursor = db_conn.cursor()

    log.info("Dropping SPLADE indexes...")
    cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk")
    cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token")

    log.info("Dropping splade_posting_list table...")
    cursor.execute("DROP TABLE IF EXISTS splade_posting_list")

    log.info("Dropping splade_metadata table...")
    cursor.execute("DROP TABLE IF EXISTS splade_metadata")

    log.info("Migration 009 downgrade completed successfully")
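Applied standalone (outside whatever migration runner CodexLens uses, which this diff does not show), the deleted migration could be driven like this; the database path is a placeholder and `upgrade` is assumed importable from the migration module:

```python
import sqlite3

# Hypothetical manual application against a local index database
conn = sqlite3.connect("_index.db")
try:
    upgrade(conn)   # create SPLADE tables and indexes
    conn.commit()
finally:
    conn.close()
```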
|
||||
@@ -1,578 +0,0 @@
"""SPLADE inverted index storage for sparse vector retrieval.

This module implements a SQLite-based inverted index for SPLADE sparse vectors,
enabling efficient sparse retrieval using dot-product scoring.
"""

from __future__ import annotations

import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from codexlens.entities import SearchResult
from codexlens.errors import StorageError

logger = logging.getLogger(__name__)


class SpladeIndex:
    """SQLite-based inverted index for SPLADE sparse vectors.

    Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight).
    Supports efficient dot-product retrieval using SQL joins.
    """

    def __init__(self, db_path: Path | str) -> None:
        """Initialize the SPLADE index.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        # Thread-safe connection management
        self._lock = threading.RLock()
        self._local = threading.local()

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create a thread-local database connection.

        Each thread gets its own connection to ensure thread safety.
        Connections are stored in thread-local storage.
        """
        conn = getattr(self._local, "conn", None)
        if conn is None:
            # Thread-local connection - each thread has its own
            conn = sqlite3.connect(
                self.db_path,
                timeout=30.0,  # Wait up to 30s for locks
                check_same_thread=True,  # Enforce thread safety
            )
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
            conn.execute("PRAGMA foreign_keys=ON")
            # Limit mmap to 1GB to avoid OOM on smaller systems
            conn.execute("PRAGMA mmap_size=1073741824")
            # Increase cache size for better query performance (-20000 = 20000 KiB, about 20MB)
            conn.execute("PRAGMA cache_size=-20000")
            self._local.conn = conn
        return conn

    def close(self) -> None:
        """Close the thread-local database connection."""
        with self._lock:
            conn = getattr(self._local, "conn", None)
            if conn is not None:
                conn.close()
                self._local.conn = None

    def __enter__(self) -> SpladeIndex:
        """Context manager entry."""
        self.create_tables()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Context manager exit."""
        self.close()

    def has_index(self) -> bool:
        """Check whether the SPLADE tables exist in the database.

        Returns:
            True if the tables exist, False otherwise.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                cursor = conn.execute(
                    """
                    SELECT name FROM sqlite_master
                    WHERE type='table' AND name='splade_posting_list'
                    """
                )
                return cursor.fetchone() is not None
            except sqlite3.Error as e:
                logger.error("Failed to check index existence: %s", e)
                return False

    def create_tables(self) -> None:
        """Create the SPLADE schema if it does not exist.

        Note: When used with distributed indexes (multiple _index.db files),
        the SPLADE database stores chunk IDs from multiple sources. In this case,
        foreign key constraints are not enforced to allow cross-database references.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                # Inverted index for sparse vectors
                # Note: No FOREIGN KEY constraint to support distributed index architecture
                # where chunks may come from multiple _index.db files
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_posting_list (
                        token_id INTEGER NOT NULL,
                        chunk_id INTEGER NOT NULL,
                        weight REAL NOT NULL,
                        PRIMARY KEY (token_id, chunk_id)
                    )
                """)

                # Indexes for efficient lookups
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
                    ON splade_posting_list(chunk_id)
                """)
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_splade_by_token
                    ON splade_posting_list(token_id)
                """)

                # Model metadata
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_metadata (
                        id INTEGER PRIMARY KEY DEFAULT 1,
                        model_name TEXT NOT NULL,
                        vocab_size INTEGER NOT NULL,
                        onnx_path TEXT,
                        created_at REAL
                    )
                """)

                # Chunk metadata for self-contained search results
                # Stores all chunk info needed to build SearchResult without querying _index.db
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS splade_chunks (
                        id INTEGER PRIMARY KEY,
                        file_path TEXT NOT NULL,
                        content TEXT NOT NULL,
                        metadata TEXT,
                        source_db TEXT
                    )
                """)

                conn.commit()
                logger.debug("SPLADE schema created successfully")
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to create SPLADE schema: {e}",
                    db_path=str(self.db_path),
                    operation="create_tables"
                ) from e

    def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None:
        """Add a single document to the inverted index.

        Args:
            chunk_id: Chunk ID (foreign key to semantic_chunks.id).
            sparse_vec: Sparse vector as a {token_id: weight} mapping.
        """
        if not sparse_vec:
            logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id)
            return

        with self._lock:
            conn = self._get_connection()
            try:
                # Insert all non-zero weights for this chunk
                postings = [
                    (token_id, chunk_id, weight)
                    for token_id, weight in sparse_vec.items()
                    if weight > 0  # Only store non-zero weights
                ]

                if postings:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO splade_posting_list
                        (token_id, chunk_id, weight)
                        VALUES (?, ?, ?)
                        """,
                        postings
                    )
                    conn.commit()
                    logger.debug(
                        "Added %d postings for chunk_id=%d", len(postings), chunk_id
                    )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to add posting for chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="add_posting"
                ) from e

    def add_postings_batch(
        self, postings: List[Tuple[int, Dict[int, float]]]
    ) -> None:
        """Batch insert postings for multiple chunks.

        Args:
            postings: List of (chunk_id, sparse_vec) tuples.
        """
        if not postings:
            return

        with self._lock:
            conn = self._get_connection()
            try:
                # Flatten all postings into a single batch
                batch_data = []
                for chunk_id, sparse_vec in postings:
                    for token_id, weight in sparse_vec.items():
                        if weight > 0:  # Only store non-zero weights
                            batch_data.append((token_id, chunk_id, weight))

                if batch_data:
                    conn.executemany(
                        """
                        INSERT OR REPLACE INTO splade_posting_list
                        (token_id, chunk_id, weight)
                        VALUES (?, ?, ?)
                        """,
                        batch_data
                    )
                    conn.commit()
                    logger.debug(
                        "Batch inserted %d postings for %d chunks",
                        len(batch_data),
                        len(postings)
                    )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to batch insert postings: {e}",
                    db_path=str(self.db_path),
                    operation="add_postings_batch"
                ) from e
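
    # --- Illustrative note (not in the original file) --------------------------
    # The sparse vectors accepted by add_posting()/add_postings_batch() are plain
    # {token_id: weight} dicts, so a batch ingest looks like this (the weights
    # are made-up example values):
    #
    #   index.add_postings_batch([
    #       (101, {2054: 1.83, 3809: 0.92}),  # chunk 101: two active tokens
    #       (102, {2054: 0.41, 7592: 1.10}),  # chunk 102 shares token 2054
    #   ])
    # ----------------------------------------------------------------------------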
    def add_chunk_metadata(
        self,
        chunk_id: int,
        file_path: str,
        content: str,
        metadata: Optional[str] = None,
        source_db: Optional[str] = None
    ) -> None:
        """Store chunk metadata for self-contained search results.

        Args:
            chunk_id: Global chunk ID.
            file_path: Path to the source file.
            content: Chunk text content.
            metadata: JSON metadata string.
            source_db: Path to the source _index.db.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO splade_chunks
                    (id, file_path, content, metadata, source_db)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (chunk_id, file_path, content, metadata, source_db)
                )
                conn.commit()
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to add chunk metadata for chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="add_chunk_metadata"
                ) from e

    def add_chunks_metadata_batch(
        self,
        chunks: List[Tuple[int, str, str, Optional[str], Optional[str]]]
    ) -> None:
        """Batch insert chunk metadata.

        Args:
            chunks: List of (chunk_id, file_path, content, metadata, source_db) tuples.
        """
        if not chunks:
            return

        with self._lock:
            conn = self._get_connection()
            try:
                conn.executemany(
                    """
                    INSERT OR REPLACE INTO splade_chunks
                    (id, file_path, content, metadata, source_db)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    chunks
                )
                conn.commit()
                logger.debug("Batch inserted %d chunk metadata records", len(chunks))
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to batch insert chunk metadata: {e}",
                    db_path=str(self.db_path),
                    operation="add_chunks_metadata_batch"
                ) from e

    def get_chunks_by_ids(self, chunk_ids: List[int]) -> List[Dict]:
        """Get chunk metadata by IDs.

        Args:
            chunk_ids: List of chunk IDs to retrieve.

        Returns:
            List of dicts with id, file_path, content, metadata, source_db.
        """
        if not chunk_ids:
            return []

        with self._lock:
            conn = self._get_connection()
            try:
                placeholders = ",".join("?" * len(chunk_ids))
                rows = conn.execute(
                    f"""
                    SELECT id, file_path, content, metadata, source_db
                    FROM splade_chunks
                    WHERE id IN ({placeholders})
                    """,
                    chunk_ids
                ).fetchall()

                return [
                    {
                        "id": row["id"],
                        "file_path": row["file_path"],
                        "content": row["content"],
                        "metadata": row["metadata"],
                        "source_db": row["source_db"]
                    }
                    for row in rows
                ]
            except sqlite3.Error as e:
                logger.error("Failed to get chunks by IDs: %s", e)
                return []

    def remove_chunk(self, chunk_id: int) -> int:
        """Remove all postings for a chunk.

        Args:
            chunk_id: Chunk ID to remove.

        Returns:
            Number of deleted postings.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                cursor = conn.execute(
                    "DELETE FROM splade_posting_list WHERE chunk_id = ?",
                    (chunk_id,)
                )
                conn.commit()
                deleted = cursor.rowcount
                logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id)
                return deleted
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to remove chunk_id={chunk_id}: {e}",
                    db_path=str(self.db_path),
                    operation="remove_chunk"
                ) from e

    def search(
        self,
        query_sparse: Dict[int, float],
        limit: int = 50,
        min_score: float = 0.0,
        max_query_terms: int = 64
    ) -> List[Tuple[int, float]]:
        """Search for similar chunks using dot-product scoring.

        Implements an efficient sparse dot product via SQL JOIN:
            score(q, d) = sum(q[t] * d[t]) for all tokens t

        Args:
            query_sparse: Query sparse vector as {token_id: weight}.
            limit: Maximum number of results.
            min_score: Minimum score threshold.
            max_query_terms: Maximum query terms to use (default: 64).
                Pruning to the top-K terms reduces search time with minimal
                impact on quality. Set to 0 or negative to disable pruning
                (use all terms).

        Returns:
            List of (chunk_id, score) tuples, ordered by score descending.
        """
        if not query_sparse:
            logger.warning("Empty query sparse vector")
            return []

        with self._lock:
            conn = self._get_connection()
            try:
                # Build VALUES clause for query terms
                # Each term: (token_id, weight)
                query_terms = [
                    (token_id, weight)
                    for token_id, weight in query_sparse.items()
                    if weight > 0
                ]

                if not query_terms:
                    logger.warning("No non-zero query terms")
                    return []

                # Query pruning: keep only the top-K terms by weight
                # max_query_terms <= 0 means no limit (use all terms)
                if max_query_terms > 0 and len(query_terms) > max_query_terms:
                    query_terms = sorted(query_terms, key=lambda x: x[1], reverse=True)[:max_query_terms]
                    logger.debug(
                        "Query pruned from %d to %d terms",
                        len(query_sparse),
                        len(query_terms)
                    )

                # Create a CTE for query terms using parameterized VALUES
                # Build placeholders and params to prevent SQL injection
                params = []
                placeholders = []
                for token_id, weight in query_terms:
                    placeholders.append("(?, ?)")
                    params.extend([token_id, weight])

                values_placeholders = ", ".join(placeholders)

                sql = f"""
                    WITH query_terms(token_id, weight) AS (
                        VALUES {values_placeholders}
                    )
                    SELECT
                        p.chunk_id,
                        SUM(p.weight * q.weight) as score
                    FROM splade_posting_list p
                    INNER JOIN query_terms q ON p.token_id = q.token_id
                    GROUP BY p.chunk_id
                    HAVING score >= ?
                    ORDER BY score DESC
                    LIMIT ?
                """

                # Append min_score and limit to params
                params.extend([min_score, limit])
                rows = conn.execute(sql, params).fetchall()

                results = [(row["chunk_id"], float(row["score"])) for row in rows]
                logger.debug(
                    "SPLADE search: %d query terms, %d results",
                    len(query_terms),
                    len(results)
                )
                return results

            except sqlite3.Error as e:
                raise StorageError(
                    f"SPLADE search failed: {e}",
                    db_path=str(self.db_path),
                    operation="search"
                ) from e
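
    # --- Illustrative note (not in the original file) --------------------------
    # The CTE + JOIN above is the sparse dot product. The same score for one
    # document vector doc_sparse, in plain Python:
    #
    #   score = sum(w_q * doc_sparse.get(t, 0.0)
    #               for t, w_q in query_sparse.items())
    #
    # Pushing the sum into SQL lets SQLite touch only the posting lists of the
    # (at most max_query_terms) query tokens instead of scanning every document.
    # ----------------------------------------------------------------------------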
    def get_metadata(self) -> Optional[Dict]:
        """Get SPLADE model metadata.

        Returns:
            Dictionary with model_name, vocab_size, onnx_path, created_at,
            or None if not set.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute(
                    """
                    SELECT model_name, vocab_size, onnx_path, created_at
                    FROM splade_metadata
                    WHERE id = 1
                    """
                ).fetchone()

                if row:
                    return {
                        "model_name": row["model_name"],
                        "vocab_size": row["vocab_size"],
                        "onnx_path": row["onnx_path"],
                        "created_at": row["created_at"]
                    }
                return None
            except sqlite3.Error as e:
                logger.error("Failed to get metadata: %s", e)
                return None

    def set_metadata(
        self,
        model_name: str,
        vocab_size: int,
        onnx_path: Optional[str] = None
    ) -> None:
        """Set SPLADE model metadata.

        Args:
            model_name: SPLADE model name.
            vocab_size: Vocabulary size (typically ~30k for a BERT vocab).
            onnx_path: Optional path to the ONNX model file.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                current_time = time.time()
                conn.execute(
                    """
                    INSERT OR REPLACE INTO splade_metadata
                    (id, model_name, vocab_size, onnx_path, created_at)
                    VALUES (1, ?, ?, ?, ?)
                    """,
                    (model_name, vocab_size, onnx_path, current_time)
                )
                conn.commit()
                logger.info(
                    "Set SPLADE metadata: model=%s, vocab_size=%d",
                    model_name,
                    vocab_size
                )
            except sqlite3.Error as e:
                raise StorageError(
                    f"Failed to set metadata: {e}",
                    db_path=str(self.db_path),
                    operation="set_metadata"
                ) from e

    def get_stats(self) -> Dict:
        """Get index statistics.

        Returns:
            Dictionary with total_postings, unique_tokens, unique_chunks.
        """
        with self._lock:
            conn = self._get_connection()
            try:
                row = conn.execute("""
                    SELECT
                        COUNT(*) as total_postings,
                        COUNT(DISTINCT token_id) as unique_tokens,
                        COUNT(DISTINCT chunk_id) as unique_chunks
                    FROM splade_posting_list
                """).fetchone()

                return {
                    "total_postings": row["total_postings"],
                    "unique_tokens": row["unique_tokens"],
                    "unique_chunks": row["unique_chunks"]
                }
            except sqlite3.Error as e:
                logger.error("Failed to get stats: %s", e)
                return {
                    "total_postings": 0,
                    "unique_tokens": 0,
                    "unique_chunks": 0
                }
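
Taken together, the deleted class supported a small end-to-end flow. A sketch under the assumption that some encoder produces {token_id: weight} dicts (the IDs and weights below are made-up example values; SpladeIndex is exactly the class above):

    from pathlib import Path

    with SpladeIndex(Path("index/splade.db")) as index:  # __enter__ creates the tables
        index.set_metadata(model_name="splade-model", vocab_size=30522)
        index.add_posting(1, {2054: 1.8, 3809: 0.9})
        index.add_chunk_metadata(1, "src/app.py", "def main(): ...")
        for chunk_id, score in index.search({2054: 1.2}, limit=10):
            print(chunk_id, score)  # chunk 1 scores 1.8 * 1.2
        print(index.get_stats())    # {'total_postings': 2, 'unique_tokens': 2, 'unique_chunks': 1}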