refactor: remove SPLADE and hybrid_cascade, streamline the search architecture

Remove the SPLADE sparse neural search backend and the hybrid_cascade strategy,
simplifying the search architecture from six backends to four (FTS Exact/Fuzzy, Binary Vector, Dense Vector, LSP).

Key changes:
- Delete 4 files, including splade_encoder.py, splade_index.py, and migration_009
- Remove SPLADE-related settings from config.py (enable_splade, splade_model, etc.)
- Change DEFAULT_WEIGHTS to FTS weights {exact: 0.25, fuzzy: 0.1, vector: 0.5, lsp: 0.15}; see the sketch below
- Delete hybrid_cascade_search(); all cascade fallbacks now call self.search()
- Map the API's fusion_strategy='hybrid' to binary_rerank for backward compatibility
- Delete the index_splade/splade_status CLI commands and --method splade
- Update tests, benchmarks, and documentation
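For reference, a minimal sketch of the new fusion defaults and the compatibility mapping described above (illustrative only; the real constant lives in codexlens/search/ranking.py, and FUSION_STRATEGY_ALIASES is a hypothetical name for the API-layer mapping):

# Post-refactor RRF fusion weights, per the summary above.
DEFAULT_WEIGHTS = {
    "exact": 0.25,   # FTS exact match
    "fuzzy": 0.10,   # FTS fuzzy match
    "vector": 0.50,  # dense vector similarity
    "lsp": 0.15,     # LSP structural signal
}

# Backward compatibility: the public API's 'hybrid' now behaves like
# 'binary_rerank'. (Hypothetical illustration of the mapping.)
FUSION_STRATEGY_ALIASES = {"hybrid": "binary_rerank"}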
catlog22
2026-02-08 12:07:41 +08:00
parent 72d2ae750b
commit 71faaf43a8
22 changed files with 126 additions and 2883 deletions

View File

@@ -48,7 +48,8 @@ def semantic_search(
- rrf: Reciprocal Rank Fusion (recommended, default)
- staged: Staged cascade -> staged_cascade_search
- binary: Binary rerank cascade -> binary_cascade_search
- hybrid: Hybrid cascade -> hybrid_cascade_search
- hybrid: Binary rerank cascade (backward compat) -> binary_rerank_cascade_search
- dense_rerank: Dense rerank cascade -> dense_rerank_cascade_search
kind_filter: Symbol type filter (e.g., ["function", "class"])
limit: Max return count (default 20)
include_match_reason: Generate match reason (heuristic, not LLM)
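As background for the rrf default above, a self-contained sketch of weighted Reciprocal Rank Fusion (illustrative, not code from this diff; 60 is the conventional RRF damping constant):

from typing import Dict, List

def rrf_fuse(rankings: Dict[str, List[str]],
             weights: Dict[str, float],
             k: int = 60) -> List[str]:
    """Fuse per-backend rankings: score(d) = sum_b w_b / (k + rank_b(d))."""
    scores: Dict[str, float] = {}
    for backend, ranked_ids in rankings.items():
        w = weights.get(backend, 0.0)
        for rank, doc_id in enumerate(ranked_ids, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + w / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Example: fuse an FTS-exact ranking with a dense-vector ranking.
fused = rrf_fuse(
    {"exact": ["a.py", "b.py"], "vector": ["b.py", "c.py"]},
    {"exact": 0.25, "vector": 0.5},
)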
@@ -215,7 +216,8 @@ def _execute_search(
- rrf: Standard hybrid search with RRF fusion
- staged: staged_cascade_search
- binary: binary_cascade_search
- hybrid: hybrid_cascade_search
- hybrid: binary_rerank_cascade_search (backward compat)
- dense_rerank: dense_rerank_cascade_search
Args:
engine: ChainSearchEngine instance
@@ -249,8 +251,8 @@ def _execute_search(
options=options,
)
elif fusion_strategy == "hybrid":
# Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder)
return engine.hybrid_cascade_search(
# Backward compat: hybrid now maps to binary_rerank_cascade_search
return engine.binary_rerank_cascade_search(
query=query,
source_path=source_path,
k=limit,
@@ -342,8 +344,6 @@ def _transform_results(
fts_scores.append(source_scores["exact"])
if "fuzzy" in source_scores:
fts_scores.append(source_scores["fuzzy"])
if "splade" in source_scores:
fts_scores.append(source_scores["splade"])
if fts_scores:
structural_score = max(fts_scores)
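With the splade branch gone, the structural score is simply the max over the remaining FTS signals; a tiny sketch, assuming source_scores maps backend name to score:

source_scores = {"exact": 0.8, "fuzzy": 0.3, "vector": 0.9}
fts_scores = [source_scores[name]
              for name in ("exact", "fuzzy") if name in source_scores]
structural_score = max(fts_scores) if fts_scores else 0.0  # 0.8; 'vector' is not an FTS signal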

View File

@@ -6,7 +6,6 @@ import json
import logging
import os
import shutil
import sqlite3
from pathlib import Path
from typing import Annotated, Any, Dict, Iterable, List, Optional
@@ -37,7 +36,7 @@ from .output import (
app = typer.Typer(help="CodexLens CLI — local code indexing and search.")
# Index subcommand group for reorganized commands
index_app = typer.Typer(help="Index management commands (init, embeddings, splade, binary, status, migrate, all)")
index_app = typer.Typer(help="Index management commands (init, embeddings, binary, status, migrate, all)")
app.add_typer(index_app, name="index")
@@ -521,15 +520,15 @@ def search(
print_json(success=False, error=f"Invalid deprecated mode: {mode}. Use --method instead.")
else:
console.print(f"[red]Invalid deprecated mode:[/red] {mode}")
console.print("[dim]Use --method with: fts, vector, splade, hybrid, cascade[/dim]")
console.print("[dim]Use --method with: fts, vector, hybrid, cascade[/dim]")
raise typer.Exit(code=1)
# Configure search (load settings from file)
config = Config.load()
# Validate method - simplified interface exposes only dense_rerank and fts
# Other methods (vector, splade, hybrid, cascade) are hidden but still work for backward compatibility
valid_methods = ["fts", "dense_rerank", "vector", "splade", "hybrid", "cascade"]
# Other methods (vector, hybrid, cascade) are hidden but still work for backward compatibility
valid_methods = ["fts", "dense_rerank", "vector", "hybrid", "cascade"]
if actual_method not in valid_methods:
if json_mode:
print_json(success=False, error=f"Invalid method: {actual_method}. Use 'dense_rerank' (semantic) or 'fts' (exact keyword).")
@@ -561,7 +560,7 @@ def search(
try:
# Check if using key=value format (new) or legacy comma-separated format
if "=" in weights:
# New format: splade=0.4,vector=0.6 or exact=0.3,fuzzy=0.1,vector=0.6
# New format: exact=0.3,fuzzy=0.1,vector=0.6
weight_dict = {}
for pair in weights.split(","):
if "=" in pair:
@@ -592,17 +591,6 @@ def search(
"fuzzy": weight_parts[1],
"vector": weight_parts[2],
}
elif len(weight_parts) == 2:
# Two values: assume splade,vector
weight_sum = sum(weight_parts)
if abs(weight_sum - 1.0) > 0.01:
if not json_mode:
console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
weight_parts = [w / weight_sum for w in weight_parts]
hybrid_weights = {
"splade": weight_parts[0],
"vector": weight_parts[1],
}
else:
if not json_mode:
console.print("[yellow]Warning: Invalid weights format. Using defaults.[/yellow]")
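A compact sketch of the key=value weight parsing kept after this change (illustrative; the CLI retains extra validation and warning output):

def parse_weights(spec: str) -> dict:
    """Parse e.g. 'exact=0.3,fuzzy=0.1,vector=0.6' into a weight dict."""
    weights = {}
    for pair in spec.split(","):
        if "=" in pair:
            key, _, value = pair.partition("=")
            weights[key.strip()] = float(value)  # malformed values raise ValueError
    total = sum(weights.values())
    if weights and abs(total - 1.0) > 0.01:
        # Mirror the CLI's normalization-with-warning path.
        weights = {k: v / total for k, v in weights.items()}
    return weights

assert parse_weights("exact=0.3,fuzzy=0.1,vector=0.6")["vector"] == 0.6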
@@ -621,7 +609,6 @@ def search(
# Map method to SearchOptions flags
# fts: FTS-only search (optionally with fuzzy)
# vector: Pure vector semantic search
# splade: SPLADE sparse neural search
# hybrid: RRF fusion of sparse + dense
# cascade: Two-stage binary + dense retrieval
if actual_method == "fts":
@@ -629,35 +616,24 @@ def search(
enable_fuzzy = use_fuzzy
enable_vector = False
pure_vector = False
enable_splade = False
enable_cascade = False
elif actual_method == "vector":
hybrid_mode = True
enable_fuzzy = False
enable_vector = True
pure_vector = True
enable_splade = False
enable_cascade = False
elif actual_method == "splade":
hybrid_mode = True
enable_fuzzy = False
enable_vector = False
pure_vector = False
enable_splade = True
enable_cascade = False
elif actual_method == "hybrid":
hybrid_mode = True
enable_fuzzy = use_fuzzy
enable_vector = True
pure_vector = False
enable_splade = True # SPLADE is preferred sparse in hybrid
enable_cascade = False
elif actual_method == "cascade":
hybrid_mode = True
enable_fuzzy = False
enable_vector = True
pure_vector = False
enable_splade = False
enable_cascade = True
else:
raise ValueError(f"Invalid method: {actual_method}")
@@ -678,7 +654,6 @@ def search(
enable_fuzzy=enable_fuzzy,
enable_vector=enable_vector,
pure_vector=pure_vector,
enable_splade=enable_splade,
enable_cascade=enable_cascade,
hybrid_weights=hybrid_weights,
)
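The per-method flag blocks above collapse to a small table; a sketch of the post-refactor mapping (hybrid_mode is True for every method, and for 'fts' and 'hybrid' enable_fuzzy actually follows the --fuzzy flag):

# method: (enable_fuzzy, enable_vector, pure_vector, enable_cascade)
METHOD_FLAGS = {
    "fts":     (True,  False, False, False),
    "vector":  (False, True,  True,  False),
    "hybrid":  (True,  True,  False, False),
    "cascade": (False, True,  False, True),
}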
@@ -2857,251 +2832,8 @@ def gpu_reset(
# ==================== SPLADE Commands ====================
@index_app.command("splade")
def index_splade(
path: Path = typer.Argument(..., help="Project path to index"),
rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
) -> None:
"""Generate SPLADE sparse index for existing codebase.
Encodes all semantic chunks with SPLADE model and builds inverted index
for efficient sparse retrieval.
This command discovers all _index.db files recursively in the project's
index directory and builds SPLADE encodings for chunks across all of them.
Examples:
codexlens index splade ~/projects/my-app
codexlens index splade . --rebuild
"""
_configure_logging(verbose)
from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
from codexlens.storage.splade_index import SpladeIndex
from codexlens.semantic.vector_store import VectorStore
# Check SPLADE availability
ok, err = check_splade_available()
if not ok:
console.print(f"[red]SPLADE not available: {err}[/red]")
console.print("[dim]Install with: pip install transformers torch[/dim]")
raise typer.Exit(1)
# Find index root directory
target_path = path.expanduser().resolve()
# Determine index root directory (containing _index.db files)
if target_path.is_file() and target_path.name == "_index.db":
index_root = target_path.parent
elif target_path.is_dir():
# Check for local .codexlens/_index.db
local_index = target_path / ".codexlens" / "_index.db"
if local_index.exists():
index_root = local_index.parent
else:
# Try to find via registry
registry = RegistryStore()
try:
registry.initialize()
mapper = PathMapper()
index_db = mapper.source_to_index_db(target_path)
if not index_db.exists():
console.print(f"[red]Error:[/red] No index found for {target_path}")
console.print("Run 'codexlens init' first to create an index")
raise typer.Exit(1)
index_root = index_db.parent
finally:
registry.close()
else:
console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
raise typer.Exit(1)
# Discover all _index.db files recursively
all_index_dbs = sorted(index_root.rglob("_index.db"))
if not all_index_dbs:
console.print(f"[red]Error:[/red] No _index.db files found in {index_root}")
raise typer.Exit(1)
console.print(f"[blue]Discovered {len(all_index_dbs)} index databases[/blue]")
# SPLADE index is stored alongside the root _index.db
from codexlens.config import SPLADE_DB_NAME
splade_db = index_root / SPLADE_DB_NAME
if splade_db.exists() and not rebuild:
console.print("[yellow]SPLADE index exists. Use --rebuild to regenerate.[/yellow]")
return
# If rebuild, delete existing splade database
if splade_db.exists() and rebuild:
splade_db.unlink()
# Collect all chunks from all distributed index databases
# Assign globally unique IDs to avoid collisions (each DB starts with ID 1)
console.print(f"[blue]Loading chunks from {len(all_index_dbs)} distributed indexes...[/blue]")
all_chunks = [] # (global_id, chunk) pairs
total_files_checked = 0
indexes_with_chunks = 0
global_id = 0 # Sequential global ID across all databases
for index_db in all_index_dbs:
total_files_checked += 1
try:
vector_store = VectorStore(index_db)
chunks = vector_store.get_all_chunks()
if chunks:
indexes_with_chunks += 1
# Assign sequential global IDs to avoid collisions
for chunk in chunks:
global_id += 1
all_chunks.append((global_id, chunk, index_db))
if verbose:
console.print(f" [dim]{index_db.parent.name}: {len(chunks)} chunks[/dim]")
vector_store.close()
except Exception as e:
if verbose:
console.print(f" [yellow]Warning: Failed to read {index_db}: {e}[/yellow]")
if not all_chunks:
console.print("[yellow]No chunks found in any index database[/yellow]")
console.print(f"[dim]Checked {total_files_checked} index files, found 0 chunks[/dim]")
console.print("[dim]Generate embeddings first with 'codexlens embeddings-generate --recursive'[/dim]")
raise typer.Exit(1)
console.print(f"[blue]Found {len(all_chunks)} chunks across {indexes_with_chunks} indexes[/blue]")
console.print(f"[blue]Encoding with SPLADE...[/blue]")
# Initialize SPLADE
encoder = get_splade_encoder()
splade_index = SpladeIndex(splade_db)
splade_index.create_tables()
# Encode in batches with progress bar
chunk_metadata_batch = []
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console,
) as progress:
task = progress.add_task("Encoding...", total=len(all_chunks))
for global_id, chunk, source_db_path in all_chunks:
sparse_vec = encoder.encode_text(chunk.content)
splade_index.add_posting(global_id, sparse_vec)
# Store chunk metadata for self-contained search
# Serialize metadata dict to JSON string
metadata_str = None
if hasattr(chunk, 'metadata') and chunk.metadata:
try:
metadata_str = json.dumps(chunk.metadata) if isinstance(chunk.metadata, dict) else chunk.metadata
except Exception:
pass
chunk_metadata_batch.append((
global_id,
chunk.file_path or "",
chunk.content,
metadata_str,
str(source_db_path)
))
progress.advance(task)
# Batch insert chunk metadata
if chunk_metadata_batch:
splade_index.add_chunks_metadata_batch(chunk_metadata_batch)
# Set metadata
splade_index.set_metadata(
model_name=encoder.model_name,
vocab_size=encoder.vocab_size
)
stats = splade_index.get_stats()
console.print(f"[green]OK[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings")
console.print(f" Source indexes: {indexes_with_chunks}")
console.print(f" Database: [dim]{splade_db}[/dim]")
@app.command("splade-status", hidden=True, deprecated=True)
def splade_status_command(
path: Path = typer.Argument(..., help="Project path"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
) -> None:
"""[Deprecated] Use 'codexlens index status' instead.
Show SPLADE index status and statistics.
Examples:
codexlens splade-status ~/projects/my-app
codexlens splade-status .
"""
_deprecated_command_warning("splade-status", "index status")
_configure_logging(verbose)
from codexlens.storage.splade_index import SpladeIndex
from codexlens.semantic.splade_encoder import check_splade_available
from codexlens.config import SPLADE_DB_NAME
# Find index database
target_path = path.expanduser().resolve()
if target_path.is_file() and target_path.name == "_index.db":
splade_db = target_path.parent / SPLADE_DB_NAME
elif target_path.is_dir():
# Check for local .codexlens/_splade.db
local_splade = target_path / ".codexlens" / SPLADE_DB_NAME
if local_splade.exists():
splade_db = local_splade
else:
# Try to find via registry
registry = RegistryStore()
try:
registry.initialize()
mapper = PathMapper()
index_db = mapper.source_to_index_db(target_path)
splade_db = index_db.parent / SPLADE_DB_NAME
finally:
registry.close()
else:
console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
raise typer.Exit(1)
if not splade_db.exists():
console.print("[yellow]No SPLADE index found[/yellow]")
console.print(f"[dim]Run 'codexlens splade-index {path}' to create one[/dim]")
return
splade_index = SpladeIndex(splade_db)
if not splade_index.has_index():
console.print("[yellow]SPLADE tables not initialized[/yellow]")
return
metadata = splade_index.get_metadata()
stats = splade_index.get_stats()
# Create status table
table = Table(title="SPLADE Index Status", show_header=False)
table.add_column("Property", style="cyan")
table.add_column("Value")
table.add_row("Database", str(splade_db))
if metadata:
table.add_row("Model", metadata['model_name'])
table.add_row("Vocab Size", str(metadata['vocab_size']))
table.add_row("Chunks", str(stats['unique_chunks']))
table.add_row("Unique Tokens", str(stats['unique_tokens']))
table.add_row("Total Postings", str(stats['total_postings']))
ok, err = check_splade_available()
status_text = "[green]Yes[/green]" if ok else f"[red]No[/red] - {err}"
table.add_row("SPLADE Available", status_text)
console.print(table)
# ==================== Watch Command ====================
@@ -3516,11 +3248,10 @@ def index_status(
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
) -> None:
"""Show comprehensive index status (embeddings + SPLADE).
"""Show comprehensive index status (embeddings).
Shows combined status for all index types:
- Dense vector embeddings (HNSW)
- SPLADE sparse embeddings
- Binary cascade embeddings
Examples:
@@ -3531,9 +3262,6 @@ def index_status(
_configure_logging(verbose, json_mode)
from codexlens.cli.embedding_manager import check_index_embeddings, get_embedding_stats_summary
from codexlens.storage.splade_index import SpladeIndex
from codexlens.semantic.splade_encoder import check_splade_available
from codexlens.config import SPLADE_DB_NAME
# Determine target path and index root
if path is None:
@@ -3571,36 +3299,11 @@ def index_status(
# Get embeddings status
embeddings_result = get_embedding_stats_summary(index_root)
# Get SPLADE status
splade_db = index_root / SPLADE_DB_NAME
splade_status = {
"available": False,
"has_index": False,
"stats": None,
"metadata": None,
}
splade_available, splade_err = check_splade_available()
splade_status["available"] = splade_available
if splade_db.exists():
try:
splade_index = SpladeIndex(splade_db)
if splade_index.has_index():
splade_status["has_index"] = True
splade_status["stats"] = splade_index.get_stats()
splade_status["metadata"] = splade_index.get_metadata()
splade_index.close()
except Exception as e:
if verbose:
console.print(f"[yellow]Warning: Failed to read SPLADE index: {e}[/yellow]")
# Build combined result
result = {
"index_root": str(index_root),
"embeddings": embeddings_result.get("result") if embeddings_result.get("success") else None,
"embeddings_error": embeddings_result.get("error") if not embeddings_result.get("success") else None,
"splade": splade_status,
}
if json_mode:
@@ -3623,27 +3326,6 @@ def index_status(
else:
console.print(f" [yellow]--[/yellow] {embeddings_result.get('error', 'Not available')}")
# SPLADE section
console.print("\n[bold]SPLADE Sparse Index:[/bold]")
if splade_status["has_index"]:
stats = splade_status["stats"] or {}
metadata = splade_status["metadata"] or {}
console.print(f" [green]OK[/green] SPLADE index available")
console.print(f" Chunks: {stats.get('unique_chunks', 0):,}")
console.print(f" Unique tokens: {stats.get('unique_tokens', 0):,}")
console.print(f" Total postings: {stats.get('total_postings', 0):,}")
if metadata.get("model_name"):
console.print(f" Model: {metadata['model_name']}")
elif splade_available:
console.print(f" [yellow]--[/yellow] No SPLADE index found")
console.print(f" [dim]Run 'codexlens index splade <path>' to create one[/dim]")
else:
console.print(f" [yellow]--[/yellow] SPLADE not available: {splade_err}")
# Runtime availability
console.print("\n[bold]Runtime Availability:[/bold]")
console.print(f" SPLADE encoder: {'[green]Yes[/green]' if splade_available else f'[red]No[/red] ({splade_err})'}")
# ==================== Index Update Command ====================
@@ -3739,22 +3421,19 @@ def index_all(
backend: str = typer.Option("fastembed", "--backend", "-b", help="Embedding backend: fastembed or litellm."),
model: str = typer.Option("code", "--model", "-m", help="Embedding model profile or name."),
max_workers: int = typer.Option(1, "--max-workers", min=1, help="Max concurrent API calls."),
skip_splade: bool = typer.Option(False, "--skip-splade", help="Skip SPLADE index generation."),
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
"""Run all indexing operations in sequence (init, embeddings, splade).
"""Run all indexing operations in sequence (init, embeddings).
This is a convenience command that runs the complete indexing pipeline:
1. FTS index initialization (index init)
2. Dense vector embeddings (index embeddings)
3. SPLADE sparse index (index splade) - unless --skip-splade
Examples:
codexlens index all ~/projects/my-app
codexlens index all . --force
codexlens index all . --backend litellm --model text-embedding-3-small
codexlens index all . --skip-splade
"""
_configure_logging(verbose, json_mode)
@@ -3766,7 +3445,7 @@ def index_all(
# Step 1: Run init
if not json_mode:
console.print(f"[bold]Step 1/3: Initializing FTS index...[/bold]")
console.print(f"[bold]Step 1/2: Initializing FTS index...[/bold]")
try:
# Import and call the init function directly
@@ -3810,7 +3489,7 @@ def index_all(
# Step 2: Generate embeddings
if not json_mode:
console.print(f"\n[bold]Step 2/3: Generating dense embeddings...[/bold]")
console.print(f"\n[bold]Step 2/2: Generating dense embeddings...[/bold]")
try:
from codexlens.cli.embedding_manager import generate_dense_embeddings_centralized
@@ -3851,103 +3530,6 @@ def index_all(
if not json_mode:
console.print(f" [yellow]Warning:[/yellow] {e}")
# Step 3: Generate SPLADE index (unless skipped)
if not skip_splade:
if not json_mode:
console.print(f"\n[bold]Step 3/3: Generating SPLADE index...[/bold]")
try:
from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
from codexlens.storage.splade_index import SpladeIndex
from codexlens.semantic.vector_store import VectorStore
from codexlens.config import SPLADE_DB_NAME
ok, err = check_splade_available()
if not ok:
results["steps"]["splade"] = {"success": False, "error": f"SPLADE not available: {err}"}
if not json_mode:
console.print(f" [yellow]Skipped:[/yellow] SPLADE not available ({err})")
else:
# Discover all _index.db files
all_index_dbs = sorted(index_root.rglob("_index.db"))
if not all_index_dbs:
results["steps"]["splade"] = {"success": False, "error": "No index databases found"}
if not json_mode:
console.print(f" [yellow]Skipped:[/yellow] No index databases found")
else:
# Collect chunks
all_chunks = []
global_id = 0
for index_db in all_index_dbs:
try:
vector_store = VectorStore(index_db)
chunks = vector_store.get_all_chunks()
for chunk in chunks:
global_id += 1
all_chunks.append((global_id, chunk, index_db))
vector_store.close()
except Exception:
pass
if all_chunks:
splade_db = index_root / SPLADE_DB_NAME
if splade_db.exists() and force:
splade_db.unlink()
encoder = get_splade_encoder()
splade_index = SpladeIndex(splade_db)
splade_index.create_tables()
chunk_metadata_batch = []
import json as json_module
for gid, chunk, source_db_path in all_chunks:
sparse_vec = encoder.encode_text(chunk.content)
splade_index.add_posting(gid, sparse_vec)
metadata_str = None
if hasattr(chunk, 'metadata') and chunk.metadata:
try:
metadata_str = json_module.dumps(chunk.metadata) if isinstance(chunk.metadata, dict) else chunk.metadata
except Exception:
pass
chunk_metadata_batch.append((
gid,
chunk.file_path or "",
chunk.content,
metadata_str,
str(source_db_path)
))
if chunk_metadata_batch:
splade_index.add_chunks_metadata_batch(chunk_metadata_batch)
splade_index.set_metadata(
model_name=encoder.model_name,
vocab_size=encoder.vocab_size
)
stats = splade_index.get_stats()
results["steps"]["splade"] = {
"success": True,
"chunks": stats['unique_chunks'],
"postings": stats['total_postings'],
}
if not json_mode:
console.print(f" [green]OK[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings")
else:
results["steps"]["splade"] = {"success": False, "error": "No chunks found"}
if not json_mode:
console.print(f" [yellow]Skipped:[/yellow] No chunks found in indexes")
except Exception as e:
results["steps"]["splade"] = {"success": False, "error": str(e)}
if not json_mode:
console.print(f" [yellow]Warning:[/yellow] {e}")
else:
results["steps"]["splade"] = {"success": True, "skipped": True}
if not json_mode:
console.print(f"\n[bold]Step 3/3: SPLADE index...[/bold]")
console.print(f" [dim]Skipped (--skip-splade)[/dim]")
# Summary
if json_mode:
print_json(success=True, result=results)
@@ -3955,10 +3537,8 @@ def index_all(
console.print(f"\n[bold]Indexing Complete[/bold]")
init_ok = results["steps"].get("init", {}).get("success", False)
emb_ok = results["steps"].get("embeddings", {}).get("success", False)
splade_ok = results["steps"].get("splade", {}).get("success", False)
console.print(f" FTS Index: {'[green]OK[/green]' if init_ok else '[red]Failed[/red]'}")
console.print(f" Embeddings: {'[green]OK[/green]' if emb_ok else '[yellow]Partial/Skipped[/yellow]'}")
console.print(f" SPLADE: {'[green]OK[/green]' if splade_ok else '[yellow]Partial/Skipped[/yellow]'}")
# ==================== Index Migration Commands ====================
@@ -3997,50 +3577,6 @@ def _set_index_version(index_root: Path, version: str) -> None:
version_file.write_text(version, encoding="utf-8")
def _discover_distributed_splade(index_root: Path) -> List[Dict[str, Any]]:
"""Discover distributed SPLADE data in _index.db files.
Scans all _index.db files for embedded splade_postings tables.
This is the old distributed format that needs migration.
Args:
index_root: Root directory to scan
Returns:
List of dicts with db_path, posting_count, chunk_count
"""
results = []
for db_path in index_root.rglob("_index.db"):
try:
conn = sqlite3.connect(db_path, timeout=5.0)
conn.row_factory = sqlite3.Row
# Check if splade_postings table exists (old embedded format)
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='splade_postings'"
)
if cursor.fetchone():
# Count postings and chunks
try:
row = conn.execute(
"SELECT COUNT(*) as postings, COUNT(DISTINCT chunk_id) as chunks FROM splade_postings"
).fetchone()
results.append({
"db_path": db_path,
"posting_count": row["postings"] if row else 0,
"chunk_count": row["chunks"] if row else 0,
})
except Exception:
pass
conn.close()
except Exception:
pass
return results
def _discover_distributed_hnsw(index_root: Path) -> List[Dict[str, Any]]:
"""Discover distributed HNSW index files.
@@ -4075,33 +3611,18 @@ def _check_centralized_storage(index_root: Path) -> Dict[str, Any]:
index_root: Root directory to check
Returns:
Dict with has_splade, has_vectors, splade_stats, vector_stats
Dict with has_vectors, vector_stats
"""
from codexlens.config import SPLADE_DB_NAME, VECTORS_HNSW_NAME
from codexlens.config import VECTORS_HNSW_NAME
splade_db = index_root / SPLADE_DB_NAME
vectors_hnsw = index_root / VECTORS_HNSW_NAME
result = {
"has_splade": splade_db.exists(),
"has_vectors": vectors_hnsw.exists(),
"splade_path": str(splade_db) if splade_db.exists() else None,
"vectors_path": str(vectors_hnsw) if vectors_hnsw.exists() else None,
"splade_stats": None,
"vector_stats": None,
}
# Get SPLADE stats if exists
if splade_db.exists():
try:
from codexlens.storage.splade_index import SpladeIndex
splade = SpladeIndex(splade_db)
if splade.has_index():
result["splade_stats"] = splade.get_stats()
splade.close()
except Exception:
pass
# Get vector stats if exists
if vectors_hnsw.exists():
try:
@@ -4125,21 +3646,19 @@ def index_migrate_cmd(
"""Migrate old distributed index to new centralized architecture.
This command upgrades indexes from the old distributed storage format
(where SPLADE/vectors were stored in each _index.db) to the new centralized
format (single _splade.db and _vectors.hnsw at index root).
(where vectors were stored in each _index.db) to the new centralized
format (single _vectors.hnsw at index root).
Migration Steps:
1. Detect if migration is needed (check version marker)
2. Discover distributed SPLADE data in _index.db files
3. Discover distributed .hnsw files
4. Report current status
5. Create version marker (unless --dry-run)
2. Discover distributed .hnsw files
3. Report current status
4. Create version marker (unless --dry-run)
Use --dry-run to preview what would be migrated without making changes.
Use --force to re-run migration even if version marker exists.
Note: For full data migration (SPLADE/vectors consolidation), run:
codexlens index splade <path> --rebuild
Note: For full data migration (vectors consolidation), run:
codexlens index embeddings <path> --force
Examples:
@@ -4222,7 +3741,6 @@ def index_migrate_cmd(
return
# Discover distributed data
distributed_splade = _discover_distributed_splade(index_root)
distributed_hnsw = _discover_distributed_hnsw(index_root)
centralized = _check_centralized_storage(index_root)
@@ -4239,8 +3757,6 @@ def index_migrate_cmd(
"needs_migration": needs_migration,
"discovery": {
"total_index_dbs": len(all_index_dbs),
"distributed_splade_count": len(distributed_splade),
"distributed_splade_total_postings": sum(d["posting_count"] for d in distributed_splade),
"distributed_hnsw_count": len(distributed_hnsw),
"distributed_hnsw_total_bytes": sum(d["size_bytes"] for d in distributed_hnsw),
},
@@ -4249,17 +3765,12 @@ def index_migrate_cmd(
}
# Generate recommendations
if distributed_splade and not centralized["has_splade"]:
migration_report["recommendations"].append(
f"Run 'codexlens splade-index {target_path} --rebuild' to consolidate SPLADE data"
)
if distributed_hnsw and not centralized["has_vectors"]:
migration_report["recommendations"].append(
f"Run 'codexlens embeddings-generate {target_path} --recursive --force' to consolidate vector data"
)
if not distributed_splade and not distributed_hnsw:
if not distributed_hnsw:
migration_report["recommendations"].append(
"No distributed data found. Index may already be using centralized storage."
)
@@ -4280,23 +3791,6 @@ def index_migrate_cmd(
console.print(f" Total _index.db files: {len(all_index_dbs)}")
console.print()
# Distributed SPLADE
console.print("[bold]Distributed SPLADE Data:[/bold]")
if distributed_splade:
total_postings = sum(d["posting_count"] for d in distributed_splade)
total_chunks = sum(d["chunk_count"] for d in distributed_splade)
console.print(f" Found in {len(distributed_splade)} _index.db files")
console.print(f" Total postings: {total_postings:,}")
console.print(f" Total chunks: {total_chunks:,}")
if verbose:
for d in distributed_splade[:5]:
console.print(f" [dim]{d['db_path'].parent.name}: {d['posting_count']} postings[/dim]")
if len(distributed_splade) > 5:
console.print(f" [dim]... and {len(distributed_splade) - 5} more[/dim]")
else:
console.print(" [dim]None found (already centralized or not generated)[/dim]")
console.print()
# Distributed HNSW
console.print("[bold]Distributed HNSW Files:[/bold]")
if distributed_hnsw:
@@ -4314,15 +3808,6 @@ def index_migrate_cmd(
# Centralized storage status
console.print("[bold]Centralized Storage:[/bold]")
if centralized["has_splade"]:
stats = centralized.get("splade_stats") or {}
console.print(f" [green]OK[/green] _splade.db exists")
if stats:
console.print(f" Chunks: {stats.get('unique_chunks', 0):,}")
console.print(f" Postings: {stats.get('total_postings', 0):,}")
else:
console.print(f" [yellow]--[/yellow] _splade.db not found")
if centralized["has_vectors"]:
stats = centralized.get("vector_stats") or {}
size_mb = stats.get("size_bytes", 0) / (1024 * 1024)
@@ -4440,20 +3925,6 @@ def init_deprecated(
)
@app.command("splade-index", hidden=True, deprecated=True)
def splade_index_deprecated(
path: Path = typer.Argument(..., help="Project path to index"),
rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
) -> None:
"""[Deprecated] Use 'codexlens index splade' instead."""
_deprecated_command_warning("splade-index", "index splade")
index_splade(
path=path,
rebuild=rebuild,
verbose=verbose,
)
@app.command("cascade-index", hidden=True, deprecated=True)
def cascade_index_deprecated(

View File

@@ -151,15 +151,6 @@ def _cleanup_fastembed_resources() -> None:
pass
def _cleanup_splade_resources() -> None:
"""Release SPLADE encoder ONNX resources."""
try:
from codexlens.semantic.splade_encoder import clear_splade_cache
clear_splade_cache()
except Exception:
pass
def _generate_chunks_from_cursor(
cursor,
chunker,
@@ -398,7 +389,6 @@ def generate_embeddings(
endpoints: Optional[List] = None,
strategy: Optional[str] = None,
cooldown: Optional[float] = None,
splade_db_path: Optional[Path] = None,
) -> Dict[str, any]:
"""Generate embeddings for an index using memory-efficient batch processing.
@@ -428,9 +418,6 @@ def generate_embeddings(
Each dict has keys: model, api_key, api_base, weight.
strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware).
cooldown: Default cooldown seconds for rate-limited endpoints.
splade_db_path: Optional path to centralized SPLADE database. If None, SPLADE
is written to index_path (legacy behavior). Use index_root / SPLADE_DB_NAME
for centralized storage.
Returns:
Result dictionary with generation statistics
@@ -822,97 +809,10 @@ def generate_embeddings(
if progress_callback:
progress_callback(f"Finalizing index... Building ANN index for {total_chunks_created} chunks")
# --- SPLADE SPARSE ENCODING (after dense embeddings) ---
# Add SPLADE encoding if enabled in config
splade_success = False
splade_error = None
try:
from codexlens.config import Config, SPLADE_DB_NAME
config = Config.load()
if config.enable_splade:
from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
from codexlens.storage.splade_index import SpladeIndex
ok, err = check_splade_available()
if ok:
if progress_callback:
progress_callback(f"Generating SPLADE sparse vectors for {total_chunks_created} chunks...")
# Initialize SPLADE encoder and index
splade_encoder = get_splade_encoder(use_gpu=use_gpu)
# Use centralized SPLADE database if provided, otherwise fallback to index_path
effective_splade_path = splade_db_path if splade_db_path else index_path
splade_index = SpladeIndex(effective_splade_path)
splade_index.create_tables()
# Retrieve all chunks from database for SPLADE encoding
with sqlite3.connect(index_path) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.execute("SELECT id, content FROM semantic_chunks ORDER BY id")
# Batch encode for efficiency
SPLADE_BATCH_SIZE = 32
batch_postings = []
chunk_batch = []
chunk_ids = []
for row in cursor:
chunk_id = row["id"]
content = row["content"]
chunk_ids.append(chunk_id)
chunk_batch.append(content)
# Process batch when full
if len(chunk_batch) >= SPLADE_BATCH_SIZE:
sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
batch_postings.append((cid, sparse_vec))
chunk_batch = []
chunk_ids = []
# Process remaining chunks
if chunk_batch:
sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
batch_postings.append((cid, sparse_vec))
# Batch insert all postings
if batch_postings:
splade_index.add_postings_batch(batch_postings)
# Set metadata
splade_index.set_metadata(
model_name=splade_encoder.model_name,
vocab_size=splade_encoder.vocab_size
)
splade_success = True
if progress_callback:
stats = splade_index.get_stats()
progress_callback(
f"SPLADE index created: {stats['total_postings']} postings, "
f"{stats['unique_tokens']} unique tokens"
)
else:
logger.debug("SPLADE not available: %s", err)
splade_error = f"SPLADE not available: {err}"
except Exception as e:
splade_error = str(e)
logger.warning("SPLADE encoding failed: %s", e)
# Report SPLADE status after processing
if progress_callback and not splade_success and splade_error:
progress_callback(f"SPLADE index: FAILED - {splade_error}")
except Exception as e:
# Cleanup on error to prevent process hanging
try:
_cleanup_fastembed_resources()
_cleanup_splade_resources()
gc.collect()
except Exception:
pass
@@ -924,7 +824,6 @@ def generate_embeddings(
# This is critical - without it, ONNX Runtime threads prevent Python from exiting
try:
_cleanup_fastembed_resources()
_cleanup_splade_resources()
gc.collect()
except Exception:
pass
@@ -1098,10 +997,6 @@ def generate_embeddings_recursive(
if progress_callback:
progress_callback(f"Found {len(index_files)} index databases to process")
# Calculate centralized SPLADE database path
from codexlens.config import SPLADE_DB_NAME
splade_db_path = index_root / SPLADE_DB_NAME
# Process each index database
all_results = []
total_chunks = 0
@@ -1131,7 +1026,6 @@ def generate_embeddings_recursive(
endpoints=endpoints,
strategy=strategy,
cooldown=cooldown,
splade_db_path=splade_db_path, # Use centralized SPLADE storage
)
all_results.append({
@@ -1153,7 +1047,6 @@ def generate_embeddings_recursive(
# Each generate_embeddings() call does its own cleanup, but do a final one to be safe
try:
_cleanup_fastembed_resources()
_cleanup_splade_resources()
gc.collect()
except Exception:
pass
@@ -1197,7 +1090,6 @@ def generate_dense_embeddings_centralized(
Target architecture:
<index_root>/
|-- _vectors.hnsw # Centralized dense vector ANN index
|-- _splade.db # Centralized sparse vector index
|-- src/
|-- _index.db # No longer contains .hnsw file
@@ -1219,7 +1111,7 @@ def generate_dense_embeddings_centralized(
Returns:
Result dictionary with generation statistics
"""
from codexlens.config import VECTORS_HNSW_NAME, SPLADE_DB_NAME
from codexlens.config import VECTORS_HNSW_NAME
# Get defaults from config if not specified
(default_backend, default_model, default_gpu,
@@ -1543,90 +1435,6 @@ def generate_dense_embeddings_centralized(
logger.warning("Binary vector generation failed: %s", e)
# Non-fatal: continue without binary vectors
# --- SPLADE Sparse Index Generation (Centralized) ---
splade_success = False
splade_chunks_count = 0
try:
from codexlens.config import Config
config = Config.load()
if config.enable_splade and chunk_id_to_info:
from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
from codexlens.storage.splade_index import SpladeIndex
import json
ok, err = check_splade_available()
if ok:
if progress_callback:
progress_callback(f"Generating SPLADE sparse vectors for {len(chunk_id_to_info)} chunks...")
# Initialize SPLADE encoder and index
splade_encoder = get_splade_encoder(use_gpu=use_gpu)
splade_db_path = index_root / SPLADE_DB_NAME
splade_index = SpladeIndex(splade_db_path)
splade_index.create_tables()
# Batch encode for efficiency
SPLADE_BATCH_SIZE = 32
all_postings = []
all_chunk_metadata = []
# Create batches from chunk_id_to_info
chunk_items = list(chunk_id_to_info.items())
for i in range(0, len(chunk_items), SPLADE_BATCH_SIZE):
batch_items = chunk_items[i:i + SPLADE_BATCH_SIZE]
chunk_ids = [item[0] for item in batch_items]
chunk_contents = [item[1]["content"] for item in batch_items]
# Generate sparse vectors
sparse_vecs = splade_encoder.encode_batch(chunk_contents, batch_size=SPLADE_BATCH_SIZE)
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
all_postings.append((cid, sparse_vec))
if progress_callback and (i + SPLADE_BATCH_SIZE) % 100 == 0:
progress_callback(f"SPLADE encoding: {min(i + SPLADE_BATCH_SIZE, len(chunk_items))}/{len(chunk_items)}")
# Batch insert all postings
if all_postings:
splade_index.add_postings_batch(all_postings)
# CRITICAL FIX: Populate splade_chunks table
for cid, info in chunk_id_to_info.items():
metadata_str = json.dumps(info.get("metadata", {})) if info.get("metadata") else None
all_chunk_metadata.append((
cid,
info["file_path"],
info["content"],
metadata_str,
info.get("source_index_db")
))
if all_chunk_metadata:
splade_index.add_chunks_metadata_batch(all_chunk_metadata)
splade_chunks_count = len(all_chunk_metadata)
# Set metadata
splade_index.set_metadata(
model_name=splade_encoder.model_name,
vocab_size=splade_encoder.vocab_size
)
splade_index.close()
splade_success = True
if progress_callback:
progress_callback(f"SPLADE index created: {len(all_postings)} postings, {splade_chunks_count} chunks")
else:
if progress_callback:
progress_callback(f"SPLADE not available, skipping sparse index: {err}")
except Exception as e:
logger.warning("SPLADE encoding failed: %s", e)
if progress_callback:
progress_callback(f"SPLADE encoding failed: {e}")
elapsed_time = time.time() - start_time
# Cleanup
@@ -1647,8 +1455,6 @@ def generate_dense_embeddings_centralized(
"model_name": embedder.model_name,
"central_index_path": str(central_hnsw_path),
"failed_files": failed_files[:5],
"splade_success": splade_success,
"splade_chunks": splade_chunks_count,
"binary_success": binary_success,
"binary_count": binary_count,
},

View File

@@ -19,9 +19,6 @@ WORKSPACE_DIR_NAME = ".codexlens"
# Settings file name
SETTINGS_FILE_NAME = "settings.json"
# SPLADE index database name (centralized storage)
SPLADE_DB_NAME = "_splade.db"
# Dense vector storage names (centralized storage)
VECTORS_HNSW_NAME = "_vectors.hnsw"
VECTORS_META_DB_NAME = "_vectors_meta.db"
@@ -113,15 +110,6 @@ class Config:
# For litellm: model name from config (e.g., "qwen3-embedding")
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
# SPLADE sparse retrieval configuration
enable_splade: bool = False # Disable SPLADE by default (slow ~360ms, use FTS instead)
splade_model: str = "naver/splade-cocondenser-ensembledistil"
splade_threshold: float = 0.01 # Min weight to store in index
splade_onnx_path: Optional[str] = None # Custom ONNX model path
# FTS fallback (disabled by default, available via --use-fts)
use_fts_fallback: bool = True # Use FTS for sparse search (fast, SPLADE disabled)
# Indexing/search optimizations
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
@@ -152,7 +140,7 @@ class Config:
enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking)
cascade_coarse_k: int = 100 # Number of coarse candidates from first stage
cascade_fine_k: int = 10 # Number of final results after reranking
cascade_strategy: str = "binary" # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder)
cascade_strategy: str = "binary" # "binary", "binary_rerank", "dense_rerank", or "staged"
# Staged cascade search configuration (4-stage pipeline)
staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search
@@ -398,11 +386,11 @@ class Config:
cascade = settings.get("cascade", {})
if "strategy" in cascade:
strategy = cascade["strategy"]
if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}:
if strategy in {"binary", "binary_rerank", "dense_rerank", "staged"}:
self.cascade_strategy = strategy
else:
log.warning(
"Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')",
"Invalid cascade strategy in %s: %r (expected 'binary', 'binary_rerank', 'dense_rerank', or 'staged')",
self.settings_path,
strategy,
)
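A hedged example of choosing the cascade strategy through settings.json, matching the validation above (the file layout beyond the cascade key is an assumption):

import json

settings = json.loads('{"cascade": {"strategy": "binary_rerank"}}')
strategy = settings.get("cascade", {}).get("strategy")
assert strategy in {"binary", "binary_rerank", "dense_rerank", "staged"}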

View File

@@ -55,7 +55,6 @@ class SearchOptions:
enable_fuzzy: Enable fuzzy FTS in hybrid mode (default True)
enable_vector: Enable vector semantic search (default False)
pure_vector: If True, only use vector search without FTS fallback (default False)
enable_splade: Enable SPLADE sparse neural search (default False)
enable_cascade: Enable cascade (binary+dense) two-stage retrieval (default False)
hybrid_weights: Custom RRF weights for hybrid search (optional)
group_results: Enable grouping of similar results (default False)
@@ -75,7 +74,6 @@ class SearchOptions:
enable_fuzzy: bool = True
enable_vector: bool = False
pure_vector: bool = False
enable_splade: bool = False
enable_cascade: bool = False
hybrid_weights: Optional[Dict[str, float]] = None
group_results: bool = False
@@ -306,154 +304,6 @@ class ChainSearchEngine:
related_results=related_results,
)
def hybrid_cascade_search(
self,
query: str,
source_path: Path,
k: int = 10,
coarse_k: int = 100,
options: Optional[SearchOptions] = None,
) -> ChainSearchResult:
"""Execute two-stage cascade search with hybrid coarse retrieval and cross-encoder reranking.
Hybrid cascade search process:
1. Stage 1 (Coarse): Fast retrieval using RRF fusion of FTS + SPLADE + Vector
to get coarse_k candidates
2. Stage 2 (Fine): CrossEncoder reranking of candidates to get final k results
This approach balances recall (from broad coarse search) with precision
(from expensive but accurate cross-encoder scoring).
Note: This method is the original hybrid approach. For binary vector cascade,
use binary_cascade_search() instead.
Args:
query: Natural language or keyword query string
source_path: Starting directory path
k: Number of final results to return (default 10)
coarse_k: Number of coarse candidates from first stage (default 100)
options: Search configuration (uses defaults if None)
Returns:
ChainSearchResult with reranked results and statistics
Examples:
>>> engine = ChainSearchEngine(registry, mapper, config=config)
>>> result = engine.hybrid_cascade_search(
... "how to authenticate users",
... Path("D:/project/src"),
... k=10,
... coarse_k=100
... )
>>> for r in result.results:
... print(f"{r.path}: {r.score:.3f}")
"""
options = options or SearchOptions()
start_time = time.time()
stats = SearchStats()
# Use config defaults if available
if self._config is not None:
if hasattr(self._config, "cascade_coarse_k"):
coarse_k = coarse_k or self._config.cascade_coarse_k
if hasattr(self._config, "cascade_fine_k"):
k = k or self._config.cascade_fine_k
# Step 1: Find starting index
start_index = self._find_start_index(source_path)
if not start_index:
self.logger.warning(f"No index found for {source_path}")
stats.time_ms = (time.time() - start_time) * 1000
return ChainSearchResult(
query=query,
results=[],
symbols=[],
stats=stats
)
# Step 2: Collect all index paths
index_paths = self._collect_index_paths(start_index, options.depth)
stats.dirs_searched = len(index_paths)
if not index_paths:
self.logger.warning(f"No indexes collected from {start_index}")
stats.time_ms = (time.time() - start_time) * 1000
return ChainSearchResult(
query=query,
results=[],
symbols=[],
stats=stats
)
# Stage 1: Coarse retrieval with hybrid search (FTS + SPLADE + Vector)
# Use hybrid mode for multi-signal retrieval
coarse_options = SearchOptions(
depth=options.depth,
max_workers=1, # Single thread for GPU safety
limit_per_dir=max(coarse_k // len(index_paths), 20),
total_limit=coarse_k,
hybrid_mode=True,
enable_fuzzy=options.enable_fuzzy,
enable_vector=True, # Enable vector for semantic matching
pure_vector=False,
hybrid_weights=options.hybrid_weights,
)
self.logger.debug(
"Cascade Stage 1: Coarse retrieval for %d candidates", coarse_k
)
coarse_results, search_stats = self._search_parallel(
index_paths, query, coarse_options
)
stats.errors = search_stats.errors
# Merge and deduplicate coarse results
coarse_merged = self._merge_and_rank(coarse_results, coarse_k)
self.logger.debug(
"Cascade Stage 1 complete: %d candidates retrieved", len(coarse_merged)
)
if not coarse_merged:
stats.time_ms = (time.time() - start_time) * 1000
return ChainSearchResult(
query=query,
results=[],
symbols=[],
stats=stats
)
# Stage 2: Cross-encoder reranking
self.logger.debug(
"Cascade Stage 2: Cross-encoder reranking %d candidates to top-%d",
len(coarse_merged),
k,
)
final_results = self._cross_encoder_rerank(query, coarse_merged, k)
# Optional: grouping of similar results
if options.group_results:
from codexlens.search.ranking import group_similar_results
final_results = group_similar_results(
final_results, score_threshold_abs=options.grouping_threshold
)
stats.files_matched = len(final_results)
stats.time_ms = (time.time() - start_time) * 1000
self.logger.debug(
"Cascade search complete: %d results in %.2fms",
len(final_results),
stats.time_ms,
)
return ChainSearchResult(
query=query,
results=final_results,
symbols=[],
stats=stats,
)
def binary_cascade_search(
self,
query: str,
@@ -501,9 +351,9 @@ class ChainSearchEngine:
"""
if not NUMPY_AVAILABLE:
self.logger.warning(
"NumPy not available, falling back to hybrid cascade search"
"NumPy not available, falling back to standard search"
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
options = options or SearchOptions()
start_time = time.time()
@@ -552,10 +402,10 @@ class ChainSearchEngine:
except ImportError as exc:
self.logger.warning(
"Binary cascade dependencies not available: %s. "
"Falling back to hybrid cascade search.",
"Falling back to standard search.",
exc
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
# Stage 1: Binary vector coarse retrieval
self.logger.debug(
@@ -573,10 +423,10 @@ class ChainSearchEngine:
except Exception as exc:
self.logger.warning(
"Failed to generate binary query embedding: %s. "
"Falling back to hybrid cascade search.",
"Falling back to standard search.",
exc
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
# Try centralized BinarySearcher first (preferred for mmap indexes)
# The index root is the parent of the first index path
@@ -629,8 +479,8 @@ class ChainSearchEngine:
stats.errors.append(f"Binary search failed for {index_path}: {exc}")
if not all_candidates:
self.logger.debug("No binary candidates found, falling back to hybrid")
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
self.logger.debug("No binary candidates found, falling back to standard search")
return self.search(query, source_path, options=options)
# Sort by Hamming distance and take top coarse_k
all_candidates.sort(key=lambda x: x[1])
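For context on the Hamming-distance sort above, a minimal NumPy sketch of coarse ranking over packed binary vectors (illustrative; the engine's BinarySearcher differs in detail):

import numpy as np

def hamming_distances(query: np.ndarray, corpus: np.ndarray) -> np.ndarray:
    """Popcount of XOR between a packed uint8 query and each packed corpus row."""
    xor = np.bitwise_xor(corpus, query)            # shape (n, n_bytes)
    return np.unpackbits(xor, axis=1).sum(axis=1)  # bits set per row

bits = np.random.randint(0, 2, size=(4, 256)).astype(np.uint8)
corpus = np.packbits(bits, axis=1)                 # (4, 32) packed rows
dists = hamming_distances(corpus[0], corpus)
order = np.argsort(dists)                          # ascending distance = coarse rank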
@@ -828,13 +678,12 @@ class ChainSearchEngine:
k: int = 10,
coarse_k: int = 100,
options: Optional[SearchOptions] = None,
strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]] = None,
strategy: Optional[Literal["binary", "binary_rerank", "dense_rerank", "staged"]] = None,
) -> ChainSearchResult:
"""Unified cascade search entry point with strategy selection.
Provides a single interface for cascade search with configurable strategy:
- "binary": Uses binary vector coarse ranking + dense fine ranking (fastest)
- "hybrid": Uses FTS+SPLADE+Vector coarse ranking + cross-encoder reranking (original)
- "binary_rerank": Uses binary vector coarse ranking + cross-encoder reranking (best balance)
- "dense_rerank": Uses dense vector coarse ranking + cross-encoder reranking
- "staged": 4-stage pipeline: binary -> LSP expand -> clustering -> optional rerank
@@ -850,7 +699,7 @@ class ChainSearchEngine:
k: Number of final results to return (default 10)
coarse_k: Number of coarse candidates from first stage (default 100)
options: Search configuration (uses defaults if None)
strategy: Cascade strategy - "binary", "hybrid", "binary_rerank", "dense_rerank", or "staged".
strategy: Cascade strategy - "binary", "binary_rerank", "dense_rerank", or "staged".
Returns:
ChainSearchResult with reranked results and statistics
@@ -859,8 +708,6 @@ class ChainSearchEngine:
>>> engine = ChainSearchEngine(registry, mapper, config=config)
>>> # Use binary cascade (default, fastest)
>>> result = engine.cascade_search("auth", Path("D:/project"))
>>> # Use hybrid cascade (original behavior)
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="hybrid")
>>> # Use binary + cross-encoder (best balance of speed and quality)
>>> result = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
>>> # Use 4-stage pipeline (binary + LSP expand + clustering + optional rerank)
@@ -868,7 +715,7 @@ class ChainSearchEngine:
"""
# Strategy priority: parameter > config > default
effective_strategy = strategy
valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank", "staged")
valid_strategies = ("binary", "binary_rerank", "dense_rerank", "staged")
if effective_strategy is None:
# Not passed via parameter, check config
if self._config is not None:
@@ -889,7 +736,7 @@ class ChainSearchEngine:
elif effective_strategy == "staged":
return self.staged_cascade_search(query, source_path, k, coarse_k, options)
else:
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.binary_cascade_search(query, source_path, k, coarse_k, options)
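Post-refactor usage of the unified entry point (a sketch in the style of the docstring examples above; 'hybrid' is no longer accepted, and unrecognized values fall through to the binary strategy):

engine = ChainSearchEngine(registry, mapper, config=config)
result = engine.cascade_search("auth", Path("D:/project"))                     # default: binary
result = engine.cascade_search("auth", Path("D:/project"), strategy="staged")  # 4-stage pipeline
for r in result.results:
    print(f"{r.path}: {r.score:.3f}")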
def staged_cascade_search(
self,
@@ -943,9 +790,9 @@ class ChainSearchEngine:
"""
if not NUMPY_AVAILABLE:
self.logger.warning(
"NumPy not available, falling back to hybrid cascade search"
"NumPy not available, falling back to standard search"
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
options = options or SearchOptions()
start_time = time.time()
@@ -1002,8 +849,8 @@ class ChainSearchEngine:
)
if not coarse_results:
self.logger.debug("No binary candidates found, falling back to hybrid cascade")
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
self.logger.debug("No binary candidates found, falling back to standard search")
return self.search(query, source_path, options=options)
# ========== Stage 2: LSP Graph Expansion ==========
stage2_start = time.time()
@@ -1534,7 +1381,7 @@ class ChainSearchEngine:
2. Stage 2 (Fine): Cross-encoder reranking for precise semantic ranking
of candidates using query-document attention
This approach is typically faster than hybrid_cascade_search while
This approach is typically faster than binary_cascade_search while
achieving similar or better quality through cross-encoder reranking.
Performance characteristics:
@@ -1565,9 +1412,9 @@ class ChainSearchEngine:
"""
if not NUMPY_AVAILABLE:
self.logger.warning(
"NumPy not available, falling back to hybrid cascade search"
"NumPy not available, falling back to standard search"
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
options = options or SearchOptions()
start_time = time.time()
@@ -1611,10 +1458,10 @@ class ChainSearchEngine:
from codexlens.indexing.embedding import BinaryEmbeddingBackend
except ImportError as exc:
self.logger.warning(
"BinaryEmbeddingBackend not available: %s, falling back to hybrid cascade",
"BinaryEmbeddingBackend not available: %s, falling back to standard search",
exc
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
# Step 4: Binary coarse search (same as binary_cascade_search)
binary_coarse_time = time.time()
@@ -1658,7 +1505,7 @@ class ChainSearchEngine:
query_binary = binary_backend.embed_packed([query])[0]
except Exception as exc:
self.logger.warning(f"Failed to generate binary query embedding: {exc}")
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
# Fallback to per-directory binary indexes
for index_path in index_paths:
@@ -1676,9 +1523,9 @@ class ChainSearchEngine:
)
if not coarse_candidates:
self.logger.info("No binary candidates found, falling back to hybrid cascade for reranking")
# Fall back to hybrid_cascade_search which uses FTS+Vector coarse + cross-encoder rerank
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
self.logger.info("No binary candidates found, falling back to standard search for reranking")
# Fall back to standard search which uses FTS+Vector
return self.search(query, source_path, options=options)
# Sort by Hamming distance and take top coarse_k
coarse_candidates.sort(key=lambda x: x[1])
@@ -1785,7 +1632,7 @@ class ChainSearchEngine:
"Retrieved %d chunks for cross-encoder reranking", len(coarse_results)
)
# Step 6: Cross-encoder reranking (same as hybrid_cascade_search)
# Step 6: Cross-encoder reranking
rerank_time = time.time()
reranked_results = self._cross_encoder_rerank(query, coarse_results, top_k=k)
@@ -1848,9 +1695,9 @@ class ChainSearchEngine:
"""
if not NUMPY_AVAILABLE:
self.logger.warning(
"NumPy not available, falling back to hybrid cascade search"
"NumPy not available, falling back to standard search"
)
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
options = options or SearchOptions()
start_time = time.time()
@@ -1955,7 +1802,7 @@ class ChainSearchEngine:
self.logger.debug(f"Dense query embedding: {query_dense.shape[0]}-dim via {embedding_backend}/{embedding_model}")
except Exception as exc:
self.logger.warning(f"Failed to generate dense query embedding: {exc}")
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
return self.search(query, source_path, options=options)
# Step 5: Dense coarse search using centralized HNSW index
coarse_candidates: List[Tuple[int, float, Path]] = [] # (chunk_id, distance, index_path)
@@ -2006,8 +1853,8 @@ class ChainSearchEngine:
)
if not coarse_candidates:
self.logger.info("No dense candidates found, falling back to hybrid cascade")
return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
self.logger.info("No dense candidates found, falling back to standard search")
return self.search(query, source_path, options=options)
# Sort by distance (ascending for cosine distance) and take top coarse_k
coarse_candidates.sort(key=lambda x: x[1])
@@ -2972,7 +2819,6 @@ class ChainSearchEngine:
options.enable_fuzzy,
options.enable_vector,
options.pure_vector,
options.enable_splade,
options.hybrid_weights
): idx_path
for idx_path in index_paths
@@ -3001,7 +2847,6 @@ class ChainSearchEngine:
enable_fuzzy: bool = True,
enable_vector: bool = False,
pure_vector: bool = False,
enable_splade: bool = False,
hybrid_weights: Optional[Dict[str, float]] = None) -> List[SearchResult]:
"""Search a single index database.
@@ -3017,7 +2862,6 @@ class ChainSearchEngine:
enable_fuzzy: Enable fuzzy FTS in hybrid mode
enable_vector: Enable vector semantic search
pure_vector: If True, only use vector search without FTS fallback
enable_splade: If True, force SPLADE sparse neural search
hybrid_weights: Custom RRF weights for hybrid search
Returns:
@@ -3034,7 +2878,6 @@ class ChainSearchEngine:
enable_fuzzy=enable_fuzzy,
enable_vector=enable_vector,
pure_vector=pure_vector,
enable_splade=enable_splade,
)
else:
# Single-FTS search (exact or fuzzy mode)

View File

@@ -35,7 +35,6 @@ from codexlens.config import VECTORS_HNSW_NAME
from codexlens.entities import SearchResult
from codexlens.search.ranking import (
DEFAULT_WEIGHTS,
FTS_FALLBACK_WEIGHTS,
QueryIntent,
apply_symbol_boost,
cross_encoder_rerank,
@@ -57,14 +56,6 @@ except ImportError:
HAS_LSP = False
# Three-way fusion weights (FTS + Vector + SPLADE)
THREE_WAY_WEIGHTS = {
"exact": 0.2,
"splade": 0.3,
"vector": 0.5,
}
class HybridSearchEngine:
"""Hybrid search engine with parallel execution and RRF fusion.
@@ -77,8 +68,7 @@ class HybridSearchEngine:
"""
# NOTE: DEFAULT_WEIGHTS imported from ranking.py - single source of truth
# Default RRF weights: SPLADE-based hybrid (splade: 0.4, vector: 0.6)
# FTS fallback mode uses FTS_FALLBACK_WEIGHTS (exact: 0.3, fuzzy: 0.1, vector: 0.6)
# FTS + vector hybrid mode (exact: 0.25, fuzzy: 0.1, vector: 0.5, lsp_graph: 0.15)
def __init__(
self,
@@ -119,7 +109,6 @@ class HybridSearchEngine:
enable_fuzzy: bool = True,
enable_vector: bool = False,
pure_vector: bool = False,
enable_splade: bool = False,
enable_lsp_graph: bool = False,
lsp_max_depth: int = 1,
lsp_max_nodes: int = 20,
@@ -133,7 +122,6 @@ class HybridSearchEngine:
enable_fuzzy: Enable fuzzy FTS search (default True)
enable_vector: Enable vector search (default False)
pure_vector: If True, only use vector search without FTS fallback (default False)
enable_splade: If True, force SPLADE sparse neural search (default False)
enable_lsp_graph: If True, enable real-time LSP graph expansion (default False)
lsp_max_depth: Maximum depth for LSP graph BFS expansion (default 1)
lsp_max_nodes: Maximum nodes to collect in LSP graph (default 20)
@@ -150,9 +138,6 @@ class HybridSearchEngine:
>>> results = engine.search(Path("project/_index.db"),
... "how to authenticate users",
... enable_vector=True, pure_vector=True)
>>> # SPLADE sparse neural search
>>> results = engine.search(Path("project/_index.db"), "auth flow",
... enable_splade=True, enable_vector=True)
>>> # With LSP graph expansion (real-time)
>>> results = engine.search(Path("project/_index.db"), "auth flow",
... enable_vector=True, enable_lsp_graph=True)
@@ -180,26 +165,6 @@ class HybridSearchEngine:
# Determine which backends to use
backends = {}
# Check if SPLADE is available
splade_available = False
# Respect config.enable_splade flag and use_fts_fallback flag
if self._config and getattr(self._config, 'use_fts_fallback', False):
# Config explicitly requests FTS fallback - disable SPLADE
splade_available = False
elif self._config and not getattr(self._config, 'enable_splade', True):
# Config explicitly disabled SPLADE
splade_available = False
else:
# Check if SPLADE dependencies are available
try:
from codexlens.semantic.splade_encoder import check_splade_available
ok, _ = check_splade_available()
if ok:
# SPLADE tables are in main index database, will check table existence in _search_splade
splade_available = True
except Exception:
pass
if pure_vector:
# Pure vector mode: only use vector search, no FTS fallback
if enable_vector:
@@ -212,37 +177,13 @@ class HybridSearchEngine:
"To use pure vector search, enable vector search mode."
)
backends["exact"] = True
elif enable_splade:
# Explicit SPLADE mode requested via CLI --method splade
if splade_available:
backends["splade"] = True
if enable_vector:
backends["vector"] = True
else:
# SPLADE requested but not available - warn and fallback
self.logger.warning(
"SPLADE search requested but not available. "
"Falling back to FTS. Run 'codexlens index splade' to enable."
)
backends["exact"] = True
if enable_fuzzy:
backends["fuzzy"] = True
if enable_vector:
backends["vector"] = True
else:
# Hybrid mode: default to SPLADE if available, otherwise use FTS
if splade_available:
# Default: enable SPLADE, disable exact and fuzzy
backends["splade"] = True
if enable_vector:
backends["vector"] = True
else:
# Fallback mode: enable exact+fuzzy when SPLADE unavailable
backends["exact"] = True
if enable_fuzzy:
backends["fuzzy"] = True
if enable_vector:
backends["vector"] = True
# Standard hybrid mode: FTS + optional vector
backends["exact"] = True
if enable_fuzzy:
backends["fuzzy"] = True
if enable_vector:
backends["vector"] = True
# Add LSP graph expansion if requested and available
if enable_lsp_graph and HAS_LSP:
@@ -502,13 +443,6 @@ class HybridSearchEngine:
)
future_to_source[future] = "vector"
if backends.get("splade"):
submit_times["splade"] = time.perf_counter()
future = executor.submit(
self._search_splade, index_path, query, limit
)
future_to_source[future] = "splade"
if backends.get("lsp_graph"):
submit_times["lsp_graph"] = time.perf_counter()
future = executor.submit(
@@ -599,8 +533,7 @@ class HybridSearchEngine:
def _find_vectors_hnsw(self, index_path: Path) -> Optional[Path]:
"""Find the centralized _vectors.hnsw file by traversing up from index_path.
Similar to _search_splade's approach, this method searches for the
centralized dense vector index file in parent directories.
Searches for the centralized dense vector index file in parent directories.
Args:
index_path: Path to the current _index.db file
@@ -1138,124 +1071,6 @@ class HybridSearchEngine:
self.logger.error("Vector search error: %s", exc)
return []
def _search_splade(
self, index_path: Path, query: str, limit: int
) -> List[SearchResult]:
"""SPLADE sparse retrieval via inverted index.
Args:
index_path: Path to _index.db file
query: Natural language query string
limit: Maximum results
Returns:
List of SearchResult ordered by SPLADE score
"""
try:
from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
from codexlens.storage.splade_index import SpladeIndex
from codexlens.config import SPLADE_DB_NAME
import sqlite3
import json
# Check dependencies
ok, err = check_splade_available()
if not ok:
self.logger.debug("SPLADE not available: %s", err)
return []
# SPLADE index is stored in _splade.db at the project index root
# Traverse up from the current index to find the root _splade.db
current_dir = index_path.parent
splade_db_path = None
for _ in range(10): # Limit search depth
candidate = current_dir / SPLADE_DB_NAME
if candidate.exists():
splade_db_path = candidate
break
parent = current_dir.parent
if parent == current_dir: # Reached root
break
current_dir = parent
if not splade_db_path:
self.logger.debug("SPLADE index not found in ancestor directories of %s", index_path)
return []
splade_index = SpladeIndex(splade_db_path)
if not splade_index.has_index():
self.logger.debug("SPLADE index not initialized")
return []
# Encode query to sparse vector
encoder = get_splade_encoder(use_gpu=self._use_gpu)
query_sparse = encoder.encode_text(query)
# Search inverted index for top matches
raw_results = splade_index.search(query_sparse, limit=limit, min_score=0.0)
if not raw_results:
return []
# Fetch chunk details from splade_chunks table (self-contained)
chunk_ids = [chunk_id for chunk_id, _ in raw_results]
score_map = {chunk_id: score for chunk_id, score in raw_results}
# Get chunk metadata from SPLADE database
rows = splade_index.get_chunks_by_ids(chunk_ids)
# Build SearchResult objects
results = []
for row in rows:
chunk_id = row["id"]
file_path = row["file_path"]
content = row["content"]
metadata_json = row["metadata"]
metadata = json.loads(metadata_json) if metadata_json else {}
score = score_map.get(chunk_id, 0.0)
# Build excerpt (short preview)
excerpt = content[:200] + "..." if len(content) > 200 else content
# Extract symbol information from metadata
symbol_name = metadata.get("symbol_name")
symbol_kind = metadata.get("symbol_kind")
start_line = metadata.get("start_line")
end_line = metadata.get("end_line")
# Build Symbol object if we have symbol info
symbol = None
if symbol_name and symbol_kind and start_line and end_line:
try:
from codexlens.entities import Symbol
symbol = Symbol(
name=symbol_name,
kind=symbol_kind,
range=(start_line, end_line)
)
except Exception:
pass
results.append(SearchResult(
path=file_path,
score=score,
excerpt=excerpt,
content=content,
symbol=symbol,
metadata=metadata,
start_line=start_line,
end_line=end_line,
symbol_name=symbol_name,
symbol_kind=symbol_kind,
))
return results
except Exception as exc:
self.logger.debug("SPLADE search error: %s", exc)
return []
def _search_lsp_graph(
self,
index_path: Path,
@@ -1295,21 +1110,14 @@ class HybridSearchEngine:
if seeds:
seed_source = "vector"
# 2. Fallback to SPLADE if vector returns nothing
# 2. Fallback to exact FTS if vector returns nothing
if not seeds:
self.logger.debug("Vector search returned no seeds, trying SPLADE")
seeds = self._search_splade(index_path, query, limit=3)
if seeds:
seed_source = "splade"
# 3. Fallback to exact FTS if SPLADE also fails
if not seeds:
self.logger.debug("SPLADE returned no seeds, trying exact FTS")
self.logger.debug("Vector search returned no seeds, trying exact FTS")
seeds = self._search_exact(index_path, query, limit=3)
if seeds:
seed_source = "exact_fts"
# 4. No seeds available from any source
# 3. No seeds available from any source
if not seeds:
self.logger.debug("No seed results available for LSP graph expansion")
return []

View File

@@ -1,7 +1,7 @@
"""Ranking algorithms for hybrid search result fusion.
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
for combining results from heterogeneous search backends (exact FTS, fuzzy FTS, vector search).
"""
from __future__ import annotations
@@ -15,19 +15,12 @@ from typing import Any, Dict, List, Optional
from codexlens.entities import SearchResult, AdditionalLocation
# Default RRF weights for SPLADE-based hybrid search
# Default RRF weights for hybrid search
DEFAULT_WEIGHTS = {
"splade": 0.35, # Replaces exact(0.3) + fuzzy(0.1)
"vector": 0.5,
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
}
# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
"exact": 0.25,
"fuzzy": 0.1,
"vector": 0.5,
"lsp_graph": 0.15, # Real-time LSP-based graph expansion
"lsp_graph": 0.15,
}
@@ -105,22 +98,13 @@ def adjust_weights_by_intent(
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Adjust RRF weights based on query intent."""
# Check if using SPLADE or FTS mode
use_splade = "splade" in base_weights
if intent == QueryIntent.KEYWORD:
if use_splade:
target = {"splade": 0.6, "vector": 0.4}
else:
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
elif intent == QueryIntent.SEMANTIC:
if use_splade:
target = {"splade": 0.3, "vector": 0.7}
else:
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
else:
target = dict(base_weights)
# Filter to active backends
keys = list(base_weights.keys())
filtered = {k: float(target.get(k, 0.0)) for k in keys}
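# Illustrative trace with the new DEFAULT_WEIGHTS (before any later
# normalization not shown in this hunk):
#   adjust_weights_by_intent(QueryIntent.KEYWORD, DEFAULT_WEIGHTS)
#   -> {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4, "lsp_graph": 0.0}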
@@ -225,7 +209,7 @@ def simple_weighted_fusion(
Args:
results_map: Dictionary mapping source name to list of SearchResult objects
Sources: 'exact', 'fuzzy', 'vector', 'splade'
Sources: 'exact', 'fuzzy', 'vector'
weights: Dictionary mapping source name to weight (default: equal weights)
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
@@ -331,14 +315,11 @@ def reciprocal_rank_fusion(
RRF formula: score(d) = Σ weight_source / (k + rank_source(d))
Supports three-way fusion with FTS, Vector, and SPLADE sources.
Args:
results_map: Dictionary mapping source name to list of SearchResult objects
Sources: 'exact', 'fuzzy', 'vector', 'splade'
Sources: 'exact', 'fuzzy', 'vector'
weights: Dictionary mapping source name to weight (default: equal weights)
Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
Or: {'splade': 0.4, 'vector': 0.6}
k: Constant to avoid division by zero and control rank influence (default 60)
Returns:
@@ -349,14 +330,6 @@ def reciprocal_rank_fusion(
>>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
>>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
>>> fused = reciprocal_rank_fusion(results_map)
# Three-way fusion with SPLADE
>>> results_map = {
... 'exact': exact_results,
... 'vector': vector_results,
... 'splade': splade_results
... }
>>> fused = reciprocal_rank_fusion(results_map, k=60)
"""
if not results_map:
return []
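# A self-contained sketch of the RRF formula above, operating on bare ids
# rather than SearchResult objects (illustrative only):
from collections import defaultdict
from typing import Dict, List, Tuple

def rrf_sketch(results_map: Dict[str, List[str]],
               weights: Dict[str, float],
               k: int = 60) -> List[Tuple[str, float]]:
    scores: Dict[str, float] = defaultdict(float)
    for source, docs in results_map.items():
        w = weights.get(source, 0.0)
        for rank, doc in enumerate(docs, start=1):
            scores[doc] += w / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)

fused = rrf_sketch(
    {"exact": ["a.py", "b.py"], "vector": ["b.py", "c.py"]},
    weights={"exact": 0.25, "fuzzy": 0.1, "vector": 0.5, "lsp_graph": 0.15},
)
# b.py ranks first: it accumulates 0.25/(60+2) + 0.5/(60+1) across sources.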

View File

@@ -1,225 +0,0 @@
# SPLADE Encoder Implementation
## Overview
Created `splade_encoder.py` - A complete ONNX-optimized SPLADE sparse encoder for code search.
## File Location
`src/codexlens/semantic/splade_encoder.py` (474 lines)
## Key Components
### 1. Dependency Checking
**Function**: `check_splade_available() -> Tuple[bool, Optional[str]]`
- Validates numpy, onnxruntime, optimum, transformers availability
- Returns (True, None) if all dependencies present
- Returns (False, error_message) with install instructions if missing
### 2. Caching System
**Global Cache**: Thread-safe singleton pattern
- `_splade_cache: Dict[str, SpladeEncoder]` - Global encoder cache
- `_cache_lock: threading.RLock()` - Thread safety lock
**Factory Function**: `get_splade_encoder(...) -> SpladeEncoder`
- Cache key includes: model_name, gpu/cpu, max_length, sparsity_threshold
- Pre-loads model on first access
- Returns cached instance on subsequent calls
**Cleanup Function**: `clear_splade_cache() -> None`
- Releases ONNX resources
- Clears model and tokenizer references
- Prevents memory leaks
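A minimal sketch of this check-then-create pattern under an RLock (simplified
relative to the real module, which also keys the cache by configuration and
pre-loads the model):

```python
import threading
from typing import Dict

_cache: Dict[str, "SpladeEncoder"] = {}
_lock = threading.RLock()

def get_cached_encoder(cache_key: str) -> "SpladeEncoder":
    with _lock:
        encoder = _cache.get(cache_key)
        if encoder is None:
            encoder = SpladeEncoder()  # expensive: loads ONNX model + tokenizer
            _cache[cache_key] = encoder
        return encoder
```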
### 3. SpladeEncoder Class
#### Initialization Parameters
- `model_name: str` - Default: "naver/splade-cocondenser-ensembledistil"
- `use_gpu: bool` - Enable GPU acceleration (default: True)
- `max_length: int` - Max sequence length (default: 512)
- `sparsity_threshold: float` - Min weight threshold (default: 0.01)
- `providers: Optional[List]` - Explicit ONNX providers (overrides use_gpu)
#### Core Methods
**`_load_model()`**: Lazy loading with GPU support
- Uses `optimum.onnxruntime.ORTModelForMaskedLM`
- Falls back to CPU if GPU unavailable
- Integrates with `gpu_support.get_optimal_providers()`
- Handles device_id options for DirectML/CUDA
**`_splade_activation(logits, attention_mask)`**: Static method
- Formula: `log(1 + ReLU(logits)) * attention_mask`
- Input: (batch, seq_len, vocab_size)
- Output: (batch, seq_len, vocab_size)
**`_max_pooling(splade_repr)`**: Static method
- Max pooling over sequence dimension
- Input: (batch, seq_len, vocab_size)
- Output: (batch, vocab_size)
**`_to_sparse_dict(dense_vec)`**: Conversion helper
- Filters by sparsity_threshold
- Returns: `Dict[int, float]` mapping token_id to weight
**`encode_text(text: str) -> Dict[int, float]`**: Single text encoding
- Tokenizes input with truncation/padding
- Forward pass through ONNX model
- Applies SPLADE activation + max pooling
- Returns sparse vector
**`encode_batch(texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]`**: Batch encoding
- Processes in batches for memory efficiency
- Returns list of sparse vectors
#### Properties
**`vocab_size: int`**: Vocabulary size (~30k for BERT)
- Cached after first model load
- Returns tokenizer length
#### Debugging Methods
**`get_token(token_id: int) -> str`**
- Converts token_id to human-readable string
- Uses tokenizer.decode()
**`get_top_tokens(sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]`**
- Extracts top-k highest-weight tokens
- Returns (token_string, weight) pairs
- Useful for understanding model focus
## Design Patterns Followed
### 1. From `onnx_reranker.py`
✓ ONNX loading with provider detection
✓ Lazy model initialization
✓ Thread-safe loading with RLock
✓ Signature inspection for backward compatibility
✓ Fallback for older Optimum versions
✓ Static helper methods for numerical operations
### 2. From `embedder.py`
✓ Global cache with thread safety
✓ Factory function pattern (get_splade_encoder)
✓ Cache cleanup function (clear_splade_cache)
✓ GPU provider configuration
✓ Batch processing support
### 3. From `gpu_support.py`
✓ `get_optimal_providers(use_gpu, with_device_options=True)`
✓ Device ID options for DirectML/CUDA
✓ Provider tuple format: (provider_name, options_dict)
## SPLADE Algorithm
### Activation Formula
```python
# Step 1: ReLU activation
relu_logits = max(0, logits)
# Step 2: Log(1 + x) transformation
log_relu = log(1 + relu_logits)
# Step 3: Apply attention mask
splade_repr = log_relu * attention_mask
# Step 4: Max pooling over sequence
splade_vec = max(splade_repr, axis=sequence_length)
# Step 5: Sparsification by threshold
sparse_dict = {token_id: weight for token_id, weight in enumerate(splade_vec) if weight > threshold}
```
### Output Format
- Sparse dictionary: `{token_id: weight}`
- Token IDs: 0 to vocab_size-1 (typically ~30,000)
- Weights: Float values > sparsity_threshold
- Interpretable: Can decode token_ids to strings
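For example, a sparse vector can be inspected with the debugging helpers
described above (a minimal sketch; printed weights are illustrative):

```python
encoder = get_splade_encoder(use_gpu=False)
sparse = encoder.encode_text("parse JWT token")
for token, weight in encoder.get_top_tokens(sparse, top_k=5):
    print(f"{token!r}: {weight:.3f}")  # e.g. 'jwt': 2.131, 'token': 1.874
```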
## Integration Points
### With `splade_index.py`
- `SpladeIndex.add_posting(chunk_id, sparse_vec: Dict[int, float])`
- `SpladeIndex.search(query_sparse: Dict[int, float])`
- Encoder produces the sparse vectors consumed by index
### With Indexing Pipeline
```python
encoder = get_splade_encoder(use_gpu=True)
# Single document
sparse_vec = encoder.encode_text("def main():\n print('hello')")
index.add_posting(chunk_id=1, sparse_vec=sparse_vec)
# Batch indexing
texts = ["code chunk 1", "code chunk 2", ...]
sparse_vecs = encoder.encode_batch(texts, batch_size=64)
postings = [(chunk_id, vec) for chunk_id, vec in enumerate(sparse_vecs)]
index.add_postings_batch(postings)
```
### With Search Pipeline
```python
encoder = get_splade_encoder(use_gpu=True)
query_sparse = encoder.encode_text("authentication function")
results = index.search(query_sparse, limit=50, min_score=0.5)
```
## Dependencies
Required packages:
- `numpy` - Numerical operations
- `onnxruntime` - ONNX model execution (CPU)
- `onnxruntime-gpu` - ONNX with GPU support (optional)
- `optimum[onnxruntime]` - Hugging Face ONNX optimization
- `transformers` - Tokenizer and model loading
Install command:
```bash
# CPU only
pip install numpy onnxruntime optimum[onnxruntime] transformers
# With GPU support
pip install numpy onnxruntime-gpu optimum[onnxruntime-gpu] transformers
```
## Testing Status
✓ Python syntax validation passed
✓ Module import successful
✓ Dependency checking works correctly
✗ Full functional test pending (requires optimum installation)
## Next Steps
1. Install dependencies for functional testing
2. Create unit tests in `tests/semantic/test_splade_encoder.py`
3. Benchmark encoding performance (CPU vs GPU)
4. Integrate with codex-lens indexing pipeline
5. Add SPLADE option to semantic search configuration
## Performance Considerations
### Memory Usage
- Model size: ~100MB (ONNX optimized)
- Sparse vectors: ~100-500 non-zero entries per document
- Batch size: 32 recommended (adjust based on GPU memory)
### Speed Benchmarks (Expected)
- CPU encoding: ~10-20 docs/sec
- GPU encoding (CUDA): ~100-200 docs/sec
- GPU encoding (DirectML): ~50-100 docs/sec
### Sparsity Analysis
- Threshold 0.01: ~200-400 tokens per document
- Threshold 0.05: ~100-200 tokens per document
- Threshold 0.10: ~50-100 tokens per document
## References
- SPLADE paper: https://arxiv.org/abs/2107.05720
- SPLADE v2: https://arxiv.org/abs/2109.10086
- Naver model: https://huggingface.co/naver/splade-cocondenser-ensembledistil

View File

@@ -1,567 +0,0 @@
"""ONNX-optimized SPLADE sparse encoder for code search.
This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
that combine the interpretability of BM25 with neural relevance modeling.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
Install (GPU):
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
"""
from __future__ import annotations
import logging
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
def check_splade_available() -> Tuple[bool, Optional[str]]:
"""Check whether SPLADE dependencies are available.
Returns:
Tuple of (available: bool, error_message: Optional[str])
"""
try:
import numpy # noqa: F401
except ImportError as exc:
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc:
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
except ImportError as exc:
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc:
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
# Global cache for SPLADE encoders (singleton pattern)
_splade_cache: Dict[str, "SpladeEncoder"] = {}
_cache_lock = threading.RLock()
def get_splade_encoder(
model_name: str = "naver/splade-cocondenser-ensembledistil",
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
cache_dir: Optional[str] = None,
) -> "SpladeEncoder":
"""Get or create cached SPLADE encoder (thread-safe singleton).
This function provides significant performance improvement by reusing
SpladeEncoder instances across multiple searches, avoiding repeated model
loading overhead.
Args:
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
Returns:
Cached SpladeEncoder instance for the given configuration
"""
global _splade_cache
# Cache key includes all configuration parameters
cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"
with _cache_lock:
encoder = _splade_cache.get(cache_key)
if encoder is not None:
return encoder
# Create new encoder and cache it
encoder = SpladeEncoder(
model_name=model_name,
use_gpu=use_gpu,
max_length=max_length,
sparsity_threshold=sparsity_threshold,
cache_dir=cache_dir,
)
# Pre-load model to ensure it's ready
encoder._load_model()
_splade_cache[cache_key] = encoder
return encoder
def clear_splade_cache() -> None:
"""Clear the SPLADE encoder cache and release ONNX resources.
This method ensures proper cleanup of ONNX model resources to prevent
memory leaks when encoders are no longer needed.
"""
global _splade_cache
with _cache_lock:
# Release ONNX resources before clearing cache
for encoder in _splade_cache.values():
if encoder._model is not None:
del encoder._model
encoder._model = None
if encoder._tokenizer is not None:
del encoder._tokenizer
encoder._tokenizer = None
_splade_cache.clear()
class SpladeEncoder:
"""ONNX-optimized SPLADE sparse encoder.
Produces sparse vectors with vocabulary-aligned dimensions.
Output: Dict[int, float] mapping token_id to weight.
SPLADE activation formula:
splade_repr = log(1 + ReLU(logits)) * attention_mask
splade_vec = max_pooling(splade_repr, axis=sequence_length)
References:
- SPLADE: https://arxiv.org/abs/2107.05720
- SPLADE v2: https://arxiv.org/abs/2109.10086
"""
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
def __init__(
self,
model_name: str = DEFAULT_MODEL,
use_gpu: bool = True,
max_length: int = 512,
sparsity_threshold: float = 0.01,
providers: Optional[List[Any]] = None,
cache_dir: Optional[str] = None,
) -> None:
"""Initialize SPLADE encoder.
Args:
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
use_gpu: If True, use GPU acceleration when available
max_length: Maximum sequence length for tokenization
sparsity_threshold: Minimum weight to include in sparse vector
providers: Explicit ONNX providers list (overrides use_gpu)
cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade)
"""
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.max_length = int(max_length) if max_length > 0 else 512
self.sparsity_threshold = float(sparsity_threshold)
self.providers = providers
# Setup ONNX cache directory
if cache_dir:
self._cache_dir = Path(cache_dir)
else:
self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade"
self._tokenizer: Any | None = None
self._model: Any | None = None
self._vocab_size: int | None = None
self._lock = threading.RLock()
def _get_local_cache_path(self) -> Path:
"""Get local cache path for this model's ONNX files.
Returns:
Path to the local ONNX cache directory for this model
"""
# Replace / with -- for filesystem-safe naming
safe_name = self.model_name.replace("/", "--")
return self._cache_dir / safe_name
def _load_model(self) -> None:
"""Lazy load ONNX model and tokenizer.
First checks local cache for ONNX model, falling back to
HuggingFace download and conversion if not cached.
"""
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_splade_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForMaskedLM
from transformers import AutoTokenizer
if self.providers is None:
from .gpu_support import get_optimal_providers, get_selected_device_id
# Get providers as pure string list (cache-friendly)
# NOTE: with_device_options=False to avoid tuple-based providers
# which break optimum's caching mechanism
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=False
)
# Get device_id separately for provider_options
self._device_id = get_selected_device_id() if self.use_gpu else None
# Some Optimum versions accept `providers`, others accept a single `provider`
# Prefer passing the full providers list, with a conservative fallback
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
# Pass device_id via provider_options for GPU selection
if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None:
# Build provider_options dict for each GPU provider
provider_options = {}
for p in self.providers:
if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"):
provider_options[p] = {"device_id": self._device_id}
if provider_options:
model_kwargs["provider_options"] = provider_options
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception as e:
logger.debug(f"Failed to inspect ORTModel signature: {e}")
model_kwargs = {}
# Check for local ONNX cache first
local_cache = self._get_local_cache_path()
onnx_model_path = local_cache / "model.onnx"
if onnx_model_path.exists():
# Load from local cache
logger.info(f"Loading SPLADE from local cache: {local_cache}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
str(local_cache),
**model_kwargs,
)
self._tokenizer = AutoTokenizer.from_pretrained(
str(local_cache), use_fast=True
)
self._vocab_size = len(self._tokenizer)
logger.info(
f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}"
)
return
except Exception as e:
logger.warning(f"Failed to load from cache, redownloading: {e}")
# Download and convert from HuggingFace
logger.info(f"Downloading SPLADE model: {self.model_name}")
try:
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True, # Export to ONNX
**model_kwargs,
)
logger.debug(f"SPLADE model loaded: {self.model_name}")
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments
self._model = ORTModelForMaskedLM.from_pretrained(
self.model_name,
export=True,
)
logger.warning(
"Optimum version doesn't support provider parameters. "
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache vocabulary size
self._vocab_size = len(self._tokenizer)
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
# Save to local cache for future use
try:
local_cache.mkdir(parents=True, exist_ok=True)
self._model.save_pretrained(str(local_cache))
self._tokenizer.save_pretrained(str(local_cache))
logger.info(f"SPLADE model cached to: {local_cache}")
except Exception as e:
logger.warning(f"Failed to cache SPLADE model: {e}")
@staticmethod
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
"""Apply SPLADE activation function to model outputs.
Formula: log(1 + ReLU(logits)) * attention_mask
Args:
logits: Model output logits (batch, seq_len, vocab_size)
attention_mask: Attention mask (batch, seq_len)
Returns:
SPLADE representations (batch, seq_len, vocab_size)
"""
import numpy as np
# ReLU activation
relu_logits = np.maximum(0, logits)
# Log(1 + x) transformation
log_relu = np.log1p(relu_logits)
# Apply attention mask (expand to match vocab dimension)
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
mask_expanded = np.expand_dims(attention_mask, axis=-1)
# Element-wise multiplication
splade_repr = log_relu * mask_expanded
return splade_repr
@staticmethod
def _max_pooling(splade_repr: Any) -> Any:
"""Max pooling over sequence length dimension.
Args:
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
Returns:
Pooled sparse vectors (batch, vocab_size)
"""
import numpy as np
# Max pooling over sequence dimension (axis=1)
return np.max(splade_repr, axis=1)
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
"""Convert dense vector to sparse dictionary.
Args:
dense_vec: Dense vector (vocab_size,)
Returns:
Sparse dictionary {token_id: weight} with weights above threshold
"""
import numpy as np
# Find non-zero indices above threshold
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
# Create sparse dictionary
sparse_dict = {
int(idx): float(dense_vec[idx])
for idx in nonzero_indices
}
return sparse_dict
def warmup(self, text: str = "warmup query") -> None:
"""Warmup the encoder by running a dummy inference.
First-time model inference includes initialization overhead.
Call this method once before the first real search to avoid
latency spikes.
Args:
text: Dummy text for warmup (default: "warmup query")
"""
logger.info("Warming up SPLADE encoder...")
# Trigger model loading and first inference
_ = self.encode_text(text)
logger.info("SPLADE encoder warmup complete")
def encode_text(self, text: str) -> Dict[int, float]:
"""Encode text to sparse vector {token_id: weight}.
Args:
text: Input text to encode
Returns:
Sparse vector as dictionary mapping token_id to weight
"""
self._load_model()
if self._model is None or self._tokenizer is None:
raise RuntimeError("Model not loaded")
import numpy as np
# Tokenize input
encoded = self._tokenizer(
text,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="np",
)
# Forward pass through model
outputs = self._model(**encoded)
# Extract logits
if hasattr(outputs, "logits"):
logits = outputs.logits
elif isinstance(outputs, dict) and "logits" in outputs:
logits = outputs["logits"]
elif isinstance(outputs, (list, tuple)) and outputs:
logits = outputs[0]
else:
raise RuntimeError("Unexpected model output format")
# Apply SPLADE activation
attention_mask = encoded["attention_mask"]
splade_repr = self._splade_activation(logits, attention_mask)
# Max pooling over sequence length
splade_vec = self._max_pooling(splade_repr)
# Convert to sparse dictionary (single item batch)
sparse_dict = self._to_sparse_dict(splade_vec[0])
return sparse_dict
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
"""Batch encode texts to sparse vectors.
Args:
texts: List of input texts to encode
batch_size: Batch size for encoding (default: 32)
Returns:
List of sparse vectors as dictionaries
"""
if not texts:
return []
self._load_model()
if self._model is None or self._tokenizer is None:
raise RuntimeError("Model not loaded")
import numpy as np
results: List[Dict[int, float]] = []
# Process in batches
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# Tokenize batch
encoded = self._tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="np",
)
# Forward pass through model
outputs = self._model(**encoded)
# Extract logits
if hasattr(outputs, "logits"):
logits = outputs.logits
elif isinstance(outputs, dict) and "logits" in outputs:
logits = outputs["logits"]
elif isinstance(outputs, (list, tuple)) and outputs:
logits = outputs[0]
else:
raise RuntimeError("Unexpected model output format")
# Apply SPLADE activation
attention_mask = encoded["attention_mask"]
splade_repr = self._splade_activation(logits, attention_mask)
# Max pooling over sequence length
splade_vecs = self._max_pooling(splade_repr)
# Convert each vector to sparse dictionary
for vec in splade_vecs:
sparse_dict = self._to_sparse_dict(vec)
results.append(sparse_dict)
return results
@property
def vocab_size(self) -> int:
"""Return vocabulary size (~30k for BERT-based models).
Returns:
Vocabulary size (number of tokens in tokenizer)
"""
if self._vocab_size is not None:
return self._vocab_size
self._load_model()
return self._vocab_size or 0
def get_token(self, token_id: int) -> str:
"""Convert token_id to string (for debugging).
Args:
token_id: Token ID to convert
Returns:
Token string
"""
self._load_model()
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded")
return self._tokenizer.decode([token_id])
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
"""Get top-k tokens with highest weights from sparse vector.
Useful for debugging and understanding what the model is focusing on.
Args:
sparse_vec: Sparse vector as {token_id: weight}
top_k: Number of top tokens to return
Returns:
List of (token_string, weight) tuples, sorted by weight descending
"""
self._load_model()
if not sparse_vec:
return []
# Sort by weight descending
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
# Take top-k and convert token_ids to strings
top_items = sorted_items[:top_k]
return [
(self.get_token(token_id), weight)
for token_id, weight in top_items
]

View File

@@ -1,103 +0,0 @@
"""
Migration 009: Add SPLADE sparse retrieval tables.
This migration introduces SPLADE (Sparse Lexical AnD Expansion) support:
- splade_metadata: Model configuration (model name, vocab size, ONNX path)
- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight)
The SPLADE tables are designed for efficient sparse vector retrieval:
- Token-based lookup for query expansion
- Chunk-based deletion for index maintenance
- Maintains backward compatibility with existing FTS tables
"""
import logging
from sqlite3 import Connection
log = logging.getLogger(__name__)
def upgrade(db_conn: Connection) -> None:
"""
Adds SPLADE tables for sparse retrieval.
Creates:
- splade_metadata: Stores model configuration and ONNX path
- splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings
- Indexes for efficient token-based and chunk-based lookups
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Creating splade_metadata table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS splade_metadata (
id INTEGER PRIMARY KEY DEFAULT 1,
model_name TEXT NOT NULL,
vocab_size INTEGER NOT NULL,
onnx_path TEXT,
created_at REAL
)
"""
)
log.info("Creating splade_posting_list table...")
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS splade_posting_list (
token_id INTEGER NOT NULL,
chunk_id INTEGER NOT NULL,
weight REAL NOT NULL,
PRIMARY KEY (token_id, chunk_id),
FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
)
"""
)
log.info("Creating indexes for splade_posting_list...")
# Index for efficient chunk-based lookups (deletion, updates)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
ON splade_posting_list(chunk_id)
"""
)
# Index for efficient term-based retrieval
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_splade_by_token
ON splade_posting_list(token_id)
"""
)
log.info("Migration 009 completed successfully")
def downgrade(db_conn: Connection) -> None:
"""
Removes SPLADE tables.
Drops:
- splade_posting_list (and associated indexes)
- splade_metadata
Args:
db_conn: The SQLite database connection.
"""
cursor = db_conn.cursor()
log.info("Dropping SPLADE indexes...")
cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk")
cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token")
log.info("Dropping splade_posting_list table...")
cursor.execute("DROP TABLE IF EXISTS splade_posting_list")
log.info("Dropping splade_metadata table...")
cursor.execute("DROP TABLE IF EXISTS splade_metadata")
log.info("Migration 009 downgrade completed successfully")

View File

@@ -1,578 +0,0 @@
"""SPLADE inverted index storage for sparse vector retrieval.
This module implements SQLite-based inverted index for SPLADE sparse vectors,
enabling efficient sparse retrieval using dot-product scoring.
"""
from __future__ import annotations
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from codexlens.entities import SearchResult
from codexlens.errors import StorageError
logger = logging.getLogger(__name__)
class SpladeIndex:
"""SQLite-based inverted index for SPLADE sparse vectors.
Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight).
Supports efficient dot-product retrieval using SQL joins.
"""
def __init__(self, db_path: Path | str) -> None:
"""Initialize SPLADE index.
Args:
db_path: Path to SQLite database file.
"""
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
# Thread-safe connection management
self._lock = threading.RLock()
self._local = threading.local()
def _get_connection(self) -> sqlite3.Connection:
"""Get or create a thread-local database connection.
Each thread gets its own connection to ensure thread safety.
Connections are stored in thread-local storage.
"""
conn = getattr(self._local, "conn", None)
if conn is None:
# Thread-local connection - each thread has its own
conn = sqlite3.connect(
self.db_path,
timeout=30.0, # Wait up to 30s for locks
check_same_thread=True, # Enforce thread safety
)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA foreign_keys=ON")
# Limit mmap to 1GB to avoid OOM on smaller systems
conn.execute("PRAGMA mmap_size=1073741824")
# Increase cache size for better query performance (20MB = -20000 pages)
conn.execute("PRAGMA cache_size=-20000")
self._local.conn = conn
return conn
def close(self) -> None:
"""Close thread-local database connection."""
with self._lock:
conn = getattr(self._local, "conn", None)
if conn is not None:
conn.close()
self._local.conn = None
def __enter__(self) -> SpladeIndex:
"""Context manager entry."""
self.create_tables()
return self
def __exit__(self, exc_type, exc, tb) -> None:
"""Context manager exit."""
self.close()
def has_index(self) -> bool:
"""Check if SPLADE tables exist in database.
Returns:
True if tables exist, False otherwise.
"""
with self._lock:
conn = self._get_connection()
try:
cursor = conn.execute(
"""
SELECT name FROM sqlite_master
WHERE type='table' AND name='splade_posting_list'
"""
)
return cursor.fetchone() is not None
except sqlite3.Error as e:
logger.error("Failed to check index existence: %s", e)
return False
def create_tables(self) -> None:
"""Create SPLADE schema if not exists.
Note: When used with distributed indexes (multiple _index.db files),
the SPLADE database stores chunk IDs from multiple sources. In this case,
foreign key constraints are not enforced to allow cross-database references.
"""
with self._lock:
conn = self._get_connection()
try:
# Inverted index for sparse vectors
# Note: No FOREIGN KEY constraint to support distributed index architecture
# where chunks may come from multiple _index.db files
conn.execute("""
CREATE TABLE IF NOT EXISTS splade_posting_list (
token_id INTEGER NOT NULL,
chunk_id INTEGER NOT NULL,
weight REAL NOT NULL,
PRIMARY KEY (token_id, chunk_id)
)
""")
# Indexes for efficient lookups
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
ON splade_posting_list(chunk_id)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_splade_by_token
ON splade_posting_list(token_id)
""")
# Model metadata
conn.execute("""
CREATE TABLE IF NOT EXISTS splade_metadata (
id INTEGER PRIMARY KEY DEFAULT 1,
model_name TEXT NOT NULL,
vocab_size INTEGER NOT NULL,
onnx_path TEXT,
created_at REAL
)
""")
# Chunk metadata for self-contained search results
# Stores all chunk info needed to build SearchResult without querying _index.db
conn.execute("""
CREATE TABLE IF NOT EXISTS splade_chunks (
id INTEGER PRIMARY KEY,
file_path TEXT NOT NULL,
content TEXT NOT NULL,
metadata TEXT,
source_db TEXT
)
""")
conn.commit()
logger.debug("SPLADE schema created successfully")
except sqlite3.Error as e:
raise StorageError(
f"Failed to create SPLADE schema: {e}",
db_path=str(self.db_path),
operation="create_tables"
) from e
def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None:
"""Add a single document to inverted index.
Args:
chunk_id: Chunk ID (foreign key to semantic_chunks.id).
sparse_vec: Sparse vector as {token_id: weight} mapping.
"""
if not sparse_vec:
logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id)
return
with self._lock:
conn = self._get_connection()
try:
# Insert all non-zero weights for this chunk
postings = [
(token_id, chunk_id, weight)
for token_id, weight in sparse_vec.items()
if weight > 0 # Only store non-zero weights
]
if postings:
conn.executemany(
"""
INSERT OR REPLACE INTO splade_posting_list
(token_id, chunk_id, weight)
VALUES (?, ?, ?)
""",
postings
)
conn.commit()
logger.debug(
"Added %d postings for chunk_id=%d", len(postings), chunk_id
)
except sqlite3.Error as e:
raise StorageError(
f"Failed to add posting for chunk_id={chunk_id}: {e}",
db_path=str(self.db_path),
operation="add_posting"
) from e
def add_postings_batch(
self, postings: List[Tuple[int, Dict[int, float]]]
) -> None:
"""Batch insert postings for multiple chunks.
Args:
postings: List of (chunk_id, sparse_vec) tuples.
"""
if not postings:
return
with self._lock:
conn = self._get_connection()
try:
# Flatten all postings into single batch
batch_data = []
for chunk_id, sparse_vec in postings:
for token_id, weight in sparse_vec.items():
if weight > 0: # Only store non-zero weights
batch_data.append((token_id, chunk_id, weight))
if batch_data:
conn.executemany(
"""
INSERT OR REPLACE INTO splade_posting_list
(token_id, chunk_id, weight)
VALUES (?, ?, ?)
""",
batch_data
)
conn.commit()
logger.debug(
"Batch inserted %d postings for %d chunks",
len(batch_data),
len(postings)
)
except sqlite3.Error as e:
raise StorageError(
f"Failed to batch insert postings: {e}",
db_path=str(self.db_path),
operation="add_postings_batch"
) from e
def add_chunk_metadata(
self,
chunk_id: int,
file_path: str,
content: str,
metadata: Optional[str] = None,
source_db: Optional[str] = None
) -> None:
"""Store chunk metadata for self-contained search results.
Args:
chunk_id: Global chunk ID.
file_path: Path to source file.
content: Chunk text content.
metadata: JSON metadata string.
source_db: Path to source _index.db.
"""
with self._lock:
conn = self._get_connection()
try:
conn.execute(
"""
INSERT OR REPLACE INTO splade_chunks
(id, file_path, content, metadata, source_db)
VALUES (?, ?, ?, ?, ?)
""",
(chunk_id, file_path, content, metadata, source_db)
)
conn.commit()
except sqlite3.Error as e:
raise StorageError(
f"Failed to add chunk metadata for chunk_id={chunk_id}: {e}",
db_path=str(self.db_path),
operation="add_chunk_metadata"
) from e
def add_chunks_metadata_batch(
self,
chunks: List[Tuple[int, str, str, Optional[str], Optional[str]]]
) -> None:
"""Batch insert chunk metadata.
Args:
chunks: List of (chunk_id, file_path, content, metadata, source_db) tuples.
"""
if not chunks:
return
with self._lock:
conn = self._get_connection()
try:
conn.executemany(
"""
INSERT OR REPLACE INTO splade_chunks
(id, file_path, content, metadata, source_db)
VALUES (?, ?, ?, ?, ?)
""",
chunks
)
conn.commit()
logger.debug("Batch inserted %d chunk metadata records", len(chunks))
except sqlite3.Error as e:
raise StorageError(
f"Failed to batch insert chunk metadata: {e}",
db_path=str(self.db_path),
operation="add_chunks_metadata_batch"
) from e
def get_chunks_by_ids(self, chunk_ids: List[int]) -> List[Dict]:
"""Get chunk metadata by IDs.
Args:
chunk_ids: List of chunk IDs to retrieve.
Returns:
List of dicts with id, file_path, content, metadata, source_db.
"""
if not chunk_ids:
return []
with self._lock:
conn = self._get_connection()
try:
placeholders = ",".join("?" * len(chunk_ids))
rows = conn.execute(
f"""
SELECT id, file_path, content, metadata, source_db
FROM splade_chunks
WHERE id IN ({placeholders})
""",
chunk_ids
).fetchall()
return [
{
"id": row["id"],
"file_path": row["file_path"],
"content": row["content"],
"metadata": row["metadata"],
"source_db": row["source_db"]
}
for row in rows
]
except sqlite3.Error as e:
logger.error("Failed to get chunks by IDs: %s", e)
return []
def remove_chunk(self, chunk_id: int) -> int:
"""Remove all postings for a chunk.
Args:
chunk_id: Chunk ID to remove.
Returns:
Number of deleted postings.
"""
with self._lock:
conn = self._get_connection()
try:
cursor = conn.execute(
"DELETE FROM splade_posting_list WHERE chunk_id = ?",
(chunk_id,)
)
conn.commit()
deleted = cursor.rowcount
logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id)
return deleted
except sqlite3.Error as e:
raise StorageError(
f"Failed to remove chunk_id={chunk_id}: {e}",
db_path=str(self.db_path),
operation="remove_chunk"
) from e
def search(
self,
query_sparse: Dict[int, float],
limit: int = 50,
min_score: float = 0.0,
max_query_terms: int = 64
) -> List[Tuple[int, float]]:
"""Search for similar chunks using dot-product scoring.
Implements efficient sparse dot-product via SQL JOIN:
score(q, d) = sum(q[t] * d[t]) for all tokens t
Args:
query_sparse: Query sparse vector as {token_id: weight}.
limit: Maximum number of results.
min_score: Minimum score threshold.
max_query_terms: Maximum query terms to use (default: 64).
Pruning to top-K terms reduces search time with minimal impact on quality.
Set to 0 or negative to disable pruning (use all terms).
Returns:
List of (chunk_id, score) tuples, ordered by score descending.
"""
if not query_sparse:
logger.warning("Empty query sparse vector")
return []
with self._lock:
conn = self._get_connection()
try:
# Build VALUES clause for query terms
# Each term: (token_id, weight)
query_terms = [
(token_id, weight)
for token_id, weight in query_sparse.items()
if weight > 0
]
if not query_terms:
logger.warning("No non-zero query terms")
return []
# Query pruning: keep only top-K terms by weight
# max_query_terms <= 0 means no limit (use all terms)
if max_query_terms > 0 and len(query_terms) > max_query_terms:
query_terms = sorted(query_terms, key=lambda x: x[1], reverse=True)[:max_query_terms]
logger.debug(
"Query pruned from %d to %d terms",
len(query_sparse),
len(query_terms)
)
# Create CTE for query terms using parameterized VALUES
# Build placeholders and params to prevent SQL injection
params = []
placeholders = []
for token_id, weight in query_terms:
placeholders.append("(?, ?)")
params.extend([token_id, weight])
values_placeholders = ", ".join(placeholders)
sql = f"""
WITH query_terms(token_id, weight) AS (
VALUES {values_placeholders}
)
SELECT
p.chunk_id,
SUM(p.weight * q.weight) as score
FROM splade_posting_list p
INNER JOIN query_terms q ON p.token_id = q.token_id
GROUP BY p.chunk_id
HAVING score >= ?
ORDER BY score DESC
LIMIT ?
"""
# Append min_score and limit to params
params.extend([min_score, limit])
rows = conn.execute(sql, params).fetchall()
results = [(row["chunk_id"], float(row["score"])) for row in rows]
logger.debug(
"SPLADE search: %d query terms, %d results",
len(query_terms),
len(results)
)
return results
except sqlite3.Error as e:
raise StorageError(
f"SPLADE search failed: {e}",
db_path=str(self.db_path),
operation="search"
) from e
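# Worked example of the dot-product scoring above (hypothetical values):
#   query_sparse          = {5: 0.8, 9: 0.3}
#   postings for chunk 7  = {5: 1.2, 9: 0.5, 14: 0.9}
#   score(query, chunk 7) = 0.8 * 1.2 + 0.3 * 0.5 = 1.11
# Token 14 contributes nothing because it never joins a query term.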
def get_metadata(self) -> Optional[Dict]:
"""Get SPLADE model metadata.
Returns:
Dictionary with model_name, vocab_size, onnx_path, created_at,
or None if not set.
"""
with self._lock:
conn = self._get_connection()
try:
row = conn.execute(
"""
SELECT model_name, vocab_size, onnx_path, created_at
FROM splade_metadata
WHERE id = 1
"""
).fetchone()
if row:
return {
"model_name": row["model_name"],
"vocab_size": row["vocab_size"],
"onnx_path": row["onnx_path"],
"created_at": row["created_at"]
}
return None
except sqlite3.Error as e:
logger.error("Failed to get metadata: %s", e)
return None
def set_metadata(
self,
model_name: str,
vocab_size: int,
onnx_path: Optional[str] = None
) -> None:
"""Set SPLADE model metadata.
Args:
model_name: SPLADE model name.
vocab_size: Vocabulary size (typically ~30k for BERT vocab).
onnx_path: Optional path to ONNX model file.
"""
with self._lock:
conn = self._get_connection()
try:
current_time = time.time()
conn.execute(
"""
INSERT OR REPLACE INTO splade_metadata
(id, model_name, vocab_size, onnx_path, created_at)
VALUES (1, ?, ?, ?, ?)
""",
(model_name, vocab_size, onnx_path, current_time)
)
conn.commit()
logger.info(
"Set SPLADE metadata: model=%s, vocab_size=%d",
model_name,
vocab_size
)
except sqlite3.Error as e:
raise StorageError(
f"Failed to set metadata: {e}",
db_path=str(self.db_path),
operation="set_metadata"
) from e
def get_stats(self) -> Dict:
"""Get index statistics.
Returns:
Dictionary with total_postings, unique_tokens, unique_chunks.
"""
with self._lock:
conn = self._get_connection()
try:
row = conn.execute("""
SELECT
COUNT(*) as total_postings,
COUNT(DISTINCT token_id) as unique_tokens,
COUNT(DISTINCT chunk_id) as unique_chunks
FROM splade_posting_list
""").fetchone()
return {
"total_postings": row["total_postings"],
"unique_tokens": row["unique_tokens"],
"unique_chunks": row["unique_chunks"]
}
except sqlite3.Error as e:
logger.error("Failed to get stats: %s", e)
return {
"total_postings": 0,
"unique_tokens": 0,
"unique_chunks": 0
}