From 71faaf43a867c269f7d03fb335031b3914f35e81 Mon Sep 17 00:00:00 2001
From: catlog22
Date: Sun, 8 Feb 2026 12:07:41 +0800
Subject: [PATCH] refactor: remove SPLADE and hybrid_cascade, streamline the
 search architecture
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Remove the SPLADE sparse neural search backend and the hybrid_cascade
strategy, reducing the search architecture from 6 backends to 4
(FTS Exact/Fuzzy, Binary Vector, Dense Vector, LSP).

Main changes:
- Delete 4 files, including splade_encoder.py, splade_index.py, and
  migration_009
- Remove SPLADE-related settings (enable_splade, splade_model, etc.)
  from config.py
- Change DEFAULT_WEIGHTS to the FTS weights
  {exact:0.25, fuzzy:0.1, vector:0.5, lsp:0.15}
- Delete hybrid_cascade_search(); all cascade fallbacks now call
  self.search()
- Map API fusion_strategy='hybrid' to binary_rerank for backward
  compatibility
- Delete the CLI index_splade/splade_status commands and --method splade
- Update tests, benchmarks, and documentation
---
 codex-lens/benchmarks/analyze_methods.py      |  42 +-
 codex-lens/benchmarks/cascade_benchmark.py    |   2 +-
 .../benchmarks/compare_semantic_methods.py    | 108 +---
 .../method_contribution_analysis.py           |  42 +-
 codex-lens/docs/CODEXLENS_LSP_API_SPEC.md     |   2 +-
 codex-lens/pyproject.toml                     |  12 -
 codex-lens/src/codexlens/api/semantic.py      |  12 +-
 codex-lens/src/codexlens/cli/commands.py      | 565 +----------------
 .../src/codexlens/cli/embedding_manager.py    | 196 +-----
 codex-lens/src/codexlens/config.py            |  18 +-
 .../src/codexlens/search/chain_search.py      | 219 +------
 .../src/codexlens/search/hybrid_search.py     | 214 +------
 codex-lens/src/codexlens/search/ranking.py    |  43 +-
 .../semantic/SPLADE_IMPLEMENTATION.md         | 225 -------
 .../src/codexlens/semantic/splade_encoder.py  | 567 -----------------
 .../migrations/migration_009_add_splade.py    | 103 ----
 .../src/codexlens/storage/splade_index.py     | 578 ------------------
 codex-lens/tests/api/test_semantic_search.py  |   8 +-
 .../test_lsp_search_integration.py            |  23 +-
 .../tests/real/test_lsp_real_interface.py     |   2 +-
 codex-lens/tests/test_chain_search.py         |   8 -
 codex-lens/tests/test_staged_cascade.py       |  20 +-
 22 files changed, 126 insertions(+), 2883 deletions(-)
 delete mode 100644 codex-lens/src/codexlens/semantic/SPLADE_IMPLEMENTATION.md
 delete mode 100644 codex-lens/src/codexlens/semantic/splade_encoder.py
 delete mode 100644 codex-lens/src/codexlens/storage/migrations/migration_009_add_splade.py
 delete mode 100644 codex-lens/src/codexlens/storage/splade_index.py

diff --git a/codex-lens/benchmarks/analyze_methods.py b/codex-lens/benchmarks/analyze_methods.py
index fa51aa3b..9973d64c 100644
--- a/codex-lens/benchmarks/analyze_methods.py
+++ b/codex-lens/benchmarks/analyze_methods.py
@@ -12,7 +12,6 @@ from codexlens.search.ranking import (
     reciprocal_rank_fusion,
     cross_encoder_rerank,
     DEFAULT_WEIGHTS,
-    FTS_FALLBACK_WEIGHTS,
 )
 
 # Use index with most data
@@ -65,12 +64,6 @@ with sqlite3.connect(index_path) as conn:
             non_null = semantic_count - null_count
             print(f"  {col}: {non_null}/{semantic_count} non-null")
 
-    if "splade_posting_list" in tables:
-        splade_count = conn.execute("SELECT COUNT(*) FROM splade_posting_list").fetchone()[0]
-        print(f"\n  splade_posting_list: {splade_count} postings")
-    else:
-        print("\n  splade_posting_list: NOT EXISTS")
-
 print("\n" + "=" * 60)
 print("2. 
METHOD CONTRIBUTION ANALYSIS") print("=" * 60) @@ -87,7 +80,6 @@ results_summary = { "fts_exact": [], "fts_fuzzy": [], "vector": [], - "splade": [], } for query in queries: @@ -95,10 +87,9 @@ for query in queries: # FTS Exact try: - engine = HybridSearchEngine(weights=FTS_FALLBACK_WEIGHTS) + engine = HybridSearchEngine(weights=DEFAULT_WEIGHTS) engine._config = type("obj", (object,), { "use_fts_fallback": True, - "enable_splade": False, "embedding_use_gpu": True, "symbol_boost_factor": 1.5, "enable_reranking": False, @@ -117,10 +108,9 @@ for query in queries: # FTS Fuzzy try: - engine = HybridSearchEngine(weights=FTS_FALLBACK_WEIGHTS) + engine = HybridSearchEngine(weights=DEFAULT_WEIGHTS) engine._config = type("obj", (object,), { "use_fts_fallback": True, - "enable_splade": False, "embedding_use_gpu": True, "symbol_boost_factor": 1.5, "enable_reranking": False, @@ -142,7 +132,6 @@ for query in queries: engine = HybridSearchEngine() engine._config = type("obj", (object,), { "use_fts_fallback": False, - "enable_splade": False, "embedding_use_gpu": True, "symbol_boost_factor": 1.5, "enable_reranking": False, @@ -159,28 +148,6 @@ for query in queries: except Exception as e: print(f" Vector: ERROR - {e}") - # SPLADE - try: - engine = HybridSearchEngine(weights={"splade": 1.0}) - engine._config = type("obj", (object,), { - "use_fts_fallback": False, - "enable_splade": True, - "embedding_use_gpu": True, - "symbol_boost_factor": 1.5, - "enable_reranking": False, - })() - - start = time.perf_counter() - results = engine.search(index_path, query, limit=10, enable_fuzzy=False, enable_vector=False) - latency = (time.perf_counter() - start) * 1000 - - results_summary["splade"].append({"count": len(results), "latency": latency}) - top_file = results[0].path.split("\\")[-1] if results else "N/A" - top_score = results[0].score if results else 0 - print(f" SPLADE: {len(results)} results, {latency:.1f}ms, top: {top_file} ({top_score:.3f})") - except Exception as e: - print(f" SPLADE: ERROR - {e}") - print("\n--- Summary ---") for method, data in results_summary.items(): if data: @@ -210,10 +177,9 @@ for query in test_queries: # Strategy 1: Standard Hybrid (FTS exact+fuzzy RRF) try: - engine = HybridSearchEngine(weights=FTS_FALLBACK_WEIGHTS) + engine = HybridSearchEngine(weights=DEFAULT_WEIGHTS) engine._config = type("obj", (object,), { "use_fts_fallback": True, - "enable_splade": False, "embedding_use_gpu": True, "symbol_boost_factor": 1.5, "enable_reranking": False, @@ -263,7 +229,6 @@ print(""" 1. Storage Architecture: - semantic_chunks: Used by cascade-index (binary+dense vectors) - chunks: Used by legacy SQLiteStore (currently empty in this index) - - splade_posting_list: Used by SPLADE sparse retrieval - files_fts_*: Used by FTS exact/fuzzy search CONFLICT: binary_cascade_search reads from semantic_chunks, @@ -272,7 +237,6 @@ print(""" 2. Method Contributions: - FTS: Fast but limited to keyword matching - Vector: Semantic understanding but requires embeddings - - SPLADE: Sparse retrieval, good for keyword+semantic hybrid 3. 
FTS + Rerank Fusion: - CrossEncoder reranking can improve precision diff --git a/codex-lens/benchmarks/cascade_benchmark.py b/codex-lens/benchmarks/cascade_benchmark.py index 14461479..90abfda1 100644 --- a/codex-lens/benchmarks/cascade_benchmark.py +++ b/codex-lens/benchmarks/cascade_benchmark.py @@ -3,7 +3,7 @@ Compares: - binary: 256-dim binary coarse ranking + 2048-dim dense fine ranking -- hybrid: FTS+SPLADE+Vector coarse ranking + CrossEncoder fine ranking +- hybrid: FTS+Vector coarse ranking + CrossEncoder fine ranking Usage: python benchmarks/cascade_benchmark.py [--source PATH] [--queries N] [--warmup N] diff --git a/codex-lens/benchmarks/compare_semantic_methods.py b/codex-lens/benchmarks/compare_semantic_methods.py index 23837103..da7b4873 100644 --- a/codex-lens/benchmarks/compare_semantic_methods.py +++ b/codex-lens/benchmarks/compare_semantic_methods.py @@ -1,9 +1,8 @@ -"""Compare Binary Cascade, SPLADE, and Vector semantic search methods. +"""Compare Binary Cascade and Vector semantic search methods. -This script compares the three semantic retrieval approaches: +This script compares the two semantic retrieval approaches: 1. Binary Cascade: 256-bit binary vectors for coarse ranking -2. SPLADE: Sparse learned representations with inverted index -3. Vector Dense: Full semantic embeddings with cosine similarity +2. Vector Dense: Full semantic embeddings with cosine similarity """ import sys @@ -14,7 +13,6 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from codexlens.storage.dir_index import DirIndexStore -from codexlens.storage.splade_index import SpladeIndex from codexlens.semantic.vector_store import VectorStore @@ -27,19 +25,6 @@ def get_filename(path: str) -> str: return path -def find_splade_db(index_root: Path) -> Path: - """Find SPLADE database by searching directory tree.""" - # Check root first - if (index_root / "_splade.db").exists(): - return index_root / "_splade.db" - - # Search in subdirectories - for splade_db in index_root.rglob("_splade.db"): - return splade_db - - return None - - def find_binary_indexes(index_root: Path): """Find all binary index files.""" return list(index_root.rglob("_index_binary_vectors.bin")) @@ -108,55 +93,6 @@ def test_vector_search(query: str, limit: int = 10): return [], 0, str(e) -def test_splade_search(query: str, limit: int = 10): - """Test SPLADE sparse search.""" - try: - from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available - - ok, err = check_splade_available() - if not ok: - return [], 0, f"SPLADE not available: {err}" - - splade_db_path = find_splade_db(INDEX_ROOT) - if not splade_db_path: - return [], 0, "SPLADE database not found" - - splade_index = SpladeIndex(splade_db_path) - if not splade_index.has_index(): - return [], 0, "SPLADE index not initialized" - - start = time.perf_counter() - encoder = get_splade_encoder() - query_sparse = encoder.encode_text(query) - raw_results = splade_index.search(query_sparse, limit=limit, min_score=0.0) - - if not raw_results: - elapsed = (time.perf_counter() - start) * 1000 - return [], elapsed, None - - # Get chunk details - chunk_ids = [chunk_id for chunk_id, _ in raw_results] - score_map = {chunk_id: score for chunk_id, score in raw_results} - rows = splade_index.get_chunks_by_ids(chunk_ids) - - elapsed = (time.perf_counter() - start) * 1000 - - # Build result objects - results = [] - for row in rows: - chunk_id = row["id"] - results.append({ - "path": row["file_path"], - "score": score_map.get(chunk_id, 
0.0), - "content": row["content"][:200] + "..." if len(row["content"]) > 200 else row["content"], - }) - - # Sort by score - results.sort(key=lambda x: x["score"], reverse=True) - return results, elapsed, None - except Exception as e: - return [], 0, str(e) - def test_binary_cascade_search(query: str, limit: int = 10): """Test binary cascade search (binary coarse + dense fine ranking).""" @@ -336,16 +272,13 @@ def compare_overlap(results1, results2, name1: str, name2: str): def main(): print("=" * 70) print("SEMANTIC SEARCH METHODS COMPARISON") - print("Binary Cascade vs SPLADE vs Vector Dense") + print("Binary Cascade vs Vector Dense") print("=" * 70) # Check prerequisites print("\n[Prerequisites Check]") print(f" Index Root: {INDEX_ROOT}") - splade_db = find_splade_db(INDEX_ROOT) - print(f" SPLADE DB: {splade_db} - {'EXISTS' if splade_db else 'NOT FOUND'}") - binary_indexes = find_binary_indexes(INDEX_ROOT) print(f" Binary Indexes: {len(binary_indexes)} found") for bi in binary_indexes[:3]: @@ -356,11 +289,10 @@ def main(): # Aggregate statistics all_results = { "binary": {"total_results": 0, "total_time": 0, "queries": 0, "errors": []}, - "splade": {"total_results": 0, "total_time": 0, "queries": 0, "errors": []}, "vector": {"total_results": 0, "total_time": 0, "queries": 0, "errors": []}, } - overlap_scores = {"binary_splade": [], "binary_vector": [], "splade_vector": []} + overlap_scores = {"binary_vector": []} for query in TEST_QUERIES: print(f"\n{'#'*70}") @@ -369,12 +301,10 @@ def main(): # Test each method binary_results, binary_time, binary_err = test_binary_cascade_search(query) - splade_results, splade_time, splade_err = test_splade_search(query) vector_results, vector_time, vector_err = test_vector_search(query) # Print results print_results("Binary Cascade (256-bit + Dense Rerank)", binary_results, binary_time, binary_err) - print_results("SPLADE (Sparse Learned)", splade_results, splade_time, splade_err) print_results("Vector Dense (Semantic Embeddings)", vector_results, vector_time, vector_err) # Update statistics @@ -385,13 +315,6 @@ def main(): else: all_results["binary"]["errors"].append(binary_err) - if not splade_err: - all_results["splade"]["total_results"] += len(splade_results) - all_results["splade"]["total_time"] += splade_time - all_results["splade"]["queries"] += 1 - else: - all_results["splade"]["errors"].append(splade_err) - if not vector_err: all_results["vector"]["total_results"] += len(vector_results) all_results["vector"]["total_time"] += vector_time @@ -401,15 +324,9 @@ def main(): # Compare overlap print("\n[Result Overlap Analysis]") - if binary_results and splade_results: - j = compare_overlap(binary_results, splade_results, "Binary", "SPLADE") - overlap_scores["binary_splade"].append(j) if binary_results and vector_results: j = compare_overlap(binary_results, vector_results, "Binary", "Vector") overlap_scores["binary_vector"].append(j) - if splade_results and vector_results: - j = compare_overlap(splade_results, vector_results, "SPLADE", "Vector") - overlap_scores["splade_vector"].append(j) # Print summary print("\n" + "=" * 70) @@ -447,13 +364,13 @@ def main(): # Analyze working methods working_methods = [m for m, s in all_results.items() if s["queries"] > 0] - if len(working_methods) == 3: + if len(working_methods) == 2: # All methods working - compare quality - print("\nAll three methods working. Quality comparison:") + print("\nBoth methods working. 
Quality comparison:") # Compare avg results print("\n Result Coverage (higher = more recall):") - for m in ["vector", "splade", "binary"]: + for m in ["vector", "binary"]: stats = all_results[m] if stats["queries"] > 0: avg = stats["total_results"] / stats["queries"] @@ -461,7 +378,7 @@ def main(): # Compare speed print("\n Speed (lower = faster):") - for m in ["binary", "splade", "vector"]: + for m in ["binary", "vector"]: stats = all_results[m] if stats["queries"] > 0: avg = stats["total_time"] / stats["queries"] @@ -470,11 +387,10 @@ def main(): # Recommend fusion strategy print("\n Recommended Fusion Strategy:") print(" For quality-focused hybrid search:") - print(" 1. Run all three in parallel") + print(" 1. Run both methods in parallel") print(" 2. Use RRF fusion with weights:") - print(" - Vector: 0.4 (best semantic understanding)") - print(" - SPLADE: 0.35 (learned sparse representations)") - print(" - Binary: 0.25 (fast coarse filtering)") + print(" - Vector: 0.6 (best semantic understanding)") + print(" - Binary: 0.4 (fast coarse filtering)") print(" 3. Apply CrossEncoder reranking on top-50") elif len(working_methods) >= 2: diff --git a/codex-lens/benchmarks/method_contribution_analysis.py b/codex-lens/benchmarks/method_contribution_analysis.py index e005f958..e16abe6a 100644 --- a/codex-lens/benchmarks/method_contribution_analysis.py +++ b/codex-lens/benchmarks/method_contribution_analysis.py @@ -1,7 +1,7 @@ """Analysis script for hybrid search method contribution and storage architecture. This script analyzes: -1. Individual method contribution in hybrid search (FTS/SPLADE/Vector) +1. Individual method contribution in hybrid search (FTS/Vector) 2. Storage architecture conflicts between different retrieval methods 3. FTS + Rerank fusion experiment """ @@ -24,9 +24,7 @@ from codexlens.search.ranking import ( reciprocal_rank_fusion, cross_encoder_rerank, DEFAULT_WEIGHTS, - FTS_FALLBACK_WEIGHTS, ) -from codexlens.search.hybrid_search import THREE_WAY_WEIGHTS from codexlens.entities import SearchResult @@ -117,15 +115,7 @@ def analyze_storage_architecture(index_path: Path) -> Dict[str, Any]: "binary cascade search." ) - # 2. Check SPLADE index status - if "splade_posting_list" in tables: - splade_count = results["tables"]["splade_posting_list"]["row_count"] - if splade_count == 0: - results["recommendations"].append( - "SPLADE tables exist but empty. Run SPLADE indexing to enable sparse retrieval." - ) - - # 3. Check FTS tables + # 2. 
Check FTS tables fts_tables = [t for t in tables if t.startswith("files_fts")] if len(fts_tables) >= 2: results["recommendations"].append( @@ -163,10 +153,9 @@ def analyze_method_contributions( # Run each method independently methods = { - "fts_exact": {"fuzzy": False, "vector": False, "splade": False}, - "fts_fuzzy": {"fuzzy": True, "vector": False, "splade": False}, - "vector": {"fuzzy": False, "vector": True, "splade": False}, - "splade": {"fuzzy": False, "vector": False, "splade": True}, + "fts_exact": {"fuzzy": False, "vector": False}, + "fts_fuzzy": {"fuzzy": True, "vector": False}, + "vector": {"fuzzy": False, "vector": True}, } method_results: Dict[str, List[SearchResult]] = {} @@ -178,7 +167,6 @@ def analyze_method_contributions( # Set config to disable/enable specific backends engine._config = type('obj', (object,), { 'use_fts_fallback': method_name.startswith("fts"), - 'enable_splade': method_name == "splade", 'embedding_use_gpu': True, })() @@ -186,13 +174,13 @@ def analyze_method_contributions( if method_name == "fts_exact": # Force FTS fallback mode with fuzzy disabled - engine.weights = FTS_FALLBACK_WEIGHTS.copy() + engine.weights = DEFAULT_WEIGHTS.copy() results_list = engine.search( index_path, query, limit=limit, enable_fuzzy=False, enable_vector=False, pure_vector=False ) elif method_name == "fts_fuzzy": - engine.weights = FTS_FALLBACK_WEIGHTS.copy() + engine.weights = DEFAULT_WEIGHTS.copy() results_list = engine.search( index_path, query, limit=limit, enable_fuzzy=True, enable_vector=False, pure_vector=False @@ -202,12 +190,6 @@ def analyze_method_contributions( index_path, query, limit=limit, enable_fuzzy=False, enable_vector=True, pure_vector=True ) - elif method_name == "splade": - engine.weights = {"splade": 1.0} - results_list = engine.search( - index_path, query, limit=limit, - enable_fuzzy=False, enable_vector=False, pure_vector=False - ) else: results_list = [] @@ -259,7 +241,7 @@ def analyze_method_contributions( # Compute RRF with each method's contribution rrf_map = {} for name, results in method_results.items(): - if results and name in ["fts_exact", "splade", "vector"]: + if results and name in ["fts_exact", "vector"]: # Rename for RRF rrf_name = name.replace("fts_exact", "exact") rrf_map[rrf_name] = results @@ -310,7 +292,7 @@ def experiment_fts_rerank_fusion( """Experiment: FTS + Rerank fusion vs standard hybrid. Compares: - 1. Standard Hybrid (SPLADE + Vector RRF) + 1. Standard Hybrid (FTS + Vector RRF) 2. 
FTS + CrossEncoder Rerank -> then fuse with Vector """ results = { @@ -336,11 +318,10 @@ def experiment_fts_rerank_fusion( "strategies": {} } - # Strategy 1: Standard Hybrid (SPLADE + Vector) + # Strategy 1: Standard Hybrid (FTS + Vector) try: engine = HybridSearchEngine(weights=DEFAULT_WEIGHTS) engine._config = type('obj', (object,), { - 'enable_splade': True, 'use_fts_fallback': False, 'embedding_use_gpu': True, })() @@ -364,10 +345,9 @@ def experiment_fts_rerank_fusion( # Strategy 2: FTS + Rerank -> Fuse with Vector try: # Step 1: Get FTS results (coarse) - fts_engine = HybridSearchEngine(weights=FTS_FALLBACK_WEIGHTS) + fts_engine = HybridSearchEngine(weights=DEFAULT_WEIGHTS) fts_engine._config = type('obj', (object,), { 'use_fts_fallback': True, - 'enable_splade': False, 'embedding_use_gpu': True, })() diff --git a/codex-lens/docs/CODEXLENS_LSP_API_SPEC.md b/codex-lens/docs/CODEXLENS_LSP_API_SPEC.md index 67467e3b..fc2be840 100644 --- a/codex-lens/docs/CODEXLENS_LSP_API_SPEC.md +++ b/codex-lens/docs/CODEXLENS_LSP_API_SPEC.md @@ -405,7 +405,7 @@ def semantic_search( - rrf: Reciprocal Rank Fusion (推荐,默认) - staged: 分阶段级联 → staged_cascade_search - binary: 二分重排级联 → binary_rerank_cascade_search - - hybrid: 混合级联 → hybrid_cascade_search + - hybrid: 混合级联 → hybrid_search kind_filter: 符号类型过滤 limit: 最大返回数量 include_match_reason: 是否生成匹配原因 (启发式,非 LLM) diff --git a/codex-lens/pyproject.toml b/codex-lens/pyproject.toml index be00a4cb..9819bfe6 100644 --- a/codex-lens/pyproject.toml +++ b/codex-lens/pyproject.toml @@ -80,18 +80,6 @@ reranker = [ "transformers>=4.36", ] -# SPLADE sparse retrieval -splade = [ - "transformers>=4.36", - "optimum[onnxruntime]>=1.16", -] - -# SPLADE with GPU acceleration (CUDA) -splade-gpu = [ - "transformers>=4.36", - "optimum[onnxruntime-gpu]>=1.16", -] - # Encoding detection for non-UTF8 files encoding = [ "chardet>=5.0", diff --git a/codex-lens/src/codexlens/api/semantic.py b/codex-lens/src/codexlens/api/semantic.py index f17e1c8b..4e66405e 100644 --- a/codex-lens/src/codexlens/api/semantic.py +++ b/codex-lens/src/codexlens/api/semantic.py @@ -48,7 +48,8 @@ def semantic_search( - rrf: Reciprocal Rank Fusion (recommended, default) - staged: Staged cascade -> staged_cascade_search - binary: Binary rerank cascade -> binary_cascade_search - - hybrid: Hybrid cascade -> hybrid_cascade_search + - hybrid: Binary rerank cascade (backward compat) -> binary_rerank_cascade_search + - dense_rerank: Dense rerank cascade -> dense_rerank_cascade_search kind_filter: Symbol type filter (e.g., ["function", "class"]) limit: Max return count (default 20) include_match_reason: Generate match reason (heuristic, not LLM) @@ -215,7 +216,8 @@ def _execute_search( - rrf: Standard hybrid search with RRF fusion - staged: staged_cascade_search - binary: binary_cascade_search - - hybrid: hybrid_cascade_search + - hybrid: binary_rerank_cascade_search (backward compat) + - dense_rerank: dense_rerank_cascade_search Args: engine: ChainSearchEngine instance @@ -249,8 +251,8 @@ def _execute_search( options=options, ) elif fusion_strategy == "hybrid": - # Use hybrid cascade search (FTS+SPLADE+Vector + cross-encoder) - return engine.hybrid_cascade_search( + # Backward compat: hybrid now maps to binary_rerank_cascade_search + return engine.binary_rerank_cascade_search( query=query, source_path=source_path, k=limit, @@ -342,8 +344,6 @@ def _transform_results( fts_scores.append(source_scores["exact"]) if "fuzzy" in source_scores: fts_scores.append(source_scores["fuzzy"]) - if "splade" in source_scores: - 
fts_scores.append(source_scores["splade"]) if fts_scores: structural_score = max(fts_scores) diff --git a/codex-lens/src/codexlens/cli/commands.py b/codex-lens/src/codexlens/cli/commands.py index ebf81101..687043a5 100644 --- a/codex-lens/src/codexlens/cli/commands.py +++ b/codex-lens/src/codexlens/cli/commands.py @@ -6,7 +6,6 @@ import json import logging import os import shutil -import sqlite3 from pathlib import Path from typing import Annotated, Any, Dict, Iterable, List, Optional @@ -37,7 +36,7 @@ from .output import ( app = typer.Typer(help="CodexLens CLI — local code indexing and search.") # Index subcommand group for reorganized commands -index_app = typer.Typer(help="Index management commands (init, embeddings, splade, binary, status, migrate, all)") +index_app = typer.Typer(help="Index management commands (init, embeddings, binary, status, migrate, all)") app.add_typer(index_app, name="index") @@ -521,15 +520,15 @@ def search( print_json(success=False, error=f"Invalid deprecated mode: {mode}. Use --method instead.") else: console.print(f"[red]Invalid deprecated mode:[/red] {mode}") - console.print("[dim]Use --method with: fts, vector, splade, hybrid, cascade[/dim]") + console.print("[dim]Use --method with: fts, vector, hybrid, cascade[/dim]") raise typer.Exit(code=1) # Configure search (load settings from file) config = Config.load() # Validate method - simplified interface exposes only dense_rerank and fts - # Other methods (vector, splade, hybrid, cascade) are hidden but still work for backward compatibility - valid_methods = ["fts", "dense_rerank", "vector", "splade", "hybrid", "cascade"] + # Other methods (vector, hybrid, cascade) are hidden but still work for backward compatibility + valid_methods = ["fts", "dense_rerank", "vector", "hybrid", "cascade"] if actual_method not in valid_methods: if json_mode: print_json(success=False, error=f"Invalid method: {actual_method}. Use 'dense_rerank' (semantic) or 'fts' (exact keyword).") @@ -561,7 +560,7 @@ def search( try: # Check if using key=value format (new) or legacy comma-separated format if "=" in weights: - # New format: splade=0.4,vector=0.6 or exact=0.3,fuzzy=0.1,vector=0.6 + # New format: exact=0.3,fuzzy=0.1,vector=0.6 weight_dict = {} for pair in weights.split(","): if "=" in pair: @@ -592,17 +591,6 @@ def search( "fuzzy": weight_parts[1], "vector": weight_parts[2], } - elif len(weight_parts) == 2: - # Two values: assume splade,vector - weight_sum = sum(weight_parts) - if abs(weight_sum - 1.0) > 0.01: - if not json_mode: - console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]") - weight_parts = [w / weight_sum for w in weight_parts] - hybrid_weights = { - "splade": weight_parts[0], - "vector": weight_parts[1], - } else: if not json_mode: console.print("[yellow]Warning: Invalid weights format. 
Using defaults.[/yellow]") @@ -621,7 +609,6 @@ def search( # Map method to SearchOptions flags # fts: FTS-only search (optionally with fuzzy) # vector: Pure vector semantic search - # splade: SPLADE sparse neural search # hybrid: RRF fusion of sparse + dense # cascade: Two-stage binary + dense retrieval if actual_method == "fts": @@ -629,35 +616,24 @@ def search( enable_fuzzy = use_fuzzy enable_vector = False pure_vector = False - enable_splade = False enable_cascade = False elif actual_method == "vector": hybrid_mode = True enable_fuzzy = False enable_vector = True pure_vector = True - enable_splade = False - enable_cascade = False - elif actual_method == "splade": - hybrid_mode = True - enable_fuzzy = False - enable_vector = False - pure_vector = False - enable_splade = True enable_cascade = False elif actual_method == "hybrid": hybrid_mode = True enable_fuzzy = use_fuzzy enable_vector = True pure_vector = False - enable_splade = True # SPLADE is preferred sparse in hybrid enable_cascade = False elif actual_method == "cascade": hybrid_mode = True enable_fuzzy = False enable_vector = True pure_vector = False - enable_splade = False enable_cascade = True else: raise ValueError(f"Invalid method: {actual_method}") @@ -678,7 +654,6 @@ def search( enable_fuzzy=enable_fuzzy, enable_vector=enable_vector, pure_vector=pure_vector, - enable_splade=enable_splade, enable_cascade=enable_cascade, hybrid_weights=hybrid_weights, ) @@ -2857,251 +2832,8 @@ def gpu_reset( -# ==================== SPLADE Commands ==================== - -@index_app.command("splade") -def index_splade( - path: Path = typer.Argument(..., help="Project path to index"), - rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"), - verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."), -) -> None: - """Generate SPLADE sparse index for existing codebase. - - Encodes all semantic chunks with SPLADE model and builds inverted index - for efficient sparse retrieval. - - This command discovers all _index.db files recursively in the project's - index directory and builds SPLADE encodings for chunks across all of them. - - Examples: - codexlens index splade ~/projects/my-app - codexlens index splade . 
--rebuild - """ - _configure_logging(verbose) - - from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available - from codexlens.storage.splade_index import SpladeIndex - from codexlens.semantic.vector_store import VectorStore - - # Check SPLADE availability - ok, err = check_splade_available() - if not ok: - console.print(f"[red]SPLADE not available: {err}[/red]") - console.print("[dim]Install with: pip install transformers torch[/dim]") - raise typer.Exit(1) - - # Find index root directory - target_path = path.expanduser().resolve() - - # Determine index root directory (containing _index.db files) - if target_path.is_file() and target_path.name == "_index.db": - index_root = target_path.parent - elif target_path.is_dir(): - # Check for local .codexlens/_index.db - local_index = target_path / ".codexlens" / "_index.db" - if local_index.exists(): - index_root = local_index.parent - else: - # Try to find via registry - registry = RegistryStore() - try: - registry.initialize() - mapper = PathMapper() - index_db = mapper.source_to_index_db(target_path) - if not index_db.exists(): - console.print(f"[red]Error:[/red] No index found for {target_path}") - console.print("Run 'codexlens init' first to create an index") - raise typer.Exit(1) - index_root = index_db.parent - finally: - registry.close() - else: - console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory") - raise typer.Exit(1) - - # Discover all _index.db files recursively - all_index_dbs = sorted(index_root.rglob("_index.db")) - if not all_index_dbs: - console.print(f"[red]Error:[/red] No _index.db files found in {index_root}") - raise typer.Exit(1) - - console.print(f"[blue]Discovered {len(all_index_dbs)} index databases[/blue]") - - # SPLADE index is stored alongside the root _index.db - from codexlens.config import SPLADE_DB_NAME - splade_db = index_root / SPLADE_DB_NAME - - if splade_db.exists() and not rebuild: - console.print("[yellow]SPLADE index exists. 
Use --rebuild to regenerate.[/yellow]") - return - - # If rebuild, delete existing splade database - if splade_db.exists() and rebuild: - splade_db.unlink() - - # Collect all chunks from all distributed index databases - # Assign globally unique IDs to avoid collisions (each DB starts with ID 1) - console.print(f"[blue]Loading chunks from {len(all_index_dbs)} distributed indexes...[/blue]") - all_chunks = [] # (global_id, chunk) pairs - total_files_checked = 0 - indexes_with_chunks = 0 - global_id = 0 # Sequential global ID across all databases - - for index_db in all_index_dbs: - total_files_checked += 1 - try: - vector_store = VectorStore(index_db) - chunks = vector_store.get_all_chunks() - if chunks: - indexes_with_chunks += 1 - # Assign sequential global IDs to avoid collisions - for chunk in chunks: - global_id += 1 - all_chunks.append((global_id, chunk, index_db)) - if verbose: - console.print(f" [dim]{index_db.parent.name}: {len(chunks)} chunks[/dim]") - vector_store.close() - except Exception as e: - if verbose: - console.print(f" [yellow]Warning: Failed to read {index_db}: {e}[/yellow]") - - if not all_chunks: - console.print("[yellow]No chunks found in any index database[/yellow]") - console.print(f"[dim]Checked {total_files_checked} index files, found 0 chunks[/dim]") - console.print("[dim]Generate embeddings first with 'codexlens embeddings-generate --recursive'[/dim]") - raise typer.Exit(1) - - console.print(f"[blue]Found {len(all_chunks)} chunks across {indexes_with_chunks} indexes[/blue]") - console.print(f"[blue]Encoding with SPLADE...[/blue]") - - # Initialize SPLADE - encoder = get_splade_encoder() - splade_index = SpladeIndex(splade_db) - splade_index.create_tables() - - # Encode in batches with progress bar - chunk_metadata_batch = [] - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeElapsedColumn(), - console=console, - ) as progress: - task = progress.add_task("Encoding...", total=len(all_chunks)) - for global_id, chunk, source_db_path in all_chunks: - sparse_vec = encoder.encode_text(chunk.content) - splade_index.add_posting(global_id, sparse_vec) - # Store chunk metadata for self-contained search - # Serialize metadata dict to JSON string - metadata_str = None - if hasattr(chunk, 'metadata') and chunk.metadata: - try: - metadata_str = json.dumps(chunk.metadata) if isinstance(chunk.metadata, dict) else chunk.metadata - except Exception: - pass - chunk_metadata_batch.append(( - global_id, - chunk.file_path or "", - chunk.content, - metadata_str, - str(source_db_path) - )) - progress.advance(task) - - # Batch insert chunk metadata - if chunk_metadata_batch: - splade_index.add_chunks_metadata_batch(chunk_metadata_batch) - - # Set metadata - splade_index.set_metadata( - model_name=encoder.model_name, - vocab_size=encoder.vocab_size - ) - - stats = splade_index.get_stats() - console.print(f"[green]OK[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings") - console.print(f" Source indexes: {indexes_with_chunks}") - console.print(f" Database: [dim]{splade_db}[/dim]") -@app.command("splade-status", hidden=True, deprecated=True) -def splade_status_command( - path: Path = typer.Argument(..., help="Project path"), - verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."), -) -> None: - """[Deprecated] Use 'codexlens index status' instead. 
- - Show SPLADE index status and statistics. - - Examples: - codexlens splade-status ~/projects/my-app - codexlens splade-status . - """ - _deprecated_command_warning("splade-status", "index status") - _configure_logging(verbose) - - from codexlens.storage.splade_index import SpladeIndex - from codexlens.semantic.splade_encoder import check_splade_available - from codexlens.config import SPLADE_DB_NAME - - # Find index database - target_path = path.expanduser().resolve() - - if target_path.is_file() and target_path.name == "_index.db": - splade_db = target_path.parent / SPLADE_DB_NAME - elif target_path.is_dir(): - # Check for local .codexlens/_splade.db - local_splade = target_path / ".codexlens" / SPLADE_DB_NAME - if local_splade.exists(): - splade_db = local_splade - else: - # Try to find via registry - registry = RegistryStore() - try: - registry.initialize() - mapper = PathMapper() - index_db = mapper.source_to_index_db(target_path) - splade_db = index_db.parent / SPLADE_DB_NAME - finally: - registry.close() - else: - console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory") - raise typer.Exit(1) - - if not splade_db.exists(): - console.print("[yellow]No SPLADE index found[/yellow]") - console.print(f"[dim]Run 'codexlens splade-index {path}' to create one[/dim]") - return - - splade_index = SpladeIndex(splade_db) - - if not splade_index.has_index(): - console.print("[yellow]SPLADE tables not initialized[/yellow]") - return - - metadata = splade_index.get_metadata() - stats = splade_index.get_stats() - - # Create status table - table = Table(title="SPLADE Index Status", show_header=False) - table.add_column("Property", style="cyan") - table.add_column("Value") - - table.add_row("Database", str(splade_db)) - if metadata: - table.add_row("Model", metadata['model_name']) - table.add_row("Vocab Size", str(metadata['vocab_size'])) - table.add_row("Chunks", str(stats['unique_chunks'])) - table.add_row("Unique Tokens", str(stats['unique_tokens'])) - table.add_row("Total Postings", str(stats['total_postings'])) - - ok, err = check_splade_available() - status_text = "[green]Yes[/green]" if ok else f"[red]No[/red] - {err}" - table.add_row("SPLADE Available", status_text) - - console.print(table) # ==================== Watch Command ==================== @@ -3516,11 +3248,10 @@ def index_status( json_mode: bool = typer.Option(False, "--json", help="Output JSON response."), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."), ) -> None: - """Show comprehensive index status (embeddings + SPLADE). + """Show comprehensive index status (embeddings). 
Shows combined status for all index types: - Dense vector embeddings (HNSW) - - SPLADE sparse embeddings - Binary cascade embeddings Examples: @@ -3531,9 +3262,6 @@ def index_status( _configure_logging(verbose, json_mode) from codexlens.cli.embedding_manager import check_index_embeddings, get_embedding_stats_summary - from codexlens.storage.splade_index import SpladeIndex - from codexlens.semantic.splade_encoder import check_splade_available - from codexlens.config import SPLADE_DB_NAME # Determine target path and index root if path is None: @@ -3571,36 +3299,11 @@ def index_status( # Get embeddings status embeddings_result = get_embedding_stats_summary(index_root) - # Get SPLADE status - splade_db = index_root / SPLADE_DB_NAME - splade_status = { - "available": False, - "has_index": False, - "stats": None, - "metadata": None, - } - - splade_available, splade_err = check_splade_available() - splade_status["available"] = splade_available - - if splade_db.exists(): - try: - splade_index = SpladeIndex(splade_db) - if splade_index.has_index(): - splade_status["has_index"] = True - splade_status["stats"] = splade_index.get_stats() - splade_status["metadata"] = splade_index.get_metadata() - splade_index.close() - except Exception as e: - if verbose: - console.print(f"[yellow]Warning: Failed to read SPLADE index: {e}[/yellow]") - # Build combined result result = { "index_root": str(index_root), "embeddings": embeddings_result.get("result") if embeddings_result.get("success") else None, "embeddings_error": embeddings_result.get("error") if not embeddings_result.get("success") else None, - "splade": splade_status, } if json_mode: @@ -3623,27 +3326,6 @@ def index_status( else: console.print(f" [yellow]--[/yellow] {embeddings_result.get('error', 'Not available')}") - # SPLADE section - console.print("\n[bold]SPLADE Sparse Index:[/bold]") - if splade_status["has_index"]: - stats = splade_status["stats"] or {} - metadata = splade_status["metadata"] or {} - console.print(f" [green]OK[/green] SPLADE index available") - console.print(f" Chunks: {stats.get('unique_chunks', 0):,}") - console.print(f" Unique tokens: {stats.get('unique_tokens', 0):,}") - console.print(f" Total postings: {stats.get('total_postings', 0):,}") - if metadata.get("model_name"): - console.print(f" Model: {metadata['model_name']}") - elif splade_available: - console.print(f" [yellow]--[/yellow] No SPLADE index found") - console.print(f" [dim]Run 'codexlens index splade ' to create one[/dim]") - else: - console.print(f" [yellow]--[/yellow] SPLADE not available: {splade_err}") - - # Runtime availability - console.print("\n[bold]Runtime Availability:[/bold]") - console.print(f" SPLADE encoder: {'[green]Yes[/green]' if splade_available else f'[red]No[/red] ({splade_err})'}") - # ==================== Index Update Command ==================== @@ -3739,22 +3421,19 @@ def index_all( backend: str = typer.Option("fastembed", "--backend", "-b", help="Embedding backend: fastembed or litellm."), model: str = typer.Option("code", "--model", "-m", help="Embedding model profile or name."), max_workers: int = typer.Option(1, "--max-workers", min=1, help="Max concurrent API calls."), - skip_splade: bool = typer.Option(False, "--skip-splade", help="Skip SPLADE index generation."), json_mode: bool = typer.Option(False, "--json", help="Output JSON response."), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."), ) -> None: - """Run all indexing operations in sequence (init, embeddings, splade). 
+ """Run all indexing operations in sequence (init, embeddings). This is a convenience command that runs the complete indexing pipeline: 1. FTS index initialization (index init) 2. Dense vector embeddings (index embeddings) - 3. SPLADE sparse index (index splade) - unless --skip-splade Examples: codexlens index all ~/projects/my-app codexlens index all . --force codexlens index all . --backend litellm --model text-embedding-3-small - codexlens index all . --skip-splade """ _configure_logging(verbose, json_mode) @@ -3766,7 +3445,7 @@ def index_all( # Step 1: Run init if not json_mode: - console.print(f"[bold]Step 1/3: Initializing FTS index...[/bold]") + console.print(f"[bold]Step 1/2: Initializing FTS index...[/bold]") try: # Import and call the init function directly @@ -3810,7 +3489,7 @@ def index_all( # Step 2: Generate embeddings if not json_mode: - console.print(f"\n[bold]Step 2/3: Generating dense embeddings...[/bold]") + console.print(f"\n[bold]Step 2/2: Generating dense embeddings...[/bold]") try: from codexlens.cli.embedding_manager import generate_dense_embeddings_centralized @@ -3851,103 +3530,6 @@ def index_all( if not json_mode: console.print(f" [yellow]Warning:[/yellow] {e}") - # Step 3: Generate SPLADE index (unless skipped) - if not skip_splade: - if not json_mode: - console.print(f"\n[bold]Step 3/3: Generating SPLADE index...[/bold]") - - try: - from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available - from codexlens.storage.splade_index import SpladeIndex - from codexlens.semantic.vector_store import VectorStore - from codexlens.config import SPLADE_DB_NAME - - ok, err = check_splade_available() - if not ok: - results["steps"]["splade"] = {"success": False, "error": f"SPLADE not available: {err}"} - if not json_mode: - console.print(f" [yellow]Skipped:[/yellow] SPLADE not available ({err})") - else: - # Discover all _index.db files - all_index_dbs = sorted(index_root.rglob("_index.db")) - if not all_index_dbs: - results["steps"]["splade"] = {"success": False, "error": "No index databases found"} - if not json_mode: - console.print(f" [yellow]Skipped:[/yellow] No index databases found") - else: - # Collect chunks - all_chunks = [] - global_id = 0 - for index_db in all_index_dbs: - try: - vector_store = VectorStore(index_db) - chunks = vector_store.get_all_chunks() - for chunk in chunks: - global_id += 1 - all_chunks.append((global_id, chunk, index_db)) - vector_store.close() - except Exception: - pass - - if all_chunks: - splade_db = index_root / SPLADE_DB_NAME - if splade_db.exists() and force: - splade_db.unlink() - - encoder = get_splade_encoder() - splade_index = SpladeIndex(splade_db) - splade_index.create_tables() - - chunk_metadata_batch = [] - import json as json_module - for gid, chunk, source_db_path in all_chunks: - sparse_vec = encoder.encode_text(chunk.content) - splade_index.add_posting(gid, sparse_vec) - metadata_str = None - if hasattr(chunk, 'metadata') and chunk.metadata: - try: - metadata_str = json_module.dumps(chunk.metadata) if isinstance(chunk.metadata, dict) else chunk.metadata - except Exception: - pass - chunk_metadata_batch.append(( - gid, - chunk.file_path or "", - chunk.content, - metadata_str, - str(source_db_path) - )) - - if chunk_metadata_batch: - splade_index.add_chunks_metadata_batch(chunk_metadata_batch) - - splade_index.set_metadata( - model_name=encoder.model_name, - vocab_size=encoder.vocab_size - ) - - stats = splade_index.get_stats() - results["steps"]["splade"] = { - "success": True, - "chunks": 
stats['unique_chunks'], - "postings": stats['total_postings'], - } - if not json_mode: - console.print(f" [green]OK[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings") - else: - results["steps"]["splade"] = {"success": False, "error": "No chunks found"} - if not json_mode: - console.print(f" [yellow]Skipped:[/yellow] No chunks found in indexes") - - except Exception as e: - results["steps"]["splade"] = {"success": False, "error": str(e)} - if not json_mode: - console.print(f" [yellow]Warning:[/yellow] {e}") - else: - results["steps"]["splade"] = {"success": True, "skipped": True} - if not json_mode: - console.print(f"\n[bold]Step 3/3: SPLADE index...[/bold]") - console.print(f" [dim]Skipped (--skip-splade)[/dim]") - # Summary if json_mode: print_json(success=True, result=results) @@ -3955,10 +3537,8 @@ def index_all( console.print(f"\n[bold]Indexing Complete[/bold]") init_ok = results["steps"].get("init", {}).get("success", False) emb_ok = results["steps"].get("embeddings", {}).get("success", False) - splade_ok = results["steps"].get("splade", {}).get("success", False) console.print(f" FTS Index: {'[green]OK[/green]' if init_ok else '[red]Failed[/red]'}") console.print(f" Embeddings: {'[green]OK[/green]' if emb_ok else '[yellow]Partial/Skipped[/yellow]'}") - console.print(f" SPLADE: {'[green]OK[/green]' if splade_ok else '[yellow]Partial/Skipped[/yellow]'}") # ==================== Index Migration Commands ==================== @@ -3997,50 +3577,6 @@ def _set_index_version(index_root: Path, version: str) -> None: version_file.write_text(version, encoding="utf-8") -def _discover_distributed_splade(index_root: Path) -> List[Dict[str, Any]]: - """Discover distributed SPLADE data in _index.db files. - - Scans all _index.db files for embedded splade_postings tables. - This is the old distributed format that needs migration. - - Args: - index_root: Root directory to scan - - Returns: - List of dicts with db_path, posting_count, chunk_count - """ - results = [] - - for db_path in index_root.rglob("_index.db"): - try: - conn = sqlite3.connect(db_path, timeout=5.0) - conn.row_factory = sqlite3.Row - - # Check if splade_postings table exists (old embedded format) - cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='splade_postings'" - ) - if cursor.fetchone(): - # Count postings and chunks - try: - row = conn.execute( - "SELECT COUNT(*) as postings, COUNT(DISTINCT chunk_id) as chunks FROM splade_postings" - ).fetchone() - results.append({ - "db_path": db_path, - "posting_count": row["postings"] if row else 0, - "chunk_count": row["chunks"] if row else 0, - }) - except Exception: - pass - - conn.close() - except Exception: - pass - - return results - - def _discover_distributed_hnsw(index_root: Path) -> List[Dict[str, Any]]: """Discover distributed HNSW index files. 
@@ -4075,33 +3611,18 @@ def _check_centralized_storage(index_root: Path) -> Dict[str, Any]: index_root: Root directory to check Returns: - Dict with has_splade, has_vectors, splade_stats, vector_stats + Dict with has_vectors, vector_stats """ - from codexlens.config import SPLADE_DB_NAME, VECTORS_HNSW_NAME + from codexlens.config import VECTORS_HNSW_NAME - splade_db = index_root / SPLADE_DB_NAME vectors_hnsw = index_root / VECTORS_HNSW_NAME result = { - "has_splade": splade_db.exists(), "has_vectors": vectors_hnsw.exists(), - "splade_path": str(splade_db) if splade_db.exists() else None, "vectors_path": str(vectors_hnsw) if vectors_hnsw.exists() else None, - "splade_stats": None, "vector_stats": None, } - # Get SPLADE stats if exists - if splade_db.exists(): - try: - from codexlens.storage.splade_index import SpladeIndex - splade = SpladeIndex(splade_db) - if splade.has_index(): - result["splade_stats"] = splade.get_stats() - splade.close() - except Exception: - pass - # Get vector stats if exists if vectors_hnsw.exists(): try: @@ -4125,21 +3646,19 @@ def index_migrate_cmd( """Migrate old distributed index to new centralized architecture. This command upgrades indexes from the old distributed storage format - (where SPLADE/vectors were stored in each _index.db) to the new centralized - format (single _splade.db and _vectors.hnsw at index root). + (where vectors were stored in each _index.db) to the new centralized + format (single _vectors.hnsw at index root). Migration Steps: 1. Detect if migration is needed (check version marker) - 2. Discover distributed SPLADE data in _index.db files - 3. Discover distributed .hnsw files - 4. Report current status - 5. Create version marker (unless --dry-run) + 2. Discover distributed .hnsw files + 3. Report current status + 4. Create version marker (unless --dry-run) Use --dry-run to preview what would be migrated without making changes. Use --force to re-run migration even if version marker exists. - Note: For full data migration (SPLADE/vectors consolidation), run: - codexlens index splade --rebuild + Note: For full data migration (vectors consolidation), run: codexlens index embeddings --force Examples: @@ -4222,7 +3741,6 @@ def index_migrate_cmd( return # Discover distributed data - distributed_splade = _discover_distributed_splade(index_root) distributed_hnsw = _discover_distributed_hnsw(index_root) centralized = _check_centralized_storage(index_root) @@ -4239,8 +3757,6 @@ def index_migrate_cmd( "needs_migration": needs_migration, "discovery": { "total_index_dbs": len(all_index_dbs), - "distributed_splade_count": len(distributed_splade), - "distributed_splade_total_postings": sum(d["posting_count"] for d in distributed_splade), "distributed_hnsw_count": len(distributed_hnsw), "distributed_hnsw_total_bytes": sum(d["size_bytes"] for d in distributed_hnsw), }, @@ -4249,17 +3765,12 @@ def index_migrate_cmd( } # Generate recommendations - if distributed_splade and not centralized["has_splade"]: - migration_report["recommendations"].append( - f"Run 'codexlens splade-index {target_path} --rebuild' to consolidate SPLADE data" - ) - if distributed_hnsw and not centralized["has_vectors"]: migration_report["recommendations"].append( f"Run 'codexlens embeddings-generate {target_path} --recursive --force' to consolidate vector data" ) - if not distributed_splade and not distributed_hnsw: + if not distributed_hnsw: migration_report["recommendations"].append( "No distributed data found. Index may already be using centralized storage." 
) @@ -4280,23 +3791,6 @@ def index_migrate_cmd( console.print(f" Total _index.db files: {len(all_index_dbs)}") console.print() - # Distributed SPLADE - console.print("[bold]Distributed SPLADE Data:[/bold]") - if distributed_splade: - total_postings = sum(d["posting_count"] for d in distributed_splade) - total_chunks = sum(d["chunk_count"] for d in distributed_splade) - console.print(f" Found in {len(distributed_splade)} _index.db files") - console.print(f" Total postings: {total_postings:,}") - console.print(f" Total chunks: {total_chunks:,}") - if verbose: - for d in distributed_splade[:5]: - console.print(f" [dim]{d['db_path'].parent.name}: {d['posting_count']} postings[/dim]") - if len(distributed_splade) > 5: - console.print(f" [dim]... and {len(distributed_splade) - 5} more[/dim]") - else: - console.print(" [dim]None found (already centralized or not generated)[/dim]") - console.print() - # Distributed HNSW console.print("[bold]Distributed HNSW Files:[/bold]") if distributed_hnsw: @@ -4314,15 +3808,6 @@ def index_migrate_cmd( # Centralized storage status console.print("[bold]Centralized Storage:[/bold]") - if centralized["has_splade"]: - stats = centralized.get("splade_stats") or {} - console.print(f" [green]OK[/green] _splade.db exists") - if stats: - console.print(f" Chunks: {stats.get('unique_chunks', 0):,}") - console.print(f" Postings: {stats.get('total_postings', 0):,}") - else: - console.print(f" [yellow]--[/yellow] _splade.db not found") - if centralized["has_vectors"]: stats = centralized.get("vector_stats") or {} size_mb = stats.get("size_bytes", 0) / (1024 * 1024) @@ -4440,20 +3925,6 @@ def init_deprecated( ) -@app.command("splade-index", hidden=True, deprecated=True) -def splade_index_deprecated( - path: Path = typer.Argument(..., help="Project path to index"), - rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"), - verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."), -) -> None: - """[Deprecated] Use 'codexlens index splade' instead.""" - _deprecated_command_warning("splade-index", "index splade") - index_splade( - path=path, - rebuild=rebuild, - verbose=verbose, - ) - @app.command("cascade-index", hidden=True, deprecated=True) def cascade_index_deprecated( diff --git a/codex-lens/src/codexlens/cli/embedding_manager.py b/codex-lens/src/codexlens/cli/embedding_manager.py index bb6467f5..97544b0a 100644 --- a/codex-lens/src/codexlens/cli/embedding_manager.py +++ b/codex-lens/src/codexlens/cli/embedding_manager.py @@ -151,15 +151,6 @@ def _cleanup_fastembed_resources() -> None: pass -def _cleanup_splade_resources() -> None: - """Release SPLADE encoder ONNX resources.""" - try: - from codexlens.semantic.splade_encoder import clear_splade_cache - clear_splade_cache() - except Exception: - pass - - def _generate_chunks_from_cursor( cursor, chunker, @@ -398,7 +389,6 @@ def generate_embeddings( endpoints: Optional[List] = None, strategy: Optional[str] = None, cooldown: Optional[float] = None, - splade_db_path: Optional[Path] = None, ) -> Dict[str, any]: """Generate embeddings for an index using memory-efficient batch processing. @@ -428,9 +418,6 @@ def generate_embeddings( Each dict has keys: model, api_key, api_base, weight. strategy: Selection strategy for multi-endpoint mode (round_robin, latency_aware). cooldown: Default cooldown seconds for rate-limited endpoints. - splade_db_path: Optional path to centralized SPLADE database. If None, SPLADE - is written to index_path (legacy behavior). 
Use index_root / SPLADE_DB_NAME - for centralized storage. Returns: Result dictionary with generation statistics @@ -822,97 +809,10 @@ def generate_embeddings( if progress_callback: progress_callback(f"Finalizing index... Building ANN index for {total_chunks_created} chunks") - # --- SPLADE SPARSE ENCODING (after dense embeddings) --- - # Add SPLADE encoding if enabled in config - splade_success = False - splade_error = None - - try: - from codexlens.config import Config, SPLADE_DB_NAME - config = Config.load() - - if config.enable_splade: - from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder - from codexlens.storage.splade_index import SpladeIndex - - ok, err = check_splade_available() - if ok: - if progress_callback: - progress_callback(f"Generating SPLADE sparse vectors for {total_chunks_created} chunks...") - - # Initialize SPLADE encoder and index - splade_encoder = get_splade_encoder(use_gpu=use_gpu) - # Use centralized SPLADE database if provided, otherwise fallback to index_path - effective_splade_path = splade_db_path if splade_db_path else index_path - splade_index = SpladeIndex(effective_splade_path) - splade_index.create_tables() - - # Retrieve all chunks from database for SPLADE encoding - with sqlite3.connect(index_path) as conn: - conn.row_factory = sqlite3.Row - cursor = conn.execute("SELECT id, content FROM semantic_chunks ORDER BY id") - - # Batch encode for efficiency - SPLADE_BATCH_SIZE = 32 - batch_postings = [] - chunk_batch = [] - chunk_ids = [] - - for row in cursor: - chunk_id = row["id"] - content = row["content"] - - chunk_ids.append(chunk_id) - chunk_batch.append(content) - - # Process batch when full - if len(chunk_batch) >= SPLADE_BATCH_SIZE: - sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE) - for cid, sparse_vec in zip(chunk_ids, sparse_vecs): - batch_postings.append((cid, sparse_vec)) - - chunk_batch = [] - chunk_ids = [] - - # Process remaining chunks - if chunk_batch: - sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE) - for cid, sparse_vec in zip(chunk_ids, sparse_vecs): - batch_postings.append((cid, sparse_vec)) - - # Batch insert all postings - if batch_postings: - splade_index.add_postings_batch(batch_postings) - - # Set metadata - splade_index.set_metadata( - model_name=splade_encoder.model_name, - vocab_size=splade_encoder.vocab_size - ) - - splade_success = True - if progress_callback: - stats = splade_index.get_stats() - progress_callback( - f"SPLADE index created: {stats['total_postings']} postings, " - f"{stats['unique_tokens']} unique tokens" - ) - else: - logger.debug("SPLADE not available: %s", err) - splade_error = f"SPLADE not available: {err}" - except Exception as e: - splade_error = str(e) - logger.warning("SPLADE encoding failed: %s", e) - - # Report SPLADE status after processing - if progress_callback and not splade_success and splade_error: - progress_callback(f"SPLADE index: FAILED - {splade_error}") - except Exception as e: # Cleanup on error to prevent process hanging try: _cleanup_fastembed_resources() - _cleanup_splade_resources() gc.collect() except Exception: pass @@ -924,7 +824,6 @@ def generate_embeddings( # This is critical - without it, ONNX Runtime threads prevent Python from exiting try: _cleanup_fastembed_resources() - _cleanup_splade_resources() gc.collect() except Exception: pass @@ -1098,10 +997,6 @@ def generate_embeddings_recursive( if progress_callback: progress_callback(f"Found {len(index_files)} index 
databases to process") - # Calculate centralized SPLADE database path - from codexlens.config import SPLADE_DB_NAME - splade_db_path = index_root / SPLADE_DB_NAME - # Process each index database all_results = [] total_chunks = 0 @@ -1131,7 +1026,6 @@ def generate_embeddings_recursive( endpoints=endpoints, strategy=strategy, cooldown=cooldown, - splade_db_path=splade_db_path, # Use centralized SPLADE storage ) all_results.append({ @@ -1153,7 +1047,6 @@ def generate_embeddings_recursive( # Each generate_embeddings() call does its own cleanup, but do a final one to be safe try: _cleanup_fastembed_resources() - _cleanup_splade_resources() gc.collect() except Exception: pass @@ -1197,7 +1090,6 @@ def generate_dense_embeddings_centralized( Target architecture: / |-- _vectors.hnsw # Centralized dense vector ANN index - |-- _splade.db # Centralized sparse vector index |-- src/ |-- _index.db # No longer contains .hnsw file @@ -1219,7 +1111,7 @@ def generate_dense_embeddings_centralized( Returns: Result dictionary with generation statistics """ - from codexlens.config import VECTORS_HNSW_NAME, SPLADE_DB_NAME + from codexlens.config import VECTORS_HNSW_NAME # Get defaults from config if not specified (default_backend, default_model, default_gpu, @@ -1543,90 +1435,6 @@ def generate_dense_embeddings_centralized( logger.warning("Binary vector generation failed: %s", e) # Non-fatal: continue without binary vectors - # --- SPLADE Sparse Index Generation (Centralized) --- - splade_success = False - splade_chunks_count = 0 - try: - from codexlens.config import Config - config = Config.load() - - if config.enable_splade and chunk_id_to_info: - from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder - from codexlens.storage.splade_index import SpladeIndex - import json - - ok, err = check_splade_available() - if ok: - if progress_callback: - progress_callback(f"Generating SPLADE sparse vectors for {len(chunk_id_to_info)} chunks...") - - # Initialize SPLADE encoder and index - splade_encoder = get_splade_encoder(use_gpu=use_gpu) - splade_db_path = index_root / SPLADE_DB_NAME - splade_index = SpladeIndex(splade_db_path) - splade_index.create_tables() - - # Batch encode for efficiency - SPLADE_BATCH_SIZE = 32 - all_postings = [] - all_chunk_metadata = [] - - # Create batches from chunk_id_to_info - chunk_items = list(chunk_id_to_info.items()) - - for i in range(0, len(chunk_items), SPLADE_BATCH_SIZE): - batch_items = chunk_items[i:i + SPLADE_BATCH_SIZE] - chunk_ids = [item[0] for item in batch_items] - chunk_contents = [item[1]["content"] for item in batch_items] - - # Generate sparse vectors - sparse_vecs = splade_encoder.encode_batch(chunk_contents, batch_size=SPLADE_BATCH_SIZE) - for cid, sparse_vec in zip(chunk_ids, sparse_vecs): - all_postings.append((cid, sparse_vec)) - - if progress_callback and (i + SPLADE_BATCH_SIZE) % 100 == 0: - progress_callback(f"SPLADE encoding: {min(i + SPLADE_BATCH_SIZE, len(chunk_items))}/{len(chunk_items)}") - - # Batch insert all postings - if all_postings: - splade_index.add_postings_batch(all_postings) - - # CRITICAL FIX: Populate splade_chunks table - for cid, info in chunk_id_to_info.items(): - metadata_str = json.dumps(info.get("metadata", {})) if info.get("metadata") else None - all_chunk_metadata.append(( - cid, - info["file_path"], - info["content"], - metadata_str, - info.get("source_index_db") - )) - - if all_chunk_metadata: - splade_index.add_chunks_metadata_batch(all_chunk_metadata) - splade_chunks_count = len(all_chunk_metadata) - 
- # Set metadata - splade_index.set_metadata( - model_name=splade_encoder.model_name, - vocab_size=splade_encoder.vocab_size - ) - - splade_index.close() - splade_success = True - - if progress_callback: - progress_callback(f"SPLADE index created: {len(all_postings)} postings, {splade_chunks_count} chunks") - - else: - if progress_callback: - progress_callback(f"SPLADE not available, skipping sparse index: {err}") - - except Exception as e: - logger.warning("SPLADE encoding failed: %s", e) - if progress_callback: - progress_callback(f"SPLADE encoding failed: {e}") - elapsed_time = time.time() - start_time # Cleanup @@ -1647,8 +1455,6 @@ def generate_dense_embeddings_centralized( "model_name": embedder.model_name, "central_index_path": str(central_hnsw_path), "failed_files": failed_files[:5], - "splade_success": splade_success, - "splade_chunks": splade_chunks_count, "binary_success": binary_success, "binary_count": binary_count, }, diff --git a/codex-lens/src/codexlens/config.py b/codex-lens/src/codexlens/config.py index 238e922d..b5012f64 100644 --- a/codex-lens/src/codexlens/config.py +++ b/codex-lens/src/codexlens/config.py @@ -19,9 +19,6 @@ WORKSPACE_DIR_NAME = ".codexlens" # Settings file name SETTINGS_FILE_NAME = "settings.json" -# SPLADE index database name (centralized storage) -SPLADE_DB_NAME = "_splade.db" - # Dense vector storage names (centralized storage) VECTORS_HNSW_NAME = "_vectors.hnsw" VECTORS_META_DB_NAME = "_vectors_meta.db" @@ -113,15 +110,6 @@ class Config: # For litellm: model name from config (e.g., "qwen3-embedding") embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration - # SPLADE sparse retrieval configuration - enable_splade: bool = False # Disable SPLADE by default (slow ~360ms, use FTS instead) - splade_model: str = "naver/splade-cocondenser-ensembledistil" - splade_threshold: float = 0.01 # Min weight to store in index - splade_onnx_path: Optional[str] = None # Custom ONNX model path - - # FTS fallback (disabled by default, available via --use-fts) - use_fts_fallback: bool = True # Use FTS for sparse search (fast, SPLADE disabled) - # Indexing/search optimizations global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path enable_merkle_detection: bool = True # Enable content-hash based incremental indexing @@ -152,7 +140,7 @@ class Config: enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking) cascade_coarse_k: int = 100 # Number of coarse candidates from first stage cascade_fine_k: int = 10 # Number of final results after reranking - cascade_strategy: str = "binary" # "binary" (fast binary+dense) or "hybrid" (FTS+SPLADE+Vector+CrossEncoder) + cascade_strategy: str = "binary" # "binary", "binary_rerank", "dense_rerank", or "staged" # Staged cascade search configuration (4-stage pipeline) staged_coarse_k: int = 200 # Number of coarse candidates from Stage 1 binary search @@ -398,11 +386,11 @@ class Config: cascade = settings.get("cascade", {}) if "strategy" in cascade: strategy = cascade["strategy"] - if strategy in {"binary", "hybrid", "binary_rerank", "dense_rerank"}: + if strategy in {"binary", "binary_rerank", "dense_rerank", "staged"}: self.cascade_strategy = strategy else: log.warning( - "Invalid cascade strategy in %s: %r (expected 'binary', 'hybrid', 'binary_rerank', or 'dense_rerank')", + "Invalid cascade strategy in %s: %r (expected 'binary', 'binary_rerank', 'dense_rerank', or 'staged')", self.settings_path, strategy, ) diff --git 
a/codex-lens/src/codexlens/search/chain_search.py b/codex-lens/src/codexlens/search/chain_search.py index 9090dbca..5a06b93c 100644 --- a/codex-lens/src/codexlens/search/chain_search.py +++ b/codex-lens/src/codexlens/search/chain_search.py @@ -55,7 +55,6 @@ class SearchOptions: enable_fuzzy: Enable fuzzy FTS in hybrid mode (default True) enable_vector: Enable vector semantic search (default False) pure_vector: If True, only use vector search without FTS fallback (default False) - enable_splade: Enable SPLADE sparse neural search (default False) enable_cascade: Enable cascade (binary+dense) two-stage retrieval (default False) hybrid_weights: Custom RRF weights for hybrid search (optional) group_results: Enable grouping of similar results (default False) @@ -75,7 +74,6 @@ class SearchOptions: enable_fuzzy: bool = True enable_vector: bool = False pure_vector: bool = False - enable_splade: bool = False enable_cascade: bool = False hybrid_weights: Optional[Dict[str, float]] = None group_results: bool = False @@ -306,154 +304,6 @@ class ChainSearchEngine: related_results=related_results, ) - def hybrid_cascade_search( - self, - query: str, - source_path: Path, - k: int = 10, - coarse_k: int = 100, - options: Optional[SearchOptions] = None, - ) -> ChainSearchResult: - """Execute two-stage cascade search with hybrid coarse retrieval and cross-encoder reranking. - - Hybrid cascade search process: - 1. Stage 1 (Coarse): Fast retrieval using RRF fusion of FTS + SPLADE + Vector - to get coarse_k candidates - 2. Stage 2 (Fine): CrossEncoder reranking of candidates to get final k results - - This approach balances recall (from broad coarse search) with precision - (from expensive but accurate cross-encoder scoring). - - Note: This method is the original hybrid approach. For binary vector cascade, - use binary_cascade_search() instead. - - Args: - query: Natural language or keyword query string - source_path: Starting directory path - k: Number of final results to return (default 10) - coarse_k: Number of coarse candidates from first stage (default 100) - options: Search configuration (uses defaults if None) - - Returns: - ChainSearchResult with reranked results and statistics - - Examples: - >>> engine = ChainSearchEngine(registry, mapper, config=config) - >>> result = engine.hybrid_cascade_search( - ... "how to authenticate users", - ... Path("D:/project/src"), - ... k=10, - ... coarse_k=100 - ... ) - >>> for r in result.results: - ... 
print(f"{r.path}: {r.score:.3f}") - """ - options = options or SearchOptions() - start_time = time.time() - stats = SearchStats() - - # Use config defaults if available - if self._config is not None: - if hasattr(self._config, "cascade_coarse_k"): - coarse_k = coarse_k or self._config.cascade_coarse_k - if hasattr(self._config, "cascade_fine_k"): - k = k or self._config.cascade_fine_k - - # Step 1: Find starting index - start_index = self._find_start_index(source_path) - if not start_index: - self.logger.warning(f"No index found for {source_path}") - stats.time_ms = (time.time() - start_time) * 1000 - return ChainSearchResult( - query=query, - results=[], - symbols=[], - stats=stats - ) - - # Step 2: Collect all index paths - index_paths = self._collect_index_paths(start_index, options.depth) - stats.dirs_searched = len(index_paths) - - if not index_paths: - self.logger.warning(f"No indexes collected from {start_index}") - stats.time_ms = (time.time() - start_time) * 1000 - return ChainSearchResult( - query=query, - results=[], - symbols=[], - stats=stats - ) - - # Stage 1: Coarse retrieval with hybrid search (FTS + SPLADE + Vector) - # Use hybrid mode for multi-signal retrieval - coarse_options = SearchOptions( - depth=options.depth, - max_workers=1, # Single thread for GPU safety - limit_per_dir=max(coarse_k // len(index_paths), 20), - total_limit=coarse_k, - hybrid_mode=True, - enable_fuzzy=options.enable_fuzzy, - enable_vector=True, # Enable vector for semantic matching - pure_vector=False, - hybrid_weights=options.hybrid_weights, - ) - - self.logger.debug( - "Cascade Stage 1: Coarse retrieval for %d candidates", coarse_k - ) - coarse_results, search_stats = self._search_parallel( - index_paths, query, coarse_options - ) - stats.errors = search_stats.errors - - # Merge and deduplicate coarse results - coarse_merged = self._merge_and_rank(coarse_results, coarse_k) - self.logger.debug( - "Cascade Stage 1 complete: %d candidates retrieved", len(coarse_merged) - ) - - if not coarse_merged: - stats.time_ms = (time.time() - start_time) * 1000 - return ChainSearchResult( - query=query, - results=[], - symbols=[], - stats=stats - ) - - # Stage 2: Cross-encoder reranking - self.logger.debug( - "Cascade Stage 2: Cross-encoder reranking %d candidates to top-%d", - len(coarse_merged), - k, - ) - - final_results = self._cross_encoder_rerank(query, coarse_merged, k) - - # Optional: grouping of similar results - if options.group_results: - from codexlens.search.ranking import group_similar_results - final_results = group_similar_results( - final_results, score_threshold_abs=options.grouping_threshold - ) - - stats.files_matched = len(final_results) - stats.time_ms = (time.time() - start_time) * 1000 - - self.logger.debug( - "Cascade search complete: %d results in %.2fms", - len(final_results), - stats.time_ms, - ) - - return ChainSearchResult( - query=query, - results=final_results, - symbols=[], - stats=stats, - ) - def binary_cascade_search( self, query: str, @@ -501,9 +351,9 @@ class ChainSearchEngine: """ if not NUMPY_AVAILABLE: self.logger.warning( - "NumPy not available, falling back to hybrid cascade search" + "NumPy not available, falling back to standard search" ) - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) options = options or SearchOptions() start_time = time.time() @@ -552,10 +402,10 @@ class ChainSearchEngine: except ImportError as exc: self.logger.warning( "Binary cascade dependencies not 
available: %s. " - "Falling back to hybrid cascade search.", + "Falling back to standard search.", exc ) - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) # Stage 1: Binary vector coarse retrieval self.logger.debug( @@ -573,10 +423,10 @@ class ChainSearchEngine: except Exception as exc: self.logger.warning( "Failed to generate binary query embedding: %s. " - "Falling back to hybrid cascade search.", + "Falling back to standard search.", exc ) - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) # Try centralized BinarySearcher first (preferred for mmap indexes) # The index root is the parent of the first index path @@ -629,8 +479,8 @@ class ChainSearchEngine: stats.errors.append(f"Binary search failed for {index_path}: {exc}") if not all_candidates: - self.logger.debug("No binary candidates found, falling back to hybrid") - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + self.logger.debug("No binary candidates found, falling back to standard search") + return self.search(query, source_path, options=options) # Sort by Hamming distance and take top coarse_k all_candidates.sort(key=lambda x: x[1]) @@ -828,13 +678,12 @@ class ChainSearchEngine: k: int = 10, coarse_k: int = 100, options: Optional[SearchOptions] = None, - strategy: Optional[Literal["binary", "hybrid", "binary_rerank", "dense_rerank", "staged"]] = None, + strategy: Optional[Literal["binary", "binary_rerank", "dense_rerank", "staged"]] = None, ) -> ChainSearchResult: """Unified cascade search entry point with strategy selection. Provides a single interface for cascade search with configurable strategy: - "binary": Uses binary vector coarse ranking + dense fine ranking (fastest) - - "hybrid": Uses FTS+SPLADE+Vector coarse ranking + cross-encoder reranking (original) - "binary_rerank": Uses binary vector coarse ranking + cross-encoder reranking (best balance) - "dense_rerank": Uses dense vector coarse ranking + cross-encoder reranking - "staged": 4-stage pipeline: binary -> LSP expand -> clustering -> optional rerank @@ -850,7 +699,7 @@ class ChainSearchEngine: k: Number of final results to return (default 10) coarse_k: Number of coarse candidates from first stage (default 100) options: Search configuration (uses defaults if None) - strategy: Cascade strategy - "binary", "hybrid", "binary_rerank", "dense_rerank", or "staged". + strategy: Cascade strategy - "binary", "binary_rerank", "dense_rerank", or "staged". 
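Annotation for reviewers: the `strategy` values accepted here are the same four that `Config._apply_settings` validates in the config.py hunk above. A sketch of the corresponding `settings.json` fragment, written as the dict the loader reads; the key layout is taken from `settings.get("cascade", {})`, and the values are illustrative:

```python
# Hypothetical settings fragment matching the validation in config.py above.
settings = {
    "cascade": {
        # one of: "binary", "binary_rerank", "dense_rerank", "staged";
        # "hybrid" is no longer a valid value after this patch
        "strategy": "binary_rerank",
    }
}
```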
Returns:
            ChainSearchResult with reranked results and statistics

        Examples:
            >>> engine = ChainSearchEngine(registry, mapper, config=config)
            >>> # Use binary cascade (default, fastest)
            >>> result = engine.cascade_search("auth", Path("D:/project"))
-            >>> # Use hybrid cascade (original behavior)
-            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="hybrid")
            >>> # Use binary + cross-encoder (best balance of speed and quality)
            >>> result = engine.cascade_search("auth", Path("D:/project"), strategy="binary_rerank")
            >>> # Use 4-stage pipeline (binary + LSP expand + clustering + optional rerank)
@@ -868,7 +715,7 @@
        """
        # Strategy priority: parameter > config > default
        effective_strategy = strategy
-        valid_strategies = ("binary", "hybrid", "binary_rerank", "dense_rerank", "staged")
+        valid_strategies = ("binary", "binary_rerank", "dense_rerank", "staged")
        if effective_strategy is None:
            # Not passed via parameter, check config
            if self._config is not None:
@@ -889,7 +736,7 @@
        elif effective_strategy == "staged":
            return self.staged_cascade_search(query, source_path, k, coarse_k, options)
        else:
-            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
+            return self.binary_cascade_search(query, source_path, k, coarse_k, options)

    def staged_cascade_search(
        self,
@@ -943,9 +790,9 @@
        """
        if not NUMPY_AVAILABLE:
            self.logger.warning(
-                "NumPy not available, falling back to hybrid cascade search"
+                "NumPy not available, falling back to standard search"
            )
-            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
+            return self.search(query, source_path, options=options)

        options = options or SearchOptions()
        start_time = time.time()
@@ -1002,8 +849,8 @@
        )

        if not coarse_results:
-            self.logger.debug("No binary candidates found, falling back to hybrid cascade")
-            return self.hybrid_cascade_search(query, source_path, k, coarse_k, options)
+            self.logger.debug("No binary candidates found, falling back to standard search")
+            return self.search(query, source_path, options=options)

        # ========== Stage 2: LSP Graph Expansion ==========
        stage2_start = time.time()
@@ -1534,7 +1381,7 @@ class ChainSearchEngine:
        2. Stage 2 (Fine): Cross-encoder reranking for precise semantic
        ranking of candidates using query-document attention

-        This approach is typically faster than hybrid_cascade_search while
+        This approach is typically slower than binary_cascade_search while
        achieving similar or better quality through cross-encoder reranking.
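Annotation for reviewers: the hunks above and below repeat one pattern — every path that used to re-enter `hybrid_cascade_search` on failure now degrades to the plain `self.search()` FTS+vector path. A condensed sketch of that contract under the engine API shown in this diff; `cascade_or_fallback` is a hypothetical illustration, not repository code:

```python
def cascade_or_fallback(engine, query, source_path, k=10, coarse_k=100, options=None):
    """Degrade to standard FTS+vector search when cascade prerequisites
    are missing; otherwise run the binary two-stage cascade."""
    try:
        import numpy  # noqa: F401  # binary cascade requires NumPy
    except ImportError:
        return engine.search(query, source_path, options=options)
    return engine.binary_cascade_search(query, source_path, k, coarse_k, options)
```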
Performance characteristics: @@ -1565,9 +1412,9 @@ class ChainSearchEngine: """ if not NUMPY_AVAILABLE: self.logger.warning( - "NumPy not available, falling back to hybrid cascade search" + "NumPy not available, falling back to standard search" ) - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) options = options or SearchOptions() start_time = time.time() @@ -1611,10 +1458,10 @@ class ChainSearchEngine: from codexlens.indexing.embedding import BinaryEmbeddingBackend except ImportError as exc: self.logger.warning( - "BinaryEmbeddingBackend not available: %s, falling back to hybrid cascade", + "BinaryEmbeddingBackend not available: %s, falling back to standard search", exc ) - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) # Step 4: Binary coarse search (same as binary_cascade_search) binary_coarse_time = time.time() @@ -1658,7 +1505,7 @@ class ChainSearchEngine: query_binary = binary_backend.embed_packed([query])[0] except Exception as exc: self.logger.warning(f"Failed to generate binary query embedding: {exc}") - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) # Fallback to per-directory binary indexes for index_path in index_paths: @@ -1676,9 +1523,9 @@ class ChainSearchEngine: ) if not coarse_candidates: - self.logger.info("No binary candidates found, falling back to hybrid cascade for reranking") - # Fall back to hybrid_cascade_search which uses FTS+Vector coarse + cross-encoder rerank - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + self.logger.info("No binary candidates found, falling back to standard search for reranking") + # Fall back to standard search which uses FTS+Vector + return self.search(query, source_path, options=options) # Sort by Hamming distance and take top coarse_k coarse_candidates.sort(key=lambda x: x[1]) @@ -1785,7 +1632,7 @@ class ChainSearchEngine: "Retrieved %d chunks for cross-encoder reranking", len(coarse_results) ) - # Step 6: Cross-encoder reranking (same as hybrid_cascade_search) + # Step 6: Cross-encoder reranking rerank_time = time.time() reranked_results = self._cross_encoder_rerank(query, coarse_results, top_k=k) @@ -1848,9 +1695,9 @@ class ChainSearchEngine: """ if not NUMPY_AVAILABLE: self.logger.warning( - "NumPy not available, falling back to hybrid cascade search" + "NumPy not available, falling back to standard search" ) - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) options = options or SearchOptions() start_time = time.time() @@ -1955,7 +1802,7 @@ class ChainSearchEngine: self.logger.debug(f"Dense query embedding: {query_dense.shape[0]}-dim via {embedding_backend}/{embedding_model}") except Exception as exc: self.logger.warning(f"Failed to generate dense query embedding: {exc}") - return self.hybrid_cascade_search(query, source_path, k, coarse_k, options) + return self.search(query, source_path, options=options) # Step 5: Dense coarse search using centralized HNSW index coarse_candidates: List[Tuple[int, float, Path]] = [] # (chunk_id, distance, index_path) @@ -2006,8 +1853,8 @@ class ChainSearchEngine: ) if not coarse_candidates: - self.logger.info("No dense candidates found, falling back to hybrid cascade") - return self.hybrid_cascade_search(query, source_path, k, 
coarse_k, options)
+            self.logger.info("No dense candidates found, falling back to standard search")
+            return self.search(query, source_path, options=options)

        # Sort by distance (ascending for cosine distance) and take top coarse_k
        coarse_candidates.sort(key=lambda x: x[1])
@@ -2972,7 +2819,6 @@ class ChainSearchEngine:
                    options.enable_fuzzy,
                    options.enable_vector,
                    options.pure_vector,
-                    options.enable_splade,
                    options.hybrid_weights
                ): idx_path
                for idx_path in index_paths
@@ -3001,7 +2847,6 @@ class ChainSearchEngine:
                            enable_fuzzy: bool = True,
                            enable_vector: bool = False,
                            pure_vector: bool = False,
-                            enable_splade: bool = False,
                            hybrid_weights: Optional[Dict[str, float]] = None) -> List[SearchResult]:
        """Search a single index database.

@@ -3017,7 +2862,6 @@
            enable_fuzzy: Enable fuzzy FTS in hybrid mode
            enable_vector: Enable vector semantic search
            pure_vector: If True, only use vector search without FTS fallback
-            enable_splade: If True, force SPLADE sparse neural search
            hybrid_weights: Custom RRF weights for hybrid search

        Returns:
@@ -3034,7 +2878,6 @@
                enable_fuzzy=enable_fuzzy,
                enable_vector=enable_vector,
                pure_vector=pure_vector,
-                enable_splade=enable_splade,
            )
        else:
            # Single-FTS search (exact or fuzzy mode)
diff --git a/codex-lens/src/codexlens/search/hybrid_search.py b/codex-lens/src/codexlens/search/hybrid_search.py
index 805b2fdb..cdc37277 100644
--- a/codex-lens/src/codexlens/search/hybrid_search.py
+++ b/codex-lens/src/codexlens/search/hybrid_search.py
@@ -35,7 +35,6 @@ from codexlens.config import VECTORS_HNSW_NAME
 from codexlens.entities import SearchResult
 from codexlens.search.ranking import (
     DEFAULT_WEIGHTS,
-    FTS_FALLBACK_WEIGHTS,
     QueryIntent,
     apply_symbol_boost,
     cross_encoder_rerank,
@@ -57,14 +56,6 @@
 except ImportError:
     HAS_LSP = False


-# Three-way fusion weights (FTS + Vector + SPLADE)
-THREE_WAY_WEIGHTS = {
-    "exact": 0.2,
-    "splade": 0.3,
-    "vector": 0.5,
-}
-
-
 class HybridSearchEngine:
     """Hybrid search engine with parallel execution and RRF fusion.

@@ -77,8 +68,7 @@
     """

     # NOTE: DEFAULT_WEIGHTS imported from ranking.py - single source of truth
-    # Default RRF weights: SPLADE-based hybrid (splade: 0.4, vector: 0.6)
-    # FTS fallback mode uses FTS_FALLBACK_WEIGHTS (exact: 0.3, fuzzy: 0.1, vector: 0.6)
+    # FTS + vector hybrid mode; DEFAULT_WEIGHTS = (exact: 0.25, fuzzy: 0.1, vector: 0.5, lsp_graph: 0.15)

    def __init__(
        self,
@@ -119,7 +109,6 @@
        enable_fuzzy: bool = True,
        enable_vector: bool = False,
        pure_vector: bool = False,
-        enable_splade: bool = False,
        enable_lsp_graph: bool = False,
        lsp_max_depth: int = 1,
        lsp_max_nodes: int = 20,
@@ -133,7 +122,6 @@
            enable_fuzzy: Enable fuzzy FTS search (default True)
            enable_vector: Enable vector search (default False)
            pure_vector: If True, only use vector search without FTS fallback (default False)
-            enable_splade: If True, force SPLADE sparse neural search (default False)
            enable_lsp_graph: If True, enable real-time LSP graph expansion (default False)
            lsp_max_depth: Maximum depth for LSP graph BFS expansion (default 1)
            lsp_max_nodes: Maximum nodes to collect in LSP graph (default 20)
@@ -150,9 +138,6 @@
            >>> results = engine.search(Path("project/_index.db"),
            ...                         "how to authenticate users",
            ...                         enable_vector=True, pure_vector=True)
-            >>> # SPLADE sparse neural search
-            >>> results = engine.search(Path("project/_index.db"), "auth flow",
-            ...
enable_splade=True, enable_vector=True) >>> # With LSP graph expansion (real-time) >>> results = engine.search(Path("project/_index.db"), "auth flow", ... enable_vector=True, enable_lsp_graph=True) @@ -180,26 +165,6 @@ class HybridSearchEngine: # Determine which backends to use backends = {} - # Check if SPLADE is available - splade_available = False - # Respect config.enable_splade flag and use_fts_fallback flag - if self._config and getattr(self._config, 'use_fts_fallback', False): - # Config explicitly requests FTS fallback - disable SPLADE - splade_available = False - elif self._config and not getattr(self._config, 'enable_splade', True): - # Config explicitly disabled SPLADE - splade_available = False - else: - # Check if SPLADE dependencies are available - try: - from codexlens.semantic.splade_encoder import check_splade_available - ok, _ = check_splade_available() - if ok: - # SPLADE tables are in main index database, will check table existence in _search_splade - splade_available = True - except Exception: - pass - if pure_vector: # Pure vector mode: only use vector search, no FTS fallback if enable_vector: @@ -212,37 +177,13 @@ class HybridSearchEngine: "To use pure vector search, enable vector search mode." ) backends["exact"] = True - elif enable_splade: - # Explicit SPLADE mode requested via CLI --method splade - if splade_available: - backends["splade"] = True - if enable_vector: - backends["vector"] = True - else: - # SPLADE requested but not available - warn and fallback - self.logger.warning( - "SPLADE search requested but not available. " - "Falling back to FTS. Run 'codexlens index splade' to enable." - ) - backends["exact"] = True - if enable_fuzzy: - backends["fuzzy"] = True - if enable_vector: - backends["vector"] = True else: - # Hybrid mode: default to SPLADE if available, otherwise use FTS - if splade_available: - # Default: enable SPLADE, disable exact and fuzzy - backends["splade"] = True - if enable_vector: - backends["vector"] = True - else: - # Fallback mode: enable exact+fuzzy when SPLADE unavailable - backends["exact"] = True - if enable_fuzzy: - backends["fuzzy"] = True - if enable_vector: - backends["vector"] = True + # Standard hybrid mode: FTS + optional vector + backends["exact"] = True + if enable_fuzzy: + backends["fuzzy"] = True + if enable_vector: + backends["vector"] = True # Add LSP graph expansion if requested and available if enable_lsp_graph and HAS_LSP: @@ -502,13 +443,6 @@ class HybridSearchEngine: ) future_to_source[future] = "vector" - if backends.get("splade"): - submit_times["splade"] = time.perf_counter() - future = executor.submit( - self._search_splade, index_path, query, limit - ) - future_to_source[future] = "splade" - if backends.get("lsp_graph"): submit_times["lsp_graph"] = time.perf_counter() future = executor.submit( @@ -599,8 +533,7 @@ class HybridSearchEngine: def _find_vectors_hnsw(self, index_path: Path) -> Optional[Path]: """Find the centralized _vectors.hnsw file by traversing up from index_path. - Similar to _search_splade's approach, this method searches for the - centralized dense vector index file in parent directories. + Searches for the centralized dense vector index file in parent directories. Args: index_path: Path to the current _index.db file @@ -1138,124 +1071,6 @@ class HybridSearchEngine: self.logger.error("Vector search error: %s", exc) return [] - def _search_splade( - self, index_path: Path, query: str, limit: int - ) -> List[SearchResult]: - """SPLADE sparse retrieval via inverted index. 
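Annotation for reviewers: the replacement branch above reduces backend selection to a small decision table. A distilled, dependency-free sketch of that logic, assuming the same flag semantics as `search()`; `select_backends` is illustrative only:

```python
from typing import Dict


def select_backends(pure_vector: bool, enable_fuzzy: bool, enable_vector: bool) -> Dict[str, bool]:
    """Distilled version of the backend selection introduced above."""
    backends: Dict[str, bool] = {}
    if pure_vector:
        if enable_vector:
            backends["vector"] = True
        else:
            backends["exact"] = True  # pure_vector without vector search falls back to FTS
    else:
        backends["exact"] = True  # standard hybrid mode always includes exact FTS
        if enable_fuzzy:
            backends["fuzzy"] = True
        if enable_vector:
            backends["vector"] = True
    return backends
```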
- - Args: - index_path: Path to _index.db file - query: Natural language query string - limit: Maximum results - - Returns: - List of SearchResult ordered by SPLADE score - """ - try: - from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available - from codexlens.storage.splade_index import SpladeIndex - from codexlens.config import SPLADE_DB_NAME - import sqlite3 - import json - - # Check dependencies - ok, err = check_splade_available() - if not ok: - self.logger.debug("SPLADE not available: %s", err) - return [] - - # SPLADE index is stored in _splade.db at the project index root - # Traverse up from the current index to find the root _splade.db - current_dir = index_path.parent - splade_db_path = None - for _ in range(10): # Limit search depth - candidate = current_dir / SPLADE_DB_NAME - if candidate.exists(): - splade_db_path = candidate - break - parent = current_dir.parent - if parent == current_dir: # Reached root - break - current_dir = parent - - if not splade_db_path: - self.logger.debug("SPLADE index not found in ancestor directories of %s", index_path) - return [] - - splade_index = SpladeIndex(splade_db_path) - if not splade_index.has_index(): - self.logger.debug("SPLADE index not initialized") - return [] - - # Encode query to sparse vector - encoder = get_splade_encoder(use_gpu=self._use_gpu) - query_sparse = encoder.encode_text(query) - - # Search inverted index for top matches - raw_results = splade_index.search(query_sparse, limit=limit, min_score=0.0) - - if not raw_results: - return [] - - # Fetch chunk details from splade_chunks table (self-contained) - chunk_ids = [chunk_id for chunk_id, _ in raw_results] - score_map = {chunk_id: score for chunk_id, score in raw_results} - - # Get chunk metadata from SPLADE database - rows = splade_index.get_chunks_by_ids(chunk_ids) - - # Build SearchResult objects - results = [] - for row in rows: - chunk_id = row["id"] - file_path = row["file_path"] - content = row["content"] - metadata_json = row["metadata"] - metadata = json.loads(metadata_json) if metadata_json else {} - - score = score_map.get(chunk_id, 0.0) - - # Build excerpt (short preview) - excerpt = content[:200] + "..." if len(content) > 200 else content - - # Extract symbol information from metadata - symbol_name = metadata.get("symbol_name") - symbol_kind = metadata.get("symbol_kind") - start_line = metadata.get("start_line") - end_line = metadata.get("end_line") - - # Build Symbol object if we have symbol info - symbol = None - if symbol_name and symbol_kind and start_line and end_line: - try: - from codexlens.entities import Symbol - symbol = Symbol( - name=symbol_name, - kind=symbol_kind, - range=(start_line, end_line) - ) - except Exception: - pass - - results.append(SearchResult( - path=file_path, - score=score, - excerpt=excerpt, - content=content, - symbol=symbol, - metadata=metadata, - start_line=start_line, - end_line=end_line, - symbol_name=symbol_name, - symbol_kind=symbol_kind, - )) - - return results - - except Exception as exc: - self.logger.debug("SPLADE search error: %s", exc) - return [] - def _search_lsp_graph( self, index_path: Path, @@ -1295,21 +1110,14 @@ class HybridSearchEngine: if seeds: seed_source = "vector" - # 2. Fallback to SPLADE if vector returns nothing + # 2. Fallback to exact FTS if vector returns nothing if not seeds: - self.logger.debug("Vector search returned no seeds, trying SPLADE") - seeds = self._search_splade(index_path, query, limit=3) - if seeds: - seed_source = "splade" - - # 3. 
Fallback to exact FTS if SPLADE also fails - if not seeds: - self.logger.debug("SPLADE returned no seeds, trying exact FTS") + self.logger.debug("Vector search returned no seeds, trying exact FTS") seeds = self._search_exact(index_path, query, limit=3) if seeds: seed_source = "exact_fts" - # 4. No seeds available from any source + # 3. No seeds available from any source if not seeds: self.logger.debug("No seed results available for LSP graph expansion") return [] diff --git a/codex-lens/src/codexlens/search/ranking.py b/codex-lens/src/codexlens/search/ranking.py index 256c78bd..a578466b 100644 --- a/codex-lens/src/codexlens/search/ranking.py +++ b/codex-lens/src/codexlens/search/ranking.py @@ -1,7 +1,7 @@ """Ranking algorithms for hybrid search result fusion. Implements Reciprocal Rank Fusion (RRF) and score normalization utilities -for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search). +for combining results from heterogeneous search backends (exact FTS, fuzzy FTS, vector search). """ from __future__ import annotations @@ -15,19 +15,12 @@ from typing import Any, Dict, List, Optional from codexlens.entities import SearchResult, AdditionalLocation -# Default RRF weights for SPLADE-based hybrid search +# Default RRF weights for hybrid search DEFAULT_WEIGHTS = { - "splade": 0.35, # Replaces exact(0.3) + fuzzy(0.1) - "vector": 0.5, - "lsp_graph": 0.15, # Real-time LSP-based graph expansion -} - -# Legacy weights for FTS fallback mode (when SPLADE unavailable) -FTS_FALLBACK_WEIGHTS = { "exact": 0.25, "fuzzy": 0.1, "vector": 0.5, - "lsp_graph": 0.15, # Real-time LSP-based graph expansion + "lsp_graph": 0.15, } @@ -105,22 +98,13 @@ def adjust_weights_by_intent( base_weights: Dict[str, float], ) -> Dict[str, float]: """Adjust RRF weights based on query intent.""" - # Check if using SPLADE or FTS mode - use_splade = "splade" in base_weights - if intent == QueryIntent.KEYWORD: - if use_splade: - target = {"splade": 0.6, "vector": 0.4} - else: - target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4} + target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4} elif intent == QueryIntent.SEMANTIC: - if use_splade: - target = {"splade": 0.3, "vector": 0.7} - else: - target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7} + target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7} else: target = dict(base_weights) - + # Filter to active backends keys = list(base_weights.keys()) filtered = {k: float(target.get(k, 0.0)) for k in keys} @@ -225,7 +209,7 @@ def simple_weighted_fusion( Args: results_map: Dictionary mapping source name to list of SearchResult objects - Sources: 'exact', 'fuzzy', 'vector', 'splade' + Sources: 'exact', 'fuzzy', 'vector' weights: Dictionary mapping source name to weight (default: equal weights) Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6} @@ -331,14 +315,11 @@ def reciprocal_rank_fusion( RRF formula: score(d) = Σ weight_source / (k + rank_source(d)) - Supports three-way fusion with FTS, Vector, and SPLADE sources. 
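Annotation for reviewers: to sanity-check the two-way fusion that remains, here is a self-contained sketch of the RRF formula quoted above, evaluated with the new DEFAULT_WEIGHTS. The toy rankings and the `rrf_scores` helper are illustrative; the real `reciprocal_rank_fusion` operates on SearchResult objects:

```python
def rrf_scores(rankings, weights, k=60):
    """score(d) = sum over sources s of weight_s / (k + rank_s(d)), ranks 1-based."""
    scores = {}
    for source, docs in rankings.items():
        w = weights.get(source, 0.0)
        for rank, doc in enumerate(docs, start=1):
            scores[doc] = scores.get(doc, 0.0) + w / (k + rank)
    return scores


fused = rrf_scores(
    {"exact": ["a.py", "b.py"], "vector": ["b.py", "c.py"]},  # toy rankings
    {"exact": 0.25, "fuzzy": 0.1, "vector": 0.5, "lsp_graph": 0.15},
)
# b.py scores 0.25/62 + 0.5/61 ≈ 0.0122 and ranks first, ahead of c.py and a.py.
```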
- Args: results_map: Dictionary mapping source name to list of SearchResult objects - Sources: 'exact', 'fuzzy', 'vector', 'splade' + Sources: 'exact', 'fuzzy', 'vector' weights: Dictionary mapping source name to weight (default: equal weights) Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6} - Or: {'splade': 0.4, 'vector': 0.6} k: Constant to avoid division by zero and control rank influence (default 60) Returns: @@ -349,14 +330,6 @@ def reciprocal_rank_fusion( >>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")] >>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results} >>> fused = reciprocal_rank_fusion(results_map) - - # Three-way fusion with SPLADE - >>> results_map = { - ... 'exact': exact_results, - ... 'vector': vector_results, - ... 'splade': splade_results - ... } - >>> fused = reciprocal_rank_fusion(results_map, k=60) """ if not results_map: return [] diff --git a/codex-lens/src/codexlens/semantic/SPLADE_IMPLEMENTATION.md b/codex-lens/src/codexlens/semantic/SPLADE_IMPLEMENTATION.md deleted file mode 100644 index cb3062e1..00000000 --- a/codex-lens/src/codexlens/semantic/SPLADE_IMPLEMENTATION.md +++ /dev/null @@ -1,225 +0,0 @@ -# SPLADE Encoder Implementation - -## Overview - -Created `splade_encoder.py` - A complete ONNX-optimized SPLADE sparse encoder for code search. - -## File Location - -`src/codexlens/semantic/splade_encoder.py` (474 lines) - -## Key Components - -### 1. Dependency Checking - -**Function**: `check_splade_available() -> Tuple[bool, Optional[str]]` -- Validates numpy, onnxruntime, optimum, transformers availability -- Returns (True, None) if all dependencies present -- Returns (False, error_message) with install instructions if missing - -### 2. Caching System - -**Global Cache**: Thread-safe singleton pattern -- `_splade_cache: Dict[str, SpladeEncoder]` - Global encoder cache -- `_cache_lock: threading.RLock()` - Thread safety lock - -**Factory Function**: `get_splade_encoder(...) -> SpladeEncoder` -- Cache key includes: model_name, gpu/cpu, max_length, sparsity_threshold -- Pre-loads model on first access -- Returns cached instance on subsequent calls - -**Cleanup Function**: `clear_splade_cache() -> None` -- Releases ONNX resources -- Clears model and tokenizer references -- Prevents memory leaks - -### 3. 
SpladeEncoder Class - -#### Initialization Parameters -- `model_name: str` - Default: "naver/splade-cocondenser-ensembledistil" -- `use_gpu: bool` - Enable GPU acceleration (default: True) -- `max_length: int` - Max sequence length (default: 512) -- `sparsity_threshold: float` - Min weight threshold (default: 0.01) -- `providers: Optional[List]` - Explicit ONNX providers (overrides use_gpu) - -#### Core Methods - -**`_load_model()`**: Lazy loading with GPU support -- Uses `optimum.onnxruntime.ORTModelForMaskedLM` -- Falls back to CPU if GPU unavailable -- Integrates with `gpu_support.get_optimal_providers()` -- Handles device_id options for DirectML/CUDA - -**`_splade_activation(logits, attention_mask)`**: Static method -- Formula: `log(1 + ReLU(logits)) * attention_mask` -- Input: (batch, seq_len, vocab_size) -- Output: (batch, seq_len, vocab_size) - -**`_max_pooling(splade_repr)`**: Static method -- Max pooling over sequence dimension -- Input: (batch, seq_len, vocab_size) -- Output: (batch, vocab_size) - -**`_to_sparse_dict(dense_vec)`**: Conversion helper -- Filters by sparsity_threshold -- Returns: `Dict[int, float]` mapping token_id to weight - -**`encode_text(text: str) -> Dict[int, float]`**: Single text encoding -- Tokenizes input with truncation/padding -- Forward pass through ONNX model -- Applies SPLADE activation + max pooling -- Returns sparse vector - -**`encode_batch(texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]`**: Batch encoding -- Processes in batches for memory efficiency -- Returns list of sparse vectors - -#### Properties - -**`vocab_size: int`**: Vocabulary size (~30k for BERT) -- Cached after first model load -- Returns tokenizer length - -#### Debugging Methods - -**`get_token(token_id: int) -> str`** -- Converts token_id to human-readable string -- Uses tokenizer.decode() - -**`get_top_tokens(sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]`** -- Extracts top-k highest-weight tokens -- Returns (token_string, weight) pairs -- Useful for understanding model focus - -## Design Patterns Followed - -### 1. From `onnx_reranker.py` -✓ ONNX loading with provider detection -✓ Lazy model initialization -✓ Thread-safe loading with RLock -✓ Signature inspection for backward compatibility -✓ Fallback for older Optimum versions -✓ Static helper methods for numerical operations - -### 2. From `embedder.py` -✓ Global cache with thread safety -✓ Factory function pattern (get_splade_encoder) -✓ Cache cleanup function (clear_splade_cache) -✓ GPU provider configuration -✓ Batch processing support - -### 3. 
From `gpu_support.py` -✓ `get_optimal_providers(use_gpu, with_device_options=True)` -✓ Device ID options for DirectML/CUDA -✓ Provider tuple format: (provider_name, options_dict) - -## SPLADE Algorithm - -### Activation Formula -```python -# Step 1: ReLU activation -relu_logits = max(0, logits) - -# Step 2: Log(1 + x) transformation -log_relu = log(1 + relu_logits) - -# Step 3: Apply attention mask -splade_repr = log_relu * attention_mask - -# Step 4: Max pooling over sequence -splade_vec = max(splade_repr, axis=sequence_length) - -# Step 5: Sparsification by threshold -sparse_dict = {token_id: weight for token_id, weight in enumerate(splade_vec) if weight > threshold} -``` - -### Output Format -- Sparse dictionary: `{token_id: weight}` -- Token IDs: 0 to vocab_size-1 (typically ~30,000) -- Weights: Float values > sparsity_threshold -- Interpretable: Can decode token_ids to strings - -## Integration Points - -### With `splade_index.py` -- `SpladeIndex.add_posting(chunk_id, sparse_vec: Dict[int, float])` -- `SpladeIndex.search(query_sparse: Dict[int, float])` -- Encoder produces the sparse vectors consumed by index - -### With Indexing Pipeline -```python -encoder = get_splade_encoder(use_gpu=True) - -# Single document -sparse_vec = encoder.encode_text("def main():\n print('hello')") -index.add_posting(chunk_id=1, sparse_vec=sparse_vec) - -# Batch indexing -texts = ["code chunk 1", "code chunk 2", ...] -sparse_vecs = encoder.encode_batch(texts, batch_size=64) -postings = [(chunk_id, vec) for chunk_id, vec in enumerate(sparse_vecs)] -index.add_postings_batch(postings) -``` - -### With Search Pipeline -```python -encoder = get_splade_encoder(use_gpu=True) -query_sparse = encoder.encode_text("authentication function") -results = index.search(query_sparse, limit=50, min_score=0.5) -``` - -## Dependencies - -Required packages: -- `numpy` - Numerical operations -- `onnxruntime` - ONNX model execution (CPU) -- `onnxruntime-gpu` - ONNX with GPU support (optional) -- `optimum[onnxruntime]` - Hugging Face ONNX optimization -- `transformers` - Tokenizer and model loading - -Install command: -```bash -# CPU only -pip install numpy onnxruntime optimum[onnxruntime] transformers - -# With GPU support -pip install numpy onnxruntime-gpu optimum[onnxruntime-gpu] transformers -``` - -## Testing Status - -✓ Python syntax validation passed -✓ Module import successful -✓ Dependency checking works correctly -✗ Full functional test pending (requires optimum installation) - -## Next Steps - -1. Install dependencies for functional testing -2. Create unit tests in `tests/semantic/test_splade_encoder.py` -3. Benchmark encoding performance (CPU vs GPU) -4. Integrate with codex-lens indexing pipeline -5. 
Add SPLADE option to semantic search configuration - -## Performance Considerations - -### Memory Usage -- Model size: ~100MB (ONNX optimized) -- Sparse vectors: ~100-500 non-zero entries per document -- Batch size: 32 recommended (adjust based on GPU memory) - -### Speed Benchmarks (Expected) -- CPU encoding: ~10-20 docs/sec -- GPU encoding (CUDA): ~100-200 docs/sec -- GPU encoding (DirectML): ~50-100 docs/sec - -### Sparsity Analysis -- Threshold 0.01: ~200-400 tokens per document -- Threshold 0.05: ~100-200 tokens per document -- Threshold 0.10: ~50-100 tokens per document - -## References - -- SPLADE paper: https://arxiv.org/abs/2107.05720 -- SPLADE v2: https://arxiv.org/abs/2109.10086 -- Naver model: https://huggingface.co/naver/splade-cocondenser-ensembledistil diff --git a/codex-lens/src/codexlens/semantic/splade_encoder.py b/codex-lens/src/codexlens/semantic/splade_encoder.py deleted file mode 100644 index de92c69d..00000000 --- a/codex-lens/src/codexlens/semantic/splade_encoder.py +++ /dev/null @@ -1,567 +0,0 @@ -"""ONNX-optimized SPLADE sparse encoder for code search. - -This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime -for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors -that combine the interpretability of BM25 with neural relevance modeling. - -Install (CPU): - pip install onnxruntime optimum[onnxruntime] transformers - -Install (GPU): - pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers -""" - -from __future__ import annotations - -import logging -import threading -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -def check_splade_available() -> Tuple[bool, Optional[str]]: - """Check whether SPLADE dependencies are available. - - Returns: - Tuple of (available: bool, error_message: Optional[str]) - """ - try: - import numpy # noqa: F401 - except ImportError as exc: - return False, f"numpy not available: {exc}. Install with: pip install numpy" - - try: - import onnxruntime # noqa: F401 - except ImportError as exc: - return ( - False, - f"onnxruntime not available: {exc}. Install with: pip install onnxruntime", - ) - - try: - from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401 - except ImportError as exc: - return ( - False, - f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]", - ) - - try: - from transformers import AutoTokenizer # noqa: F401 - except ImportError as exc: - return ( - False, - f"transformers not available: {exc}. Install with: pip install transformers", - ) - - return True, None - - -# Global cache for SPLADE encoders (singleton pattern) -_splade_cache: Dict[str, "SpladeEncoder"] = {} -_cache_lock = threading.RLock() - - -def get_splade_encoder( - model_name: str = "naver/splade-cocondenser-ensembledistil", - use_gpu: bool = True, - max_length: int = 512, - sparsity_threshold: float = 0.01, - cache_dir: Optional[str] = None, -) -> "SpladeEncoder": - """Get or create cached SPLADE encoder (thread-safe singleton). - - This function provides significant performance improvement by reusing - SpladeEncoder instances across multiple searches, avoiding repeated model - loading overhead. 
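Annotation for readers auditing this removal: the cache deleted below follows the same thread-safe factory pattern the deleted notes attribute to `embedder.py`, which remains in the codebase. A distilled sketch of that pattern with hypothetical names:

```python
import threading
from typing import Dict

_cache: Dict[str, "DummyEncoder"] = {}
_lock = threading.RLock()


class DummyEncoder:
    """Stand-in for any expensive-to-load model wrapper (hypothetical)."""

    def __init__(self, model_name: str, use_gpu: bool) -> None:
        self.model_name, self.use_gpu = model_name, use_gpu  # real loading elided


def get_encoder(model_name: str, use_gpu: bool = True) -> DummyEncoder:
    """Return one cached instance per configuration, created under a lock."""
    key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}"
    with _lock:
        encoder = _cache.get(key)
        if encoder is None:
            encoder = DummyEncoder(model_name, use_gpu)
            _cache[key] = encoder
        return encoder
```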
- - Args: - model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil) - use_gpu: If True, use GPU acceleration when available - max_length: Maximum sequence length for tokenization - sparsity_threshold: Minimum weight to include in sparse vector - cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade) - - Returns: - Cached SpladeEncoder instance for the given configuration - """ - global _splade_cache - - # Cache key includes all configuration parameters - cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}" - - with _cache_lock: - encoder = _splade_cache.get(cache_key) - if encoder is not None: - return encoder - - # Create new encoder and cache it - encoder = SpladeEncoder( - model_name=model_name, - use_gpu=use_gpu, - max_length=max_length, - sparsity_threshold=sparsity_threshold, - cache_dir=cache_dir, - ) - # Pre-load model to ensure it's ready - encoder._load_model() - _splade_cache[cache_key] = encoder - - return encoder - - -def clear_splade_cache() -> None: - """Clear the SPLADE encoder cache and release ONNX resources. - - This method ensures proper cleanup of ONNX model resources to prevent - memory leaks when encoders are no longer needed. - """ - global _splade_cache - with _cache_lock: - # Release ONNX resources before clearing cache - for encoder in _splade_cache.values(): - if encoder._model is not None: - del encoder._model - encoder._model = None - if encoder._tokenizer is not None: - del encoder._tokenizer - encoder._tokenizer = None - _splade_cache.clear() - - -class SpladeEncoder: - """ONNX-optimized SPLADE sparse encoder. - - Produces sparse vectors with vocabulary-aligned dimensions. - Output: Dict[int, float] mapping token_id to weight. - - SPLADE activation formula: - splade_repr = log(1 + ReLU(logits)) * attention_mask - splade_vec = max_pooling(splade_repr, axis=sequence_length) - - References: - - SPLADE: https://arxiv.org/abs/2107.05720 - - SPLADE v2: https://arxiv.org/abs/2109.10086 - """ - - DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil" - - def __init__( - self, - model_name: str = DEFAULT_MODEL, - use_gpu: bool = True, - max_length: int = 512, - sparsity_threshold: float = 0.01, - providers: Optional[List[Any]] = None, - cache_dir: Optional[str] = None, - ) -> None: - """Initialize SPLADE encoder. - - Args: - model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil) - use_gpu: If True, use GPU acceleration when available - max_length: Maximum sequence length for tokenization - sparsity_threshold: Minimum weight to include in sparse vector - providers: Explicit ONNX providers list (overrides use_gpu) - cache_dir: Directory to cache ONNX models (default: ~/.cache/codexlens/splade) - """ - self.model_name = (model_name or self.DEFAULT_MODEL).strip() - if not self.model_name: - raise ValueError("model_name cannot be blank") - - self.use_gpu = bool(use_gpu) - self.max_length = int(max_length) if max_length > 0 else 512 - self.sparsity_threshold = float(sparsity_threshold) - self.providers = providers - - # Setup ONNX cache directory - if cache_dir: - self._cache_dir = Path(cache_dir) - else: - self._cache_dir = Path.home() / ".cache" / "codexlens" / "splade" - - self._tokenizer: Any | None = None - self._model: Any | None = None - self._vocab_size: int | None = None - self._lock = threading.RLock() - - def _get_local_cache_path(self) -> Path: - """Get local cache path for this model's ONNX files. 
- - Returns: - Path to the local ONNX cache directory for this model - """ - # Replace / with -- for filesystem-safe naming - safe_name = self.model_name.replace("/", "--") - return self._cache_dir / safe_name - - def _load_model(self) -> None: - """Lazy load ONNX model and tokenizer. - - First checks local cache for ONNX model, falling back to - HuggingFace download and conversion if not cached. - """ - if self._model is not None and self._tokenizer is not None: - return - - ok, err = check_splade_available() - if not ok: - raise ImportError(err) - - with self._lock: - if self._model is not None and self._tokenizer is not None: - return - - from inspect import signature - - from optimum.onnxruntime import ORTModelForMaskedLM - from transformers import AutoTokenizer - - if self.providers is None: - from .gpu_support import get_optimal_providers, get_selected_device_id - - # Get providers as pure string list (cache-friendly) - # NOTE: with_device_options=False to avoid tuple-based providers - # which break optimum's caching mechanism - self.providers = get_optimal_providers( - use_gpu=self.use_gpu, with_device_options=False - ) - # Get device_id separately for provider_options - self._device_id = get_selected_device_id() if self.use_gpu else None - - # Some Optimum versions accept `providers`, others accept a single `provider` - # Prefer passing the full providers list, with a conservative fallback - model_kwargs: dict[str, Any] = {} - try: - params = signature(ORTModelForMaskedLM.from_pretrained).parameters - if "providers" in params: - model_kwargs["providers"] = self.providers - # Pass device_id via provider_options for GPU selection - if "provider_options" in params and hasattr(self, '_device_id') and self._device_id is not None: - # Build provider_options dict for each GPU provider - provider_options = {} - for p in self.providers: - if p in ("DmlExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider"): - provider_options[p] = {"device_id": self._device_id} - if provider_options: - model_kwargs["provider_options"] = provider_options - elif "provider" in params: - provider_name = "CPUExecutionProvider" - if self.providers: - first = self.providers[0] - provider_name = first[0] if isinstance(first, tuple) else str(first) - model_kwargs["provider"] = provider_name - except Exception as e: - logger.debug(f"Failed to inspect ORTModel signature: {e}") - model_kwargs = {} - - # Check for local ONNX cache first - local_cache = self._get_local_cache_path() - onnx_model_path = local_cache / "model.onnx" - - if onnx_model_path.exists(): - # Load from local cache - logger.info(f"Loading SPLADE from local cache: {local_cache}") - try: - self._model = ORTModelForMaskedLM.from_pretrained( - str(local_cache), - **model_kwargs, - ) - self._tokenizer = AutoTokenizer.from_pretrained( - str(local_cache), use_fast=True - ) - self._vocab_size = len(self._tokenizer) - logger.info( - f"SPLADE loaded from cache: {self.model_name}, vocab={self._vocab_size}" - ) - return - except Exception as e: - logger.warning(f"Failed to load from cache, redownloading: {e}") - - # Download and convert from HuggingFace - logger.info(f"Downloading SPLADE model: {self.model_name}") - try: - self._model = ORTModelForMaskedLM.from_pretrained( - self.model_name, - export=True, # Export to ONNX - **model_kwargs, - ) - logger.debug(f"SPLADE model loaded: {self.model_name}") - except TypeError: - # Fallback for older Optimum versions: retry without provider arguments - self._model = ORTModelForMaskedLM.from_pretrained( - 
self.model_name, - export=True, - ) - logger.warning( - "Optimum version doesn't support provider parameters. " - "Upgrade optimum for GPU acceleration: pip install --upgrade optimum" - ) - - self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) - - # Cache vocabulary size - self._vocab_size = len(self._tokenizer) - logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}") - - # Save to local cache for future use - try: - local_cache.mkdir(parents=True, exist_ok=True) - self._model.save_pretrained(str(local_cache)) - self._tokenizer.save_pretrained(str(local_cache)) - logger.info(f"SPLADE model cached to: {local_cache}") - except Exception as e: - logger.warning(f"Failed to cache SPLADE model: {e}") - - @staticmethod - def _splade_activation(logits: Any, attention_mask: Any) -> Any: - """Apply SPLADE activation function to model outputs. - - Formula: log(1 + ReLU(logits)) * attention_mask - - Args: - logits: Model output logits (batch, seq_len, vocab_size) - attention_mask: Attention mask (batch, seq_len) - - Returns: - SPLADE representations (batch, seq_len, vocab_size) - """ - import numpy as np - - # ReLU activation - relu_logits = np.maximum(0, logits) - - # Log(1 + x) transformation - log_relu = np.log1p(relu_logits) - - # Apply attention mask (expand to match vocab dimension) - # attention_mask: (batch, seq_len) -> (batch, seq_len, 1) - mask_expanded = np.expand_dims(attention_mask, axis=-1) - - # Element-wise multiplication - splade_repr = log_relu * mask_expanded - - return splade_repr - - @staticmethod - def _max_pooling(splade_repr: Any) -> Any: - """Max pooling over sequence length dimension. - - Args: - splade_repr: SPLADE representations (batch, seq_len, vocab_size) - - Returns: - Pooled sparse vectors (batch, vocab_size) - """ - import numpy as np - - # Max pooling over sequence dimension (axis=1) - return np.max(splade_repr, axis=1) - - def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]: - """Convert dense vector to sparse dictionary. - - Args: - dense_vec: Dense vector (vocab_size,) - - Returns: - Sparse dictionary {token_id: weight} with weights above threshold - """ - import numpy as np - - # Find non-zero indices above threshold - nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0] - - # Create sparse dictionary - sparse_dict = { - int(idx): float(dense_vec[idx]) - for idx in nonzero_indices - } - - return sparse_dict - - def warmup(self, text: str = "warmup query") -> None: - """Warmup the encoder by running a dummy inference. - - First-time model inference includes initialization overhead. - Call this method once before the first real search to avoid - latency spikes. - - Args: - text: Dummy text for warmup (default: "warmup query") - """ - logger.info("Warming up SPLADE encoder...") - # Trigger model loading and first inference - _ = self.encode_text(text) - logger.info("SPLADE encoder warmup complete") - - def encode_text(self, text: str) -> Dict[int, float]: - """Encode text to sparse vector {token_id: weight}. 
- - Args: - text: Input text to encode - - Returns: - Sparse vector as dictionary mapping token_id to weight - """ - self._load_model() - - if self._model is None or self._tokenizer is None: - raise RuntimeError("Model not loaded") - - import numpy as np - - # Tokenize input - encoded = self._tokenizer( - text, - padding=True, - truncation=True, - max_length=self.max_length, - return_tensors="np", - ) - - # Forward pass through model - outputs = self._model(**encoded) - - # Extract logits - if hasattr(outputs, "logits"): - logits = outputs.logits - elif isinstance(outputs, dict) and "logits" in outputs: - logits = outputs["logits"] - elif isinstance(outputs, (list, tuple)) and outputs: - logits = outputs[0] - else: - raise RuntimeError("Unexpected model output format") - - # Apply SPLADE activation - attention_mask = encoded["attention_mask"] - splade_repr = self._splade_activation(logits, attention_mask) - - # Max pooling over sequence length - splade_vec = self._max_pooling(splade_repr) - - # Convert to sparse dictionary (single item batch) - sparse_dict = self._to_sparse_dict(splade_vec[0]) - - return sparse_dict - - def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]: - """Batch encode texts to sparse vectors. - - Args: - texts: List of input texts to encode - batch_size: Batch size for encoding (default: 32) - - Returns: - List of sparse vectors as dictionaries - """ - if not texts: - return [] - - self._load_model() - - if self._model is None or self._tokenizer is None: - raise RuntimeError("Model not loaded") - - import numpy as np - - results: List[Dict[int, float]] = [] - - # Process in batches - for i in range(0, len(texts), batch_size): - batch_texts = texts[i:i + batch_size] - - # Tokenize batch - encoded = self._tokenizer( - batch_texts, - padding=True, - truncation=True, - max_length=self.max_length, - return_tensors="np", - ) - - # Forward pass through model - outputs = self._model(**encoded) - - # Extract logits - if hasattr(outputs, "logits"): - logits = outputs.logits - elif isinstance(outputs, dict) and "logits" in outputs: - logits = outputs["logits"] - elif isinstance(outputs, (list, tuple)) and outputs: - logits = outputs[0] - else: - raise RuntimeError("Unexpected model output format") - - # Apply SPLADE activation - attention_mask = encoded["attention_mask"] - splade_repr = self._splade_activation(logits, attention_mask) - - # Max pooling over sequence length - splade_vecs = self._max_pooling(splade_repr) - - # Convert each vector to sparse dictionary - for vec in splade_vecs: - sparse_dict = self._to_sparse_dict(vec) - results.append(sparse_dict) - - return results - - @property - def vocab_size(self) -> int: - """Return vocabulary size (~30k for BERT-based models). - - Returns: - Vocabulary size (number of tokens in tokenizer) - """ - if self._vocab_size is not None: - return self._vocab_size - - self._load_model() - return self._vocab_size or 0 - - def get_token(self, token_id: int) -> str: - """Convert token_id to string (for debugging). - - Args: - token_id: Token ID to convert - - Returns: - Token string - """ - self._load_model() - - if self._tokenizer is None: - raise RuntimeError("Tokenizer not loaded") - - return self._tokenizer.decode([token_id]) - - def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]: - """Get top-k tokens with highest weights from sparse vector. - - Useful for debugging and understanding what the model is focusing on. 
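Annotation: before the encoder disappears entirely, a tiny numeric check of the activation it implemented (`log1p` of ReLU, masked, then max-pooled over the sequence), assuming only numpy; shapes follow the `_splade_activation`/`_max_pooling` helpers above:

```python
import numpy as np

logits = np.array([[[-1.0, 2.0],
                    [3.0, -4.0]]])       # (batch=1, seq_len=2, vocab_size=2)
mask = np.array([[1, 1]])                # (batch, seq_len)

splade_repr = np.log1p(np.maximum(0, logits)) * mask[..., None]
splade_vec = splade_repr.max(axis=1)     # max-pool over seq_len

# splade_vec -> [[log1p(3.0), log1p(2.0)]] ≈ [[1.3863, 1.0986]]
```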
- - Args: - sparse_vec: Sparse vector as {token_id: weight} - top_k: Number of top tokens to return - - Returns: - List of (token_string, weight) tuples, sorted by weight descending - """ - self._load_model() - - if not sparse_vec: - return [] - - # Sort by weight descending - sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True) - - # Take top-k and convert token_ids to strings - top_items = sorted_items[:top_k] - - return [ - (self.get_token(token_id), weight) - for token_id, weight in top_items - ] diff --git a/codex-lens/src/codexlens/storage/migrations/migration_009_add_splade.py b/codex-lens/src/codexlens/storage/migrations/migration_009_add_splade.py deleted file mode 100644 index c675233e..00000000 --- a/codex-lens/src/codexlens/storage/migrations/migration_009_add_splade.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Migration 009: Add SPLADE sparse retrieval tables. - -This migration introduces SPLADE (Sparse Lexical AnD Expansion) support: -- splade_metadata: Model configuration (model name, vocab size, ONNX path) -- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight) - -The SPLADE tables are designed for efficient sparse vector retrieval: -- Token-based lookup for query expansion -- Chunk-based deletion for index maintenance -- Maintains backward compatibility with existing FTS tables -""" - -import logging -from sqlite3 import Connection - -log = logging.getLogger(__name__) - - -def upgrade(db_conn: Connection) -> None: - """ - Adds SPLADE tables for sparse retrieval. - - Creates: - - splade_metadata: Stores model configuration and ONNX path - - splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings - - Indexes for efficient token-based and chunk-based lookups - - Args: - db_conn: The SQLite database connection. - """ - cursor = db_conn.cursor() - - log.info("Creating splade_metadata table...") - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS splade_metadata ( - id INTEGER PRIMARY KEY DEFAULT 1, - model_name TEXT NOT NULL, - vocab_size INTEGER NOT NULL, - onnx_path TEXT, - created_at REAL - ) - """ - ) - - log.info("Creating splade_posting_list table...") - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS splade_posting_list ( - token_id INTEGER NOT NULL, - chunk_id INTEGER NOT NULL, - weight REAL NOT NULL, - PRIMARY KEY (token_id, chunk_id), - FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE - ) - """ - ) - - log.info("Creating indexes for splade_posting_list...") - # Index for efficient chunk-based lookups (deletion, updates) - cursor.execute( - """ - CREATE INDEX IF NOT EXISTS idx_splade_by_chunk - ON splade_posting_list(chunk_id) - """ - ) - - # Index for efficient term-based retrieval - cursor.execute( - """ - CREATE INDEX IF NOT EXISTS idx_splade_by_token - ON splade_posting_list(token_id) - """ - ) - - log.info("Migration 009 completed successfully") - - -def downgrade(db_conn: Connection) -> None: - """ - Removes SPLADE tables. - - Drops: - - splade_posting_list (and associated indexes) - - splade_metadata - - Args: - db_conn: The SQLite database connection. 
- """ - cursor = db_conn.cursor() - - log.info("Dropping SPLADE indexes...") - cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk") - cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token") - - log.info("Dropping splade_posting_list table...") - cursor.execute("DROP TABLE IF EXISTS splade_posting_list") - - log.info("Dropping splade_metadata table...") - cursor.execute("DROP TABLE IF EXISTS splade_metadata") - - log.info("Migration 009 downgrade completed successfully") diff --git a/codex-lens/src/codexlens/storage/splade_index.py b/codex-lens/src/codexlens/storage/splade_index.py deleted file mode 100644 index d090a12d..00000000 --- a/codex-lens/src/codexlens/storage/splade_index.py +++ /dev/null @@ -1,578 +0,0 @@ -"""SPLADE inverted index storage for sparse vector retrieval. - -This module implements SQLite-based inverted index for SPLADE sparse vectors, -enabling efficient sparse retrieval using dot-product scoring. -""" - -from __future__ import annotations - -import logging -import sqlite3 -import threading -import time -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -from codexlens.entities import SearchResult -from codexlens.errors import StorageError - -logger = logging.getLogger(__name__) - - -class SpladeIndex: - """SQLite-based inverted index for SPLADE sparse vectors. - - Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight). - Supports efficient dot-product retrieval using SQL joins. - """ - - def __init__(self, db_path: Path | str) -> None: - """Initialize SPLADE index. - - Args: - db_path: Path to SQLite database file. - """ - self.db_path = Path(db_path) - self.db_path.parent.mkdir(parents=True, exist_ok=True) - - # Thread-safe connection management - self._lock = threading.RLock() - self._local = threading.local() - - def _get_connection(self) -> sqlite3.Connection: - """Get or create a thread-local database connection. - - Each thread gets its own connection to ensure thread safety. - Connections are stored in thread-local storage. - """ - conn = getattr(self._local, "conn", None) - if conn is None: - # Thread-local connection - each thread has its own - conn = sqlite3.connect( - self.db_path, - timeout=30.0, # Wait up to 30s for locks - check_same_thread=True, # Enforce thread safety - ) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - conn.execute("PRAGMA foreign_keys=ON") - # Limit mmap to 1GB to avoid OOM on smaller systems - conn.execute("PRAGMA mmap_size=1073741824") - # Increase cache size for better query performance (20MB = -20000 pages) - conn.execute("PRAGMA cache_size=-20000") - self._local.conn = conn - return conn - - def close(self) -> None: - """Close thread-local database connection.""" - with self._lock: - conn = getattr(self._local, "conn", None) - if conn is not None: - conn.close() - self._local.conn = None - - def __enter__(self) -> SpladeIndex: - """Context manager entry.""" - self.create_tables() - return self - - def __exit__(self, exc_type, exc, tb) -> None: - """Context manager exit.""" - self.close() - - def has_index(self) -> bool: - """Check if SPLADE tables exist in database. - - Returns: - True if tables exist, False otherwise. 
- """ - with self._lock: - conn = self._get_connection() - try: - cursor = conn.execute( - """ - SELECT name FROM sqlite_master - WHERE type='table' AND name='splade_posting_list' - """ - ) - return cursor.fetchone() is not None - except sqlite3.Error as e: - logger.error("Failed to check index existence: %s", e) - return False - - def create_tables(self) -> None: - """Create SPLADE schema if not exists. - - Note: When used with distributed indexes (multiple _index.db files), - the SPLADE database stores chunk IDs from multiple sources. In this case, - foreign key constraints are not enforced to allow cross-database references. - """ - with self._lock: - conn = self._get_connection() - try: - # Inverted index for sparse vectors - # Note: No FOREIGN KEY constraint to support distributed index architecture - # where chunks may come from multiple _index.db files - conn.execute(""" - CREATE TABLE IF NOT EXISTS splade_posting_list ( - token_id INTEGER NOT NULL, - chunk_id INTEGER NOT NULL, - weight REAL NOT NULL, - PRIMARY KEY (token_id, chunk_id) - ) - """) - - # Indexes for efficient lookups - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_splade_by_chunk - ON splade_posting_list(chunk_id) - """) - conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_splade_by_token - ON splade_posting_list(token_id) - """) - - # Model metadata - conn.execute(""" - CREATE TABLE IF NOT EXISTS splade_metadata ( - id INTEGER PRIMARY KEY DEFAULT 1, - model_name TEXT NOT NULL, - vocab_size INTEGER NOT NULL, - onnx_path TEXT, - created_at REAL - ) - """) - - # Chunk metadata for self-contained search results - # Stores all chunk info needed to build SearchResult without querying _index.db - conn.execute(""" - CREATE TABLE IF NOT EXISTS splade_chunks ( - id INTEGER PRIMARY KEY, - file_path TEXT NOT NULL, - content TEXT NOT NULL, - metadata TEXT, - source_db TEXT - ) - """) - - conn.commit() - logger.debug("SPLADE schema created successfully") - except sqlite3.Error as e: - raise StorageError( - f"Failed to create SPLADE schema: {e}", - db_path=str(self.db_path), - operation="create_tables" - ) from e - - def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None: - """Add a single document to inverted index. - - Args: - chunk_id: Chunk ID (foreign key to semantic_chunks.id). - sparse_vec: Sparse vector as {token_id: weight} mapping. - """ - if not sparse_vec: - logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id) - return - - with self._lock: - conn = self._get_connection() - try: - # Insert all non-zero weights for this chunk - postings = [ - (token_id, chunk_id, weight) - for token_id, weight in sparse_vec.items() - if weight > 0 # Only store non-zero weights - ] - - if postings: - conn.executemany( - """ - INSERT OR REPLACE INTO splade_posting_list - (token_id, chunk_id, weight) - VALUES (?, ?, ?) - """, - postings - ) - conn.commit() - logger.debug( - "Added %d postings for chunk_id=%d", len(postings), chunk_id - ) - except sqlite3.Error as e: - raise StorageError( - f"Failed to add posting for chunk_id={chunk_id}: {e}", - db_path=str(self.db_path), - operation="add_posting" - ) from e - - def add_postings_batch( - self, postings: List[Tuple[int, Dict[int, float]]] - ) -> None: - """Batch insert postings for multiple chunks. - - Args: - postings: List of (chunk_id, sparse_vec) tuples. 
- """ - if not postings: - return - - with self._lock: - conn = self._get_connection() - try: - # Flatten all postings into single batch - batch_data = [] - for chunk_id, sparse_vec in postings: - for token_id, weight in sparse_vec.items(): - if weight > 0: # Only store non-zero weights - batch_data.append((token_id, chunk_id, weight)) - - if batch_data: - conn.executemany( - """ - INSERT OR REPLACE INTO splade_posting_list - (token_id, chunk_id, weight) - VALUES (?, ?, ?) - """, - batch_data - ) - conn.commit() - logger.debug( - "Batch inserted %d postings for %d chunks", - len(batch_data), - len(postings) - ) - except sqlite3.Error as e: - raise StorageError( - f"Failed to batch insert postings: {e}", - db_path=str(self.db_path), - operation="add_postings_batch" - ) from e - - def add_chunk_metadata( - self, - chunk_id: int, - file_path: str, - content: str, - metadata: Optional[str] = None, - source_db: Optional[str] = None - ) -> None: - """Store chunk metadata for self-contained search results. - - Args: - chunk_id: Global chunk ID. - file_path: Path to source file. - content: Chunk text content. - metadata: JSON metadata string. - source_db: Path to source _index.db. - """ - with self._lock: - conn = self._get_connection() - try: - conn.execute( - """ - INSERT OR REPLACE INTO splade_chunks - (id, file_path, content, metadata, source_db) - VALUES (?, ?, ?, ?, ?) - """, - (chunk_id, file_path, content, metadata, source_db) - ) - conn.commit() - except sqlite3.Error as e: - raise StorageError( - f"Failed to add chunk metadata for chunk_id={chunk_id}: {e}", - db_path=str(self.db_path), - operation="add_chunk_metadata" - ) from e - - def add_chunks_metadata_batch( - self, - chunks: List[Tuple[int, str, str, Optional[str], Optional[str]]] - ) -> None: - """Batch insert chunk metadata. - - Args: - chunks: List of (chunk_id, file_path, content, metadata, source_db) tuples. - """ - if not chunks: - return - - with self._lock: - conn = self._get_connection() - try: - conn.executemany( - """ - INSERT OR REPLACE INTO splade_chunks - (id, file_path, content, metadata, source_db) - VALUES (?, ?, ?, ?, ?) - """, - chunks - ) - conn.commit() - logger.debug("Batch inserted %d chunk metadata records", len(chunks)) - except sqlite3.Error as e: - raise StorageError( - f"Failed to batch insert chunk metadata: {e}", - db_path=str(self.db_path), - operation="add_chunks_metadata_batch" - ) from e - - def get_chunks_by_ids(self, chunk_ids: List[int]) -> List[Dict]: - """Get chunk metadata by IDs. - - Args: - chunk_ids: List of chunk IDs to retrieve. - - Returns: - List of dicts with id, file_path, content, metadata, source_db. - """ - if not chunk_ids: - return [] - - with self._lock: - conn = self._get_connection() - try: - placeholders = ",".join("?" * len(chunk_ids)) - rows = conn.execute( - f""" - SELECT id, file_path, content, metadata, source_db - FROM splade_chunks - WHERE id IN ({placeholders}) - """, - chunk_ids - ).fetchall() - - return [ - { - "id": row["id"], - "file_path": row["file_path"], - "content": row["content"], - "metadata": row["metadata"], - "source_db": row["source_db"] - } - for row in rows - ] - except sqlite3.Error as e: - logger.error("Failed to get chunks by IDs: %s", e) - return [] - - def remove_chunk(self, chunk_id: int) -> int: - """Remove all postings for a chunk. - - Args: - chunk_id: Chunk ID to remove. - - Returns: - Number of deleted postings. 
- """ - with self._lock: - conn = self._get_connection() - try: - cursor = conn.execute( - "DELETE FROM splade_posting_list WHERE chunk_id = ?", - (chunk_id,) - ) - conn.commit() - deleted = cursor.rowcount - logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id) - return deleted - except sqlite3.Error as e: - raise StorageError( - f"Failed to remove chunk_id={chunk_id}: {e}", - db_path=str(self.db_path), - operation="remove_chunk" - ) from e - - def search( - self, - query_sparse: Dict[int, float], - limit: int = 50, - min_score: float = 0.0, - max_query_terms: int = 64 - ) -> List[Tuple[int, float]]: - """Search for similar chunks using dot-product scoring. - - Implements efficient sparse dot-product via SQL JOIN: - score(q, d) = sum(q[t] * d[t]) for all tokens t - - Args: - query_sparse: Query sparse vector as {token_id: weight}. - limit: Maximum number of results. - min_score: Minimum score threshold. - max_query_terms: Maximum query terms to use (default: 64). - Pruning to top-K terms reduces search time with minimal impact on quality. - Set to 0 or negative to disable pruning (use all terms). - - Returns: - List of (chunk_id, score) tuples, ordered by score descending. - """ - if not query_sparse: - logger.warning("Empty query sparse vector") - return [] - - with self._lock: - conn = self._get_connection() - try: - # Build VALUES clause for query terms - # Each term: (token_id, weight) - query_terms = [ - (token_id, weight) - for token_id, weight in query_sparse.items() - if weight > 0 - ] - - if not query_terms: - logger.warning("No non-zero query terms") - return [] - - # Query pruning: keep only top-K terms by weight - # max_query_terms <= 0 means no limit (use all terms) - if max_query_terms > 0 and len(query_terms) > max_query_terms: - query_terms = sorted(query_terms, key=lambda x: x[1], reverse=True)[:max_query_terms] - logger.debug( - "Query pruned from %d to %d terms", - len(query_sparse), - len(query_terms) - ) - - # Create CTE for query terms using parameterized VALUES - # Build placeholders and params to prevent SQL injection - params = [] - placeholders = [] - for token_id, weight in query_terms: - placeholders.append("(?, ?)") - params.extend([token_id, weight]) - - values_placeholders = ", ".join(placeholders) - - sql = f""" - WITH query_terms(token_id, weight) AS ( - VALUES {values_placeholders} - ) - SELECT - p.chunk_id, - SUM(p.weight * q.weight) as score - FROM splade_posting_list p - INNER JOIN query_terms q ON p.token_id = q.token_id - GROUP BY p.chunk_id - HAVING score >= ? - ORDER BY score DESC - LIMIT ? - """ - - # Append min_score and limit to params - params.extend([min_score, limit]) - rows = conn.execute(sql, params).fetchall() - - results = [(row["chunk_id"], float(row["score"])) for row in rows] - logger.debug( - "SPLADE search: %d query terms, %d results", - len(query_terms), - len(results) - ) - return results - - except sqlite3.Error as e: - raise StorageError( - f"SPLADE search failed: {e}", - db_path=str(self.db_path), - operation="search" - ) from e - - def get_metadata(self) -> Optional[Dict]: - """Get SPLADE model metadata. - - Returns: - Dictionary with model_name, vocab_size, onnx_path, created_at, - or None if not set. 
- """ - with self._lock: - conn = self._get_connection() - try: - row = conn.execute( - """ - SELECT model_name, vocab_size, onnx_path, created_at - FROM splade_metadata - WHERE id = 1 - """ - ).fetchone() - - if row: - return { - "model_name": row["model_name"], - "vocab_size": row["vocab_size"], - "onnx_path": row["onnx_path"], - "created_at": row["created_at"] - } - return None - except sqlite3.Error as e: - logger.error("Failed to get metadata: %s", e) - return None - - def set_metadata( - self, - model_name: str, - vocab_size: int, - onnx_path: Optional[str] = None - ) -> None: - """Set SPLADE model metadata. - - Args: - model_name: SPLADE model name. - vocab_size: Vocabulary size (typically ~30k for BERT vocab). - onnx_path: Optional path to ONNX model file. - """ - with self._lock: - conn = self._get_connection() - try: - current_time = time.time() - conn.execute( - """ - INSERT OR REPLACE INTO splade_metadata - (id, model_name, vocab_size, onnx_path, created_at) - VALUES (1, ?, ?, ?, ?) - """, - (model_name, vocab_size, onnx_path, current_time) - ) - conn.commit() - logger.info( - "Set SPLADE metadata: model=%s, vocab_size=%d", - model_name, - vocab_size - ) - except sqlite3.Error as e: - raise StorageError( - f"Failed to set metadata: {e}", - db_path=str(self.db_path), - operation="set_metadata" - ) from e - - def get_stats(self) -> Dict: - """Get index statistics. - - Returns: - Dictionary with total_postings, unique_tokens, unique_chunks. - """ - with self._lock: - conn = self._get_connection() - try: - row = conn.execute(""" - SELECT - COUNT(*) as total_postings, - COUNT(DISTINCT token_id) as unique_tokens, - COUNT(DISTINCT chunk_id) as unique_chunks - FROM splade_posting_list - """).fetchone() - - return { - "total_postings": row["total_postings"], - "unique_tokens": row["unique_tokens"], - "unique_chunks": row["unique_chunks"] - } - except sqlite3.Error as e: - logger.error("Failed to get stats: %s", e) - return { - "total_postings": 0, - "unique_tokens": 0, - "unique_chunks": 0 - } diff --git a/codex-lens/tests/api/test_semantic_search.py b/codex-lens/tests/api/test_semantic_search.py index ee6757bc..f44b2f61 100644 --- a/codex-lens/tests/api/test_semantic_search.py +++ b/codex-lens/tests/api/test_semantic_search.py @@ -427,12 +427,12 @@ class TestFusionStrategyMapping: mock_engine.binary_cascade_search.assert_called_once() - def test_hybrid_strategy_calls_hybrid_cascade_search(self): - """Test that hybrid strategy maps to hybrid_cascade_search.""" + def test_hybrid_strategy_maps_to_binary_rerank(self): + """Test that hybrid strategy maps to binary_rerank_cascade_search (backward compat).""" from codexlens.api.semantic import _execute_search mock_engine = MagicMock() - mock_engine.hybrid_cascade_search.return_value = MagicMock(results=[]) + mock_engine.binary_rerank_cascade_search.return_value = MagicMock(results=[]) mock_options = MagicMock() _execute_search( @@ -444,7 +444,7 @@ class TestFusionStrategyMapping: limit=20, ) - mock_engine.hybrid_cascade_search.assert_called_once() + mock_engine.binary_rerank_cascade_search.assert_called_once() def test_unknown_strategy_defaults_to_rrf(self): """Test that unknown strategy defaults to standard search (rrf).""" diff --git a/codex-lens/tests/integration/test_lsp_search_integration.py b/codex-lens/tests/integration/test_lsp_search_integration.py index 08768137..f6b68bc0 100644 --- a/codex-lens/tests/integration/test_lsp_search_integration.py +++ b/codex-lens/tests/integration/test_lsp_search_integration.py @@ -1,7 +1,7 @@ 
"""Integration tests for HybridSearchEngine LSP graph search. Tests the _search_lsp_graph method which orchestrates: -1. Seed retrieval via vector/splade/exact fallback chain +1. Seed retrieval via vector/exact fallback chain 2. LSP graph expansion via LspBridge and LspGraphBuilder 3. Result deduplication and merging @@ -184,8 +184,6 @@ class TestP0CriticalLspSearch: with patch.object( engine, "_search_vector", return_value=[sample_search_result] ) as mock_vector, patch.object( - engine, "_search_splade", return_value=[] - ), patch.object( engine, "_search_exact", return_value=[] ): # Patch LSP module at the import location @@ -251,11 +249,10 @@ class TestP0CriticalLspSearch: sample_search_result: SearchResult, sample_code_symbol_node: CodeSymbolNode, ) -> None: - """Test seed fallback chain: vector -> splade -> exact. + """Test seed fallback chain: vector -> exact. Input: query="init_db" Mock: _search_vector returns [] - Mock: _search_splade returns [] Mock: _search_exact returns 1 seed Assert: Fallback chain called in order, uses exact's seed """ @@ -267,10 +264,6 @@ class TestP0CriticalLspSearch: call_order.append("vector") return [] - def track_splade(*args, **kwargs): - call_order.append("splade") - return [] - def track_exact(*args, **kwargs): call_order.append("exact") return [sample_search_result] @@ -284,8 +277,6 @@ class TestP0CriticalLspSearch: with patch.object( engine, "_search_vector", side_effect=track_vector ) as mock_vector, patch.object( - engine, "_search_splade", side_effect=track_splade - ) as mock_splade, patch.object( engine, "_search_exact", side_effect=track_exact ) as mock_exact: with patch("codexlens.search.hybrid_search.HAS_LSP", True): @@ -322,12 +313,11 @@ class TestP0CriticalLspSearch: max_nodes=20, ) - # Verify fallback chain order: vector -> splade -> exact - assert call_order == ["vector", "splade", "exact"] + # Verify fallback chain order: vector -> exact + assert call_order == ["vector", "exact"] - # All three methods should be called + # Both methods should be called mock_vector.assert_called_once() - mock_splade.assert_called_once() mock_exact.assert_called_once() # Should return results from graph expansion (1 related node) @@ -357,8 +347,6 @@ class TestP1ImportantLspSearch: with patch.object( engine, "_search_vector", return_value=[] ) as mock_vector, patch.object( - engine, "_search_splade", return_value=[] - ) as mock_splade, patch.object( engine, "_search_exact", return_value=[] ) as mock_exact: with patch("codexlens.search.hybrid_search.HAS_LSP", True): @@ -379,7 +367,6 @@ class TestP1ImportantLspSearch: # All search methods should be tried mock_vector.assert_called_once() - mock_splade.assert_called_once() mock_exact.assert_called_once() # Should return empty list diff --git a/codex-lens/tests/real/test_lsp_real_interface.py b/codex-lens/tests/real/test_lsp_real_interface.py index f6a69bff..587d3f90 100644 --- a/codex-lens/tests/real/test_lsp_real_interface.py +++ b/codex-lens/tests/real/test_lsp_real_interface.py @@ -303,7 +303,7 @@ class TestRealHybridSearchIntegrationStandalone: """Test the full LSP search pipeline with real LSP.""" print(f"\n>>> Testing full LSP search pipeline") - # Create mock seeds (normally from vector/splade search) + # Create mock seeds (normally from vector/FTS search) seeds = [ CodeSymbolNode( id=f"{TEST_PYTHON_FILE}:LspBridge:96", diff --git a/codex-lens/tests/test_chain_search.py b/codex-lens/tests/test_chain_search.py index c4acd8af..09e4b166 100644 --- a/codex-lens/tests/test_chain_search.py +++ 
@@ -109,14 +109,6 @@ def test_cascade_search_strategy_routing(temp_paths: Path) -> None:
        engine.cascade_search("query", source_path, strategy="binary")
        mock_binary.assert_called_once()

-    # Test strategy='hybrid' routing
-    with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
-        mock_hybrid.return_value = ChainSearchResult(
-            query="query", results=[], symbols=[], stats=SearchStats()
-        )
-        engine.cascade_search("query", source_path, strategy="hybrid")
-        mock_hybrid.assert_called_once()
-
    # Test strategy='binary_rerank' routing
    with patch.object(engine, "binary_rerank_cascade_search") as mock_br:
        mock_br.return_value = ChainSearchResult(
diff --git a/codex-lens/tests/test_staged_cascade.py
index b12b2271..11c8b829 100644
--- a/codex-lens/tests/test_staged_cascade.py
+++ b/codex-lens/tests/test_staged_cascade.py
@@ -576,20 +576,20 @@ class TestStagedCascadeIntegration:
        # Verify stage 4 was called
        mock_stage4.assert_called_once()

-    def test_staged_cascade_fallback_to_hybrid(
+    def test_staged_cascade_fallback_to_search(
        self, mock_registry, mock_mapper, mock_config, temp_paths
    ):
-        """Test staged_cascade_search falls back to hybrid when numpy unavailable."""
+        """Test staged_cascade_search falls back to standard search when numpy unavailable."""
        engine = ChainSearchEngine(mock_registry, mock_mapper, config=mock_config)

        with patch("codexlens.search.chain_search.NUMPY_AVAILABLE", False):
-            with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
-                mock_hybrid.return_value = MagicMock()
+            with patch.object(engine, "search") as mock_search:
+                mock_search.return_value = MagicMock()

                engine.staged_cascade_search("query", temp_paths / "src")

-                # Should fall back to hybrid cascade
-                mock_hybrid.assert_called_once()
+                # Should fall back to standard search
+                mock_search.assert_called_once()

    def test_staged_cascade_deduplicates_final_results(
        self, mock_registry, mock_mapper, mock_config, temp_paths
@@ -689,10 +689,10 @@ class TestStagedCascadeGracefulDegradation:
        # Stage 1 returns no results
        mock_stage1.return_value = ([], None)

-        with patch.object(engine, "hybrid_cascade_search") as mock_hybrid:
-            mock_hybrid.return_value = MagicMock()
+        with patch.object(engine, "search") as mock_search:
+            mock_search.return_value = MagicMock()

            engine.staged_cascade_search("query", temp_paths / "src")

-        # Should fall back to hybrid when stage 1 fails
-        mock_hybrid.assert_called_once()
+        # Should fall back to standard search when stage 1 fails
+        mock_search.assert_called_once()
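
Taken together, the test updates above pin down the simplified routing contract: strategy="hybrid" survives only as a backward-compatible alias for binary_rerank, and the degraded paths (numpy unavailable, empty stage 1 results) now fall back to the standard search instead of the removed hybrid_cascade_search. The following is a minimal sketch of the dispatch these mocks imply; the method names mirror those exercised in test_chain_search.py and test_staged_cascade.py, but the class, signature, and body are illustrative assumptions, not the actual chain_search.py implementation.

class CascadeRoutingSketch:
    """Hypothetical illustration of the post-refactor routing, not the real engine."""

    def cascade_search(self, query, source_path, strategy="binary"):
        if strategy == "hybrid":
            # Backward compatibility: hybrid_cascade_search was removed,
            # so the old strategy name now routes to the binary-rerank cascade.
            strategy = "binary_rerank"
        if strategy == "binary":
            return self.binary_cascade_search(query, source_path)
        if strategy == "binary_rerank":
            return self.binary_rerank_cascade_search(query, source_path)
        # Unknown strategies, and degraded paths such as numpy being
        # unavailable in staged_cascade_search, fall back to standard search.
        return self.search(query, source_path)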