Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-04 01:40:45 +08:00)
Implement SPLADE sparse encoder and associated database migrations
- Added `splade_encoder.py` for ONNX-optimized SPLADE encoding, including methods for encoding text and batch processing.
- Created `SPLADE_IMPLEMENTATION.md` to document the SPLADE encoder's functionality, design patterns, and integration points.
- Introduced migration script `migration_009_add_splade.py` to add SPLADE metadata and posting-list tables to the database.
- Developed `splade_index.py` for managing the SPLADE inverted index, supporting efficient sparse-vector retrieval.
- Added verification script `verify_watcher.py` to test FileWatcher event filtering and debouncing functionality.
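The new `splade_encoder.py` itself is not part of the excerpt below, so here is a minimal, illustrative sketch of SPLADE-style sparse encoding, assuming the standard formulation (max-pooled `log(1 + ReLU(logits))` over a masked-LM head) and the model name referenced elsewhere in this diff. Function and variable names are hypothetical, not the module's actual API.

```python
# Minimal sketch of SPLADE-style sparse encoding (illustrative, not the repo's actual API).
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

MODEL_ID = "naver/splade-cocondenser-ensembledistil"  # model referenced elsewhere in this diff

def encode_sparse(text: str, top_k: int = 64) -> dict[str, float]:
    """Return the strongest term -> weight pairs for one text."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits  # (1, seq_len, vocab_size)
    # SPLADE activation: log(1 + ReLU(logits)), max-pooled over the sequence,
    # masked so padding positions do not contribute.
    weights = torch.log1p(torch.relu(logits))
    weights = (weights * inputs["attention_mask"].unsqueeze(-1)).max(dim=1).values.squeeze(0)
    top = torch.topk(weights, k=top_k)
    id_to_token = {v: k for k, v in tokenizer.get_vocab().items()}
    return {id_to_token[int(i)]: float(w) for w, i in zip(top.values, top.indices) if w > 0}
```

The resulting term-to-weight map is the kind of sparse vector that the posting-list tables added by `migration_009_add_splade.py` store per chunk.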
(Two file diffs in this commit are suppressed because they are too large.)
@@ -17,6 +17,16 @@ import {
|
||||
isIndexingInProgress
|
||||
} from '../../tools/codex-lens.js';
|
||||
import type { ProgressInfo, GpuMode } from '../../tools/codex-lens.js';
|
||||
import { loadLiteLLMApiConfig } from '../../config/litellm-api-config-manager.js';
|
||||
|
||||
// File watcher state (persisted across requests)
|
||||
let watcherProcess: any = null;
|
||||
let watcherStats = {
|
||||
running: false,
|
||||
root_path: '',
|
||||
events_processed: 0,
|
||||
start_time: null as Date | null
|
||||
};
|
||||
|
||||
export interface RouteContext {
|
||||
pathname: string;
|
||||
@@ -1052,5 +1062,478 @@ export async function handleCodexLensRoutes(ctx: RouteContext): Promise<boolean>
|
||||
return true;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// RERANKER CONFIGURATION ENDPOINTS
|
||||
// ============================================================
|
||||
|
||||
// API: Get Reranker Configuration
|
||||
if (pathname === '/api/codexlens/reranker/config' && req.method === 'GET') {
|
||||
try {
|
||||
const venvStatus = await checkVenvStatus();
|
||||
|
||||
// Default reranker config
|
||||
const rerankerConfig = {
|
||||
backend: 'onnx',
|
||||
model_name: 'cross-encoder/ms-marco-MiniLM-L-6-v2',
|
||||
api_provider: 'siliconflow',
|
||||
api_key_set: false,
|
||||
available_backends: ['onnx', 'api', 'litellm', 'legacy'],
|
||||
api_providers: ['siliconflow', 'cohere', 'jina'],
|
||||
litellm_endpoints: [] as string[],
|
||||
config_source: 'default'
|
||||
};
|
||||
|
||||
// Load LiteLLM endpoints for dropdown
|
||||
try {
|
||||
const litellmConfig = loadLiteLLMApiConfig(initialPath);
|
||||
if (litellmConfig.endpoints && Array.isArray(litellmConfig.endpoints)) {
|
||||
rerankerConfig.litellm_endpoints = litellmConfig.endpoints.map(
|
||||
(ep: any) => ep.alias || ep.name || ep.baseUrl
|
||||
).filter(Boolean);
|
||||
}
|
||||
} catch (e) {
|
||||
// LiteLLM config not available, continue with empty endpoints
|
||||
}
|
||||
|
||||
// If CodexLens is installed, try to get actual config
|
||||
if (venvStatus.ready) {
|
||||
try {
|
||||
const result = await executeCodexLens(['config', '--json']);
|
||||
if (result.success) {
|
||||
const config = extractJSON(result.output);
|
||||
if (config.success && config.result) {
|
||||
// Map config values
|
||||
if (config.result.reranker_backend) {
|
||||
rerankerConfig.backend = config.result.reranker_backend;
|
||||
rerankerConfig.config_source = 'codexlens';
|
||||
}
|
||||
if (config.result.reranker_model) {
|
||||
rerankerConfig.model_name = config.result.reranker_model;
|
||||
}
|
||||
if (config.result.reranker_api_provider) {
|
||||
rerankerConfig.api_provider = config.result.reranker_api_provider;
|
||||
}
|
||||
// Check if API key is set (from env)
|
||||
if (process.env.RERANKER_API_KEY) {
|
||||
rerankerConfig.api_key_set = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('[CodexLens] Failed to get reranker config:', e);
|
||||
}
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ success: true, ...rerankerConfig }));
|
||||
} catch (err) {
|
||||
res.writeHead(500, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ success: false, error: err.message }));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// API: Set Reranker Configuration
|
||||
if (pathname === '/api/codexlens/reranker/config' && req.method === 'POST') {
|
||||
handlePostRequest(req, res, async (body) => {
|
||||
const { backend, model_name, api_provider, api_key, litellm_endpoint } = body;
|
||||
|
||||
// Validate backend
|
||||
const validBackends = ['onnx', 'api', 'litellm', 'legacy'];
|
||||
if (backend && !validBackends.includes(backend)) {
|
||||
return { success: false, error: `Invalid backend: ${backend}. Valid options: ${validBackends.join(', ')}`, status: 400 };
|
||||
}
|
||||
|
||||
// Validate api_provider
|
||||
const validProviders = ['siliconflow', 'cohere', 'jina'];
|
||||
if (api_provider && !validProviders.includes(api_provider)) {
|
||||
return { success: false, error: `Invalid api_provider: ${api_provider}. Valid options: ${validProviders.join(', ')}`, status: 400 };
|
||||
}
|
||||
|
||||
try {
|
||||
const updates: string[] = [];
|
||||
|
||||
// Set backend
|
||||
if (backend) {
|
||||
const result = await executeCodexLens(['config', 'set', 'reranker_backend', backend, '--json']);
|
||||
if (result.success) updates.push('backend');
|
||||
}
|
||||
|
||||
// Set model
|
||||
if (model_name) {
|
||||
const result = await executeCodexLens(['config', 'set', 'reranker_model', model_name, '--json']);
|
||||
if (result.success) updates.push('model_name');
|
||||
}
|
||||
|
||||
// Set API provider
|
||||
if (api_provider) {
|
||||
const result = await executeCodexLens(['config', 'set', 'reranker_api_provider', api_provider, '--json']);
|
||||
if (result.success) updates.push('api_provider');
|
||||
}
|
||||
|
||||
// Set LiteLLM endpoint
|
||||
if (litellm_endpoint) {
|
||||
const result = await executeCodexLens(['config', 'set', 'reranker_litellm_endpoint', litellm_endpoint, '--json']);
|
||||
if (result.success) updates.push('litellm_endpoint');
|
||||
}
|
||||
|
||||
// Handle API key - write to .env file or environment
|
||||
if (api_key) {
|
||||
// For security, we store in process.env for the current session
|
||||
// In production, this should be written to a secure .env file
|
||||
process.env.RERANKER_API_KEY = api_key;
|
||||
updates.push('api_key');
|
||||
}
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: `Updated: ${updates.join(', ')}`,
|
||||
updated_fields: updates
|
||||
};
|
||||
} catch (err) {
|
||||
return { success: false, error: err.message, status: 500 };
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
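As a quick usage sketch, the two handlers above can be exercised over HTTP roughly as follows (Python with the `requests` package; host, port, model name, and API key are placeholders).

```python
# Sketch of exercising the reranker config endpoints added above.
# Host/port and the API key are placeholders; requires the `requests` package.
import requests

BASE = "http://localhost:3000"

# Read the current reranker configuration.
config = requests.get(f"{BASE}/api/codexlens/reranker/config", timeout=10).json()
print(config["backend"], config["model_name"])

# Switch to the API backend with a SiliconFlow reranker model.
resp = requests.post(
    f"{BASE}/api/codexlens/reranker/config",
    json={
        "backend": "api",
        "api_provider": "siliconflow",
        "model_name": "BAAI/bge-reranker-v2-m3",
        "api_key": "sk-...",  # stored in process.env.RERANKER_API_KEY for the session
    },
    timeout=30,
)
print(resp.json()["message"])  # e.g. "Updated: backend, api_provider, model_name, api_key"
```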
|
||||
|
||||
// ============================================================
|
||||
// FILE WATCHER CONTROL ENDPOINTS
|
||||
// ============================================================
|
||||
|
||||
// API: Get File Watcher Status
|
||||
if (pathname === '/api/codexlens/watch/status') {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
success: true,
|
||||
running: watcherStats.running,
|
||||
root_path: watcherStats.root_path,
|
||||
events_processed: watcherStats.events_processed,
|
||||
start_time: watcherStats.start_time?.toISOString() || null,
|
||||
uptime_seconds: watcherStats.start_time
|
||||
? Math.floor((Date.now() - watcherStats.start_time.getTime()) / 1000)
|
||||
: 0
|
||||
}));
|
||||
return true;
|
||||
}
|
||||
|
||||
// API: Start File Watcher
|
||||
if (pathname === '/api/codexlens/watch/start' && req.method === 'POST') {
|
||||
handlePostRequest(req, res, async (body) => {
|
||||
const { path: watchPath, debounce_ms = 1000 } = body;
|
||||
const targetPath = watchPath || initialPath;
|
||||
|
||||
if (watcherStats.running) {
|
||||
return { success: false, error: 'Watcher already running', status: 400 };
|
||||
}
|
||||
|
||||
try {
|
||||
const { spawn } = await import('child_process');
|
||||
const { join } = await import('path');
|
||||
const { existsSync, statSync } = await import('fs');
|
||||
|
||||
// Validate path exists and is a directory
|
||||
if (!existsSync(targetPath)) {
|
||||
return { success: false, error: `Path does not exist: ${targetPath}`, status: 400 };
|
||||
}
|
||||
const pathStat = statSync(targetPath);
|
||||
if (!pathStat.isDirectory()) {
|
||||
return { success: false, error: `Path is not a directory: ${targetPath}`, status: 400 };
|
||||
}
|
||||
|
||||
// Get the codexlens CLI path
|
||||
const venvStatus = await checkVenvStatus();
|
||||
if (!venvStatus.ready) {
|
||||
return { success: false, error: 'CodexLens not installed', status: 400 };
|
||||
}
|
||||
|
||||
// Spawn watch process (no shell: true for security)
|
||||
// Use process.platform to pick the right executable name (codexlens.exe on Windows)
const isWindows = process.platform === 'win32';
const codexlensCmd = isWindows ? 'codexlens.exe' : 'codexlens';
|
||||
const args = ['watch', targetPath, '--debounce', String(debounce_ms)];
|
||||
watcherProcess = spawn(codexlensCmd, args, {
|
||||
cwd: targetPath,
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
env: { ...process.env }
|
||||
});
|
||||
|
||||
watcherStats = {
|
||||
running: true,
|
||||
root_path: targetPath,
|
||||
events_processed: 0,
|
||||
start_time: new Date()
|
||||
};
|
||||
|
||||
// Handle process output for event counting
|
||||
if (watcherProcess.stdout) {
|
||||
watcherProcess.stdout.on('data', (data: Buffer) => {
|
||||
const output = data.toString();
|
||||
// Count processed events from output
|
||||
const matches = output.match(/Processed \d+ events?/g);
|
||||
if (matches) {
|
||||
watcherStats.events_processed += matches.length;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Handle process exit
|
||||
watcherProcess.on('exit', (code: number) => {
|
||||
watcherStats.running = false;
|
||||
watcherProcess = null;
|
||||
console.log(`[CodexLens] Watcher exited with code ${code}`);
|
||||
});
|
||||
|
||||
// Broadcast watcher started
|
||||
broadcastToClients({
|
||||
type: 'CODEXLENS_WATCHER_STATUS',
|
||||
payload: { running: true, path: targetPath }
|
||||
});
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: 'Watcher started',
|
||||
path: targetPath,
|
||||
pid: watcherProcess.pid
|
||||
};
|
||||
} catch (err) {
|
||||
return { success: false, error: err.message, status: 500 };
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
// API: Stop File Watcher
|
||||
if (pathname === '/api/codexlens/watch/stop' && req.method === 'POST') {
|
||||
handlePostRequest(req, res, async () => {
|
||||
if (!watcherStats.running || !watcherProcess) {
|
||||
return { success: false, error: 'Watcher not running', status: 400 };
|
||||
}
|
||||
|
||||
try {
|
||||
// Send SIGTERM to gracefully stop the watcher
|
||||
watcherProcess.kill('SIGTERM');
|
||||
|
||||
// Wait a moment for graceful shutdown
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
// Force kill if still running
|
||||
if (watcherProcess && !watcherProcess.killed) {
|
||||
watcherProcess.kill('SIGKILL');
|
||||
}
|
||||
|
||||
const finalStats = {
|
||||
events_processed: watcherStats.events_processed,
|
||||
uptime_seconds: watcherStats.start_time
|
||||
? Math.floor((Date.now() - watcherStats.start_time.getTime()) / 1000)
|
||||
: 0
|
||||
};
|
||||
|
||||
watcherStats = {
|
||||
running: false,
|
||||
root_path: '',
|
||||
events_processed: 0,
|
||||
start_time: null
|
||||
};
|
||||
watcherProcess = null;
|
||||
|
||||
// Broadcast watcher stopped
|
||||
broadcastToClients({
|
||||
type: 'CODEXLENS_WATCHER_STATUS',
|
||||
payload: { running: false }
|
||||
});
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: 'Watcher stopped',
|
||||
...finalStats
|
||||
};
|
||||
} catch (err) {
|
||||
return { success: false, error: err.message, status: 500 };
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
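A hedged end-to-end sketch of driving the three watcher endpoints above (host/port and the project path are placeholders):

```python
# Sketch of starting, polling, and stopping the file watcher via the endpoints above.
import time
import requests

BASE = "http://localhost:3000"

# Start watching a project with a 2-second debounce window.
start = requests.post(
    f"{BASE}/api/codexlens/watch/start",
    json={"path": "/path/to/project", "debounce_ms": 2000},
    timeout=30,
).json()
print("watcher pid:", start.get("pid"))

time.sleep(10)
status = requests.get(f"{BASE}/api/codexlens/watch/status", timeout=10).json()
print(status["events_processed"], "events,", status["uptime_seconds"], "s uptime")

# Stop and report final statistics.
final = requests.post(f"{BASE}/api/codexlens/watch/stop", json={}, timeout=30).json()
print(final["message"], final["events_processed"])
```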
|
||||
|
||||
|
||||
// ============================================================
|
||||
// SPLADE ENDPOINTS
|
||||
// ============================================================
|
||||
|
||||
// API: SPLADE Status - Check if SPLADE is available and installed
|
||||
if (pathname === '/api/codexlens/splade/status') {
|
||||
try {
|
||||
// Check if CodexLens is installed first
|
||||
const venvStatus = await checkVenvStatus();
|
||||
if (!venvStatus.ready) {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
available: false,
|
||||
installed: false,
|
||||
model: 'naver/splade-cocondenser-ensembledistil',
|
||||
error: 'CodexLens not installed'
|
||||
}));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check SPLADE availability using Python check
|
||||
const result = await executeCodexLens(['python', '-c',
|
||||
'from codexlens.semantic.splade_encoder import check_splade_available; ok, err = check_splade_available(); print("OK" if ok else err)'
|
||||
]);
|
||||
|
||||
const available = result.output.includes('OK');
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
available,
|
||||
installed: available,
|
||||
model: 'naver/splade-cocondenser-ensembledistil',
|
||||
error: available ? null : result.output.trim()
|
||||
}));
|
||||
} catch (err) {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
available: false,
|
||||
installed: false,
|
||||
model: 'naver/splade-cocondenser-ensembledistil',
|
||||
error: err.message
|
||||
}));
|
||||
}
|
||||
return true;
|
||||
}
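The status check above shells out to `check_splade_available()`; the real function lives in `codexlens/semantic/splade_encoder.py` and is not shown in this diff. A plausible sketch, assuming it simply probes the optional dependencies pulled in by the `splade` extra:

```python
# Illustrative sketch of what a check like check_splade_available() might do;
# the actual implementation in splade_encoder.py may differ.
from __future__ import annotations

def check_splade_available() -> tuple[bool, str | None]:
    """Probe the optional dependencies installed by codex-lens[splade]."""
    try:
        import transformers  # noqa: F401
        import optimum.onnxruntime  # noqa: F401
    except ImportError as exc:
        return False, f"SPLADE dependencies missing: {exc}"
    return True, None

ok, err = check_splade_available()
print("OK" if ok else err)  # mirrors the output the endpoint above parses
```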
|
||||
|
||||
// API: SPLADE Install - Install SPLADE dependencies
|
||||
if (pathname === '/api/codexlens/splade/install' && req.method === 'POST') {
|
||||
handlePostRequest(req, res, async (body) => {
|
||||
try {
|
||||
const gpu = body?.gpu || false;
|
||||
const packageName = gpu ? 'codex-lens[splade-gpu]' : 'codex-lens[splade]';
|
||||
|
||||
// Use pip to install the SPLADE extras
|
||||
const { execFile } = await import('child_process');
const { promisify } = await import('util');
const execFilePromise = promisify(execFile);
|
||||
|
||||
const result = await execFilePromise('pip', ['install', packageName], {
|
||||
timeout: 600000 // 10 minutes
|
||||
});
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: `SPLADE installed successfully (${gpu ? 'GPU' : 'CPU'} mode)`,
|
||||
output: result.stdout
|
||||
};
|
||||
} catch (err) {
|
||||
return {
|
||||
success: false,
|
||||
error: err.message,
|
||||
stderr: err.stderr,
|
||||
status: 500
|
||||
};
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
// API: SPLADE Index Status - Check if SPLADE index exists for a project
|
||||
if (pathname === '/api/codexlens/splade/index-status') {
|
||||
try {
|
||||
const projectPath = url.searchParams.get('path');
|
||||
if (!projectPath) {
|
||||
res.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ success: false, error: 'Missing path parameter' }));
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if CodexLens is installed first
|
||||
const venvStatus = await checkVenvStatus();
|
||||
if (!venvStatus.ready) {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ exists: false, error: 'CodexLens not installed' }));
|
||||
return true;
|
||||
}
|
||||
|
||||
const { join } = await import('path');
|
||||
const indexDb = join(projectPath, '.codexlens', '_index.db');
|
||||
|
||||
// Use Python to check SPLADE index status
|
||||
const pythonCode = `
|
||||
from codexlens.storage.splade_index import SpladeIndex
|
||||
from pathlib import Path
|
||||
try:
|
||||
idx = SpladeIndex(Path("${indexDb.replace(/\\/g, '\\\\')}"))
|
||||
if idx.has_index():
|
||||
stats = idx.get_stats()
|
||||
meta = idx.get_metadata()
|
||||
model = meta.get('model_name', '') if meta else ''
|
||||
print(f"OK|{stats['unique_chunks']}|{stats['total_postings']}|{model}")
|
||||
else:
|
||||
print("NO_INDEX")
|
||||
except Exception as e:
|
||||
print(f"ERROR|{str(e)}")
|
||||
`;
|
||||
|
||||
const result = await executeCodexLens(['python', '-c', pythonCode]);
|
||||
|
||||
if (result.output.startsWith('OK|')) {
|
||||
const parts = result.output.trim().split('|');
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
exists: true,
|
||||
chunks: parseInt(parts[1]),
|
||||
postings: parseInt(parts[2]),
|
||||
model: parts[3]
|
||||
}));
|
||||
} else if (result.output.startsWith('ERROR|')) {
|
||||
const errorMsg = result.output.substring(6).trim();
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ exists: false, error: errorMsg }));
|
||||
} else {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ exists: false }));
|
||||
}
|
||||
} catch (err) {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({ exists: false, error: err.message }));
|
||||
}
|
||||
return true;
|
||||
}
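The `chunks` and `postings` figures returned above come from the SPLADE inverted index. As an illustration of how such posting lists can be queried, here is a sketch; the table and column names are hypothetical, since the actual schema is defined by `migration_009_add_splade.py` and `splade_index.py`.

```python
# Sketch of a SPLADE inverted-index lookup against SQLite posting lists.
# Table/column names are illustrative, not the repo's actual schema.
import sqlite3
from collections import defaultdict

def splade_search(db_path: str, query_terms: dict[str, float], limit: int = 10):
    """Accumulate dot-product scores per chunk from per-term posting lists."""
    scores: dict[int, float] = defaultdict(float)
    with sqlite3.connect(db_path) as conn:
        for term, q_weight in query_terms.items():
            rows = conn.execute(
                "SELECT chunk_id, weight FROM splade_postings WHERE term = ?",
                (term,),
            )
            for chunk_id, doc_weight in rows:
                scores[chunk_id] += q_weight * doc_weight
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:limit]
```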
|
||||
|
||||
// API: SPLADE Index Rebuild - Rebuild SPLADE index for a project
|
||||
if (pathname === '/api/codexlens/splade/rebuild' && req.method === 'POST') {
|
||||
handlePostRequest(req, res, async (body) => {
|
||||
const { path: projectPath } = body;
|
||||
|
||||
if (!projectPath) {
|
||||
return { success: false, error: 'Missing path parameter', status: 400 };
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await executeCodexLens(['splade-index', projectPath, '--rebuild'], {
|
||||
cwd: projectPath,
|
||||
timeout: 1800000 // 30 minutes for large codebases
|
||||
});
|
||||
|
||||
if (result.success) {
|
||||
return {
|
||||
success: true,
|
||||
message: 'SPLADE index rebuilt successfully',
|
||||
output: result.output
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
success: false,
|
||||
error: result.error || 'Failed to rebuild SPLADE index',
|
||||
output: result.output,
|
||||
status: 500
|
||||
};
|
||||
}
|
||||
} catch (err) {
|
||||
return { success: false, error: err.message, status: 500 };
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -366,6 +366,25 @@ const i18n = {
|
||||
'codexlens.depsInstalled': 'Dependencies installed successfully',
|
||||
'codexlens.depsInstallFailed': 'Failed to install dependencies',
|
||||
|
||||
// SPLADE Dependencies
|
||||
'codexlens.spladeDeps': 'SPLADE Sparse Retrieval',
|
||||
'codexlens.spladeInstalled': 'SPLADE Available',
|
||||
'codexlens.spladeNotInstalled': 'SPLADE Not Installed',
|
||||
'codexlens.spladeInstallHint': 'Install for improved synonym matching in code search',
|
||||
'codexlens.installingSpladePackage': 'Installing SPLADE package',
|
||||
'codexlens.spladeInstallSuccess': 'SPLADE installed successfully',
|
||||
'codexlens.spladeInstallFailed': 'SPLADE installation failed',
|
||||
'codexlens.spladeModel': 'Model',
|
||||
'codexlens.spladeIndexStatus': 'SPLADE Index',
|
||||
'codexlens.spladeIndexExists': 'Index available',
|
||||
'codexlens.spladeIndexMissing': 'No SPLADE index',
|
||||
'codexlens.spladeRebuild': 'Rebuild SPLADE Index',
|
||||
'codexlens.spladeRebuilding': 'Rebuilding SPLADE index...',
|
||||
'codexlens.spladeRebuildSuccess': 'SPLADE index rebuilt',
|
||||
'codexlens.spladeRebuildFailed': 'SPLADE index rebuild failed',
|
||||
'codexlens.spladeChunks': 'Chunks',
|
||||
'codexlens.spladePostings': 'Postings',
|
||||
|
||||
// GPU Mode Selection
|
||||
'codexlens.selectGpuMode': 'Select acceleration mode',
|
||||
'codexlens.cpuModeDesc': 'Standard CPU processing',
|
||||
@@ -2288,6 +2307,25 @@ const i18n = {
|
||||
'codexlens.depsInstalled': '依赖安装成功',
|
||||
'codexlens.depsInstallFailed': '依赖安装失败',
|
||||
|
||||
// SPLADE 依赖
|
||||
'codexlens.spladeDeps': 'SPLADE 稀疏检索',
|
||||
'codexlens.spladeInstalled': 'SPLADE 已安装',
|
||||
'codexlens.spladeNotInstalled': 'SPLADE 未安装',
|
||||
'codexlens.spladeInstallHint': '安装后可改进代码搜索的同义词匹配',
|
||||
'codexlens.installingSpladePackage': '正在安装 SPLADE 包',
|
||||
'codexlens.spladeInstallSuccess': 'SPLADE 安装成功',
|
||||
'codexlens.spladeInstallFailed': 'SPLADE 安装失败',
|
||||
'codexlens.spladeModel': '模型',
|
||||
'codexlens.spladeIndexStatus': 'SPLADE 索引',
|
||||
'codexlens.spladeIndexExists': '索引可用',
|
||||
'codexlens.spladeIndexMissing': '无 SPLADE 索引',
|
||||
'codexlens.spladeRebuild': '重建 SPLADE 索引',
|
||||
'codexlens.spladeRebuilding': '正在重建 SPLADE 索引...',
|
||||
'codexlens.spladeRebuildSuccess': 'SPLADE 索引重建完成',
|
||||
'codexlens.spladeRebuildFailed': 'SPLADE 索引重建失败',
|
||||
'codexlens.spladeChunks': '分块数',
|
||||
'codexlens.spladePostings': '词条数',
|
||||
|
||||
// GPU 模式选择
|
||||
'codexlens.selectGpuMode': '选择加速模式',
|
||||
'codexlens.cpuModeDesc': '标准 CPU 处理',
|
||||
|
||||
@@ -120,6 +120,12 @@ function buildCodexLensConfigContent(config) {
|
||||
? '<button class="inline-flex items-center gap-1.5 px-3 py-1.5 text-xs font-medium rounded-md border border-primary/30 bg-primary/5 text-primary hover:bg-primary/10 transition-colors" onclick="initCodexLensIndex()">' +
|
||||
'<i data-lucide="database" class="w-3.5 h-3.5"></i> ' + t('codexlens.initializeIndex') +
|
||||
'</button>' +
|
||||
'<button class="inline-flex items-center gap-1.5 px-3 py-1.5 text-xs font-medium rounded-md border border-primary/30 bg-primary/5 text-primary hover:bg-primary/10 transition-colors" onclick="showRerankerConfigModal()">' +
|
||||
'<i data-lucide="layers" class="w-3.5 h-3.5"></i> ' + (t('codexlens.rerankerConfig') || 'Reranker Config') +
|
||||
'</button>' +
|
||||
'<button class="inline-flex items-center gap-1.5 px-3 py-1.5 text-xs font-medium rounded-md border border-primary/30 bg-primary/5 text-primary hover:bg-primary/10 transition-colors" onclick="showWatcherControlModal()">' +
|
||||
'<i data-lucide="eye" class="w-3.5 h-3.5"></i> ' + (t('codexlens.watcherControl') || 'File Watcher') +
|
||||
'</button>' +
|
||||
'<button class="inline-flex items-center gap-1.5 px-3 py-1.5 text-xs font-medium rounded-md border border-border bg-background hover:bg-muted/50 transition-colors" onclick="cleanCurrentWorkspaceIndex()">' +
|
||||
'<i data-lucide="folder-x" class="w-3.5 h-3.5"></i> ' + t('codexlens.cleanCurrentWorkspace') +
|
||||
'</button>' +
|
||||
@@ -145,6 +151,17 @@ function buildCodexLensConfigContent(config) {
|
||||
'</div>'
|
||||
: '') +
|
||||
|
||||
// SPLADE Section
|
||||
(isInstalled
|
||||
? '<div class="tool-config-section">' +
|
||||
'<h4>' + t('codexlens.spladeDeps') + '</h4>' +
|
||||
'<div id="spladeStatus" class="space-y-2">' +
|
||||
'<div class="text-sm text-muted-foreground">' + t('common.loading') + '</div>' +
|
||||
'</div>' +
|
||||
'</div>'
|
||||
: '') +
|
||||
|
||||
|
||||
// Model Management Section
|
||||
(isInstalled
|
||||
? '<div class="tool-config-section">' +
|
||||
@@ -335,6 +352,9 @@ function initCodexLensConfigEvents(currentConfig) {
|
||||
// Load semantic dependencies status
|
||||
loadSemanticDepsStatus();
|
||||
|
||||
// Load SPLADE status
|
||||
loadSpladeStatus();
|
||||
|
||||
// Load model list
|
||||
loadModelList();
|
||||
}
|
||||
@@ -714,6 +734,95 @@ async function installSemanticDeps() {
|
||||
await installSemanticDepsWithGpu();
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// SPLADE MANAGEMENT
|
||||
// ============================================================
|
||||
|
||||
/**
|
||||
* Load SPLADE status
|
||||
*/
|
||||
async function loadSpladeStatus() {
|
||||
var container = document.getElementById('spladeStatus');
|
||||
if (!container) return;
|
||||
|
||||
try {
|
||||
var response = await fetch('/api/codexlens/splade/status');
|
||||
var status = await response.json();
|
||||
|
||||
if (status.available) {
|
||||
container.innerHTML =
|
||||
'<div class="flex items-center justify-between p-3 border border-success/30 rounded-lg bg-success/5">' +
|
||||
'<div class="flex items-center gap-3">' +
|
||||
'<i data-lucide="check-circle" class="w-5 h-5 text-success"></i>' +
|
||||
'<div>' +
|
||||
'<span class="font-medium">' + t('codexlens.spladeInstalled') + '</span>' +
|
||||
'<div class="text-xs text-muted-foreground">' + status.model + '</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>';
|
||||
} else {
|
||||
container.innerHTML =
|
||||
'<div class="flex items-center justify-between p-3 border border-border rounded-lg">' +
|
||||
'<div class="flex items-center gap-3">' +
|
||||
'<i data-lucide="alert-circle" class="w-5 h-5 text-muted-foreground"></i>' +
|
||||
'<div>' +
|
||||
'<span class="font-medium">' + t('codexlens.spladeNotInstalled') + '</span>' +
|
||||
'<div class="text-xs text-muted-foreground">' + (status.error || t('codexlens.spladeInstallHint')) + '</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'<div class="flex gap-2">' +
|
||||
'<button class="btn-sm btn-outline" onclick="installSplade(false)">' +
|
||||
'<i data-lucide="download" class="w-3.5 h-3.5 mr-1"></i>CPU' +
|
||||
'</button>' +
|
||||
'<button class="btn-sm btn-primary" onclick="installSplade(true)">' +
|
||||
'<i data-lucide="zap" class="w-3.5 h-3.5 mr-1"></i>GPU' +
|
||||
'</button>' +
|
||||
'</div>' +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
if (window.lucide) lucide.createIcons();
|
||||
} catch (err) {
|
||||
container.innerHTML = '<div class="text-sm text-error">' + err.message + '</div>';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Install SPLADE package
|
||||
*/
|
||||
async function installSplade(gpu) {
|
||||
var container = document.getElementById('spladeStatus');
|
||||
if (!container) return;
|
||||
|
||||
container.innerHTML =
|
||||
'<div class="flex items-center gap-3 p-3 border border-primary/30 rounded-lg">' +
|
||||
'<div class="animate-spin"><i data-lucide="loader-2" class="w-5 h-5 text-primary"></i></div>' +
|
||||
'<span>' + t('codexlens.installingSpladePackage') + (gpu ? ' (GPU)' : ' (CPU)') + '...</span>' +
|
||||
'</div>';
|
||||
if (window.lucide) lucide.createIcons();
|
||||
|
||||
try {
|
||||
var response = await fetch('/api/codexlens/splade/install', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ gpu: gpu })
|
||||
});
|
||||
var result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
showRefreshToast(t('codexlens.spladeInstallSuccess'), 'success');
|
||||
loadSpladeStatus();
|
||||
} else {
|
||||
showRefreshToast(t('codexlens.spladeInstallFailed') + ': ' + result.error, 'error');
|
||||
loadSpladeStatus();
|
||||
}
|
||||
} catch (err) {
|
||||
showRefreshToast(t('common.error') + ': ' + err.message, 'error');
|
||||
loadSpladeStatus();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ============================================================
|
||||
// MODEL MANAGEMENT
|
||||
// ============================================================
|
||||
@@ -2975,3 +3084,546 @@ async function saveRotationConfig() {
|
||||
showRefreshToast(t('common.error') + ': ' + err.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// RERANKER CONFIGURATION MODAL
|
||||
// ============================================================
|
||||
|
||||
/**
|
||||
* Show Reranker configuration modal
|
||||
*/
|
||||
async function showRerankerConfigModal() {
|
||||
try {
|
||||
showRefreshToast(t('codexlens.loadingRerankerConfig') || 'Loading reranker configuration...', 'info');
|
||||
|
||||
// Fetch current reranker config
|
||||
const response = await fetch('/api/codexlens/reranker/config');
|
||||
const config = await response.json();
|
||||
|
||||
if (!config.success) {
|
||||
showRefreshToast(t('common.error') + ': ' + (config.error || 'Failed to load config'), 'error');
|
||||
return;
|
||||
}
|
||||
|
||||
const modalHtml = buildRerankerConfigContent(config);
|
||||
|
||||
// Create and show modal
|
||||
const tempContainer = document.createElement('div');
|
||||
tempContainer.innerHTML = modalHtml;
|
||||
const modal = tempContainer.firstElementChild;
|
||||
document.body.appendChild(modal);
|
||||
|
||||
// Initialize icons
|
||||
if (window.lucide) lucide.createIcons();
|
||||
|
||||
// Initialize event handlers
|
||||
initRerankerConfigEvents(config);
|
||||
} catch (err) {
|
||||
showRefreshToast(t('common.error') + ': ' + err.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build Reranker configuration modal content
|
||||
*/
|
||||
function buildRerankerConfigContent(config) {
|
||||
const backend = config.backend || 'onnx';
|
||||
const modelName = config.model_name || '';
|
||||
const apiProvider = config.api_provider || 'siliconflow';
|
||||
const apiKeySet = config.api_key_set || false;
|
||||
const availableBackends = config.available_backends || ['onnx', 'api', 'litellm', 'legacy'];
|
||||
const apiProviders = config.api_providers || ['siliconflow', 'cohere', 'jina'];
|
||||
const litellmEndpoints = config.litellm_endpoints || [];
|
||||
|
||||
// ONNX models
|
||||
const onnxModels = [
|
||||
'cross-encoder/ms-marco-MiniLM-L-6-v2',
|
||||
'cross-encoder/ms-marco-TinyBERT-L-2-v2',
|
||||
'BAAI/bge-reranker-base',
|
||||
'BAAI/bge-reranker-large'
|
||||
];
|
||||
|
||||
// Build backend options
|
||||
const backendOptions = availableBackends.map(function(b) {
|
||||
const labels = {
|
||||
'onnx': 'ONNX (Local, Optimum)',
|
||||
'api': 'API (SiliconFlow/Cohere/Jina)',
|
||||
'litellm': 'LiteLLM (Custom Endpoint)',
|
||||
'legacy': 'Legacy (SentenceTransformers)'
|
||||
};
|
||||
return '<option value="' + b + '" ' + (backend === b ? 'selected' : '') + '>' + (labels[b] || b) + '</option>';
|
||||
}).join('');
|
||||
|
||||
// Build API provider options
|
||||
const providerOptions = apiProviders.map(function(p) {
|
||||
return '<option value="' + p + '" ' + (apiProvider === p ? 'selected' : '') + '>' + p.charAt(0).toUpperCase() + p.slice(1) + '</option>';
|
||||
}).join('');
|
||||
|
||||
// Build ONNX model options
|
||||
const onnxModelOptions = onnxModels.map(function(m) {
|
||||
return '<option value="' + m + '" ' + (modelName === m ? 'selected' : '') + '>' + m + '</option>';
|
||||
}).join('');
|
||||
|
||||
// Build LiteLLM endpoint options
|
||||
const litellmOptions = litellmEndpoints.length > 0
|
||||
? litellmEndpoints.map(function(ep) {
|
||||
return '<option value="' + ep + '">' + ep + '</option>';
|
||||
}).join('')
|
||||
: '<option value="" disabled>No endpoints configured</option>';
|
||||
|
||||
return '<div class="modal-backdrop" id="rerankerConfigModal">' +
|
||||
'<div class="modal-container max-w-xl">' +
|
||||
'<div class="modal-header">' +
|
||||
'<div class="flex items-center gap-3">' +
|
||||
'<div class="modal-icon">' +
|
||||
'<i data-lucide="layers" class="w-5 h-5"></i>' +
|
||||
'</div>' +
|
||||
'<div>' +
|
||||
'<h2 class="text-lg font-bold">' + (t('codexlens.rerankerConfig') || 'Reranker Configuration') + '</h2>' +
|
||||
'<p class="text-xs text-muted-foreground">' + (t('codexlens.rerankerConfigDesc') || 'Configure cross-encoder reranking for semantic search') + '</p>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'<button onclick="closeRerankerModal()" class="text-muted-foreground hover:text-foreground">' +
|
||||
'<i data-lucide="x" class="w-5 h-5"></i>' +
|
||||
'</button>' +
|
||||
'</div>' +
|
||||
|
||||
'<div class="modal-body space-y-4">' +
|
||||
// Backend Selection
|
||||
'<div class="tool-config-section">' +
|
||||
'<h4>' + (t('codexlens.rerankerBackend') || 'Backend') + '</h4>' +
|
||||
'<select id="rerankerBackend" class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm" onchange="toggleRerankerSections()">' +
|
||||
backendOptions +
|
||||
'</select>' +
|
||||
'<p class="text-xs text-muted-foreground mt-1">' + (t('codexlens.rerankerBackendHint') || 'Select reranking backend based on your needs') + '</p>' +
|
||||
'</div>' +
|
||||
|
||||
// ONNX Section (visible when backend=onnx)
|
||||
'<div id="rerankerOnnxSection" class="tool-config-section" style="display:' + (backend === 'onnx' ? 'block' : 'none') + '">' +
|
||||
'<h4>' + (t('codexlens.onnxModel') || 'ONNX Model') + '</h4>' +
|
||||
'<select id="rerankerOnnxModel" class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm">' +
|
||||
onnxModelOptions +
|
||||
'<option value="custom">Custom model...</option>' +
|
||||
'</select>' +
|
||||
'<input type="text" id="rerankerCustomModel" value="' + (onnxModels.includes(modelName) ? '' : modelName) + '" ' +
|
||||
'placeholder="Enter custom model name" ' +
|
||||
'class="w-full mt-2 px-3 py-2 border border-border rounded-lg bg-background text-sm" style="display:' + (onnxModels.includes(modelName) ? 'none' : 'block') + '" />' +
|
||||
'</div>' +
|
||||
|
||||
// API Section (visible when backend=api)
|
||||
'<div id="rerankerApiSection" class="tool-config-section" style="display:' + (backend === 'api' ? 'block' : 'none') + '">' +
|
||||
'<h4>' + (t('codexlens.apiConfig') || 'API Configuration') + '</h4>' +
|
||||
'<div class="space-y-3">' +
|
||||
'<div>' +
|
||||
'<label class="block text-sm font-medium mb-1.5">' + (t('codexlens.apiProvider') || 'Provider') + '</label>' +
|
||||
'<select id="rerankerApiProvider" class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm">' +
|
||||
providerOptions +
|
||||
'</select>' +
|
||||
'</div>' +
|
||||
'<div>' +
|
||||
'<label class="block text-sm font-medium mb-1.5">' + (t('codexlens.apiKey') || 'API Key') + '</label>' +
|
||||
'<div class="flex items-center gap-2">' +
|
||||
'<input type="password" id="rerankerApiKey" placeholder="' + (apiKeySet ? '••••••••' : 'Enter API key') + '" ' +
|
||||
'class="flex-1 px-3 py-2 border border-border rounded-lg bg-background text-sm" />' +
|
||||
(apiKeySet ? '<span class="inline-flex items-center gap-1 px-2 py-1 rounded text-xs bg-success/10 text-success border border-success/20"><i data-lucide="check" class="w-3 h-3"></i>Set</span>' : '') +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'<div>' +
|
||||
'<label class="block text-sm font-medium mb-1.5">' + (t('codexlens.modelName') || 'Model Name') + '</label>' +
|
||||
'<input type="text" id="rerankerApiModel" value="' + modelName + '" ' +
|
||||
'placeholder="e.g., BAAI/bge-reranker-v2-m3" ' +
|
||||
'class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm" />' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
||||
// LiteLLM Section (visible when backend=litellm)
|
||||
'<div id="rerankerLitellmSection" class="tool-config-section" style="display:' + (backend === 'litellm' ? 'block' : 'none') + '">' +
|
||||
'<h4>' + (t('codexlens.litellmEndpoint') || 'LiteLLM Endpoint') + '</h4>' +
|
||||
'<select id="rerankerLitellmEndpoint" class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm">' +
|
||||
litellmOptions +
|
||||
'</select>' +
|
||||
(litellmEndpoints.length === 0
|
||||
? '<p class="text-xs text-warning mt-1">' + (t('codexlens.noEndpointsHint') || 'Configure LiteLLM endpoints in API Settings first') + '</p>'
|
||||
: '') +
|
||||
'</div>' +
|
||||
|
||||
// Legacy Section (visible when backend=legacy)
|
||||
'<div id="rerankerLegacySection" class="tool-config-section" style="display:' + (backend === 'legacy' ? 'block' : 'none') + '">' +
|
||||
'<div class="flex items-start gap-2 bg-warning/10 border border-warning/30 rounded-lg p-3">' +
|
||||
'<i data-lucide="alert-triangle" class="w-4 h-4 text-warning mt-0.5"></i>' +
|
||||
'<div class="text-sm">' +
|
||||
'<p class="font-medium text-warning">' + (t('codexlens.legacyWarning') || 'Legacy Backend') + '</p>' +
|
||||
'<p class="text-muted-foreground mt-1">' + (t('codexlens.legacyWarningDesc') || 'Uses SentenceTransformers CrossEncoder. Consider using ONNX for better performance.') + '</p>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
||||
'<div class="modal-footer">' +
|
||||
'<button onclick="resetRerankerConfig()" class="btn btn-outline">' +
|
||||
'<i data-lucide="rotate-ccw" class="w-4 h-4"></i> ' + (t('common.reset') || 'Reset') +
|
||||
'</button>' +
|
||||
'<button onclick="closeRerankerModal()" class="btn btn-outline">' + t('common.cancel') + '</button>' +
|
||||
'<button onclick="saveRerankerConfig()" class="btn btn-primary">' +
|
||||
'<i data-lucide="save" class="w-4 h-4"></i> ' + t('common.save') +
|
||||
'</button>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle reranker configuration sections based on selected backend
|
||||
*/
|
||||
function toggleRerankerSections() {
|
||||
var backend = document.getElementById('rerankerBackend').value;
|
||||
|
||||
document.getElementById('rerankerOnnxSection').style.display = backend === 'onnx' ? 'block' : 'none';
|
||||
document.getElementById('rerankerApiSection').style.display = backend === 'api' ? 'block' : 'none';
|
||||
document.getElementById('rerankerLitellmSection').style.display = backend === 'litellm' ? 'block' : 'none';
|
||||
document.getElementById('rerankerLegacySection').style.display = backend === 'legacy' ? 'block' : 'none';
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize reranker config modal events
|
||||
*/
|
||||
function initRerankerConfigEvents(config) {
|
||||
// Handle ONNX model custom input toggle
|
||||
var onnxModelSelect = document.getElementById('rerankerOnnxModel');
|
||||
var customModelInput = document.getElementById('rerankerCustomModel');
|
||||
|
||||
if (onnxModelSelect && customModelInput) {
|
||||
onnxModelSelect.addEventListener('change', function() {
|
||||
customModelInput.style.display = this.value === 'custom' ? 'block' : 'none';
|
||||
});
|
||||
}
|
||||
|
||||
// Store original config for reset
|
||||
window._rerankerOriginalConfig = config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the reranker config modal
|
||||
*/
|
||||
function closeRerankerModal() {
|
||||
var modal = document.getElementById('rerankerConfigModal');
|
||||
if (modal) modal.remove();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset reranker config to original values
|
||||
*/
|
||||
function resetRerankerConfig() {
|
||||
var config = window._rerankerOriginalConfig;
|
||||
if (!config) return;
|
||||
|
||||
document.getElementById('rerankerBackend').value = config.backend || 'onnx';
|
||||
toggleRerankerSections();
|
||||
|
||||
// Reset ONNX section
|
||||
var onnxModels = [
|
||||
'cross-encoder/ms-marco-MiniLM-L-6-v2',
|
||||
'cross-encoder/ms-marco-TinyBERT-L-2-v2',
|
||||
'BAAI/bge-reranker-base',
|
||||
'BAAI/bge-reranker-large'
|
||||
];
|
||||
if (onnxModels.includes(config.model_name)) {
|
||||
document.getElementById('rerankerOnnxModel').value = config.model_name;
|
||||
document.getElementById('rerankerCustomModel').style.display = 'none';
|
||||
} else {
|
||||
document.getElementById('rerankerOnnxModel').value = 'custom';
|
||||
document.getElementById('rerankerCustomModel').value = config.model_name || '';
|
||||
document.getElementById('rerankerCustomModel').style.display = 'block';
|
||||
}
|
||||
|
||||
// Reset API section
|
||||
document.getElementById('rerankerApiProvider').value = config.api_provider || 'siliconflow';
|
||||
document.getElementById('rerankerApiKey').value = '';
|
||||
document.getElementById('rerankerApiModel').value = config.model_name || '';
|
||||
|
||||
showRefreshToast(t('common.reset') || 'Reset to original values', 'info');
|
||||
}
|
||||
|
||||
/**
|
||||
* Save reranker configuration
|
||||
*/
|
||||
async function saveRerankerConfig() {
|
||||
try {
|
||||
var backend = document.getElementById('rerankerBackend').value;
|
||||
var payload = { backend: backend };
|
||||
|
||||
// Collect model name based on backend
|
||||
if (backend === 'onnx') {
|
||||
var onnxModel = document.getElementById('rerankerOnnxModel').value;
|
||||
if (onnxModel === 'custom') {
|
||||
payload.model_name = document.getElementById('rerankerCustomModel').value.trim();
|
||||
} else {
|
||||
payload.model_name = onnxModel;
|
||||
}
|
||||
} else if (backend === 'api') {
|
||||
payload.api_provider = document.getElementById('rerankerApiProvider').value;
|
||||
payload.model_name = document.getElementById('rerankerApiModel').value.trim();
|
||||
var apiKey = document.getElementById('rerankerApiKey').value.trim();
|
||||
if (apiKey) {
|
||||
payload.api_key = apiKey;
|
||||
}
|
||||
} else if (backend === 'litellm') {
|
||||
payload.litellm_endpoint = document.getElementById('rerankerLitellmEndpoint').value;
|
||||
}
|
||||
|
||||
var response = await fetch('/api/codexlens/reranker/config', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
|
||||
var result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
showRefreshToast((t('codexlens.rerankerConfigSaved') || 'Reranker configuration saved') + ': ' + result.message, 'success');
|
||||
closeRerankerModal();
|
||||
} else {
|
||||
showRefreshToast(t('common.saveFailed') + ': ' + result.error, 'error');
|
||||
}
|
||||
} catch (err) {
|
||||
showRefreshToast(t('common.error') + ': ' + err.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// FILE WATCHER CONTROL
|
||||
// ============================================================
|
||||
|
||||
/**
|
||||
* Show File Watcher control modal
|
||||
*/
|
||||
async function showWatcherControlModal() {
|
||||
try {
|
||||
showRefreshToast(t('codexlens.loadingWatcherStatus') || 'Loading watcher status...', 'info');
|
||||
|
||||
// Fetch current watcher status
|
||||
const response = await fetch('/api/codexlens/watch/status');
|
||||
const status = await response.json();
|
||||
|
||||
const modalHtml = buildWatcherControlContent(status);
|
||||
|
||||
// Create and show modal
|
||||
const tempContainer = document.createElement('div');
|
||||
tempContainer.innerHTML = modalHtml;
|
||||
const modal = tempContainer.firstElementChild;
|
||||
document.body.appendChild(modal);
|
||||
|
||||
// Initialize icons
|
||||
if (window.lucide) lucide.createIcons();
|
||||
|
||||
// Start polling if watcher is running
|
||||
if (status.running) {
|
||||
startWatcherStatusPolling();
|
||||
}
|
||||
} catch (err) {
|
||||
showRefreshToast(t('common.error') + ': ' + err.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build File Watcher control modal content
|
||||
*/
|
||||
function buildWatcherControlContent(status) {
|
||||
const running = status.running || false;
|
||||
const rootPath = status.root_path || '';
|
||||
const eventsProcessed = status.events_processed || 0;
|
||||
const uptimeSeconds = status.uptime_seconds || 0;
|
||||
|
||||
// Format uptime
|
||||
const formatUptime = function(seconds) {
|
||||
if (seconds < 60) return seconds + 's';
|
||||
if (seconds < 3600) return Math.floor(seconds / 60) + 'm ' + (seconds % 60) + 's';
|
||||
return Math.floor(seconds / 3600) + 'h ' + Math.floor((seconds % 3600) / 60) + 'm';
|
||||
};
|
||||
|
||||
return '<div class="modal-backdrop" id="watcherControlModal">' +
|
||||
'<div class="modal-container max-w-lg">' +
|
||||
'<div class="modal-header">' +
|
||||
'<div class="flex items-center gap-3">' +
|
||||
'<div class="modal-icon">' +
|
||||
'<i data-lucide="eye" class="w-5 h-5"></i>' +
|
||||
'</div>' +
|
||||
'<div>' +
|
||||
'<h2 class="text-lg font-bold">' + (t('codexlens.watcherControl') || 'File Watcher') + '</h2>' +
|
||||
'<p class="text-xs text-muted-foreground">' + (t('codexlens.watcherControlDesc') || 'Real-time incremental index updates') + '</p>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'<button onclick="closeWatcherModal()" class="text-muted-foreground hover:text-foreground">' +
|
||||
'<i data-lucide="x" class="w-5 h-5"></i>' +
|
||||
'</button>' +
|
||||
'</div>' +
|
||||
|
||||
'<div class="modal-body space-y-4">' +
|
||||
// Status and Toggle
|
||||
'<div class="flex items-center justify-between p-4 bg-muted/30 rounded-lg">' +
|
||||
'<div class="flex items-center gap-3">' +
|
||||
'<div class="w-3 h-3 rounded-full ' + (running ? 'bg-success animate-pulse' : 'bg-muted-foreground') + '"></div>' +
|
||||
'<div>' +
|
||||
'<span class="font-medium">' + (running ? (t('codexlens.watcherRunning') || 'Watcher Running') : (t('codexlens.watcherStopped') || 'Watcher Stopped')) + '</span>' +
|
||||
(running ? '<p class="text-xs text-muted-foreground">' + rootPath + '</p>' : '') +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'<label class="relative inline-flex items-center cursor-pointer">' +
|
||||
'<input type="checkbox" id="watcherToggle" ' + (running ? 'checked' : '') + ' onchange="toggleWatcher()" class="sr-only peer" />' +
|
||||
'<div class="w-11 h-6 bg-muted peer-focus:outline-none rounded-full peer peer-checked:after:translate-x-full peer-checked:after:border-white after:content-[\'\'] after:absolute after:top-[2px] after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all peer-checked:bg-success"></div>' +
|
||||
'</label>' +
|
||||
'</div>' +
|
||||
|
||||
// Statistics (shown when running)
|
||||
'<div id="watcherStats" class="tool-config-section" style="display:' + (running ? 'block' : 'none') + '">' +
|
||||
'<h4>' + (t('codexlens.watcherStats') || 'Statistics') + '</h4>' +
|
||||
'<div class="grid grid-cols-2 gap-4">' +
|
||||
'<div class="p-3 bg-muted/20 rounded-lg">' +
|
||||
'<div class="text-2xl font-bold text-primary" id="watcherEventsCount">' + eventsProcessed + '</div>' +
|
||||
'<div class="text-xs text-muted-foreground">' + (t('codexlens.eventsProcessed') || 'Events Processed') + '</div>' +
|
||||
'</div>' +
|
||||
'<div class="p-3 bg-muted/20 rounded-lg">' +
|
||||
'<div class="text-2xl font-bold text-primary" id="watcherUptime">' + formatUptime(uptimeSeconds) + '</div>' +
|
||||
'<div class="text-xs text-muted-foreground">' + (t('codexlens.uptime') || 'Uptime') + '</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
||||
// Start Configuration (shown when not running)
|
||||
'<div id="watcherStartConfig" class="tool-config-section" style="display:' + (running ? 'none' : 'block') + '">' +
|
||||
'<h4>' + (t('codexlens.watcherConfig') || 'Configuration') + '</h4>' +
|
||||
'<div class="space-y-3">' +
|
||||
'<div>' +
|
||||
'<label class="block text-sm font-medium mb-1.5">' + (t('codexlens.watchPath') || 'Watch Path') + '</label>' +
|
||||
'<input type="text" id="watcherPath" value="" placeholder="Leave empty for current workspace" ' +
|
||||
'class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm" />' +
|
||||
'</div>' +
|
||||
'<div>' +
|
||||
'<label class="block text-sm font-medium mb-1.5">' + (t('codexlens.debounceMs') || 'Debounce (ms)') + '</label>' +
|
||||
'<input type="number" id="watcherDebounce" value="1000" min="100" max="10000" step="100" ' +
|
||||
'class="w-full px-3 py-2 border border-border rounded-lg bg-background text-sm" />' +
|
||||
'<p class="text-xs text-muted-foreground mt-1">' + (t('codexlens.debounceHint') || 'Time to wait before processing file changes') + '</p>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
||||
// Info box
|
||||
'<div class="flex items-start gap-2 bg-primary/10 border border-primary/30 rounded-lg p-3">' +
|
||||
'<i data-lucide="info" class="w-4 h-4 text-primary mt-0.5"></i>' +
|
||||
'<div class="text-sm text-muted-foreground">' +
|
||||
(t('codexlens.watcherInfo') || 'The file watcher monitors your codebase for changes and automatically updates the search index in real-time.') +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
|
||||
'<div class="modal-footer">' +
|
||||
'<button onclick="closeWatcherModal()" class="btn btn-outline">' + t('common.close') + '</button>' +
|
||||
'</div>' +
|
||||
'</div>' +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle file watcher on/off
|
||||
*/
|
||||
async function toggleWatcher() {
|
||||
var toggle = document.getElementById('watcherToggle');
|
||||
var shouldRun = toggle.checked;
|
||||
|
||||
try {
|
||||
if (shouldRun) {
|
||||
// Start watcher
|
||||
var watchPath = document.getElementById('watcherPath').value.trim();
|
||||
var debounceMs = parseInt(document.getElementById('watcherDebounce').value, 10) || 1000;
|
||||
|
||||
var response = await fetch('/api/codexlens/watch/start', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ path: watchPath || undefined, debounce_ms: debounceMs })
|
||||
});
|
||||
|
||||
var result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
showRefreshToast((t('codexlens.watcherStarted') || 'Watcher started') + ': ' + result.path, 'success');
|
||||
document.getElementById('watcherStats').style.display = 'block';
|
||||
document.getElementById('watcherStartConfig').style.display = 'none';
|
||||
startWatcherStatusPolling();
|
||||
} else {
|
||||
toggle.checked = false;
|
||||
showRefreshToast(t('common.error') + ': ' + result.error, 'error');
|
||||
}
|
||||
} else {
|
||||
// Stop watcher
|
||||
var response = await fetch('/api/codexlens/watch/stop', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
});
|
||||
|
||||
var result = await response.json();
|
||||
|
||||
if (result.success) {
|
||||
showRefreshToast((t('codexlens.watcherStopped') || 'Watcher stopped') + ': ' + result.events_processed + ' events processed', 'success');
|
||||
document.getElementById('watcherStats').style.display = 'none';
|
||||
document.getElementById('watcherStartConfig').style.display = 'block';
|
||||
stopWatcherStatusPolling();
|
||||
} else {
|
||||
toggle.checked = true;
|
||||
showRefreshToast(t('common.error') + ': ' + result.error, 'error');
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
toggle.checked = !shouldRun;
|
||||
showRefreshToast(t('common.error') + ': ' + err.message, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
// Watcher status polling
|
||||
var watcherPollingInterval = null;
|
||||
|
||||
function startWatcherStatusPolling() {
|
||||
if (watcherPollingInterval) return;
|
||||
|
||||
watcherPollingInterval = setInterval(async function() {
|
||||
try {
|
||||
var response = await fetch('/api/codexlens/watch/status');
|
||||
var status = await response.json();
|
||||
|
||||
if (status.running) {
|
||||
document.getElementById('watcherEventsCount').textContent = status.events_processed || 0;
|
||||
|
||||
// Format uptime
|
||||
var seconds = status.uptime_seconds || 0;
|
||||
var formatted = seconds < 60 ? seconds + 's' :
|
||||
seconds < 3600 ? Math.floor(seconds / 60) + 'm ' + (seconds % 60) + 's' :
|
||||
Math.floor(seconds / 3600) + 'h ' + Math.floor((seconds % 3600) / 60) + 'm';
|
||||
document.getElementById('watcherUptime').textContent = formatted;
|
||||
} else {
|
||||
// Watcher stopped externally
|
||||
stopWatcherStatusPolling();
|
||||
document.getElementById('watcherToggle').checked = false;
|
||||
document.getElementById('watcherStats').style.display = 'none';
|
||||
document.getElementById('watcherStartConfig').style.display = 'block';
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to poll watcher status:', err);
|
||||
}
|
||||
}, 2000);
|
||||
}
|
||||
|
||||
function stopWatcherStatusPolling() {
|
||||
if (watcherPollingInterval) {
|
||||
clearInterval(watcherPollingInterval);
|
||||
watcherPollingInterval = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the watcher control modal
|
||||
*/
|
||||
function closeWatcherModal() {
|
||||
stopWatcherStatusPolling();
|
||||
var modal = document.getElementById('watcherControlModal');
|
||||
if (modal) modal.remove();
|
||||
}
|
||||
|
||||
@@ -80,6 +80,18 @@ reranker = [
    "transformers>=4.36",
]

# SPLADE sparse retrieval
splade = [
    "transformers>=4.36",
    "optimum[onnxruntime]>=1.16",
]

# SPLADE with GPU acceleration (CUDA)
splade-gpu = [
    "transformers>=4.36",
    "optimum[onnxruntime-gpu]>=1.16",
]

# Encoding detection for non-UTF8 files
encoding = [
    "chardet>=5.0",
@@ -415,11 +415,20 @@ def search(
|
||||
depth: int = typer.Option(-1, "--depth", "-d", help="Search depth (-1 = unlimited, 0 = current only)."),
|
||||
files_only: bool = typer.Option(False, "--files-only", "-f", help="Return only file paths without content snippets."),
|
||||
mode: str = typer.Option("auto", "--mode", "-m", help="Search mode: auto, exact, fuzzy, hybrid, vector, pure-vector."),
|
||||
weights: Optional[str] = typer.Option(None, "--weights", help="Custom RRF weights as 'exact,fuzzy,vector' (e.g., '0.5,0.3,0.2')."),
|
||||
weights: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--weights", "-w",
|
||||
help="RRF weights as key=value pairs (e.g., 'splade=0.4,vector=0.6' or 'exact=0.3,fuzzy=0.1,vector=0.6'). Default: auto-detect based on available backends."
|
||||
),
|
||||
use_fts: bool = typer.Option(
|
||||
False,
|
||||
"--use-fts",
|
||||
help="Use FTS (exact+fuzzy) instead of SPLADE for sparse retrieval"
|
||||
),
|
||||
json_mode: bool = typer.Option(False, "--json", help="Output JSON response."),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
|
||||
) -> None:
|
||||
"""Search indexed file contents using SQLite FTS5 or semantic vectors.
|
||||
"""Search indexed file contents using hybrid retrieval.
|
||||
|
||||
Uses chain search across directory indexes.
|
||||
Use --depth to limit search recursion (0 = current dir only).
|
||||
@@ -428,17 +437,27 @@ def search(
|
||||
- auto: Auto-detect (hybrid if embeddings exist, exact otherwise) [default]
|
||||
- exact: Exact FTS using unicode61 tokenizer - for code identifiers
|
||||
- fuzzy: Fuzzy FTS using trigram tokenizer - for typo-tolerant search
|
||||
- hybrid: RRF fusion of exact + fuzzy + vector (recommended) - best recall
|
||||
- vector: Vector search with exact FTS fallback - semantic + keyword
|
||||
- hybrid: RRF fusion of sparse + dense search (recommended) - best recall
|
||||
- vector: Vector search with sparse fallback - semantic + keyword
|
||||
- pure-vector: Pure semantic vector search only - natural language queries
|
||||
|
||||
SPLADE Mode:
|
||||
When SPLADE is available (pip install codex-lens[splade]), it automatically
|
||||
replaces FTS (exact+fuzzy) as the sparse retrieval backend. SPLADE provides
|
||||
semantic term expansion for better synonym handling.
|
||||
|
||||
Use --use-fts to force FTS mode instead of SPLADE.
|
||||
|
||||
Vector Search Requirements:
|
||||
Vector search modes require pre-generated embeddings.
|
||||
Use 'codexlens embeddings-generate' to create embeddings first.
|
||||
|
||||
Hybrid Mode:
|
||||
Default weights: exact=0.3, fuzzy=0.1, vector=0.6
|
||||
Use --weights to customize (e.g., --weights 0.5,0.3,0.2)
|
||||
Hybrid Mode Weights:
|
||||
Use --weights to adjust RRF fusion weights:
|
||||
- SPLADE mode: 'splade=0.4,vector=0.6' (default)
|
||||
- FTS mode: 'exact=0.3,fuzzy=0.1,vector=0.6' (default)
|
||||
|
||||
Legacy format also supported: '0.3,0.1,0.6' (exact,fuzzy,vector)
|
||||
|
||||
Examples:
|
||||
# Auto-detect mode (uses hybrid if embeddings available)
|
||||
@@ -450,11 +469,19 @@ def search(
|
||||
# Semantic search (requires embeddings)
|
||||
codexlens search "how to verify user credentials" --mode pure-vector
|
||||
|
||||
# Force hybrid mode
|
||||
codexlens search "authentication" --mode hybrid
|
||||
# Force hybrid mode with custom weights
|
||||
codexlens search "authentication" --mode hybrid --weights splade=0.5,vector=0.5
|
||||
|
||||
# Force FTS instead of SPLADE
|
||||
codexlens search "authentication" --use-fts
|
||||
"""
|
||||
_configure_logging(verbose, json_mode)
|
||||
search_path = path.expanduser().resolve()
|
||||
|
||||
# Configure search with FTS fallback if requested
|
||||
config = Config()
|
||||
if use_fts:
|
||||
config.use_fts_fallback = True
|
||||
|
||||
# Validate mode
|
||||
valid_modes = ["auto", "exact", "fuzzy", "hybrid", "vector", "pure-vector"]
|
||||
@@ -470,22 +497,56 @@ def search(
|
||||
hybrid_weights = None
|
||||
if weights:
|
||||
try:
|
||||
weight_parts = [float(w.strip()) for w in weights.split(",")]
|
||||
if len(weight_parts) == 3:
|
||||
weight_sum = sum(weight_parts)
|
||||
# Check if using key=value format (new) or legacy comma-separated format
|
||||
if "=" in weights:
|
||||
# New format: splade=0.4,vector=0.6 or exact=0.3,fuzzy=0.1,vector=0.6
|
||||
weight_dict = {}
|
||||
for pair in weights.split(","):
|
||||
if "=" in pair:
|
||||
key, val = pair.split("=", 1)
|
||||
weight_dict[key.strip()] = float(val.strip())
|
||||
else:
|
||||
raise ValueError("Mixed format not supported - use all key=value pairs")
|
||||
|
||||
# Validate and normalize weights
|
||||
weight_sum = sum(weight_dict.values())
|
||||
if abs(weight_sum - 1.0) > 0.01:
|
||||
console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
|
||||
# Normalize weights
|
||||
weight_parts = [w / weight_sum for w in weight_parts]
|
||||
hybrid_weights = {
|
||||
"exact": weight_parts[0],
|
||||
"fuzzy": weight_parts[1],
|
||||
"vector": weight_parts[2],
|
||||
}
|
||||
if not json_mode:
|
||||
console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
|
||||
weight_dict = {k: v / weight_sum for k, v in weight_dict.items()}
|
||||
|
||||
hybrid_weights = weight_dict
|
||||
else:
|
||||
console.print("[yellow]Warning: Invalid weights format (need 3 values). Using defaults.[/yellow]")
|
||||
except ValueError:
|
||||
console.print("[yellow]Warning: Invalid weights format. Using defaults.[/yellow]")
|
||||
# Legacy format: 0.3,0.1,0.6 (exact,fuzzy,vector)
|
||||
weight_parts = [float(w.strip()) for w in weights.split(",")]
|
||||
if len(weight_parts) == 3:
|
||||
weight_sum = sum(weight_parts)
|
||||
if abs(weight_sum - 1.0) > 0.01:
|
||||
if not json_mode:
|
||||
console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
|
||||
weight_parts = [w / weight_sum for w in weight_parts]
|
||||
hybrid_weights = {
|
||||
"exact": weight_parts[0],
|
||||
"fuzzy": weight_parts[1],
|
||||
"vector": weight_parts[2],
|
||||
}
|
||||
elif len(weight_parts) == 2:
|
||||
# Two values: assume splade,vector
|
||||
weight_sum = sum(weight_parts)
|
||||
if abs(weight_sum - 1.0) > 0.01:
|
||||
if not json_mode:
|
||||
console.print(f"[yellow]Warning: Weights sum to {weight_sum:.2f}, should sum to 1.0. Normalizing...[/yellow]")
|
||||
weight_parts = [w / weight_sum for w in weight_parts]
|
||||
hybrid_weights = {
|
||||
"splade": weight_parts[0],
|
||||
"vector": weight_parts[1],
|
||||
}
|
||||
else:
|
||||
if not json_mode:
|
||||
console.print("[yellow]Warning: Invalid weights format. Using defaults.[/yellow]")
|
||||
except ValueError as e:
|
||||
if not json_mode:
|
||||
console.print(f"[yellow]Warning: Invalid weights format ({e}). Using defaults.[/yellow]")
|
||||
|
||||
registry: RegistryStore | None = None
|
||||
try:
|
||||
@@ -2381,6 +2442,188 @@ def gpu_reset(
|
||||
console.print(f" Device: [cyan]{gpu_info.gpu_name}[/cyan]")
|
||||
|
||||
|
||||
|
||||
# ==================== SPLADE Commands ====================
|
||||
|
||||
@app.command("splade-index")
|
||||
def splade_index_command(
|
||||
path: Path = typer.Argument(..., help="Project path to index"),
|
||||
rebuild: bool = typer.Option(False, "--rebuild", "-r", help="Force rebuild SPLADE index"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
|
||||
) -> None:
|
||||
"""Generate SPLADE sparse index for existing codebase.
|
||||
|
||||
Encodes all semantic chunks with SPLADE model and builds inverted index
|
||||
for efficient sparse retrieval.
|
||||
|
||||
Examples:
|
||||
codexlens splade-index ~/projects/my-app
|
||||
codexlens splade-index . --rebuild
|
||||
"""
|
||||
_configure_logging(verbose)
|
||||
|
||||
from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
|
||||
from codexlens.storage.splade_index import SpladeIndex
|
||||
from codexlens.semantic.vector_store import VectorStore
|
||||
|
||||
# Check SPLADE availability
|
||||
ok, err = check_splade_available()
|
||||
if not ok:
|
||||
console.print(f"[red]SPLADE not available: {err}[/red]")
|
||||
console.print("[dim]Install with: pip install transformers torch[/dim]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Find index database
|
||||
target_path = path.expanduser().resolve()
|
||||
|
||||
# Try to find _index.db
|
||||
if target_path.is_file() and target_path.name == "_index.db":
|
||||
index_db = target_path
|
||||
elif target_path.is_dir():
|
||||
# Check for local .codexlens/_index.db
|
||||
local_index = target_path / ".codexlens" / "_index.db"
|
||||
if local_index.exists():
|
||||
index_db = local_index
|
||||
else:
|
||||
# Try to find via registry
|
||||
registry = RegistryStore()
|
||||
try:
|
||||
registry.initialize()
|
||||
mapper = PathMapper()
|
||||
index_db = mapper.source_to_index_db(target_path)
|
||||
if not index_db.exists():
|
||||
console.print(f"[red]Error:[/red] No index found for {target_path}")
|
||||
console.print("Run 'codexlens init' first to create an index")
|
||||
raise typer.Exit(1)
|
||||
finally:
|
||||
registry.close()
|
||||
else:
|
||||
console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
|
||||
raise typer.Exit(1)
|
||||
|
||||
splade_db = index_db.parent / "_splade.db"
|
||||
|
||||
if splade_db.exists() and not rebuild:
|
||||
console.print("[yellow]SPLADE index exists. Use --rebuild to regenerate.[/yellow]")
|
||||
return
|
||||
|
||||
# Load chunks from vector store
|
||||
console.print(f"[blue]Loading chunks from {index_db.name}...[/blue]")
|
||||
vector_store = VectorStore(index_db)
|
||||
chunks = vector_store.get_all_chunks()
|
||||
|
||||
if not chunks:
|
||||
console.print("[yellow]No chunks found in vector store[/yellow]")
|
||||
console.print("[dim]Generate embeddings first with 'codexlens embeddings-generate'[/dim]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[blue]Encoding {len(chunks)} chunks with SPLADE...[/blue]")
|
||||
|
||||
# Initialize SPLADE
|
||||
encoder = get_splade_encoder()
|
||||
splade_index = SpladeIndex(splade_db)
|
||||
splade_index.create_tables()
|
||||
|
||||
# Encode in batches with progress bar
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("Encoding...", total=len(chunks))
|
||||
for chunk in chunks:
|
||||
sparse_vec = encoder.encode_text(chunk.content)
|
||||
splade_index.add_posting(chunk.id, sparse_vec)
|
||||
progress.advance(task)
|
||||
|
||||
# Set metadata
|
||||
splade_index.set_metadata(
|
||||
model_name=encoder.model_name,
|
||||
vocab_size=encoder.vocab_size
|
||||
)
|
||||
|
||||
stats = splade_index.get_stats()
|
||||
console.print(f"[green]✓[/green] SPLADE index built: {stats['unique_chunks']} chunks, {stats['total_postings']} postings")
|
||||
console.print(f" Database: [dim]{splade_db}[/dim]")
|
||||
|
||||
|
||||
@app.command("splade-status")
|
||||
def splade_status_command(
|
||||
path: Path = typer.Argument(..., help="Project path"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output."),
|
||||
) -> None:
|
||||
"""Show SPLADE index status and statistics.
|
||||
|
||||
Examples:
|
||||
codexlens splade-status ~/projects/my-app
|
||||
codexlens splade-status .
|
||||
"""
|
||||
_configure_logging(verbose)
|
||||
|
||||
from codexlens.storage.splade_index import SpladeIndex
|
||||
from codexlens.semantic.splade_encoder import check_splade_available
|
||||
|
||||
# Find index database
|
||||
target_path = path.expanduser().resolve()
|
||||
|
||||
if target_path.is_file() and target_path.name == "_index.db":
|
||||
splade_db = target_path.parent / "_splade.db"
|
||||
elif target_path.is_dir():
|
||||
# Check for local .codexlens/_splade.db
|
||||
local_splade = target_path / ".codexlens" / "_splade.db"
|
||||
if local_splade.exists():
|
||||
splade_db = local_splade
|
||||
else:
|
||||
# Try to find via registry
|
||||
registry = RegistryStore()
|
||||
try:
|
||||
registry.initialize()
|
||||
mapper = PathMapper()
|
||||
index_db = mapper.source_to_index_db(target_path)
|
||||
splade_db = index_db.parent / "_splade.db"
|
||||
finally:
|
||||
registry.close()
|
||||
else:
|
||||
console.print(f"[red]Error:[/red] Path must be _index.db file or indexed directory")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if not splade_db.exists():
|
||||
console.print("[yellow]No SPLADE index found[/yellow]")
|
||||
console.print(f"[dim]Run 'codexlens splade-index {path}' to create one[/dim]")
|
||||
return
|
||||
|
||||
splade_index = SpladeIndex(splade_db)
|
||||
|
||||
if not splade_index.has_index():
|
||||
console.print("[yellow]SPLADE tables not initialized[/yellow]")
|
||||
return
|
||||
|
||||
metadata = splade_index.get_metadata()
|
||||
stats = splade_index.get_stats()
|
||||
|
||||
# Create status table
|
||||
table = Table(title="SPLADE Index Status", show_header=False)
|
||||
table.add_column("Property", style="cyan")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("Database", str(splade_db))
|
||||
if metadata:
|
||||
table.add_row("Model", metadata['model_name'])
|
||||
table.add_row("Vocab Size", str(metadata['vocab_size']))
|
||||
table.add_row("Chunks", str(stats['unique_chunks']))
|
||||
table.add_row("Unique Tokens", str(stats['unique_tokens']))
|
||||
table.add_row("Total Postings", str(stats['total_postings']))
|
||||
|
||||
ok, err = check_splade_available()
|
||||
status_text = "[green]Yes[/green]" if ok else f"[red]No[/red] - {err}"
|
||||
table.add_row("SPLADE Available", status_text)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
# ==================== Watch Command ====================
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -33,6 +33,15 @@ def _cleanup_fastembed_resources() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _cleanup_splade_resources() -> None:
|
||||
"""Release SPLADE encoder ONNX resources."""
|
||||
try:
|
||||
from codexlens.semantic.splade_encoder import clear_splade_cache
|
||||
clear_splade_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _generate_chunks_from_cursor(
|
||||
cursor,
|
||||
chunker,
|
||||
@@ -675,10 +684,96 @@ def generate_embeddings(
|
||||
if progress_callback:
|
||||
progress_callback(f"Finalizing index... Building ANN index for {total_chunks_created} chunks")
|
||||
|
||||
# --- SPLADE SPARSE ENCODING (after dense embeddings) ---
|
||||
# Add SPLADE encoding if enabled in config
|
||||
splade_success = False
|
||||
splade_error = None
|
||||
|
||||
try:
|
||||
from codexlens.config import Config
|
||||
config = Config.load()
|
||||
|
||||
if config.enable_splade:
|
||||
from codexlens.semantic.splade_encoder import check_splade_available, get_splade_encoder
|
||||
from codexlens.storage.splade_index import SpladeIndex
|
||||
|
||||
ok, err = check_splade_available()
|
||||
if ok:
|
||||
if progress_callback:
|
||||
progress_callback(f"Generating SPLADE sparse vectors for {total_chunks_created} chunks...")
|
||||
|
||||
# Initialize SPLADE encoder and index
|
||||
splade_encoder = get_splade_encoder(use_gpu=use_gpu)
|
||||
# Use main index database for SPLADE (not separate _splade.db)
|
||||
splade_index = SpladeIndex(index_path)
|
||||
splade_index.create_tables()
|
||||
|
||||
# Retrieve all chunks from database for SPLADE encoding
|
||||
with sqlite3.connect(index_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute("SELECT id, content FROM semantic_chunks ORDER BY id")
|
||||
|
||||
# Batch encode for efficiency
|
||||
SPLADE_BATCH_SIZE = 32
|
||||
batch_postings = []
|
||||
chunk_batch = []
|
||||
chunk_ids = []
|
||||
|
||||
for row in cursor:
|
||||
chunk_id = row["id"]
|
||||
content = row["content"]
|
||||
|
||||
chunk_ids.append(chunk_id)
|
||||
chunk_batch.append(content)
|
||||
|
||||
# Process batch when full
|
||||
if len(chunk_batch) >= SPLADE_BATCH_SIZE:
|
||||
sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
|
||||
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
|
||||
batch_postings.append((cid, sparse_vec))
|
||||
|
||||
chunk_batch = []
|
||||
chunk_ids = []
|
||||
|
||||
# Process remaining chunks
|
||||
if chunk_batch:
|
||||
sparse_vecs = splade_encoder.encode_batch(chunk_batch, batch_size=SPLADE_BATCH_SIZE)
|
||||
for cid, sparse_vec in zip(chunk_ids, sparse_vecs):
|
||||
batch_postings.append((cid, sparse_vec))
|
||||
|
||||
# Batch insert all postings
|
||||
if batch_postings:
|
||||
splade_index.add_postings_batch(batch_postings)
|
||||
|
||||
# Set metadata
|
||||
splade_index.set_metadata(
|
||||
model_name=splade_encoder.model_name,
|
||||
vocab_size=splade_encoder.vocab_size
|
||||
)
|
||||
|
||||
splade_success = True
|
||||
if progress_callback:
|
||||
stats = splade_index.get_stats()
|
||||
progress_callback(
|
||||
f"SPLADE index created: {stats['total_postings']} postings, "
|
||||
f"{stats['unique_tokens']} unique tokens"
|
||||
)
|
||||
else:
|
||||
logger.debug("SPLADE not available: %s", err)
|
||||
splade_error = f"SPLADE not available: {err}"
|
||||
except Exception as e:
|
||||
splade_error = str(e)
|
||||
logger.warning("SPLADE encoding failed: %s", e)
|
||||
|
||||
# Report SPLADE status after processing
|
||||
if progress_callback and not splade_success and splade_error:
|
||||
progress_callback(f"SPLADE index: FAILED - {splade_error}")
|
||||
|
||||
except Exception as e:
|
||||
# Cleanup on error to prevent process hanging
|
||||
try:
|
||||
_cleanup_fastembed_resources()
|
||||
_cleanup_splade_resources()
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -690,6 +785,7 @@ def generate_embeddings(
|
||||
# This is critical - without it, ONNX Runtime threads prevent Python from exiting
|
||||
try:
|
||||
_cleanup_fastembed_resources()
|
||||
_cleanup_splade_resources()
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -874,6 +970,7 @@ def generate_embeddings_recursive(
|
||||
# Each generate_embeddings() call does its own cleanup, but do a final one to be safe
|
||||
try:
|
||||
_cleanup_fastembed_resources()
|
||||
_cleanup_splade_resources()
|
||||
gc.collect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -103,6 +103,15 @@ class Config:
|
||||
# For litellm: model name from config (e.g., "qwen3-embedding")
|
||||
embedding_use_gpu: bool = True # For fastembed: whether to use GPU acceleration
|
||||
|
||||
# SPLADE sparse retrieval configuration
|
||||
enable_splade: bool = True # Enable SPLADE as default sparse backend
|
||||
splade_model: str = "naver/splade-cocondenser-ensembledistil"
|
||||
splade_threshold: float = 0.01 # Min weight to store in index
|
||||
splade_onnx_path: Optional[str] = None # Custom ONNX model path
|
||||
|
||||
# FTS fallback (disabled by default, available via --use-fts)
|
||||
use_fts_fallback: bool = False # Use FTS instead of SPLADE
|
||||
|
||||
# Indexing/search optimizations
|
||||
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
|
||||
enable_merkle_detection: bool = True # Enable content-hash based incremental indexing
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError, as_completed
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -33,6 +33,8 @@ def timer(name: str, logger: logging.Logger, level: int = logging.DEBUG):
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import (
|
||||
DEFAULT_WEIGHTS,
|
||||
FTS_FALLBACK_WEIGHTS,
|
||||
apply_symbol_boost,
|
||||
cross_encoder_rerank,
|
||||
get_rrf_weights,
|
||||
@@ -54,12 +56,9 @@ class HybridSearchEngine:
|
||||
default_weights: Default RRF weights for each source
|
||||
"""
|
||||
|
||||
# Default RRF weights (vector: 60%, exact: 30%, fuzzy: 10%)
|
||||
DEFAULT_WEIGHTS = {
|
||||
"exact": 0.3,
|
||||
"fuzzy": 0.1,
|
||||
"vector": 0.6,
|
||||
}
|
||||
# NOTE: DEFAULT_WEIGHTS imported from ranking.py - single source of truth
|
||||
# Default RRF weights: SPLADE-based hybrid (splade: 0.4, vector: 0.6)
|
||||
# FTS fallback mode uses FTS_FALLBACK_WEIGHTS (exact: 0.3, fuzzy: 0.1, vector: 0.6)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -75,10 +74,11 @@ class HybridSearchEngine:
|
||||
embedder: Optional embedder instance for embedding-based reranking
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.weights = weights or self.DEFAULT_WEIGHTS.copy()
|
||||
self.weights = weights or DEFAULT_WEIGHTS.copy()
|
||||
self._config = config
|
||||
self.embedder = embedder
|
||||
self.reranker: Any = None
|
||||
self._use_gpu = config.embedding_use_gpu if config else True
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -124,6 +124,26 @@ class HybridSearchEngine:
|
||||
|
||||
# Determine which backends to use
|
||||
backends = {}
|
||||
|
||||
# Check if SPLADE is available
|
||||
splade_available = False
|
||||
# Respect config.enable_splade flag and use_fts_fallback flag
|
||||
if self._config and getattr(self._config, 'use_fts_fallback', False):
|
||||
# Config explicitly requests FTS fallback - disable SPLADE
|
||||
splade_available = False
|
||||
elif self._config and not getattr(self._config, 'enable_splade', True):
|
||||
# Config explicitly disabled SPLADE
|
||||
splade_available = False
|
||||
else:
|
||||
# Check if SPLADE dependencies are available
|
||||
try:
|
||||
from codexlens.semantic.splade_encoder import check_splade_available
|
||||
ok, _ = check_splade_available()
|
||||
if ok:
|
||||
# SPLADE tables are in main index database, will check table existence in _search_splade
|
||||
splade_available = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if pure_vector:
|
||||
# Pure vector mode: only use vector search, no FTS fallback
|
||||
@@ -138,12 +158,19 @@ class HybridSearchEngine:
|
||||
)
|
||||
backends["exact"] = True
|
||||
else:
|
||||
# Hybrid mode: always include exact search as baseline
|
||||
backends["exact"] = True
|
||||
if enable_fuzzy:
|
||||
backends["fuzzy"] = True
|
||||
if enable_vector:
|
||||
backends["vector"] = True
|
||||
# Hybrid mode: default to SPLADE if available, otherwise use FTS
|
||||
if splade_available:
|
||||
# Default: enable SPLADE, disable exact and fuzzy
|
||||
backends["splade"] = True
|
||||
if enable_vector:
|
||||
backends["vector"] = True
|
||||
else:
|
||||
# Fallback mode: enable exact+fuzzy when SPLADE unavailable
|
||||
backends["exact"] = True
|
||||
if enable_fuzzy:
|
||||
backends["fuzzy"] = True
|
||||
if enable_vector:
|
||||
backends["vector"] = True
|
||||
|
||||
# Execute parallel searches
|
||||
with timer("parallel_search_total", self.logger):
|
||||
@@ -354,23 +381,40 @@ class HybridSearchEngine:
|
||||
)
|
||||
future_to_source[future] = "vector"
|
||||
|
||||
# Collect results as they complete
|
||||
for future in as_completed(future_to_source):
|
||||
source = future_to_source[future]
|
||||
elapsed_ms = (time.perf_counter() - submit_times[source]) * 1000
|
||||
timing_data[source] = elapsed_ms
|
||||
try:
|
||||
results = future.result()
|
||||
# Tag results with source for debugging
|
||||
tagged_results = tag_search_source(results, source)
|
||||
results_map[source] = tagged_results
|
||||
self.logger.debug(
|
||||
"[TIMING] %s_search: %.2fms (%d results)",
|
||||
source, elapsed_ms, len(results)
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.error("Search failed for %s: %s", source, exc)
|
||||
results_map[source] = []
|
||||
if backends.get("splade"):
|
||||
submit_times["splade"] = time.perf_counter()
|
||||
future = executor.submit(
|
||||
self._search_splade, index_path, query, limit
|
||||
)
|
||||
future_to_source[future] = "splade"
|
||||
|
||||
# Collect results as they complete with timeout protection
|
||||
try:
|
||||
for future in as_completed(future_to_source, timeout=30.0):
|
||||
source = future_to_source[future]
|
||||
elapsed_ms = (time.perf_counter() - submit_times[source]) * 1000
|
||||
timing_data[source] = elapsed_ms
|
||||
try:
|
||||
results = future.result(timeout=10.0)
|
||||
# Tag results with source for debugging
|
||||
tagged_results = tag_search_source(results, source)
|
||||
results_map[source] = tagged_results
|
||||
self.logger.debug(
|
||||
"[TIMING] %s_search: %.2fms (%d results)",
|
||||
source, elapsed_ms, len(results)
|
||||
)
|
||||
except (Exception, FuturesTimeoutError) as exc:
|
||||
self.logger.error("Search failed for %s: %s", source, exc)
|
||||
results_map[source] = []
|
||||
except FuturesTimeoutError:
|
||||
self.logger.warning("Search timeout: some backends did not respond in time")
|
||||
# Cancel remaining futures
|
||||
for future in future_to_source:
|
||||
future.cancel()
|
||||
# Set empty results for sources that didn't complete
|
||||
for source in backends:
|
||||
if source not in results_map:
|
||||
results_map[source] = []
|
||||
|
||||
# Log timing summary
|
||||
if timing_data:
|
||||
@@ -564,3 +608,113 @@ class HybridSearchEngine:
|
||||
except Exception as exc:
|
||||
self.logger.error("Vector search error: %s", exc)
|
||||
return []
|
||||
|
||||
def _search_splade(
|
||||
self, index_path: Path, query: str, limit: int
|
||||
) -> List[SearchResult]:
|
||||
"""SPLADE sparse retrieval via inverted index.
|
||||
|
||||
Args:
|
||||
index_path: Path to _index.db file
|
||||
query: Natural language query string
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of SearchResult ordered by SPLADE score
|
||||
"""
|
||||
try:
|
||||
from codexlens.semantic.splade_encoder import get_splade_encoder, check_splade_available
|
||||
from codexlens.storage.splade_index import SpladeIndex
|
||||
import sqlite3
|
||||
import json
|
||||
|
||||
# Check dependencies
|
||||
ok, err = check_splade_available()
|
||||
if not ok:
|
||||
self.logger.debug("SPLADE not available: %s", err)
|
||||
return []
|
||||
|
||||
# Use main index database (SPLADE tables are in _index.db, not separate _splade.db)
|
||||
splade_index = SpladeIndex(index_path)
|
||||
if not splade_index.has_index():
|
||||
self.logger.debug("SPLADE index not initialized")
|
||||
return []
|
||||
|
||||
# Encode query to sparse vector
|
||||
encoder = get_splade_encoder(use_gpu=self._use_gpu)
|
||||
query_sparse = encoder.encode_text(query)
|
||||
|
||||
# Search inverted index for top matches
|
||||
raw_results = splade_index.search(query_sparse, limit=limit, min_score=0.0)
|
||||
|
||||
if not raw_results:
|
||||
return []
|
||||
|
||||
# Fetch chunk details from main index database
|
||||
chunk_ids = [chunk_id for chunk_id, _ in raw_results]
|
||||
score_map = {chunk_id: score for chunk_id, score in raw_results}
|
||||
|
||||
# Query semantic_chunks table for full details
|
||||
placeholders = ",".join("?" * len(chunk_ids))
|
||||
with sqlite3.connect(index_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
f"""
|
||||
SELECT id, file_path, content, metadata
|
||||
FROM semantic_chunks
|
||||
WHERE id IN ({placeholders})
|
||||
""",
|
||||
chunk_ids
|
||||
).fetchall()
|
||||
|
||||
# Build SearchResult objects
|
||||
results = []
|
||||
for row in rows:
|
||||
chunk_id = row["id"]
|
||||
file_path = row["file_path"]
|
||||
content = row["content"]
|
||||
metadata_json = row["metadata"]
|
||||
metadata = json.loads(metadata_json) if metadata_json else {}
|
||||
|
||||
score = score_map.get(chunk_id, 0.0)
|
||||
|
||||
# Build excerpt (short preview)
|
||||
excerpt = content[:200] + "..." if len(content) > 200 else content
|
||||
|
||||
# Extract symbol information from metadata
|
||||
symbol_name = metadata.get("symbol_name")
|
||||
symbol_kind = metadata.get("symbol_kind")
|
||||
start_line = metadata.get("start_line")
|
||||
end_line = metadata.get("end_line")
|
||||
|
||||
# Build Symbol object if we have symbol info
|
||||
symbol = None
|
||||
if symbol_name and symbol_kind and start_line and end_line:
|
||||
try:
|
||||
from codexlens.entities import Symbol
|
||||
symbol = Symbol(
|
||||
name=symbol_name,
|
||||
kind=symbol_kind,
|
||||
range=(start_line, end_line)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
results.append(SearchResult(
|
||||
path=file_path,
|
||||
score=score,
|
||||
excerpt=excerpt,
|
||||
content=content,
|
||||
symbol=symbol,
|
||||
metadata=metadata,
|
||||
start_line=start_line,
|
||||
end_line=end_line,
|
||||
symbol_name=symbol_name,
|
||||
symbol_kind=symbol_kind,
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
except Exception as exc:
|
||||
self.logger.debug("SPLADE search error: %s", exc)
|
||||
return []
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Ranking algorithms for hybrid search result fusion.
|
||||
|
||||
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
|
||||
for combining results from heterogeneous search backends (exact FTS, fuzzy FTS, vector search).
|
||||
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,6 +14,20 @@ from typing import Any, Dict, List
|
||||
from codexlens.entities import SearchResult, AdditionalLocation
|
||||
|
||||
|
||||
# Default RRF weights for SPLADE-based hybrid search
|
||||
DEFAULT_WEIGHTS = {
|
||||
"splade": 0.4, # Replaces exact(0.3) + fuzzy(0.1)
|
||||
"vector": 0.6,
|
||||
}
|
||||
|
||||
# Legacy weights for FTS fallback mode (when SPLADE unavailable)
|
||||
FTS_FALLBACK_WEIGHTS = {
|
||||
"exact": 0.3,
|
||||
"fuzzy": 0.1,
|
||||
"vector": 0.6,
|
||||
}
|
||||
|
||||
|
||||
class QueryIntent(str, Enum):
|
||||
"""Query intent for adaptive RRF weights (Python/TypeScript parity)."""
|
||||
|
||||
@@ -87,15 +101,24 @@ def adjust_weights_by_intent(
|
||||
intent: QueryIntent,
|
||||
base_weights: Dict[str, float],
|
||||
) -> Dict[str, float]:
|
||||
"""Map intent → weights (kept aligned with TypeScript mapping)."""
|
||||
"""Adjust RRF weights based on query intent."""
|
||||
# Check if using SPLADE or FTS mode
|
||||
use_splade = "splade" in base_weights
|
||||
|
||||
if intent == QueryIntent.KEYWORD:
|
||||
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
|
||||
if use_splade:
|
||||
target = {"splade": 0.6, "vector": 0.4}
|
||||
else:
|
||||
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
|
||||
elif intent == QueryIntent.SEMANTIC:
|
||||
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
|
||||
if use_splade:
|
||||
target = {"splade": 0.3, "vector": 0.7}
|
||||
else:
|
||||
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
|
||||
else:
|
||||
target = dict(base_weights)
|
||||
|
||||
# Preserve only keys that are present in base_weights (active backends).
|
||||
|
||||
# Filter to active backends
|
||||
keys = list(base_weights.keys())
|
||||
filtered = {k: float(target.get(k, 0.0)) for k in keys}
|
||||
return normalize_weights(filtered)
|
||||
|
||||
225
codex-lens/src/codexlens/semantic/SPLADE_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,225 @@
|
||||
# SPLADE Encoder Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
Created `splade_encoder.py` - A complete ONNX-optimized SPLADE sparse encoder for code search.
|
||||
|
||||
## File Location
|
||||
|
||||
`src/codexlens/semantic/splade_encoder.py` (474 lines)
|
||||
|
||||
## Key Components
|
||||
|
||||
### 1. Dependency Checking
|
||||
|
||||
**Function**: `check_splade_available() -> Tuple[bool, Optional[str]]`
|
||||
- Validates numpy, onnxruntime, optimum, transformers availability
|
||||
- Returns (True, None) if all dependencies present
|
||||
- Returns (False, error_message) with install instructions if missing
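
A minimal call-site sketch (hypothetical usage, mirroring the return contract above):

```python
from codexlens.semantic.splade_encoder import check_splade_available

ok, err = check_splade_available()
if not ok:
    # err names the missing package and includes a pip install hint
    raise RuntimeError(f"SPLADE unavailable: {err}")
```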
|
||||
|
||||
### 2. Caching System
|
||||
|
||||
**Global Cache**: Thread-safe singleton pattern
|
||||
- `_splade_cache: Dict[str, SpladeEncoder]` - Global encoder cache
|
||||
- `_cache_lock: threading.RLock()` - Thread safety lock
|
||||
|
||||
**Factory Function**: `get_splade_encoder(...) -> SpladeEncoder`
|
||||
- Cache key includes: model_name, gpu/cpu, max_length, sparsity_threshold
|
||||
- Pre-loads model on first access
|
||||
- Returns cached instance on subsequent calls
|
||||
|
||||
**Cleanup Function**: `clear_splade_cache() -> None`
|
||||
- Releases ONNX resources
|
||||
- Clears model and tokenizer references
|
||||
- Prevents memory leaks
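
A short usage sketch of the caching API described above (argument values are illustrative):

```python
from codexlens.semantic.splade_encoder import get_splade_encoder, clear_splade_cache

# First call loads and caches the ONNX model; repeat calls with the same
# configuration return the cached instance
encoder = get_splade_encoder(use_gpu=False, sparsity_threshold=0.01)
sparse_vec = encoder.encode_text("def authenticate(user, password): ...")

# Release ONNX resources when done (e.g. before process exit)
clear_splade_cache()
```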
|
||||
|
||||
### 3. SpladeEncoder Class
|
||||
|
||||
#### Initialization Parameters
|
||||
- `model_name: str` - Default: "naver/splade-cocondenser-ensembledistil"
|
||||
- `use_gpu: bool` - Enable GPU acceleration (default: True)
|
||||
- `max_length: int` - Max sequence length (default: 512)
|
||||
- `sparsity_threshold: float` - Min weight threshold (default: 0.01)
|
||||
- `providers: Optional[List]` - Explicit ONNX providers (overrides use_gpu)
|
||||
|
||||
#### Core Methods
|
||||
|
||||
**`_load_model()`**: Lazy loading with GPU support
|
||||
- Uses `optimum.onnxruntime.ORTModelForMaskedLM`
|
||||
- Falls back to CPU if GPU unavailable
|
||||
- Integrates with `gpu_support.get_optimal_providers()`
|
||||
- Handles device_id options for DirectML/CUDA
|
||||
|
||||
**`_splade_activation(logits, attention_mask)`**: Static method
|
||||
- Formula: `log(1 + ReLU(logits)) * attention_mask`
|
||||
- Input: (batch, seq_len, vocab_size)
|
||||
- Output: (batch, seq_len, vocab_size)
|
||||
|
||||
**`_max_pooling(splade_repr)`**: Static method
|
||||
- Max pooling over sequence dimension
|
||||
- Input: (batch, seq_len, vocab_size)
|
||||
- Output: (batch, vocab_size)
|
||||
|
||||
**`_to_sparse_dict(dense_vec)`**: Conversion helper
|
||||
- Filters by sparsity_threshold
|
||||
- Returns: `Dict[int, float]` mapping token_id to weight
|
||||
|
||||
**`encode_text(text: str) -> Dict[int, float]`**: Single text encoding
|
||||
- Tokenizes input with truncation/padding
|
||||
- Forward pass through ONNX model
|
||||
- Applies SPLADE activation + max pooling
|
||||
- Returns sparse vector
|
||||
|
||||
**`encode_batch(texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]`**: Batch encoding
|
||||
- Processes in batches for memory efficiency
|
||||
- Returns list of sparse vectors
|
||||
|
||||
#### Properties
|
||||
|
||||
**`vocab_size: int`**: Vocabulary size (~30k for BERT)
|
||||
- Cached after first model load
|
||||
- Returns tokenizer length
|
||||
|
||||
#### Debugging Methods
|
||||
|
||||
**`get_token(token_id: int) -> str`**
|
||||
- Converts token_id to human-readable string
|
||||
- Uses tokenizer.decode()
|
||||
|
||||
**`get_top_tokens(sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]`**
|
||||
- Extracts top-k highest-weight tokens
|
||||
- Returns (token_string, weight) pairs
|
||||
- Useful for understanding model focus
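
A quick debugging sketch (the query text and printed values are illustrative):

```python
encoder = get_splade_encoder(use_gpu=False)
sparse_vec = encoder.encode_text("verify user credentials")
for token, weight in encoder.get_top_tokens(sparse_vec, top_k=5):
    # Shows which (possibly expanded) vocabulary terms SPLADE weighted highest
    print(f"{token}\t{weight:.3f}")
```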
|
||||
|
||||
## Design Patterns Followed
|
||||
|
||||
### 1. From `onnx_reranker.py`
|
||||
✓ ONNX loading with provider detection
|
||||
✓ Lazy model initialization
|
||||
✓ Thread-safe loading with RLock
|
||||
✓ Signature inspection for backward compatibility
|
||||
✓ Fallback for older Optimum versions
|
||||
✓ Static helper methods for numerical operations
|
||||
|
||||
### 2. From `embedder.py`
|
||||
✓ Global cache with thread safety
|
||||
✓ Factory function pattern (get_splade_encoder)
|
||||
✓ Cache cleanup function (clear_splade_cache)
|
||||
✓ GPU provider configuration
|
||||
✓ Batch processing support
|
||||
|
||||
### 3. From `gpu_support.py`
|
||||
✓ `get_optimal_providers(use_gpu, with_device_options=True)`
|
||||
✓ Device ID options for DirectML/CUDA
|
||||
✓ Provider tuple format: (provider_name, options_dict)
|
||||
|
||||
## SPLADE Algorithm
|
||||
|
||||
### Activation Formula
|
||||
```python
|
||||
# Step 1: ReLU activation
|
||||
relu_logits = max(0, logits)
|
||||
|
||||
# Step 2: Log(1 + x) transformation
|
||||
log_relu = log(1 + relu_logits)
|
||||
|
||||
# Step 3: Apply attention mask
|
||||
splade_repr = log_relu * attention_mask
|
||||
|
||||
# Step 4: Max pooling over sequence
|
||||
splade_vec = max(splade_repr, axis=sequence_length)
|
||||
|
||||
# Step 5: Sparsification by threshold
|
||||
sparse_dict = {token_id: weight for token_id, weight in enumerate(splade_vec) if weight > threshold}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
- Sparse dictionary: `{token_id: weight}`
|
||||
- Token IDs: 0 to vocab_size-1 (typically ~30,000)
|
||||
- Weights: Float values > sparsity_threshold
|
||||
- Interpretable: Can decode token_ids to strings
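
For example (token ids and weights below are invented for illustration):

```python
sparse_vec = {2339: 1.82, 7592: 0.95, 10104: 0.43}  # token_id -> weight
encoder.get_token(2339)  # might decode to something like "auth"
```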
|
||||
|
||||
## Integration Points
|
||||
|
||||
### With `splade_index.py`
|
||||
- `SpladeIndex.add_posting(chunk_id, sparse_vec: Dict[int, float])`
|
||||
- `SpladeIndex.search(query_sparse: Dict[int, float])`
|
||||
- Encoder produces the sparse vectors consumed by index
|
||||
|
||||
### With Indexing Pipeline
|
||||
```python
|
||||
encoder = get_splade_encoder(use_gpu=True)
|
||||
|
||||
# Single document
|
||||
sparse_vec = encoder.encode_text("def main():\n print('hello')")
|
||||
index.add_posting(chunk_id=1, sparse_vec=sparse_vec)
|
||||
|
||||
# Batch indexing
|
||||
texts = ["code chunk 1", "code chunk 2", ...]
|
||||
sparse_vecs = encoder.encode_batch(texts, batch_size=64)
|
||||
postings = [(chunk_id, vec) for chunk_id, vec in enumerate(sparse_vecs)]
|
||||
index.add_postings_batch(postings)
|
||||
```
|
||||
|
||||
### With Search Pipeline
|
||||
```python
|
||||
encoder = get_splade_encoder(use_gpu=True)
|
||||
query_sparse = encoder.encode_text("authentication function")
|
||||
results = index.search(query_sparse, limit=50, min_score=0.5)
|
||||
```
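
The `index.search` call above relies on a sparse dot product over the posting list. A self-contained sketch of that scoring idea (not the library's actual implementation), written directly against the `splade_posting_list` schema:

```python
import sqlite3
from typing import Dict, List, Tuple

def splade_score_sketch(conn: sqlite3.Connection,
                        query_sparse: Dict[int, float],
                        limit: int = 50) -> List[Tuple[int, float]]:
    """Score chunks by sum(query_weight * stored_weight) over shared tokens."""
    scores: Dict[int, float] = {}
    for token_id, q_weight in query_sparse.items():
        rows = conn.execute(
            "SELECT chunk_id, weight FROM splade_posting_list WHERE token_id = ?",
            (token_id,),
        )
        for chunk_id, weight in rows:
            scores[chunk_id] = scores.get(chunk_id, 0.0) + q_weight * weight
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:limit]
```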
|
||||
|
||||
## Dependencies
|
||||
|
||||
Required packages:
|
||||
- `numpy` - Numerical operations
|
||||
- `onnxruntime` - ONNX model execution (CPU)
|
||||
- `onnxruntime-gpu` - ONNX with GPU support (optional)
|
||||
- `optimum[onnxruntime]` - Hugging Face ONNX optimization
|
||||
- `transformers` - Tokenizer and model loading
|
||||
|
||||
Install command:
|
||||
```bash
|
||||
# CPU only
|
||||
pip install numpy onnxruntime optimum[onnxruntime] transformers
|
||||
|
||||
# With GPU support
|
||||
pip install numpy onnxruntime-gpu optimum[onnxruntime-gpu] transformers
|
||||
```
|
||||
|
||||
## Testing Status
|
||||
|
||||
✓ Python syntax validation passed
|
||||
✓ Module import successful
|
||||
✓ Dependency checking works correctly
|
||||
✗ Full functional test pending (requires optimum installation)
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Install dependencies for functional testing
|
||||
2. Create unit tests in `tests/semantic/test_splade_encoder.py`
|
||||
3. Benchmark encoding performance (CPU vs GPU)
|
||||
4. Integrate with codex-lens indexing pipeline
|
||||
5. Add SPLADE option to semantic search configuration
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Memory Usage
|
||||
- Model size: ~100MB (ONNX optimized)
|
||||
- Sparse vectors: ~100-500 non-zero entries per document
|
||||
- Batch size: 32 recommended (adjust based on GPU memory)
|
||||
|
||||
### Speed Benchmarks (Expected)
|
||||
- CPU encoding: ~10-20 docs/sec
|
||||
- GPU encoding (CUDA): ~100-200 docs/sec
|
||||
- GPU encoding (DirectML): ~50-100 docs/sec
|
||||
|
||||
### Sparsity Analysis
|
||||
- Threshold 0.01: ~200-400 tokens per document
|
||||
- Threshold 0.05: ~100-200 tokens per document
|
||||
- Threshold 0.10: ~50-100 tokens per document
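
These counts can be checked empirically by re-filtering an encoded vector (sketch; the file name and exact counts are illustrative):

```python
encoder = get_splade_encoder(sparsity_threshold=0.0)  # keep all positive weights
vec = encoder.encode_text(open("some_module.py").read())
for t in (0.01, 0.05, 0.10):
    kept = sum(1 for w in vec.values() if w > t)
    print(f"threshold {t}: {kept} tokens")
```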
|
||||
|
||||
## References
|
||||
|
||||
- SPLADE paper: https://arxiv.org/abs/2107.05720
|
||||
- SPLADE v2: https://arxiv.org/abs/2109.10086
|
||||
- Naver model: https://huggingface.co/naver/splade-cocondenser-ensembledistil
|
||||
474
codex-lens/src/codexlens/semantic/splade_encoder.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""ONNX-optimized SPLADE sparse encoder for code search.
|
||||
|
||||
This module provides SPLADE (Sparse Lexical and Expansion) encoding using ONNX Runtime
|
||||
for efficient sparse vector generation. SPLADE produces vocabulary-aligned sparse vectors
|
||||
that combine the interpretability of BM25 with neural relevance modeling.
|
||||
|
||||
Install (CPU):
|
||||
pip install onnxruntime optimum[onnxruntime] transformers
|
||||
|
||||
Install (GPU):
|
||||
pip install onnxruntime-gpu optimum[onnxruntime-gpu] transformers
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_splade_available() -> Tuple[bool, Optional[str]]:
|
||||
"""Check whether SPLADE dependencies are available.
|
||||
|
||||
Returns:
|
||||
Tuple of (available: bool, error_message: Optional[str])
|
||||
"""
|
||||
try:
|
||||
import numpy # noqa: F401
|
||||
except ImportError as exc:
|
||||
return False, f"numpy not available: {exc}. Install with: pip install numpy"
|
||||
|
||||
try:
|
||||
import onnxruntime # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
|
||||
)
|
||||
|
||||
try:
|
||||
from optimum.onnxruntime import ORTModelForMaskedLM # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
|
||||
)
|
||||
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: F401
|
||||
except ImportError as exc:
|
||||
return (
|
||||
False,
|
||||
f"transformers not available: {exc}. Install with: pip install transformers",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# Global cache for SPLADE encoders (singleton pattern)
|
||||
_splade_cache: Dict[str, "SpladeEncoder"] = {}
|
||||
_cache_lock = threading.RLock()
|
||||
|
||||
|
||||
def get_splade_encoder(
|
||||
model_name: str = "naver/splade-cocondenser-ensembledistil",
|
||||
use_gpu: bool = True,
|
||||
max_length: int = 512,
|
||||
sparsity_threshold: float = 0.01,
|
||||
) -> "SpladeEncoder":
|
||||
"""Get or create cached SPLADE encoder (thread-safe singleton).
|
||||
|
||||
This function provides significant performance improvement by reusing
|
||||
SpladeEncoder instances across multiple searches, avoiding repeated model
|
||||
loading overhead.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||
use_gpu: If True, use GPU acceleration when available
|
||||
max_length: Maximum sequence length for tokenization
|
||||
sparsity_threshold: Minimum weight to include in sparse vector
|
||||
|
||||
Returns:
|
||||
Cached SpladeEncoder instance for the given configuration
|
||||
"""
|
||||
global _splade_cache
|
||||
|
||||
# Cache key includes all configuration parameters
|
||||
cache_key = f"{model_name}:{'gpu' if use_gpu else 'cpu'}:{max_length}:{sparsity_threshold}"
|
||||
|
||||
with _cache_lock:
|
||||
encoder = _splade_cache.get(cache_key)
|
||||
if encoder is not None:
|
||||
return encoder
|
||||
|
||||
# Create new encoder and cache it
|
||||
encoder = SpladeEncoder(
|
||||
model_name=model_name,
|
||||
use_gpu=use_gpu,
|
||||
max_length=max_length,
|
||||
sparsity_threshold=sparsity_threshold,
|
||||
)
|
||||
# Pre-load model to ensure it's ready
|
||||
encoder._load_model()
|
||||
_splade_cache[cache_key] = encoder
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def clear_splade_cache() -> None:
|
||||
"""Clear the SPLADE encoder cache and release ONNX resources.
|
||||
|
||||
This method ensures proper cleanup of ONNX model resources to prevent
|
||||
memory leaks when encoders are no longer needed.
|
||||
"""
|
||||
global _splade_cache
|
||||
with _cache_lock:
|
||||
# Release ONNX resources before clearing cache
|
||||
for encoder in _splade_cache.values():
|
||||
if encoder._model is not None:
|
||||
del encoder._model
|
||||
encoder._model = None
|
||||
if encoder._tokenizer is not None:
|
||||
del encoder._tokenizer
|
||||
encoder._tokenizer = None
|
||||
_splade_cache.clear()
|
||||
|
||||
|
||||
class SpladeEncoder:
|
||||
"""ONNX-optimized SPLADE sparse encoder.
|
||||
|
||||
Produces sparse vectors with vocabulary-aligned dimensions.
|
||||
Output: Dict[int, float] mapping token_id to weight.
|
||||
|
||||
SPLADE activation formula:
|
||||
splade_repr = log(1 + ReLU(logits)) * attention_mask
|
||||
splade_vec = max_pooling(splade_repr, axis=sequence_length)
|
||||
|
||||
References:
|
||||
- SPLADE: https://arxiv.org/abs/2107.05720
|
||||
- SPLADE v2: https://arxiv.org/abs/2109.10086
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL = "naver/splade-cocondenser-ensembledistil"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DEFAULT_MODEL,
|
||||
use_gpu: bool = True,
|
||||
max_length: int = 512,
|
||||
sparsity_threshold: float = 0.01,
|
||||
providers: Optional[List[Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize SPLADE encoder.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name (default: naver/splade-cocondenser-ensembledistil)
|
||||
use_gpu: If True, use GPU acceleration when available
|
||||
max_length: Maximum sequence length for tokenization
|
||||
sparsity_threshold: Minimum weight to include in sparse vector
|
||||
providers: Explicit ONNX providers list (overrides use_gpu)
|
||||
"""
|
||||
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.use_gpu = bool(use_gpu)
|
||||
self.max_length = int(max_length) if max_length > 0 else 512
|
||||
self.sparsity_threshold = float(sparsity_threshold)
|
||||
self.providers = providers
|
||||
|
||||
self._tokenizer: Any | None = None
|
||||
self._model: Any | None = None
|
||||
self._vocab_size: int | None = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
"""Lazy load ONNX model and tokenizer."""
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
ok, err = check_splade_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None and self._tokenizer is not None:
|
||||
return
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from optimum.onnxruntime import ORTModelForMaskedLM
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if self.providers is None:
|
||||
from .gpu_support import get_optimal_providers
|
||||
|
||||
# Include device_id options for DirectML/CUDA selection when available
|
||||
self.providers = get_optimal_providers(
|
||||
use_gpu=self.use_gpu, with_device_options=True
|
||||
)
|
||||
|
||||
# Some Optimum versions accept `providers`, others accept a single `provider`
|
||||
# Prefer passing the full providers list, with a conservative fallback
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
try:
|
||||
params = signature(ORTModelForMaskedLM.from_pretrained).parameters
|
||||
if "providers" in params:
|
||||
model_kwargs["providers"] = self.providers
|
||||
elif "provider" in params:
|
||||
provider_name = "CPUExecutionProvider"
|
||||
if self.providers:
|
||||
first = self.providers[0]
|
||||
provider_name = first[0] if isinstance(first, tuple) else str(first)
|
||||
model_kwargs["provider"] = provider_name
|
||||
except Exception:
|
||||
model_kwargs = {}
|
||||
|
||||
try:
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(
|
||||
self.model_name,
|
||||
**model_kwargs,
|
||||
)
|
||||
logger.debug(f"SPLADE model loaded: {self.model_name}")
|
||||
except TypeError:
|
||||
# Fallback for older Optimum versions: retry without provider arguments
|
||||
self._model = ORTModelForMaskedLM.from_pretrained(self.model_name)
|
||||
logger.warning(
|
||||
"Optimum version doesn't support provider parameters. "
|
||||
"Upgrade optimum for GPU acceleration: pip install --upgrade optimum"
|
||||
)
|
||||
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
|
||||
|
||||
# Cache vocabulary size
|
||||
self._vocab_size = len(self._tokenizer)
|
||||
logger.debug(f"SPLADE tokenizer loaded: vocab_size={self._vocab_size}")
|
||||
|
||||
@staticmethod
|
||||
def _splade_activation(logits: Any, attention_mask: Any) -> Any:
|
||||
"""Apply SPLADE activation function to model outputs.
|
||||
|
||||
Formula: log(1 + ReLU(logits)) * attention_mask
|
||||
|
||||
Args:
|
||||
logits: Model output logits (batch, seq_len, vocab_size)
|
||||
attention_mask: Attention mask (batch, seq_len)
|
||||
|
||||
Returns:
|
||||
SPLADE representations (batch, seq_len, vocab_size)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# ReLU activation
|
||||
relu_logits = np.maximum(0, logits)
|
||||
|
||||
# Log(1 + x) transformation
|
||||
log_relu = np.log1p(relu_logits)
|
||||
|
||||
# Apply attention mask (expand to match vocab dimension)
|
||||
# attention_mask: (batch, seq_len) -> (batch, seq_len, 1)
|
||||
mask_expanded = np.expand_dims(attention_mask, axis=-1)
|
||||
|
||||
# Element-wise multiplication
|
||||
splade_repr = log_relu * mask_expanded
|
||||
|
||||
return splade_repr
|
||||
|
||||
@staticmethod
|
||||
def _max_pooling(splade_repr: Any) -> Any:
|
||||
"""Max pooling over sequence length dimension.
|
||||
|
||||
Args:
|
||||
splade_repr: SPLADE representations (batch, seq_len, vocab_size)
|
||||
|
||||
Returns:
|
||||
Pooled sparse vectors (batch, vocab_size)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Max pooling over sequence dimension (axis=1)
|
||||
return np.max(splade_repr, axis=1)
|
||||
|
||||
def _to_sparse_dict(self, dense_vec: Any) -> Dict[int, float]:
|
||||
"""Convert dense vector to sparse dictionary.
|
||||
|
||||
Args:
|
||||
dense_vec: Dense vector (vocab_size,)
|
||||
|
||||
Returns:
|
||||
Sparse dictionary {token_id: weight} with weights above threshold
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Find non-zero indices above threshold
|
||||
nonzero_indices = np.where(dense_vec > self.sparsity_threshold)[0]
|
||||
|
||||
# Create sparse dictionary
|
||||
sparse_dict = {
|
||||
int(idx): float(dense_vec[idx])
|
||||
for idx in nonzero_indices
|
||||
}
|
||||
|
||||
return sparse_dict
|
||||
|
||||
def encode_text(self, text: str) -> Dict[int, float]:
|
||||
"""Encode text to sparse vector {token_id: weight}.
|
||||
|
||||
Args:
|
||||
text: Input text to encode
|
||||
|
||||
Returns:
|
||||
Sparse vector as dictionary mapping token_id to weight
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Tokenize input
|
||||
encoded = self._tokenizer(
|
||||
text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Forward pass through model
|
||||
outputs = self._model(**encoded)
|
||||
|
||||
# Extract logits
|
||||
if hasattr(outputs, "logits"):
|
||||
logits = outputs.logits
|
||||
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||
logits = outputs["logits"]
|
||||
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||
logits = outputs[0]
|
||||
else:
|
||||
raise RuntimeError("Unexpected model output format")
|
||||
|
||||
# Apply SPLADE activation
|
||||
attention_mask = encoded["attention_mask"]
|
||||
splade_repr = self._splade_activation(logits, attention_mask)
|
||||
|
||||
# Max pooling over sequence length
|
||||
splade_vec = self._max_pooling(splade_repr)
|
||||
|
||||
# Convert to sparse dictionary (single item batch)
|
||||
sparse_dict = self._to_sparse_dict(splade_vec[0])
|
||||
|
||||
return sparse_dict
|
||||
|
||||
def encode_batch(self, texts: List[str], batch_size: int = 32) -> List[Dict[int, float]]:
|
||||
"""Batch encode texts to sparse vectors.
|
||||
|
||||
Args:
|
||||
texts: List of input texts to encode
|
||||
batch_size: Batch size for encoding (default: 32)
|
||||
|
||||
Returns:
|
||||
List of sparse vectors as dictionaries
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
import numpy as np
|
||||
|
||||
results: List[Dict[int, float]] = []
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Tokenize batch
|
||||
encoded = self._tokenizer(
|
||||
batch_texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
# Forward pass through model
|
||||
outputs = self._model(**encoded)
|
||||
|
||||
# Extract logits
|
||||
if hasattr(outputs, "logits"):
|
||||
logits = outputs.logits
|
||||
elif isinstance(outputs, dict) and "logits" in outputs:
|
||||
logits = outputs["logits"]
|
||||
elif isinstance(outputs, (list, tuple)) and outputs:
|
||||
logits = outputs[0]
|
||||
else:
|
||||
raise RuntimeError("Unexpected model output format")
|
||||
|
||||
# Apply SPLADE activation
|
||||
attention_mask = encoded["attention_mask"]
|
||||
splade_repr = self._splade_activation(logits, attention_mask)
|
||||
|
||||
# Max pooling over sequence length
|
||||
splade_vecs = self._max_pooling(splade_repr)
|
||||
|
||||
# Convert each vector to sparse dictionary
|
||||
for vec in splade_vecs:
|
||||
sparse_dict = self._to_sparse_dict(vec)
|
||||
results.append(sparse_dict)
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
"""Return vocabulary size (~30k for BERT-based models).
|
||||
|
||||
Returns:
|
||||
Vocabulary size (number of tokens in tokenizer)
|
||||
"""
|
||||
if self._vocab_size is not None:
|
||||
return self._vocab_size
|
||||
|
||||
self._load_model()
|
||||
return self._vocab_size or 0
|
||||
|
||||
def get_token(self, token_id: int) -> str:
|
||||
"""Convert token_id to string (for debugging).
|
||||
|
||||
Args:
|
||||
token_id: Token ID to convert
|
||||
|
||||
Returns:
|
||||
Token string
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if self._tokenizer is None:
|
||||
raise RuntimeError("Tokenizer not loaded")
|
||||
|
||||
return self._tokenizer.decode([token_id])
|
||||
|
||||
def get_top_tokens(self, sparse_vec: Dict[int, float], top_k: int = 10) -> List[Tuple[str, float]]:
|
||||
"""Get top-k tokens with highest weights from sparse vector.
|
||||
|
||||
Useful for debugging and understanding what the model is focusing on.
|
||||
|
||||
Args:
|
||||
sparse_vec: Sparse vector as {token_id: weight}
|
||||
top_k: Number of top tokens to return
|
||||
|
||||
Returns:
|
||||
List of (token_string, weight) tuples, sorted by weight descending
|
||||
"""
|
||||
self._load_model()
|
||||
|
||||
if not sparse_vec:
|
||||
return []
|
||||
|
||||
# Sort by weight descending
|
||||
sorted_items = sorted(sparse_vec.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Take top-k and convert token_ids to strings
|
||||
top_items = sorted_items[:top_k]
|
||||
|
||||
return [
|
||||
(self.get_token(token_id), weight)
|
||||
for token_id, weight in top_items
|
||||
]
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Migration 009: Add SPLADE sparse retrieval tables.
|
||||
|
||||
This migration introduces SPLADE (Sparse Lexical AnD Expansion) support:
|
||||
- splade_metadata: Model configuration (model name, vocab size, ONNX path)
|
||||
- splade_posting_list: Inverted index mapping token_id -> (chunk_id, weight)
|
||||
|
||||
The SPLADE tables are designed for efficient sparse vector retrieval:
|
||||
- Token-based lookup for query expansion
|
||||
- Chunk-based deletion for index maintenance
|
||||
- Maintains backward compatibility with existing FTS tables
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlite3 import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade(db_conn: Connection) -> None:
|
||||
"""
|
||||
Adds SPLADE tables for sparse retrieval.
|
||||
|
||||
Creates:
|
||||
- splade_metadata: Stores model configuration and ONNX path
|
||||
- splade_posting_list: Inverted index with token_id -> (chunk_id, weight) mappings
|
||||
- Indexes for efficient token-based and chunk-based lookups
|
||||
|
||||
Args:
|
||||
db_conn: The SQLite database connection.
|
||||
"""
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Creating splade_metadata table...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS splade_metadata (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
model_name TEXT NOT NULL,
|
||||
vocab_size INTEGER NOT NULL,
|
||||
onnx_path TEXT,
|
||||
created_at REAL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating splade_posting_list table...")
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS splade_posting_list (
|
||||
token_id INTEGER NOT NULL,
|
||||
chunk_id INTEGER NOT NULL,
|
||||
weight REAL NOT NULL,
|
||||
PRIMARY KEY (token_id, chunk_id),
|
||||
FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating indexes for splade_posting_list...")
|
||||
# Index for efficient chunk-based lookups (deletion, updates)
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
|
||||
ON splade_posting_list(chunk_id)
|
||||
"""
|
||||
)
|
||||
|
||||
# Index for efficient term-based retrieval
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_splade_by_token
|
||||
ON splade_posting_list(token_id)
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Migration 009 completed successfully")
|
||||
|
||||
|
||||
def downgrade(db_conn: Connection) -> None:
|
||||
"""
|
||||
Removes SPLADE tables.
|
||||
|
||||
Drops:
|
||||
- splade_posting_list (and associated indexes)
|
||||
- splade_metadata
|
||||
|
||||
Args:
|
||||
db_conn: The SQLite database connection.
|
||||
"""
|
||||
cursor = db_conn.cursor()
|
||||
|
||||
log.info("Dropping SPLADE indexes...")
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_splade_by_chunk")
|
||||
cursor.execute("DROP INDEX IF EXISTS idx_splade_by_token")
|
||||
|
||||
log.info("Dropping splade_posting_list table...")
|
||||
cursor.execute("DROP TABLE IF EXISTS splade_posting_list")
|
||||
|
||||
log.info("Dropping splade_metadata table...")
|
||||
cursor.execute("DROP TABLE IF EXISTS splade_metadata")
|
||||
|
||||
log.info("Migration 009 downgrade completed successfully")
|
||||
432
codex-lens/src/codexlens/storage/splade_index.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""SPLADE inverted index storage for sparse vector retrieval.
|
||||
|
||||
This module implements SQLite-based inverted index for SPLADE sparse vectors,
|
||||
enabling efficient sparse retrieval using dot-product scoring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.errors import StorageError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpladeIndex:
|
||||
"""SQLite-based inverted index for SPLADE sparse vectors.
|
||||
|
||||
Stores sparse vectors as posting lists mapping token_id -> (chunk_id, weight).
|
||||
Supports efficient dot-product retrieval using SQL joins.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path | str) -> None:
|
||||
"""Initialize SPLADE index.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database file.
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Thread-safe connection management
|
||||
self._lock = threading.RLock()
|
||||
self._local = threading.local()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""Get or create a thread-local database connection."""
|
||||
conn = getattr(self._local, "conn", None)
|
||||
if conn is None:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
conn.execute("PRAGMA mmap_size=30000000000") # 30GB limit
|
||||
self._local.conn = conn
|
||||
return conn
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close thread-local database connection."""
|
||||
with self._lock:
|
||||
conn = getattr(self._local, "conn", None)
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
self._local.conn = None
|
||||
|
||||
def __enter__(self) -> SpladeIndex:
|
||||
"""Context manager entry."""
|
||||
self.create_tables()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
def has_index(self) -> bool:
|
||||
"""Check if SPLADE tables exist in database.
|
||||
|
||||
Returns:
|
||||
True if tables exist, False otherwise.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='splade_posting_list'
|
||||
"""
|
||||
)
|
||||
return cursor.fetchone() is not None
|
||||
except sqlite3.Error as e:
|
||||
logger.error("Failed to check index existence: %s", e)
|
||||
return False
|
||||
|
||||
def create_tables(self) -> None:
|
||||
"""Create SPLADE schema if not exists.
|
||||
|
||||
Note: The splade_posting_list table has a FOREIGN KEY constraint
|
||||
referencing semantic_chunks(id). Ensure VectorStore.create_tables()
|
||||
is called first to create the semantic_chunks table.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
# Inverted index for sparse vectors
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS splade_posting_list (
|
||||
token_id INTEGER NOT NULL,
|
||||
chunk_id INTEGER NOT NULL,
|
||||
weight REAL NOT NULL,
|
||||
PRIMARY KEY (token_id, chunk_id),
|
||||
FOREIGN KEY (chunk_id) REFERENCES semantic_chunks(id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
# Indexes for efficient lookups
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_splade_by_chunk
|
||||
ON splade_posting_list(chunk_id)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_splade_by_token
|
||||
ON splade_posting_list(token_id)
|
||||
""")
|
||||
|
||||
# Model metadata
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS splade_metadata (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
model_name TEXT NOT NULL,
|
||||
vocab_size INTEGER NOT NULL,
|
||||
onnx_path TEXT,
|
||||
created_at REAL
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
logger.debug("SPLADE schema created successfully")
|
||||
except sqlite3.Error as e:
|
||||
raise StorageError(
|
||||
f"Failed to create SPLADE schema: {e}",
|
||||
db_path=str(self.db_path),
|
||||
operation="create_tables"
|
||||
) from e
|
||||
|
||||
def add_posting(self, chunk_id: int, sparse_vec: Dict[int, float]) -> None:
|
||||
"""Add a single document to inverted index.
|
||||
|
||||
Args:
|
||||
chunk_id: Chunk ID (foreign key to semantic_chunks.id).
|
||||
sparse_vec: Sparse vector as {token_id: weight} mapping.
|
||||
"""
|
||||
if not sparse_vec:
|
||||
logger.warning("Empty sparse vector for chunk_id=%d, skipping", chunk_id)
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
# Insert all non-zero weights for this chunk
|
||||
postings = [
|
||||
(token_id, chunk_id, weight)
|
||||
for token_id, weight in sparse_vec.items()
|
||||
if weight > 0 # Only store non-zero weights
|
||||
]
|
||||
|
||||
if postings:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO splade_posting_list
|
||||
(token_id, chunk_id, weight)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
postings
|
||||
)
|
||||
conn.commit()
|
||||
logger.debug(
|
||||
"Added %d postings for chunk_id=%d", len(postings), chunk_id
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise StorageError(
|
||||
f"Failed to add posting for chunk_id={chunk_id}: {e}",
|
||||
db_path=str(self.db_path),
|
||||
operation="add_posting"
|
||||
) from e
|
||||
|
||||
def add_postings_batch(
|
||||
self, postings: List[Tuple[int, Dict[int, float]]]
|
||||
) -> None:
|
||||
"""Batch insert postings for multiple chunks.
|
||||
|
||||
Args:
|
||||
postings: List of (chunk_id, sparse_vec) tuples.
|
||||
"""
|
||||
if not postings:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
# Flatten all postings into single batch
|
||||
batch_data = []
|
||||
for chunk_id, sparse_vec in postings:
|
||||
for token_id, weight in sparse_vec.items():
|
||||
if weight > 0: # Only store non-zero weights
|
||||
batch_data.append((token_id, chunk_id, weight))
|
||||
|
||||
if batch_data:
|
||||
conn.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO splade_posting_list
|
||||
(token_id, chunk_id, weight)
|
||||
VALUES (?, ?, ?)
|
||||
""",
|
||||
batch_data
|
||||
)
|
||||
conn.commit()
|
||||
logger.debug(
|
||||
"Batch inserted %d postings for %d chunks",
|
||||
len(batch_data),
|
||||
len(postings)
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise StorageError(
|
||||
f"Failed to batch insert postings: {e}",
|
||||
db_path=str(self.db_path),
|
||||
operation="add_postings_batch"
|
||||
) from e
|
||||
|
||||
def remove_chunk(self, chunk_id: int) -> int:
|
||||
"""Remove all postings for a chunk.
|
||||
|
||||
Args:
|
||||
chunk_id: Chunk ID to remove.
|
||||
|
||||
Returns:
|
||||
Number of deleted postings.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM splade_posting_list WHERE chunk_id = ?",
|
||||
(chunk_id,)
|
||||
)
|
||||
conn.commit()
|
||||
deleted = cursor.rowcount
|
||||
logger.debug("Removed %d postings for chunk_id=%d", deleted, chunk_id)
|
||||
return deleted
|
||||
except sqlite3.Error as e:
|
||||
raise StorageError(
|
||||
f"Failed to remove chunk_id={chunk_id}: {e}",
|
||||
db_path=str(self.db_path),
|
||||
operation="remove_chunk"
|
||||
) from e
|
||||
|
||||
def search(
|
||||
self,
|
||||
query_sparse: Dict[int, float],
|
||||
limit: int = 50,
|
||||
min_score: float = 0.0
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""Search for similar chunks using dot-product scoring.
|
||||
|
||||
Implements efficient sparse dot-product via SQL JOIN:
|
||||
score(q, d) = sum(q[t] * d[t]) for all tokens t
|
||||
|
||||
Args:
|
||||
query_sparse: Query sparse vector as {token_id: weight}.
|
||||
limit: Maximum number of results.
|
||||
min_score: Minimum score threshold.
|
||||
|
||||
Returns:
|
||||
List of (chunk_id, score) tuples, ordered by score descending.
|
||||
"""
|
||||
if not query_sparse:
|
||||
logger.warning("Empty query sparse vector")
|
||||
return []
|
||||
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
# Build VALUES clause for query terms
|
||||
# Each term: (token_id, weight)
|
||||
query_terms = [
|
||||
(token_id, weight)
|
||||
for token_id, weight in query_sparse.items()
|
||||
if weight > 0
|
||||
]
|
||||
|
||||
if not query_terms:
|
||||
logger.warning("No non-zero query terms")
|
||||
return []
|
||||
|
||||
# Create CTE for query terms using parameterized VALUES
|
||||
# Build placeholders and params to prevent SQL injection
|
||||
params = []
|
||||
placeholders = []
|
||||
for token_id, weight in query_terms:
|
||||
placeholders.append("(?, ?)")
|
||||
params.extend([token_id, weight])
|
||||
|
||||
values_placeholders = ", ".join(placeholders)
|
||||
|
||||
sql = f"""
|
||||
WITH query_terms(token_id, weight) AS (
|
||||
VALUES {values_placeholders}
|
||||
)
|
||||
SELECT
|
||||
p.chunk_id,
|
||||
SUM(p.weight * q.weight) as score
|
||||
FROM splade_posting_list p
|
||||
INNER JOIN query_terms q ON p.token_id = q.token_id
|
||||
GROUP BY p.chunk_id
|
||||
HAVING score >= ?
|
||||
ORDER BY score DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
# Append min_score and limit to params
|
||||
params.extend([min_score, limit])
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
|
||||
results = [(row["chunk_id"], float(row["score"])) for row in rows]
|
||||
logger.debug(
|
||||
"SPLADE search: %d query terms, %d results",
|
||||
len(query_terms),
|
||||
len(results)
|
||||
)
|
||||
return results
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise StorageError(
|
||||
f"SPLADE search failed: {e}",
|
||||
db_path=str(self.db_path),
|
||||
operation="search"
|
||||
) from e
|
||||
|
||||
def get_metadata(self) -> Optional[Dict]:
|
||||
"""Get SPLADE model metadata.
|
||||
|
||||
Returns:
|
||||
Dictionary with model_name, vocab_size, onnx_path, created_at,
|
||||
or None if not set.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT model_name, vocab_size, onnx_path, created_at
|
||||
FROM splade_metadata
|
||||
WHERE id = 1
|
||||
"""
|
||||
).fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"model_name": row["model_name"],
|
||||
"vocab_size": row["vocab_size"],
|
||||
"onnx_path": row["onnx_path"],
|
||||
"created_at": row["created_at"]
|
||||
}
|
||||
return None
|
||||
except sqlite3.Error as e:
|
||||
logger.error("Failed to get metadata: %s", e)
|
||||
return None
|
||||
|
||||
def set_metadata(
|
||||
self,
|
||||
model_name: str,
|
||||
vocab_size: int,
|
||||
onnx_path: Optional[str] = None
|
||||
) -> None:
|
||||
"""Set SPLADE model metadata.
|
||||
|
||||
Args:
|
||||
model_name: SPLADE model name.
|
||||
vocab_size: Vocabulary size (typically ~30k for BERT vocab).
|
||||
onnx_path: Optional path to ONNX model file.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
current_time = time.time()
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO splade_metadata
|
||||
(id, model_name, vocab_size, onnx_path, created_at)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
""",
|
||||
(model_name, vocab_size, onnx_path, current_time)
|
||||
)
|
||||
conn.commit()
|
||||
logger.info(
|
||||
"Set SPLADE metadata: model=%s, vocab_size=%d",
|
||||
model_name,
|
||||
vocab_size
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise StorageError(
|
||||
f"Failed to set metadata: {e}",
|
||||
db_path=str(self.db_path),
|
||||
operation="set_metadata"
|
||||
) from e
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""Get index statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with total_postings, unique_tokens, unique_chunks.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
row = conn.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total_postings,
|
||||
COUNT(DISTINCT token_id) as unique_tokens,
|
||||
COUNT(DISTINCT chunk_id) as unique_chunks
|
||||
FROM splade_posting_list
|
||||
""").fetchone()
|
||||
|
||||
return {
|
||||
"total_postings": row["total_postings"],
|
||||
"unique_tokens": row["unique_tokens"],
|
||||
"unique_chunks": row["unique_chunks"]
|
||||
}
|
||||
except sqlite3.Error as e:
|
||||
logger.error("Failed to get stats: %s", e)
|
||||
return {
|
||||
"total_postings": 0,
|
||||
"unique_tokens": 0,
|
||||
"unique_chunks": 0
|
||||
}
|
||||
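A minimal usage sketch for SpladeIndex (not part of this commit). It assumes the semantic_chunks table already exists and contains the referenced chunk ids (for example via VectorStore.create_tables()), so the foreign key on splade_posting_list can resolve; the chunk ids, token ids, weights, and model name below are illustrative values only.

from codexlens.storage.splade_index import SpladeIndex

with SpladeIndex("index.db") as index:  # __enter__ calls create_tables()
    # Record which SPLADE model produced the sparse vectors (name is illustrative).
    index.set_metadata(model_name="naver/splade-v3", vocab_size=30522)

    # Store sparse vectors as {token_id: weight} per chunk.
    index.add_postings_batch([
        (1, {1012: 0.8, 2033: 0.4}),
        (2, {1012: 0.1, 7592: 0.9}),
    ])

    # Dot-product retrieval via the SQL JOIN in search().
    hits = index.search({1012: 0.5, 7592: 0.7}, limit=10)
    print(hits)  # [(chunk_id, score), ...] ordered by score descending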
117
codex-lens/verify_watcher.py
Normal file
@@ -0,0 +1,117 @@
#!/usr/bin/env python3
"""Verification script for FileWatcher event filtering and debouncing."""

import time
from pathlib import Path
from codexlens.watcher.file_watcher import FileWatcher
from codexlens.watcher.events import WatcherConfig, FileEvent


def test_should_index_file():
    """Test _should_index_file filtering logic."""
    print("Testing _should_index_file filtering...")

    # Create watcher instance
    config = WatcherConfig()
    watcher = FileWatcher(
        root_path=Path("."),
        config=config,
        on_changes=lambda events: None,
    )

    # Test cases
    test_cases = [
        # (path, expected_result, description)
        (Path("test.py"), True, "Python file should be indexed"),
        (Path("test.txt"), True, "Text file should be indexed"),
        (Path("test.js"), True, "JavaScript file should be indexed"),
        (Path("test.ts"), True, "TypeScript file should be indexed"),
        (Path("src/test.py"), True, "Python file in subdirectory should be indexed"),
        (Path(".git/config"), False, ".git files should be filtered"),
        (Path("node_modules/pkg/index.js"), False, "node_modules should be filtered"),
        (Path("__pycache__/test.pyc"), False, "__pycache__ should be filtered"),
        (Path(".venv/lib/test.py"), False, ".venv should be filtered"),
        (Path("test.unknown"), False, "Unknown extension should be filtered"),
        (Path("README.md"), True, "Markdown file should be indexed"),
    ]

    passed = 0
    failed = 0

    for path, expected, description in test_cases:
        result = watcher._should_index_file(path)
        status = "✓" if result == expected else "✗"

        if result == expected:
            passed += 1
        else:
            failed += 1

        print(f"  {status} {description}")
        print(f"    Path: {path}, Expected: {expected}, Got: {result}")

    print(f"\nResults: {passed} passed, {failed} failed")
    return failed == 0


def test_debounce_and_dedup():
    """Test event debouncing and deduplication."""
    print("\nTesting event debouncing and deduplication...")

    received_events = []

    def on_changes(events):
        received_events.append(events)
        print(f"  Received batch: {len(events)} events")

    # Create watcher with short debounce time for testing
    config = WatcherConfig(debounce_ms=500)
    watcher = FileWatcher(
        root_path=Path("."),
        config=config,
        on_changes=on_changes,
    )

    # Simulate rapid events to same file (should be deduplicated)
    from codexlens.watcher.events import ChangeType

    test_path = Path("test_file.py")
    for i in range(5):
        event = FileEvent(
            path=test_path,
            change_type=ChangeType.MODIFIED,
            timestamp=time.time(),
        )
        watcher._on_raw_event(event)

    # Wait for debounce
    time.sleep(0.6)

    # Force flush to ensure we get the events
    watcher._flush_events()

    if received_events:
        batch = received_events[0]
        # Should deduplicate 5 events to 1
        if len(batch) == 1:
            print("  ✓ Deduplication working: 5 events reduced to 1")
            return True
        else:
            print(f"  ✗ Deduplication failed: expected 1 event, got {len(batch)}")
            return False
    else:
        print("  ✗ No events received")
        return False


if __name__ == "__main__":
    print("=" * 60)
    print("FileWatcher Verification")
    print("=" * 60)

    test1 = test_should_index_file()
    test2 = test_debounce_and_dedup()

    print("\n" + "=" * 60)
    if test1 and test2:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed")
    print("=" * 60)