From 520f2d26f267433606f17a06d3ed244aed600efe Mon Sep 17 00:00:00 2001 From: catlog22 Date: Thu, 1 Jan 2026 13:23:52 +0800 Subject: [PATCH] feat(codex-lens): add unified reranker architecture and file watcher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Unified Reranker Architecture: - Add BaseReranker ABC with factory pattern - Implement 4 backends: ONNX (default), API, LiteLLM, Legacy - Add .env configuration parsing for API credentials - Migrate from sentence-transformers to optimum+onnxruntime File Watcher Module: - Add real-time file system monitoring with watchdog - Implement IncrementalIndexer for single-file updates - Add WatcherManager with signal handling and graceful shutdown - Add 'codexlens watch' CLI command - Event filtering, debouncing, and deduplication - Thread-safe design with proper resource cleanup Tests: 16 watcher tests + 5 reranker test files 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- codex-lens/.env.example | 66 ++++ codex-lens/pyproject.toml | 29 +- codex-lens/src/codexlens/cli/commands.py | 185 +++++++++ codex-lens/src/codexlens/config.py | 33 +- codex-lens/src/codexlens/env_config.py | 260 +++++++++++++ .../src/codexlens/search/hybrid_search.py | 44 ++- .../codexlens/semantic/reranker/__init__.py | 22 ++ .../semantic/reranker/api_reranker.py | 310 +++++++++++++++ .../src/codexlens/semantic/reranker/base.py | 36 ++ .../codexlens/semantic/reranker/factory.py | 138 +++++++ .../{reranker.py => reranker/legacy.py} | 15 +- .../semantic/reranker/litellm_reranker.py | 214 +++++++++++ .../semantic/reranker/onnx_reranker.py | 268 +++++++++++++ codex-lens/src/codexlens/watcher/__init__.py | 17 + codex-lens/src/codexlens/watcher/events.py | 54 +++ .../src/codexlens/watcher/file_watcher.py | 245 ++++++++++++ .../codexlens/watcher/incremental_indexer.py | 359 ++++++++++++++++++ codex-lens/src/codexlens/watcher/manager.py | 194 ++++++++++ codex-lens/tests/test_api_reranker.py | 171 +++++++++ .../test_hybrid_search_reranker_backend.py | 139 +++++++ codex-lens/tests/test_litellm_reranker.py | 85 +++++ codex-lens/tests/test_reranker_backends.py | 115 ++++++ codex-lens/tests/test_reranker_factory.py | 315 +++++++++++++++ codex-lens/tests/test_watcher/__init__.py | 1 + codex-lens/tests/test_watcher/conftest.py | 43 +++ codex-lens/tests/test_watcher/test_events.py | 103 +++++ .../tests/test_watcher/test_file_watcher.py | 124 ++++++ 27 files changed, 3571 insertions(+), 14 deletions(-) create mode 100644 codex-lens/.env.example create mode 100644 codex-lens/src/codexlens/env_config.py create mode 100644 codex-lens/src/codexlens/semantic/reranker/__init__.py create mode 100644 codex-lens/src/codexlens/semantic/reranker/api_reranker.py create mode 100644 codex-lens/src/codexlens/semantic/reranker/base.py create mode 100644 codex-lens/src/codexlens/semantic/reranker/factory.py rename codex-lens/src/codexlens/semantic/{reranker.py => reranker/legacy.py} (87%) create mode 100644 codex-lens/src/codexlens/semantic/reranker/litellm_reranker.py create mode 100644 codex-lens/src/codexlens/semantic/reranker/onnx_reranker.py create mode 100644 codex-lens/src/codexlens/watcher/__init__.py create mode 100644 codex-lens/src/codexlens/watcher/events.py create mode 100644 codex-lens/src/codexlens/watcher/file_watcher.py create mode 100644 codex-lens/src/codexlens/watcher/incremental_indexer.py create mode 100644 codex-lens/src/codexlens/watcher/manager.py create mode 100644 
codex-lens/tests/test_api_reranker.py create mode 100644 codex-lens/tests/test_hybrid_search_reranker_backend.py create mode 100644 codex-lens/tests/test_litellm_reranker.py create mode 100644 codex-lens/tests/test_reranker_backends.py create mode 100644 codex-lens/tests/test_reranker_factory.py create mode 100644 codex-lens/tests/test_watcher/__init__.py create mode 100644 codex-lens/tests/test_watcher/conftest.py create mode 100644 codex-lens/tests/test_watcher/test_events.py create mode 100644 codex-lens/tests/test_watcher/test_file_watcher.py diff --git a/codex-lens/.env.example b/codex-lens/.env.example new file mode 100644 index 00000000..1db45b70 --- /dev/null +++ b/codex-lens/.env.example @@ -0,0 +1,66 @@ +# CodexLens Environment Configuration +# Copy this file to .codexlens/.env and fill in your values +# +# Priority order: +# 1. Environment variables (already set in shell) +# 2. .codexlens/.env (workspace-local, this file) +# 3. .env (project root) + +# ============================================ +# RERANKER Configuration +# ============================================ + +# API key for reranker service (SiliconFlow/Cohere/Jina) +# Required for 'api' backend +# RERANKER_API_KEY=sk-xxxx + +# Base URL for reranker API (overrides provider default) +# SiliconFlow: https://api.siliconflow.cn +# Cohere: https://api.cohere.ai +# Jina: https://api.jina.ai +# RERANKER_API_BASE=https://api.siliconflow.cn + +# Reranker provider: siliconflow, cohere, jina +# RERANKER_PROVIDER=siliconflow + +# Reranker model name +# SiliconFlow: BAAI/bge-reranker-v2-m3 +# Cohere: rerank-english-v3.0 +# Jina: jina-reranker-v2-base-multilingual +# RERANKER_MODEL=BAAI/bge-reranker-v2-m3 + +# ============================================ +# EMBEDDING Configuration +# ============================================ + +# API key for embedding service (for litellm backend) +# EMBEDDING_API_KEY=sk-xxxx + +# Base URL for embedding API +# EMBEDDING_API_BASE=https://api.openai.com + +# Embedding model name +# EMBEDDING_MODEL=text-embedding-3-small + +# ============================================ +# LITELLM Configuration +# ============================================ + +# API key for LiteLLM (for litellm reranker backend) +# LITELLM_API_KEY=sk-xxxx + +# Base URL for LiteLLM +# LITELLM_API_BASE= + +# LiteLLM model name +# LITELLM_MODEL=gpt-4o-mini + +# ============================================ +# General Configuration +# ============================================ + +# Custom data directory path (default: ~/.codexlens) +# CODEXLENS_DATA_DIR=~/.codexlens + +# Enable debug mode (true/false) +# CODEXLENS_DEBUG=false diff --git a/codex-lens/pyproject.toml b/codex-lens/pyproject.toml index fbb61294..35e66501 100644 --- a/codex-lens/pyproject.toml +++ b/codex-lens/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "tree-sitter-javascript>=0.25", "tree-sitter-typescript>=0.23", "pathspec>=0.11", + "watchdog>=3.0", ] [project.optional-dependencies] @@ -50,11 +51,35 @@ semantic-directml = [ ] # Cross-encoder reranking (second-stage, optional) -# Install with: pip install codexlens[reranker] -reranker = [ +# Install with: pip install codexlens[reranker] (default: ONNX backend) +reranker-onnx = [ + "optimum>=1.16", + "onnxruntime>=1.15", + "transformers>=4.36", +] + +# Remote reranking via HTTP API +reranker-api = [ + "httpx>=0.25", +] + +# LLM-based reranking via ccw-litellm +reranker-litellm = [ + "ccw-litellm>=0.1", +] + +# Legacy sentence-transformers CrossEncoder reranker +reranker-legacy = [ "sentence-transformers>=2.2", ] +# 
Backward-compatible alias for default reranker backend +reranker = [ + "optimum>=1.16", + "onnxruntime>=1.15", + "transformers>=4.36", +] + # Encoding detection for non-UTF8 files encoding = [ "chardet>=5.0",
diff --git a/codex-lens/src/codexlens/cli/commands.py b/codex-lens/src/codexlens/cli/commands.py index 10f66e9b..119e6cc2 100644 --- a/codex-lens/src/codexlens/cli/commands.py +++ b/codex-lens/src/codexlens/cli/commands.py @@ -22,6 +22,7 @@ from codexlens.storage.registry import RegistryStore, ProjectInfo from codexlens.storage.index_tree import IndexTreeBuilder from codexlens.storage.dir_index import DirIndexStore from codexlens.search.chain_search import ChainSearchEngine, SearchOptions +from codexlens.watcher import WatcherManager, WatcherConfig from .output import ( console,
@@ -2293,3 +2294,101 @@ def gpu_reset( if gpu_info.preferred_device_id is not None: console.print(f" Auto-selected device: {gpu_info.preferred_device_id}") console.print(f" Device: [cyan]{gpu_info.gpu_name}[/cyan]") + + +# ==================== Watch Command ==================== + +@app.command() +def watch( + path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to watch."), + language: Optional[List[str]] = typer.Option(None, "--language", "-l", help="Languages to watch (comma-separated)."), + debounce: int = typer.Option(1000, "--debounce", "-d", min=100, max=10000, help="Debounce interval in milliseconds."), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."), +) -> None: + """Watch a directory for file changes and incrementally update the index. + + Monitors the specified directory for file system changes (create, modify, delete) + and automatically updates the CodexLens index. The directory must already be indexed + using 'codexlens init' before watching. + + Examples: + # Watch current directory + codexlens watch . + + # Watch with custom debounce interval + codexlens watch . --debounce 2000 + + # Watch only Python and JavaScript files + codexlens watch . --language python,javascript + + Press Ctrl+C to stop watching. + """ + _configure_logging(verbose) + watch_path = path.expanduser().resolve() + + registry: RegistryStore | None = None + try: + # Validate that path is indexed + registry = RegistryStore() + registry.initialize() + + project_record = registry.find_by_source_path(str(watch_path)) + if not project_record: + console.print(f"[red]Error:[/red] Directory is not indexed: {watch_path}") + console.print("[dim]Run 'codexlens init' first to create an index.[/dim]") + raise typer.Exit(code=1) + + # Parse languages + languages = _parse_languages(language) + + # Create watcher config + watcher_config = WatcherConfig( + debounce_ms=debounce, + languages=languages, + )
+ + # Display startup message + console.print(f"[green]Starting watcher for:[/green] {watch_path}") + console.print(f"[dim]Debounce interval: {debounce}ms[/dim]") + if languages: + console.print(f"[dim]Watching languages: {', '.join(languages)}[/dim]") + console.print("[dim]Press Ctrl+C to stop[/dim]\n") + + # Create and start watcher manager + manager = WatcherManager( + root_path=watch_path, + watcher_config=watcher_config, + on_indexed=_display_index_result, + ) + + manager.start() + manager.wait() + + except KeyboardInterrupt: + console.print("\n[yellow]Stopping watcher...[/yellow]") + except CodexLensError as exc: + console.print(f"[red]Watch failed:[/red] {exc}") + raise typer.Exit(code=1) + except Exception as exc: + console.print(f"[red]Unexpected error:[/red] {exc}") + raise typer.Exit(code=1) + finally: + if registry is not None: + registry.close()
+ + +def _display_index_result(result) -> None: + """Display indexing result in real-time.""" + if result.files_indexed > 0 or result.files_removed > 0: + parts = [] + if result.files_indexed > 0: + parts.append(f"[green]✓ Indexed {result.files_indexed} file(s)[/green]") + if result.files_removed > 0: + parts.append(f"[yellow]✗ Removed {result.files_removed} file(s)[/yellow]") + console.print(" | ".join(parts)) + + if result.errors: + for error in result.errors[:3]: # Show max 3 errors + console.print(f" [red]Error:[/red] {error}") + if len(result.errors) > 3: + console.print(f" [dim]... 
and {len(result.errors) - 3} more errors[/dim]") diff --git a/codex-lens/src/codexlens/config.py b/codex-lens/src/codexlens/config.py index 11c550bf..ac523c3c 100644 --- a/codex-lens/src/codexlens/config.py +++ b/codex-lens/src/codexlens/config.py @@ -116,8 +116,9 @@ class Config: reranking_top_k: int = 50 symbol_boost_factor: float = 1.5 - # Optional cross-encoder reranking (second stage, requires codexlens[reranker]) + # Optional cross-encoder reranking (second stage; requires optional reranker deps) enable_cross_encoder_rerank: bool = False + reranker_backend: str = "onnx" reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" reranker_top_k: int = 50 @@ -311,6 +312,35 @@ class WorkspaceConfig: """Cache directory for this workspace.""" return self.codexlens_dir / "cache" + @property + def env_path(self) -> Path: + """Path to workspace .env file.""" + return self.codexlens_dir / ".env" + + def load_env(self, *, override: bool = False) -> int: + """Load .env file and apply to os.environ. + + Args: + override: If True, override existing environment variables + + Returns: + Number of variables applied + """ + from .env_config import apply_workspace_env + return apply_workspace_env(self.workspace_root, override=override) + + def get_api_config(self, prefix: str) -> dict: + """Get API configuration from environment. + + Args: + prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING") + + Returns: + Dictionary with api_key, api_base, model, etc. + """ + from .env_config import get_api_config + return get_api_config(prefix, workspace_root=self.workspace_root) + def initialize(self) -> None: """Create the .codexlens directory structure.""" try: @@ -324,6 +354,7 @@ class WorkspaceConfig: "# CodexLens workspace data\n" "cache/\n" "*.log\n" + ".env\n" # Exclude .env from git ) except Exception as exc: raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc diff --git a/codex-lens/src/codexlens/env_config.py b/codex-lens/src/codexlens/env_config.py new file mode 100644 index 00000000..bb6061c6 --- /dev/null +++ b/codex-lens/src/codexlens/env_config.py @@ -0,0 +1,260 @@ +"""Environment configuration loader for CodexLens. + +Loads .env files from workspace .codexlens directory with fallback to project root. +Provides unified access to API configurations. + +Priority order: +1. Environment variables (already set) +2. .codexlens/.env (workspace-local) +3. 
.env (project root) +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional + +log = logging.getLogger(__name__) + +# Supported environment variables with descriptions +ENV_VARS = { + # Reranker API configuration + "RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)", + "RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)", + "RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina", + "RERANKER_MODEL": "Reranker model name", + # Embedding API configuration + "EMBEDDING_API_KEY": "API key for embedding service", + "EMBEDDING_API_BASE": "Base URL for embedding API", + "EMBEDDING_MODEL": "Embedding model name", + # LiteLLM configuration + "LITELLM_API_KEY": "API key for LiteLLM", + "LITELLM_API_BASE": "Base URL for LiteLLM", + "LITELLM_MODEL": "LiteLLM model name", + # General configuration + "CODEXLENS_DATA_DIR": "Custom data directory path", + "CODEXLENS_DEBUG": "Enable debug mode (true/false)", +} + + +def _parse_env_line(line: str) -> tuple[str, str] | None: + """Parse a single .env line, returning (key, value) or None.""" + line = line.strip() + + # Skip empty lines and comments + if not line or line.startswith("#"): + return None + + # Handle export prefix + if line.startswith("export "): + line = line[7:].strip() + + # Split on first = + if "=" not in line: + return None + + key, _, value = line.partition("=") + key = key.strip() + value = value.strip() + + # Remove surrounding quotes + if len(value) >= 2: + if (value.startswith('"') and value.endswith('"')) or \ + (value.startswith("'") and value.endswith("'")): + value = value[1:-1] + + return key, value + + +def load_env_file(env_path: Path) -> Dict[str, str]: + """Load environment variables from a .env file. + + Args: + env_path: Path to .env file + + Returns: + Dictionary of environment variables + """ + if not env_path.is_file(): + return {} + + env_vars: Dict[str, str] = {} + + try: + content = env_path.read_text(encoding="utf-8") + for line in content.splitlines(): + result = _parse_env_line(line) + if result: + key, value = result + env_vars[key] = value + except Exception as exc: + log.warning("Failed to load .env file %s: %s", env_path, exc) + + return env_vars + + +def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]: + """Load environment variables from workspace .env files. + + Priority (later overrides earlier): + 1. Project root .env + 2. .codexlens/.env + + Args: + workspace_root: Workspace root directory. If None, uses current directory. + + Returns: + Merged dictionary of environment variables + """ + if workspace_root is None: + workspace_root = Path.cwd() + + workspace_root = Path(workspace_root).resolve() + + env_vars: Dict[str, str] = {} + + # Load from project root .env (lowest priority) + root_env = workspace_root / ".env" + if root_env.is_file(): + env_vars.update(load_env_file(root_env)) + log.debug("Loaded %d vars from %s", len(env_vars), root_env) + + # Load from .codexlens/.env (higher priority) + codexlens_env = workspace_root / ".codexlens" / ".env" + if codexlens_env.is_file(): + loaded = load_env_file(codexlens_env) + env_vars.update(loaded) + log.debug("Loaded %d vars from %s", len(loaded), codexlens_env) + + return env_vars + + +def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int: + """Load .env files and apply to os.environ. 
+ + Args: + workspace_root: Workspace root directory + override: If True, override existing environment variables + + Returns: + Number of variables applied + """ + env_vars = load_workspace_env(workspace_root) + applied = 0 + + for key, value in env_vars.items(): + if override or key not in os.environ: + os.environ[key] = value + applied += 1 + log.debug("Applied env var: %s", key) + + return applied + + +def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None: + """Get environment variable with .env file fallback. + + Priority: + 1. os.environ (already set) + 2. .codexlens/.env + 3. .env + 4. default value + + Args: + key: Environment variable name + default: Default value if not found + workspace_root: Workspace root for .env file lookup + + Returns: + Value or default + """ + # Check os.environ first + if key in os.environ: + return os.environ[key] + + # Load from .env files + env_vars = load_workspace_env(workspace_root) + if key in env_vars: + return env_vars[key] + + return default + + +def get_api_config( + prefix: str, + *, + workspace_root: Path | None = None, + defaults: Dict[str, Any] | None = None, +) -> Dict[str, Any]: + """Get API configuration from environment. + + Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc. + + Args: + prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING") + workspace_root: Workspace root for .env file lookup + defaults: Default values + + Returns: + Dictionary with api_key, api_base, model, etc. + """ + defaults = defaults or {} + + config: Dict[str, Any] = {} + + # Standard API config fields + field_mapping = { + "api_key": f"{prefix}_API_KEY", + "api_base": f"{prefix}_API_BASE", + "model": f"{prefix}_MODEL", + "provider": f"{prefix}_PROVIDER", + "timeout": f"{prefix}_TIMEOUT", + } + + for field, env_key in field_mapping.items(): + value = get_env(env_key, workspace_root=workspace_root) + if value is not None: + # Type conversion for specific fields + if field == "timeout": + try: + config[field] = float(value) + except ValueError: + pass + else: + config[field] = value + elif field in defaults: + config[field] = defaults[field] + + return config + + +def generate_env_example() -> str: + """Generate .env.example content with all supported variables. 
+ + Returns: + String content for .env.example file + """ + lines = [ + "# CodexLens Environment Configuration", + "# Copy this file to .codexlens/.env and fill in your values", + "", + ] + + # Group by prefix + groups: Dict[str, list] = {} + for key, desc in ENV_VARS.items(): + prefix = key.split("_")[0] + if prefix not in groups: + groups[prefix] = [] + groups[prefix].append((key, desc)) + + for prefix, items in groups.items(): + lines.append(f"# {prefix} Configuration") + for key, desc in items: + lines.append(f"# {desc}") + lines.append(f"# {key}=") + lines.append("") + + return "\n".join(lines) diff --git a/codex-lens/src/codexlens/search/hybrid_search.py b/codex-lens/src/codexlens/search/hybrid_search.py index b461e02a..ff790431 100644 --- a/codex-lens/src/codexlens/search/hybrid_search.py +++ b/codex-lens/src/codexlens/search/hybrid_search.py @@ -258,20 +258,52 @@ class HybridSearchEngine: return None try: - from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available + from codexlens.semantic.reranker import ( + check_reranker_available, + get_reranker, + ) except Exception as exc: - self.logger.debug("Cross-encoder reranker unavailable: %s", exc) + self.logger.debug("Reranker factory unavailable: %s", exc) return None - ok, err = check_cross_encoder_available() + backend = (getattr(self._config, "reranker_backend", "") or "").strip().lower() or "onnx" + + ok, err = check_reranker_available(backend) if not ok: - self.logger.debug("Cross-encoder reranker unavailable: %s", err) + self.logger.debug( + "Reranker backend unavailable (backend=%s): %s", + backend, + err, + ) return None try: - return CrossEncoderReranker(model_name=self._config.reranker_model) + model_name = (getattr(self._config, "reranker_model", "") or "").strip() or None + + if backend != "legacy" and model_name == "cross-encoder/ms-marco-MiniLM-L-6-v2": + model_name = None + + device: str | None = None + kwargs: dict[str, Any] = {} + + if backend == "onnx": + kwargs["use_gpu"] = bool(getattr(self._config, "embedding_use_gpu", True)) + elif backend == "legacy": + if not bool(getattr(self._config, "embedding_use_gpu", True)): + device = "cpu" + + return get_reranker( + backend=backend, + model_name=model_name, + device=device, + **kwargs, + ) except Exception as exc: - self.logger.debug("Failed to initialize cross-encoder reranker: %s", exc) + self.logger.debug( + "Failed to initialize reranker (backend=%s): %s", + backend, + exc, + ) return None def _search_parallel( diff --git a/codex-lens/src/codexlens/semantic/reranker/__init__.py b/codex-lens/src/codexlens/semantic/reranker/__init__.py new file mode 100644 index 00000000..18c079b8 --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker/__init__.py @@ -0,0 +1,22 @@ +"""Reranker backends for second-stage search ranking. + +This subpackage provides a unified interface and factory for different reranking +implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers). 
+""" + +from __future__ import annotations + +from .base import BaseReranker +from .factory import check_reranker_available, get_reranker +from .legacy import CrossEncoderReranker, check_cross_encoder_available +from .onnx_reranker import ONNXReranker, check_onnx_reranker_available + +__all__ = [ + "BaseReranker", + "check_reranker_available", + "get_reranker", + "CrossEncoderReranker", + "check_cross_encoder_available", + "ONNXReranker", + "check_onnx_reranker_available", +] diff --git a/codex-lens/src/codexlens/semantic/reranker/api_reranker.py b/codex-lens/src/codexlens/semantic/reranker/api_reranker.py new file mode 100644 index 00000000..88cf34a3 --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker/api_reranker.py @@ -0,0 +1,310 @@ +"""API-based reranker using a remote HTTP provider. + +Supported providers: +- SiliconFlow: https://api.siliconflow.cn/v1/rerank +- Cohere: https://api.cohere.ai/v1/rerank +- Jina: https://api.jina.ai/v1/rerank +""" + +from __future__ import annotations + +import logging +import os +import random +import time +from pathlib import Path +from typing import Any, Mapping, Sequence + +from .base import BaseReranker + +logger = logging.getLogger(__name__) + +_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY" + + +def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None: + """Get environment variable with .env file fallback.""" + # Check os.environ first + if key in os.environ: + return os.environ[key] + + # Try loading from .env files + try: + from codexlens.env_config import get_env + return get_env(key, workspace_root=workspace_root) + except ImportError: + return None + + +def check_httpx_available() -> tuple[bool, str | None]: + try: + import httpx # noqa: F401 + except ImportError as exc: # pragma: no cover - optional dependency + return False, f"httpx not available: {exc}. Install with: pip install httpx" + return True, None + + +class APIReranker(BaseReranker): + """Reranker backed by a remote reranking HTTP API.""" + + _PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = { + "siliconflow": { + "api_base": "https://api.siliconflow.cn", + "endpoint": "/v1/rerank", + "default_model": "BAAI/bge-reranker-v2-m3", + }, + "cohere": { + "api_base": "https://api.cohere.ai", + "endpoint": "/v1/rerank", + "default_model": "rerank-english-v3.0", + }, + "jina": { + "api_base": "https://api.jina.ai", + "endpoint": "/v1/rerank", + "default_model": "jina-reranker-v2-base-multilingual", + }, + } + + def __init__( + self, + *, + provider: str = "siliconflow", + model_name: str | None = None, + api_key: str | None = None, + api_base: str | None = None, + timeout: float = 30.0, + max_retries: int = 3, + backoff_base_s: float = 0.5, + backoff_max_s: float = 8.0, + env_api_key: str = _DEFAULT_ENV_API_KEY, + workspace_root: Path | str | None = None, + ) -> None: + ok, err = check_httpx_available() + if not ok: # pragma: no cover - exercised via factory availability tests + raise ImportError(err) + + import httpx + + self._workspace_root = Path(workspace_root) if workspace_root else None + + self.provider = (provider or "").strip().lower() + if self.provider not in self._PROVIDER_DEFAULTS: + raise ValueError( + f"Unknown reranker provider: {provider}. 
" + f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}" + ) + + defaults = self._PROVIDER_DEFAULTS[self.provider] + + # Load api_base from env with .env fallback + env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root) + self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/") + self.endpoint = defaults["endpoint"] + + # Load model from env with .env fallback + env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root) + self.model_name = (model_name or env_model or defaults["default_model"]).strip() + if not self.model_name: + raise ValueError("model_name cannot be blank") + + # Load API key from env with .env fallback + resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or "" + resolved_key = resolved_key.strip() + if not resolved_key: + raise ValueError( + f"Missing API key for reranker provider '{self.provider}'. " + f"Pass api_key=... or set ${env_api_key}." + ) + self._api_key = resolved_key + + self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0 + self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3 + self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5 + self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0 + + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + } + if self.provider == "cohere": + headers.setdefault("Cohere-Version", "2022-12-06") + + self._client = httpx.Client( + base_url=self.api_base, + headers=headers, + timeout=self.timeout_s, + ) + + def close(self) -> None: + try: + self._client.close() + except Exception: # pragma: no cover - defensive + return + + def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None: + if retry_after_s is not None and retry_after_s > 0: + time.sleep(min(float(retry_after_s), self.backoff_max_s)) + return + + exp = self.backoff_base_s * (2**attempt) + jitter = random.uniform(0, min(0.5, self.backoff_base_s)) + time.sleep(min(self.backoff_max_s, exp + jitter)) + + @staticmethod + def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None: + value = (headers.get("Retry-After") or "").strip() + if not value: + return None + try: + return float(value) + except ValueError: + return None + + @staticmethod + def _should_retry_status(status_code: int) -> bool: + return status_code == 429 or 500 <= status_code <= 599 + + def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]: + last_exc: Exception | None = None + + for attempt in range(self.max_retries + 1): + try: + response = self._client.post(self.endpoint, json=dict(payload)) + except Exception as exc: # httpx is optional at import-time + last_exc = exc + if attempt < self.max_retries: + self._sleep_backoff(attempt) + continue + raise RuntimeError( + f"Rerank request failed for provider '{self.provider}' after " + f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}" + ) from exc + + status = int(getattr(response, "status_code", 0) or 0) + if status >= 400: + body_preview = "" + try: + body_preview = (response.text or "").strip() + except Exception: + body_preview = "" + if len(body_preview) > 300: + body_preview = body_preview[:300] + "…" + + if self._should_retry_status(status) and attempt < self.max_retries: + retry_after = self._parse_retry_after_seconds(response.headers) + logger.warning( + "Rerank request 
to %s%s failed with HTTP %s (attempt %s/%s). Retrying…", + self.api_base, + self.endpoint, + status, + attempt + 1, + self.max_retries + 1, + ) + self._sleep_backoff(attempt, retry_after_s=retry_after) + continue + + if status in {401, 403}: + raise RuntimeError( + f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). " + "Check your API key." + ) + + raise RuntimeError( + f"Rerank request failed for provider '{self.provider}' (HTTP {status}). " + f"Response: {body_preview or ''}" + ) + + try: + data = response.json() + except Exception as exc: + raise RuntimeError( + f"Rerank response from provider '{self.provider}' is not valid JSON: " + f"{type(exc).__name__}: {exc}" + ) from exc + + if not isinstance(data, dict): + raise RuntimeError( + f"Rerank response from provider '{self.provider}' must be a JSON object; " + f"got {type(data).__name__}" + ) + + return data + + raise RuntimeError( + f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}" + ) + + @staticmethod + def _extract_scores_from_results(results: Any, expected: int) -> list[float]: + if not isinstance(results, list): + raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}") + + scores: list[float] = [0.0 for _ in range(expected)] + filled = 0 + + for item in results: + if not isinstance(item, dict): + continue + idx = item.get("index") + score = item.get("relevance_score", item.get("score")) + if idx is None or score is None: + continue + try: + idx_int = int(idx) + score_f = float(score) + except (TypeError, ValueError): + continue + if 0 <= idx_int < expected: + scores[idx_int] = score_f + filled += 1 + + if filled != expected: + raise RuntimeError( + f"Rerank response contained {filled}/{expected} scored documents; " + "ensure top_n matches the number of documents." + ) + + return scores + + def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]: + payload: dict[str, Any] = { + "model": self.model_name, + "query": query, + "documents": list(documents), + "top_n": len(documents), + "return_documents": False, + } + return payload + + def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]: + if not documents: + return [] + + payload = self._build_payload(query=query, documents=documents) + data = self._request_json(payload) + + results = data.get("results") + return self._extract_scores_from_results(results, expected=len(documents)) + + def score_pairs( + self, + pairs: Sequence[tuple[str, str]], + *, + batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility + ) -> list[float]: + if not pairs: + return [] + + grouped: dict[str, list[tuple[int, str]]] = {} + for idx, (query, doc) in enumerate(pairs): + grouped.setdefault(str(query), []).append((idx, str(doc))) + + scores: list[float] = [0.0 for _ in range(len(pairs))] + + for query, items in grouped.items(): + documents = [doc for _, doc in items] + query_scores = self._rerank_one_query(query=query, documents=documents) + for (orig_idx, _), score in zip(items, query_scores): + scores[orig_idx] = float(score) + + return scores diff --git a/codex-lens/src/codexlens/semantic/reranker/base.py b/codex-lens/src/codexlens/semantic/reranker/base.py new file mode 100644 index 00000000..870aca84 --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker/base.py @@ -0,0 +1,36 @@ +"""Base class for rerankers. + +Defines the interface that all rerankers must implement. 
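+ +Typical usage (illustrative sketch; which backend actually loads depends on the +optional dependencies installed): + + from codexlens.semantic.reranker import get_reranker + + reranker = get_reranker(backend="onnx") + scores = reranker.score_pairs([("query text", "document text")])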
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Sequence + + +class BaseReranker(ABC): + """Base class for all rerankers. + + All reranker implementations must inherit from this class and implement + the abstract methods to ensure a consistent interface. + """ + + @abstractmethod + def score_pairs( + self, + pairs: Sequence[tuple[str, str]], + *, + batch_size: int = 32, + ) -> list[float]: + """Score (query, doc) pairs. + + Args: + pairs: Sequence of (query, doc) string pairs to score. + batch_size: Batch size for scoring. + + Returns: + List of scores (one per pair). + """ + ... + diff --git a/codex-lens/src/codexlens/semantic/reranker/factory.py b/codex-lens/src/codexlens/semantic/reranker/factory.py new file mode 100644 index 00000000..6940020d --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker/factory.py @@ -0,0 +1,138 @@ +"""Factory for creating rerankers. + +Provides a unified interface for instantiating different reranker backends. +""" + +from __future__ import annotations + +from typing import Any + +from .base import BaseReranker + + +def check_reranker_available(backend: str) -> tuple[bool, str | None]: + """Check whether a specific reranker backend can be used. + + Notes: + - "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]). + - "onnx" uses Optimum + ONNX Runtime (pip install codexlens[reranker] or codexlens[reranker-onnx]). + - "api" uses a remote reranking HTTP API (requires httpx). + - "litellm" uses `ccw-litellm` for unified access to LLM providers. + """ + backend = (backend or "").strip().lower() + + if backend == "legacy": + from .legacy import check_cross_encoder_available + + return check_cross_encoder_available() + + if backend == "onnx": + from .onnx_reranker import check_onnx_reranker_available + + return check_onnx_reranker_available() + + if backend == "litellm": + try: + import ccw_litellm # noqa: F401 + except ImportError as exc: # pragma: no cover - optional dependency + return ( + False, + f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm", + ) + + try: + from .litellm_reranker import LiteLLMReranker # noqa: F401 + except Exception as exc: # pragma: no cover - defensive + return False, f"LiteLLM reranker backend not available: {exc}" + + return True, None + + if backend == "api": + from .api_reranker import check_httpx_available + + return check_httpx_available() + + return False, ( + f"Invalid reranker backend: {backend}. " + "Must be 'onnx', 'api', 'litellm', or 'legacy'." + ) + + +def get_reranker( + backend: str = "onnx", + model_name: str | None = None, + *, + device: str | None = None, + **kwargs: Any, +) -> BaseReranker: + """Factory function to create reranker based on backend. + + Args: + backend: Reranker backend to use. Options: + - "onnx": Optimum + onnxruntime backend (default) + - "api": HTTP API backend (remote providers) + - "litellm": LiteLLM backend (LLM-based, experimental) + - "legacy": sentence-transformers CrossEncoder backend (optional) + model_name: Model identifier for model-based backends. Defaults depend on backend: + - onnx: Xenova/ms-marco-MiniLM-L-6-v2 + - api: BAAI/bge-reranker-v2-m3 (SiliconFlow) + - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2 + - litellm: default + device: Optional device string for backends that support it (legacy only). + **kwargs: Additional backend-specific arguments. + + Returns: + BaseReranker: Configured reranker instance. 
+ + Raises: + ValueError: If backend is not recognized. + ImportError: If required backend dependencies are not installed or backend is unavailable. + """ + backend = (backend or "").strip().lower() + + if backend == "onnx": + ok, err = check_reranker_available("onnx") + if not ok: + raise ImportError(err) + + from .onnx_reranker import ONNXReranker + + resolved_model_name = (model_name or "").strip() or ONNXReranker.DEFAULT_MODEL + _ = device # Device selection is managed via ONNX Runtime providers. + return ONNXReranker(model_name=resolved_model_name, **kwargs) + + if backend == "legacy": + ok, err = check_reranker_available("legacy") + if not ok: + raise ImportError(err) + + from .legacy import CrossEncoderReranker + + resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2" + return CrossEncoderReranker(model_name=resolved_model_name, device=device) + + if backend == "litellm": + ok, err = check_reranker_available("litellm") + if not ok: + raise ImportError(err) + + from .litellm_reranker import LiteLLMReranker + + _ = device # Device selection is not applicable to remote LLM backends. + resolved_model_name = (model_name or "").strip() or "default" + return LiteLLMReranker(model=resolved_model_name, **kwargs) + + if backend == "api": + ok, err = check_reranker_available("api") + if not ok: + raise ImportError(err) + + from .api_reranker import APIReranker + + _ = device # Device selection is not applicable to remote HTTP backends. + resolved_model_name = (model_name or "").strip() or None + return APIReranker(model_name=resolved_model_name, **kwargs) + + raise ValueError( + f"Unknown backend: {backend}. Supported backends: 'onnx', 'api', 'litellm', 'legacy'" + ) diff --git a/codex-lens/src/codexlens/semantic/reranker.py b/codex-lens/src/codexlens/semantic/reranker/legacy.py similarity index 87% rename from codex-lens/src/codexlens/semantic/reranker.py rename to codex-lens/src/codexlens/semantic/reranker/legacy.py index 99a720fe..a5ee05de 100644 --- a/codex-lens/src/codexlens/semantic/reranker.py +++ b/codex-lens/src/codexlens/semantic/reranker/legacy.py @@ -1,6 +1,6 @@ -"""Optional cross-encoder reranker for second-stage search ranking. +"""Legacy sentence-transformers cross-encoder reranker. -Install with: pip install codexlens[reranker] +Install with: pip install codexlens[reranker-legacy] """ from __future__ import annotations @@ -9,6 +9,8 @@ import logging import threading from typing import List, Sequence, Tuple +from .base import BaseReranker + logger = logging.getLogger(__name__) try: @@ -25,10 +27,14 @@ except ImportError as exc: # pragma: no cover - optional dependency def check_cross_encoder_available() -> tuple[bool, str | None]: if CROSS_ENCODER_AVAILABLE: return True, None - return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]" + return ( + False, + _import_error + or "sentence-transformers not available. 
Install with: pip install codexlens[reranker-legacy]", + ) -class CrossEncoderReranker: +class CrossEncoderReranker(BaseReranker): """Cross-encoder reranker with lazy model loading.""" def __init__(self, model_name: str, *, device: str | None = None) -> None: @@ -83,4 +89,3 @@ class CrossEncoderReranker: bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32 scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr] return [float(s) for s in scores] - diff --git a/codex-lens/src/codexlens/semantic/reranker/litellm_reranker.py b/codex-lens/src/codexlens/semantic/reranker/litellm_reranker.py new file mode 100644 index 00000000..ec735994 --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker/litellm_reranker.py @@ -0,0 +1,214 @@ +"""Experimental LiteLLM reranker backend. + +This module provides :class:`LiteLLMReranker`, which uses an LLM to score the +relevance of a single (query, document) pair per request. + +Notes: + - This backend is experimental and may be slow/expensive compared to local + rerankers. + - It relies on `ccw-litellm` for a unified LLM API across providers. +""" + +from __future__ import annotations + +import json +import logging +import re +import threading +import time +from typing import Any, Sequence + +from .base import BaseReranker + +logger = logging.getLogger(__name__) + +_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?") + + +def _coerce_score_to_unit_interval(score: float) -> float: + """Coerce a numeric score into [0, 1]. + + The prompt asks for a float in [0, 1], but some models may respond with 0-10 + or 0-100 scales. This function attempts a conservative normalization. + """ + if 0.0 <= score <= 1.0: + return score + if 0.0 <= score <= 10.0: + return score / 10.0 + if 0.0 <= score <= 100.0: + return score / 100.0 + return max(0.0, min(1.0, score)) + + +def _extract_score(text: str) -> float | None: + """Extract a numeric relevance score from an LLM response.""" + content = (text or "").strip() + if not content: + return None + + # Prefer JSON if present. + if "{" in content and "}" in content: + try: + start = content.index("{") + end = content.rindex("}") + 1 + payload = json.loads(content[start:end]) + if isinstance(payload, dict) and "score" in payload: + return float(payload["score"]) + except Exception: + pass + + match = _NUMBER_RE.search(content) + if not match: + return None + try: + return float(match.group(0)) + except ValueError: + return None + + +class LiteLLMReranker(BaseReranker): + """Experimental reranker that uses a LiteLLM-compatible model. + + This reranker scores each (query, doc) pair in isolation (single-pair mode) + to improve prompt reliability across providers. + """ + + _SYSTEM_PROMPT = ( + "You are a relevance scoring assistant.\n" + "Given a search query and a document snippet, output a single numeric " + "relevance score between 0 and 1.\n\n" + "Scoring guidance:\n" + "- 1.0: The document directly answers the query.\n" + "- 0.5: The document is partially relevant.\n" + "- 0.0: The document is unrelated.\n\n" + "Output requirements:\n" + "- Output ONLY the number (e.g., 0.73).\n" + "- Do not include any other text." + ) + + def __init__( + self, + model: str = "default", + *, + requests_per_minute: float | None = None, + min_interval_seconds: float | None = None, + default_score: float = 0.0, + max_doc_chars: int = 8000, + **litellm_kwargs: Any, + ) -> None: + """Initialize the reranker. + + Args: + model: Model name from ccw-litellm configuration (default: "default"). 
+ requests_per_minute: Optional rate limit in requests per minute. + min_interval_seconds: Optional minimum interval between requests. If set, + it takes precedence over requests_per_minute. + default_score: Score to use when an API call fails or parsing fails. + max_doc_chars: Maximum number of document characters to include in the prompt. + **litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`. + + Raises: + ImportError: If ccw-litellm is not installed. + ValueError: If model is blank. + """ + self.model_name = (model or "").strip() + if not self.model_name: + raise ValueError("model cannot be blank") + + self.default_score = float(default_score) + + self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0 + + if min_interval_seconds is not None: + self._min_interval_seconds = max(0.0, float(min_interval_seconds)) + elif requests_per_minute is not None and float(requests_per_minute) > 0: + self._min_interval_seconds = 60.0 / float(requests_per_minute) + else: + self._min_interval_seconds = 0.0 + + # Prefer deterministic output by default; allow overrides via kwargs. + litellm_kwargs = dict(litellm_kwargs) + litellm_kwargs.setdefault("temperature", 0.0) + litellm_kwargs.setdefault("max_tokens", 16) + + try: + from ccw_litellm import ChatMessage, LiteLLMClient + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "ccw-litellm not installed. Install with: pip install ccw-litellm" + ) from exc + + self._ChatMessage = ChatMessage + self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs) + + self._lock = threading.RLock() + self._last_request_at = 0.0 + + def _sanitize_text(self, text: str) -> str: + # Keep consistent with LiteLLMEmbedderWrapper workaround. + if text.startswith("import"): + return " " + text + return text + + def _rate_limit(self) -> None: + if self._min_interval_seconds <= 0: + return + with self._lock: + now = time.monotonic() + elapsed = now - self._last_request_at + if elapsed < self._min_interval_seconds: + time.sleep(self._min_interval_seconds - elapsed) + self._last_request_at = time.monotonic() + + def _build_user_prompt(self, query: str, doc: str) -> str: + sanitized_query = self._sanitize_text(query or "") + sanitized_doc = self._sanitize_text(doc or "") + if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars: + sanitized_doc = sanitized_doc[: self.max_doc_chars] + + return ( + "Query:\n" + f"{sanitized_query}\n\n" + "Document:\n" + f"{sanitized_doc}\n\n" + "Return the relevance score (0 to 1) as a single number:" + ) + + def _score_single_pair(self, query: str, doc: str) -> float: + messages = [ + self._ChatMessage(role="system", content=self._SYSTEM_PROMPT), + self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)), + ] + + try: + self._rate_limit() + response = self._client.chat(messages) + except Exception as exc: + logger.debug("LiteLLM reranker request failed: %s", exc) + return self.default_score + + raw = getattr(response, "content", "") or "" + score = _extract_score(raw) + if score is None: + logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw) + return self.default_score + return _coerce_score_to_unit_interval(float(score)) + + def score_pairs( + self, + pairs: Sequence[tuple[str, str]], + *, + batch_size: int = 32, + ) -> list[float]: + """Score (query, doc) pairs with per-pair LLM calls.""" + if not pairs: + return [] + + bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32 + + scores: list[float] = [] + 
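# Note: batching here only chunks the loop; each (query, doc) pair still + # costs one LLM request, so large candidate sets are slow and expensive. + 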
for i in range(0, len(pairs), bs): + batch = pairs[i : i + bs] + for query, doc in batch: + scores.append(self._score_single_pair(query, doc)) + return scores diff --git a/codex-lens/src/codexlens/semantic/reranker/onnx_reranker.py b/codex-lens/src/codexlens/semantic/reranker/onnx_reranker.py new file mode 100644 index 00000000..0b22f45e --- /dev/null +++ b/codex-lens/src/codexlens/semantic/reranker/onnx_reranker.py @@ -0,0 +1,268 @@ +"""Optimum + ONNX Runtime reranker backend. + +This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence +classification models. It is designed to run without requiring PyTorch at +runtime by using numpy tensors and ONNX Runtime execution providers. + +Install (CPU): + pip install onnxruntime optimum[onnxruntime] transformers +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any, Iterable, Sequence + +from .base import BaseReranker + +logger = logging.getLogger(__name__) + + +def check_onnx_reranker_available() -> tuple[bool, str | None]: + """Check whether Optimum + ONNXRuntime reranker dependencies are available.""" + try: + import numpy # noqa: F401 + except ImportError as exc: # pragma: no cover - optional dependency + return False, f"numpy not available: {exc}. Install with: pip install numpy" + + try: + import onnxruntime # noqa: F401 + except ImportError as exc: # pragma: no cover - optional dependency + return ( + False, + f"onnxruntime not available: {exc}. Install with: pip install onnxruntime", + ) + + try: + from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401 + except ImportError as exc: # pragma: no cover - optional dependency + return ( + False, + f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]", + ) + + try: + from transformers import AutoTokenizer # noqa: F401 + except ImportError as exc: # pragma: no cover - optional dependency + return ( + False, + f"transformers not available: {exc}. 
Install with: pip install transformers", + ) + + return True, None + + +def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]: + for i in range(0, len(items), batch_size): + yield items[i : i + batch_size] + + +class ONNXReranker(BaseReranker): + """Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading.""" + + DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2" + + def __init__( + self, + model_name: str | None = None, + *, + use_gpu: bool = True, + providers: list[Any] | None = None, + max_length: int | None = None, + ) -> None: + self.model_name = (model_name or self.DEFAULT_MODEL).strip() + if not self.model_name: + raise ValueError("model_name cannot be blank") + + self.use_gpu = bool(use_gpu) + self.providers = providers + + self.max_length = int(max_length) if max_length is not None else None + + self._tokenizer: Any | None = None + self._model: Any | None = None + self._model_input_names: set[str] | None = None + self._lock = threading.RLock() + + def _load_model(self) -> None: + if self._model is not None and self._tokenizer is not None: + return + + ok, err = check_onnx_reranker_available() + if not ok: + raise ImportError(err) + + with self._lock: + if self._model is not None and self._tokenizer is not None: + return + + from inspect import signature + + from optimum.onnxruntime import ORTModelForSequenceClassification + from transformers import AutoTokenizer + + if self.providers is None: + from ..gpu_support import get_optimal_providers + + # Include device_id options for DirectML/CUDA selection when available. + self.providers = get_optimal_providers( + use_gpu=self.use_gpu, with_device_options=True + ) + + # Some Optimum versions accept `providers`, others accept a single `provider`. + # Prefer passing the full providers list, with a conservative fallback. + model_kwargs: dict[str, Any] = {} + try: + params = signature(ORTModelForSequenceClassification.from_pretrained).parameters + if "providers" in params: + model_kwargs["providers"] = self.providers + elif "provider" in params: + provider_name = "CPUExecutionProvider" + if self.providers: + first = self.providers[0] + provider_name = first[0] if isinstance(first, tuple) else str(first) + model_kwargs["provider"] = provider_name + except Exception: + model_kwargs = {} + + try: + self._model = ORTModelForSequenceClassification.from_pretrained( + self.model_name, + **model_kwargs, + ) + except TypeError: + # Fallback for older Optimum versions: retry without provider arguments. + self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name) + + self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) + + # Cache model input names to filter tokenizer outputs defensively. 
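+ # Optimum has exposed these names via different attributes across versions, + # so probe the known candidates before falling back to the underlying ONNX + # session's declared inputs.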
+ input_names: set[str] | None = None + for attr in ("input_names", "model_input_names"): + names = getattr(self._model, attr, None) + if isinstance(names, (list, tuple)) and names: + input_names = {str(n) for n in names} + break + if input_names is None: + try: + session = getattr(self._model, "model", None) + if session is not None and hasattr(session, "get_inputs"): + input_names = {i.name for i in session.get_inputs()} + except Exception: + input_names = None + self._model_input_names = input_names + + @staticmethod + def _sigmoid(x: "Any") -> "Any": + import numpy as np + + x = np.clip(x, -50.0, 50.0) + return 1.0 / (1.0 + np.exp(-x)) + + @staticmethod + def _select_relevance_logit(logits: "Any") -> "Any": + import numpy as np + + arr = np.asarray(logits) + if arr.ndim == 0: + return arr.reshape(1) + if arr.ndim == 1: + return arr + if arr.ndim >= 2: + # Common cases: + # - Regression: (batch, 1) + # - Binary classification: (batch, 2) + if arr.shape[-1] == 1: + return arr[..., 0] + if arr.shape[-1] == 2: + # Convert 2-logit softmax into a single logit via difference. + return arr[..., 1] - arr[..., 0] + return arr.max(axis=-1) + return arr.reshape(-1) + + def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]: + if self._tokenizer is None: + raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive + + queries = [q for q, _ in batch] + docs = [d for _, d in batch] + + tokenizer_kwargs: dict[str, Any] = { + "text": queries, + "text_pair": docs, + "padding": True, + "truncation": True, + "return_tensors": "np", + } + + max_len = self.max_length + if max_len is None: + try: + model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0) + if 0 < model_max < 10_000: + max_len = model_max + else: + max_len = 512 + except Exception: + max_len = 512 + if max_len is not None and max_len > 0: + tokenizer_kwargs["max_length"] = int(max_len) + + encoded = self._tokenizer(**tokenizer_kwargs) + inputs = dict(encoded) + + # Some models do not accept token_type_ids; filter to known input names if available. 
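+ # (DistilBERT-based cross-encoders, for example, have no token_type_ids input.)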
+ if self._model_input_names: + inputs = {k: v for k, v in inputs.items() if k in self._model_input_names} + + return inputs + + def _forward_logits(self, inputs: dict[str, Any]) -> Any: + if self._model is None: + raise RuntimeError("Model not loaded") # pragma: no cover - defensive + + outputs = self._model(**inputs) + if hasattr(outputs, "logits"): + return outputs.logits + if isinstance(outputs, dict) and "logits" in outputs: + return outputs["logits"] + if isinstance(outputs, (list, tuple)) and outputs: + return outputs[0] + raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive + + def score_pairs( + self, + pairs: Sequence[tuple[str, str]], + *, + batch_size: int = 32, + ) -> list[float]: + """Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1].""" + if not pairs: + return [] + + self._load_model() + + if self._model is None or self._tokenizer is None: # pragma: no cover - defensive + return [] + + import numpy as np + + bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32 + scores: list[float] = [] + + for batch in _iter_batches(list(pairs), bs): + inputs = self._tokenize_batch(batch) + logits = self._forward_logits(inputs) + rel_logits = self._select_relevance_logit(logits) + probs = self._sigmoid(rel_logits) + probs = np.clip(probs, 0.0, 1.0) + scores.extend([float(p) for p in probs.reshape(-1).tolist()]) + + if len(scores) != len(pairs): + logger.debug( + "ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs) + ) + return scores[: len(pairs)] + + return scores diff --git a/codex-lens/src/codexlens/watcher/__init__.py b/codex-lens/src/codexlens/watcher/__init__.py new file mode 100644 index 00000000..4c095ec4 --- /dev/null +++ b/codex-lens/src/codexlens/watcher/__init__.py @@ -0,0 +1,17 @@ +"""File watcher module for real-time index updates.""" + +from .events import ChangeType, FileEvent, IndexResult, WatcherConfig, WatcherStats +from .file_watcher import FileWatcher +from .incremental_indexer import IncrementalIndexer +from .manager import WatcherManager + +__all__ = [ + "ChangeType", + "FileEvent", + "IndexResult", + "WatcherConfig", + "WatcherStats", + "FileWatcher", + "IncrementalIndexer", + "WatcherManager", +] diff --git a/codex-lens/src/codexlens/watcher/events.py b/codex-lens/src/codexlens/watcher/events.py new file mode 100644 index 00000000..96860c93 --- /dev/null +++ b/codex-lens/src/codexlens/watcher/events.py @@ -0,0 +1,54 @@ +"""Event types for file watcher.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import List, Optional, Set + + +class ChangeType(Enum): + """Type of file system change.""" + CREATED = "created" + MODIFIED = "modified" + DELETED = "deleted" + MOVED = "moved" + + +@dataclass +class FileEvent: + """A file system change event.""" + path: Path + change_type: ChangeType + timestamp: float + old_path: Optional[Path] = None # For MOVED events + + +@dataclass +class WatcherConfig: + """Configuration for file watcher.""" + debounce_ms: int = 1000 + ignored_patterns: Set[str] = field(default_factory=lambda: { + ".git", ".venv", "venv", "node_modules", + "__pycache__", ".codexlens", ".idea", ".vscode", + }) + languages: Optional[List[str]] = None # None = all supported + + +@dataclass +class IndexResult: + """Result of processing file changes.""" + files_indexed: int = 0 + files_removed: int = 0 + symbols_added: int = 0 + errors: List[str] = field(default_factory=list) + + 
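# Example (illustrative): a watcher callback consuming IndexResult. The real +# wiring lives in WatcherManager and the CLI 'watch' command: +# +# def on_indexed(result: IndexResult) -> None: +#     if result.errors: +#         print(f"{len(result.errors)} error(s) during incremental update") +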
+@dataclass +class WatcherStats: + """Runtime statistics for watcher.""" + files_watched: int = 0 + events_processed: int = 0 + last_event_time: Optional[float] = None + is_running: bool = False diff --git a/codex-lens/src/codexlens/watcher/file_watcher.py b/codex-lens/src/codexlens/watcher/file_watcher.py new file mode 100644 index 00000000..e8007358 --- /dev/null +++ b/codex-lens/src/codexlens/watcher/file_watcher.py @@ -0,0 +1,245 @@ +"""File system watcher using watchdog library.""" + +from __future__ import annotations + +import logging +import threading +import time +from pathlib import Path +from typing import Callable, Dict, List, Optional + +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler + +from .events import ChangeType, FileEvent, WatcherConfig +from ..config import Config + +logger = logging.getLogger(__name__) + + +class _CodexLensHandler(FileSystemEventHandler): + """Internal handler for watchdog events.""" + + def __init__( + self, + watcher: "FileWatcher", + on_event: Callable[[FileEvent], None], + ) -> None: + super().__init__() + self._watcher = watcher + self._on_event = on_event + + def on_created(self, event) -> None: + if event.is_directory: + return + self._emit(event.src_path, ChangeType.CREATED) + + def on_modified(self, event) -> None: + if event.is_directory: + return + self._emit(event.src_path, ChangeType.MODIFIED) + + def on_deleted(self, event) -> None: + if event.is_directory: + return + self._emit(event.src_path, ChangeType.DELETED) + + def on_moved(self, event) -> None: + if event.is_directory: + return + self._emit(event.dest_path, ChangeType.MOVED, old_path=event.src_path) + + def _emit( + self, + path: str, + change_type: ChangeType, + old_path: Optional[str] = None, + ) -> None: + path_obj = Path(path) + + # Filter out files that should not be indexed + if not self._watcher._should_index_file(path_obj): + return + + event = FileEvent( + path=path_obj, + change_type=change_type, + timestamp=time.time(), + old_path=Path(old_path) if old_path else None, + ) + self._on_event(event) + + +class FileWatcher: + """File system watcher for monitoring directory changes. + + Uses watchdog library for cross-platform file system monitoring. + Events are forwarded to the on_changes callback. + + Example: + def handle_changes(events: List[FileEvent]) -> None: + for event in events: + print(f"{event.change_type}: {event.path}") + + watcher = FileWatcher(Path("."), WatcherConfig(), handle_changes) + watcher.start() + watcher.wait() # Block until stopped + """ + + def __init__( + self, + root_path: Path, + config: WatcherConfig, + on_changes: Callable[[List[FileEvent]], None], + ) -> None: + """Initialize file watcher. + + Args: + root_path: Directory to watch recursively + config: Watcher configuration + on_changes: Callback invoked with batched events + """ + self.root_path = Path(root_path).resolve() + self.config = config + self.on_changes = on_changes + + self._observer: Optional[Observer] = None + self._running = False + self._stop_event = threading.Event() + self._lock = threading.RLock() + + # Event queue for batching + self._event_queue: List[FileEvent] = [] + self._queue_lock = threading.Lock() + + # Debounce thread + self._debounce_thread: Optional[threading.Thread] = None + + # Config instance for language checking + self._codexlens_config = Config() + + def _should_index_file(self, path: Path) -> bool: + """Check if file should be indexed based on extension and ignore patterns. 
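+
+        A path is rejected when any of its components matches an ignored
+        pattern (for example ".git" anywhere in path.parts), or when its
+        extension does not map to a supported language.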
+ + Args: + path: File path to check + + Returns: + True if file should be indexed, False otherwise + """ + # Check against ignore patterns + parts = path.parts + for pattern in self.config.ignored_patterns: + if pattern in parts: + return False + + # Check extension against supported languages + language = self._codexlens_config.language_for_path(path) + return language is not None + + def _on_raw_event(self, event: FileEvent) -> None: + """Handle raw event from watchdog handler.""" + with self._queue_lock: + self._event_queue.append(event) + # Debouncing is handled by background thread + + def _debounce_loop(self) -> None: + """Background thread for debounced event batching.""" + while self._running: + time.sleep(self.config.debounce_ms / 1000.0) + self._flush_events() + + def _flush_events(self) -> None: + """Flush queued events with deduplication.""" + with self._queue_lock: + if not self._event_queue: + return + + # Deduplicate: keep latest event per path + deduped: Dict[Path, FileEvent] = {} + for event in self._event_queue: + deduped[event.path] = event + + events = list(deduped.values()) + self._event_queue.clear() + + if events: + try: + self.on_changes(events) + except Exception as exc: + logger.error("Error in on_changes callback: %s", exc) + + def start(self) -> None: + """Start watching the directory. + + Non-blocking. Use wait() to block until stopped. + """ + with self._lock: + if self._running: + logger.warning("Watcher already running") + return + + if not self.root_path.exists(): + raise ValueError(f"Root path does not exist: {self.root_path}") + + self._observer = Observer() + handler = _CodexLensHandler(self, self._on_raw_event) + self._observer.schedule(handler, str(self.root_path), recursive=True) + + self._running = True + self._stop_event.clear() + self._observer.start() + + # Start debounce thread + self._debounce_thread = threading.Thread( + target=self._debounce_loop, + daemon=True, + name="FileWatcher-Debounce", + ) + self._debounce_thread.start() + + logger.info("Started watching: %s", self.root_path) + + def stop(self) -> None: + """Stop watching the directory. + + Gracefully stops the observer and flushes remaining events. + """ + with self._lock: + if not self._running: + return + + self._running = False + self._stop_event.set() + + if self._observer: + self._observer.stop() + self._observer.join(timeout=5.0) + self._observer = None + + # Wait for debounce thread to finish + if self._debounce_thread and self._debounce_thread.is_alive(): + self._debounce_thread.join(timeout=2.0) + self._debounce_thread = None + + # Flush any remaining events + self._flush_events() + + logger.info("Stopped watching: %s", self.root_path) + + def wait(self) -> None: + """Block until watcher is stopped. + + Use Ctrl+C or call stop() from another thread to unblock. 
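+
+        Example:
+            watcher.start()
+            watcher.wait()  # blocks until stop() or Ctrl+C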
+        """
+        try:
+            while self._running:
+                self._stop_event.wait(timeout=1.0)
+        except KeyboardInterrupt:
+            logger.info("Received interrupt, stopping watcher...")
+            self.stop()
+
+    @property
+    def is_running(self) -> bool:
+        """Check if watcher is currently running."""
+        return self._running
diff --git a/codex-lens/src/codexlens/watcher/incremental_indexer.py b/codex-lens/src/codexlens/watcher/incremental_indexer.py
new file mode 100644
index 00000000..bb836034
--- /dev/null
+++ b/codex-lens/src/codexlens/watcher/incremental_indexer.py
@@ -0,0 +1,359 @@
+"""Incremental indexer for processing file changes."""
+
+from __future__ import annotations
+
+import logging
+import sqlite3
+import threading
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional
+
+from codexlens.config import Config
+from codexlens.parsers.factory import ParserFactory
+from codexlens.storage.dir_index import DirIndexStore
+from codexlens.storage.global_index import GlobalSymbolIndex
+from codexlens.storage.path_mapper import PathMapper
+from codexlens.storage.registry import RegistryStore
+
+from .events import ChangeType, FileEvent, IndexResult
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class FileIndexResult:
+    """Result of indexing a single file."""
+    path: Path
+    symbols_count: int
+    success: bool
+    error: Optional[str] = None
+
+
+class IncrementalIndexer:
+    """Incremental indexer for processing file change events.
+
+    Processes file events (create, modify, delete, move) and updates
+    the corresponding index databases incrementally.
+
+    Reuses existing infrastructure:
+    - ParserFactory for symbol extraction
+    - DirIndexStore for per-directory storage
+    - GlobalSymbolIndex for cross-file symbols
+    - PathMapper for source-to-index path conversion
+
+    Example:
+        indexer = IncrementalIndexer(registry, mapper, config)
+        result = indexer.process_changes([
+            FileEvent(Path("foo.py"), ChangeType.MODIFIED, time.time()),
+        ])
+        print(f"Indexed {result.files_indexed} files")
+    """
+
+    def __init__(
+        self,
+        registry: RegistryStore,
+        mapper: PathMapper,
+        config: Optional[Config] = None,
+    ) -> None:
+        """Initialize incremental indexer.
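+
+        Directory stores are opened lazily on first use and cached for
+        subsequent events (see _get_dir_store).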
+
+        Args:
+            registry: Global project registry
+            mapper: Path mapper for source-to-index conversion
+            config: CodexLens configuration (uses defaults if None)
+        """
+        self.registry = registry
+        self.mapper = mapper
+        self.config = config or Config()
+        self.parser_factory = ParserFactory(self.config)
+
+        self._global_index: Optional[GlobalSymbolIndex] = None
+        self._dir_stores: dict[Path, DirIndexStore] = {}
+        self._lock = threading.RLock()
+
+    def _get_global_index(self, index_root: Path) -> Optional[GlobalSymbolIndex]:
+        """Get or create global symbol index."""
+        if not self.config.global_symbol_index_enabled:
+            return None
+
+        if self._global_index is None:
+            global_db_path = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
+            if global_db_path.exists():
+                self._global_index = GlobalSymbolIndex(global_db_path)
+
+        return self._global_index
+
+    def _get_dir_store(self, dir_path: Path) -> Optional[DirIndexStore]:
+        """Get DirIndexStore for a directory, if indexed."""
+        with self._lock:
+            if dir_path in self._dir_stores:
+                return self._dir_stores[dir_path]
+
+            index_db = self.mapper.source_to_index_db(dir_path)
+            if not index_db.exists():
+                logger.debug("No index found for directory: %s", dir_path)
+                return None
+
+            # Get index root for global index
+            index_root = self.mapper.source_to_index_dir(
+                self.mapper.get_project_root(dir_path) or dir_path
+            )
+            global_index = self._get_global_index(index_root)
+
+            store = DirIndexStore(
+                index_db,
+                config=self.config,
+                global_index=global_index,
+            )
+            self._dir_stores[dir_path] = store
+            return store
+
+    def process_changes(self, events: List[FileEvent]) -> IndexResult:
+        """Process a batch of file change events.
+
+        Args:
+            events: List of file events to process
+
+        Returns:
+            IndexResult with statistics
+        """
+        result = IndexResult()
+
+        for event in events:
+            try:
+                # CREATED and MODIFIED are handled identically: (re)index the file.
+                if event.change_type in (ChangeType.CREATED, ChangeType.MODIFIED):
+                    file_result = self._index_file(event.path)
+                    if file_result.success:
+                        result.files_indexed += 1
+                        result.symbols_added += file_result.symbols_count
+                    else:
+                        result.errors.append(file_result.error or f"Failed to index: {event.path}")
+
+                elif event.change_type == ChangeType.DELETED:
+                    self._remove_file(event.path)
+                    result.files_removed += 1
+
+                elif event.change_type == ChangeType.MOVED:
+                    # Remove from old location, add at new location
+                    if event.old_path:
+                        self._remove_file(event.old_path)
+                        result.files_removed += 1
+                    file_result = self._index_file(event.path)
+                    if file_result.success:
+                        result.files_indexed += 1
+                        result.symbols_added += file_result.symbols_count
+                    else:
+                        result.errors.append(file_result.error or f"Failed to index: {event.path}")
+
+            except Exception as exc:
+                error_msg = f"Error processing {event.path}: {type(exc).__name__}: {exc}"
+                logger.error(error_msg)
+                result.errors.append(error_msg)
+
+        return result
+
+    def _index_file(self, path: Path) -> FileIndexResult:
+        """Index a single file.
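+
+        Pipeline: resolve the path, verify the language is supported, locate
+        the per-directory store, read content (UTF-8 with lossy fallback),
+        parse symbols, then write with retry on transient SQLite errors.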
+
+        Args:
+            path: Path to the file to index
+
+        Returns:
+            FileIndexResult with status
+        """
+        path = Path(path).resolve()
+
+        # Check if file exists
+        if not path.exists():
+            return FileIndexResult(
+                path=path,
+                symbols_count=0,
+                success=False,
+                error=f"File not found: {path}",
+            )
+
+        # Check if language is supported
+        language = self.config.language_for_path(path)
+        if not language:
+            return FileIndexResult(
+                path=path,
+                symbols_count=0,
+                success=False,
+                error=f"Unsupported language for: {path}",
+            )
+
+        # Get directory store
+        dir_path = path.parent
+        store = self._get_dir_store(dir_path)
+        if store is None:
+            return FileIndexResult(
+                path=path,
+                symbols_count=0,
+                success=False,
+                error=f"Directory not indexed: {dir_path}",
+            )
+
+        # Read file content with fallback encodings
+        try:
+            content = path.read_text(encoding="utf-8")
+        except UnicodeDecodeError:
+            logger.debug("UTF-8 decode failed for %s, using fallback with errors='ignore'", path)
+            try:
+                content = path.read_text(encoding="utf-8", errors="ignore")
+            except Exception as exc:
+                return FileIndexResult(
+                    path=path,
+                    symbols_count=0,
+                    success=False,
+                    error=f"Failed to read file: {exc}",
+                )
+        except Exception as exc:
+            return FileIndexResult(
+                path=path,
+                symbols_count=0,
+                success=False,
+                error=f"Failed to read file: {exc}",
+            )
+
+        # Parse symbols
+        try:
+            parser = self.parser_factory.get_parser(language)
+            indexed_file = parser.parse(content, path)
+        except Exception as exc:
+            error_msg = f"Failed to parse {path}: {type(exc).__name__}: {exc}"
+            logger.error(error_msg)
+            return FileIndexResult(
+                path=path,
+                symbols_count=0,
+                success=False,
+                error=error_msg,
+            )
+
+        # Update store with retry logic for transient database errors
+        max_retries = 3
+        for attempt in range(max_retries):
+            try:
+                store.add_file(
+                    name=path.name,
+                    full_path=str(path),
+                    content=content,
+                    language=language,
+                    symbols=indexed_file.symbols,
+                    relationships=indexed_file.relationships,
+                )
+
+                # Update merkle root
+                store.update_merkle_root()
+
+                logger.debug("Indexed file: %s (%d symbols)", path, len(indexed_file.symbols))
+
+                return FileIndexResult(
+                    path=path,
+                    symbols_count=len(indexed_file.symbols),
+                    success=True,
+                )
+
+            except sqlite3.OperationalError as exc:
+                # Transient database errors (e.g., database locked)
+                if attempt < max_retries - 1:
+                    wait_time = 0.1 * (2 ** attempt)  # Exponential backoff
+                    logger.debug("Database operation failed (attempt %d/%d), retrying in %.2fs: %s",
+                                attempt + 1, max_retries, wait_time, exc)
+                    time.sleep(wait_time)
+                    continue
+                else:
+                    error_msg = f"Failed to store {path} after {max_retries} attempts: {exc}"
+                    logger.error(error_msg)
+                    return FileIndexResult(
+                        path=path,
+                        symbols_count=0,
+                        success=False,
+                        error=error_msg,
+                    )
+            except Exception as exc:
+                error_msg = f"Failed to store {path}: {type(exc).__name__}: {exc}"
+                logger.error(error_msg)
+                return FileIndexResult(
+                    path=path,
+                    symbols_count=0,
+                    success=False,
+                    error=error_msg,
+                )
+
+        # Should never reach here
+        return FileIndexResult(
+            path=path,
+            symbols_count=0,
+            success=False,
+            error="Unexpected error in indexing loop",
+        )
+
+    def _remove_file(self, path: Path) -> bool:
+        """Remove a file from the index.
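+
+        Uses the same retry-with-backoff policy as _index_file for transient
+        "database is locked" errors.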
+
+        Args:
+            path: Path to the file to remove
+
+        Returns:
+            True if removed successfully
+        """
+        path = Path(path).resolve()
+        dir_path = path.parent
+
+        store = self._get_dir_store(dir_path)
+        if store is None:
+            logger.debug("Cannot remove file, directory not indexed: %s", dir_path)
+            return False
+
+        # Retry logic for transient database errors
+        max_retries = 3
+        for attempt in range(max_retries):
+            try:
+                store.remove_file(str(path))
+                store.update_merkle_root()
+                logger.debug("Removed file from index: %s", path)
+                return True
+
+            except sqlite3.OperationalError as exc:
+                # Transient database errors (e.g., database locked)
+                if attempt < max_retries - 1:
+                    wait_time = 0.1 * (2 ** attempt)  # Exponential backoff
+                    logger.debug("Database operation failed (attempt %d/%d), retrying in %.2fs: %s",
+                                attempt + 1, max_retries, wait_time, exc)
+                    time.sleep(wait_time)
+                    continue
+                else:
+                    logger.error("Failed to remove %s after %d attempts: %s", path, max_retries, exc)
+                    return False
+            except Exception as exc:
+                logger.error("Failed to remove %s: %s", path, exc)
+                return False
+
+        # Should never reach here
+        return False
+
+    def close(self) -> None:
+        """Close all open stores."""
+        with self._lock:
+            for store in self._dir_stores.values():
+                try:
+                    store.close()
+                except Exception:
+                    pass
+            self._dir_stores.clear()
+
+        if self._global_index:
+            try:
+                self._global_index.close()
+            except Exception:
+                pass
+            self._global_index = None
diff --git a/codex-lens/src/codexlens/watcher/manager.py b/codex-lens/src/codexlens/watcher/manager.py
new file mode 100644
index 00000000..5ca0afdb
--- /dev/null
+++ b/codex-lens/src/codexlens/watcher/manager.py
@@ -0,0 +1,194 @@
+"""Watcher manager for coordinating file watching and incremental indexing."""
+
+from __future__ import annotations
+
+import logging
+import signal
+import threading
+import time
+from pathlib import Path
+from typing import Callable, List, Optional
+
+from codexlens.config import Config
+from codexlens.storage.path_mapper import PathMapper
+from codexlens.storage.registry import RegistryStore
+
+from .events import FileEvent, IndexResult, WatcherConfig, WatcherStats
+from .file_watcher import FileWatcher
+from .incremental_indexer import IncrementalIndexer
+
+logger = logging.getLogger(__name__)
+
+
+class WatcherManager:
+    """High-level manager for file watching and incremental indexing.
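+
+    Typically driven by the 'codexlens watch' CLI command.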
+ + Coordinates FileWatcher and IncrementalIndexer with: + - Lifecycle management (start/stop) + - Signal handling (SIGINT/SIGTERM) + - Statistics tracking + - Graceful shutdown + """ + + def __init__( + self, + root_path: Path, + config: Optional[Config] = None, + watcher_config: Optional[WatcherConfig] = None, + on_indexed: Optional[Callable[[IndexResult], None]] = None, + ) -> None: + self.root_path = Path(root_path).resolve() + self.config = config or Config() + self.watcher_config = watcher_config or WatcherConfig() + self.on_indexed = on_indexed + + self._registry: Optional[RegistryStore] = None + self._mapper: Optional[PathMapper] = None + self._watcher: Optional[FileWatcher] = None + self._indexer: Optional[IncrementalIndexer] = None + + self._running = False + self._stop_event = threading.Event() + self._lock = threading.RLock() + + # Statistics + self._stats = WatcherStats() + self._original_sigint = None + self._original_sigterm = None + + def _handle_changes(self, events: List[FileEvent]) -> None: + """Handle file change events from watcher.""" + if not self._indexer or not events: + return + + logger.info("Processing %d file changes", len(events)) + result = self._indexer.process_changes(events) + + # Update stats + self._stats.events_processed += len(events) + self._stats.last_event_time = time.time() + + if result.files_indexed > 0 or result.files_removed > 0: + logger.info( + "Indexed %d files, removed %d files, %d errors", + result.files_indexed, result.files_removed, len(result.errors) + ) + + if self.on_indexed: + try: + self.on_indexed(result) + except Exception as exc: + logger.error("Error in on_indexed callback: %s", exc) + + def _signal_handler(self, signum, frame) -> None: + """Handle shutdown signals.""" + logger.info("Received signal %d, stopping...", signum) + self.stop() + + def _install_signal_handlers(self) -> None: + """Install signal handlers for graceful shutdown.""" + try: + self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler) + if hasattr(signal, 'SIGTERM'): + self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler) + except (ValueError, OSError): + # Signal handling not available (e.g., not main thread) + pass + + def _restore_signal_handlers(self) -> None: + """Restore original signal handlers.""" + try: + if self._original_sigint is not None: + signal.signal(signal.SIGINT, self._original_sigint) + if self._original_sigterm is not None and hasattr(signal, 'SIGTERM'): + signal.signal(signal.SIGTERM, self._original_sigterm) + except (ValueError, OSError): + pass + + def start(self) -> None: + """Start watching and indexing.""" + with self._lock: + if self._running: + logger.warning("WatcherManager already running") + return + + # Validate path + if not self.root_path.exists(): + raise ValueError(f"Root path does not exist: {self.root_path}") + + # Initialize components + self._registry = RegistryStore() + self._registry.initialize() + self._mapper = PathMapper() + + self._indexer = IncrementalIndexer( + self._registry, self._mapper, self.config + ) + + self._watcher = FileWatcher( + self.root_path, self.watcher_config, self._handle_changes + ) + + # Install signal handlers + self._install_signal_handlers() + + # Start watcher + self._running = True + self._stats.is_running = True + self._stop_event.clear() + self._watcher.start() + + logger.info("WatcherManager started for: %s", self.root_path) + + def stop(self) -> None: + """Stop watching and clean up.""" + with self._lock: + if not self._running: + return + + 
self._running = False + self._stats.is_running = False + self._stop_event.set() + + # Stop watcher + if self._watcher: + self._watcher.stop() + self._watcher = None + + # Close indexer + if self._indexer: + self._indexer.close() + self._indexer = None + + # Close registry + if self._registry: + self._registry.close() + self._registry = None + + # Restore signal handlers + self._restore_signal_handlers() + + logger.info("WatcherManager stopped") + + def wait(self) -> None: + """Block until stopped.""" + try: + while self._running: + self._stop_event.wait(timeout=1.0) + except KeyboardInterrupt: + logger.info("Interrupted, stopping...") + self.stop() + + @property + def is_running(self) -> bool: + """Check if manager is running.""" + return self._running + + def get_stats(self) -> WatcherStats: + """Get runtime statistics.""" + return WatcherStats( + files_watched=self._stats.files_watched, + events_processed=self._stats.events_processed, + last_event_time=self._stats.last_event_time, + is_running=self._running, + ) diff --git a/codex-lens/tests/test_api_reranker.py b/codex-lens/tests/test_api_reranker.py new file mode 100644 index 00000000..4b4bfd1e --- /dev/null +++ b/codex-lens/tests/test_api_reranker.py @@ -0,0 +1,171 @@ +"""Tests for APIReranker backend.""" + +from __future__ import annotations + +import sys +import types +from typing import Any + +import pytest + +from codexlens.semantic.reranker import get_reranker +from codexlens.semantic.reranker.api_reranker import APIReranker + + +class DummyResponse: + def __init__( + self, + *, + status_code: int = 200, + json_data: Any = None, + text: str = "", + headers: dict[str, str] | None = None, + ) -> None: + self.status_code = int(status_code) + self._json_data = json_data + self.text = text + self.headers = headers or {} + + def json(self) -> Any: + return self._json_data + + +class DummyClient: + def __init__(self, *, base_url: str | None = None, headers: dict[str, str] | None = None, timeout: float | None = None) -> None: + self.base_url = base_url + self.headers = headers or {} + self.timeout = timeout + self.closed = False + self.calls: list[dict[str, Any]] = [] + self._responses: list[DummyResponse] = [] + + def queue(self, response: DummyResponse) -> None: + self._responses.append(response) + + def post(self, endpoint: str, *, json: dict[str, Any] | None = None) -> DummyResponse: + self.calls.append({"endpoint": endpoint, "json": json}) + if not self._responses: + raise AssertionError("DummyClient has no queued responses") + return self._responses.pop(0) + + def close(self) -> None: + self.closed = True + + +@pytest.fixture +def httpx_clients(monkeypatch: pytest.MonkeyPatch) -> list[DummyClient]: + clients: list[DummyClient] = [] + + dummy_httpx = types.ModuleType("httpx") + + def Client(*, base_url: str | None = None, headers: dict[str, str] | None = None, timeout: float | None = None) -> DummyClient: + client = DummyClient(base_url=base_url, headers=headers, timeout=timeout) + clients.append(client) + return client + + dummy_httpx.Client = Client + monkeypatch.setitem(sys.modules, "httpx", dummy_httpx) + + return clients + + +def test_api_reranker_requires_api_key( + monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient] +) -> None: + monkeypatch.delenv("RERANKER_API_KEY", raising=False) + + with pytest.raises(ValueError, match="Missing API key"): + APIReranker() + + assert httpx_clients == [] + + +def test_api_reranker_reads_api_key_from_env( + monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient] +) -> 
None: + monkeypatch.setenv("RERANKER_API_KEY", "test-key") + + reranker = APIReranker() + assert len(httpx_clients) == 1 + assert httpx_clients[0].headers["Authorization"] == "Bearer test-key" + reranker.close() + assert httpx_clients[0].closed is True + + +def test_api_reranker_scores_pairs_siliconflow( + monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient] +) -> None: + monkeypatch.delenv("RERANKER_API_KEY", raising=False) + + reranker = APIReranker(api_key="k", provider="siliconflow") + client = httpx_clients[0] + + client.queue( + DummyResponse( + json_data={ + "results": [ + {"index": 0, "relevance_score": 0.9}, + {"index": 1, "relevance_score": 0.1}, + ] + } + ) + ) + + scores = reranker.score_pairs([("q", "d1"), ("q", "d2")]) + assert scores == pytest.approx([0.9, 0.1]) + + assert client.calls[0]["endpoint"] == "/v1/rerank" + payload = client.calls[0]["json"] + assert payload["model"] == "BAAI/bge-reranker-v2-m3" + assert payload["query"] == "q" + assert payload["documents"] == ["d1", "d2"] + assert payload["top_n"] == 2 + assert payload["return_documents"] is False + + +def test_api_reranker_retries_on_5xx( + monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient] +) -> None: + monkeypatch.setenv("RERANKER_API_KEY", "k") + + from codexlens.semantic.reranker import api_reranker as api_reranker_module + + monkeypatch.setattr(api_reranker_module.time, "sleep", lambda *_args, **_kwargs: None) + + reranker = APIReranker(max_retries=1) + client = httpx_clients[0] + + client.queue(DummyResponse(status_code=500, text="oops", json_data={"error": "oops"})) + client.queue( + DummyResponse( + json_data={"results": [{"index": 0, "relevance_score": 0.7}]}, + ) + ) + + scores = reranker.score_pairs([("q", "d")]) + assert scores == pytest.approx([0.7]) + assert len(client.calls) == 2 + + +def test_api_reranker_unauthorized_raises( + monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient] +) -> None: + monkeypatch.setenv("RERANKER_API_KEY", "k") + + reranker = APIReranker() + client = httpx_clients[0] + client.queue(DummyResponse(status_code=401, text="unauthorized")) + + with pytest.raises(RuntimeError, match="unauthorized"): + reranker.score_pairs([("q", "d")]) + + +def test_factory_api_backend_constructs_reranker( + monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient] +) -> None: + monkeypatch.setenv("RERANKER_API_KEY", "k") + + reranker = get_reranker(backend="api") + assert isinstance(reranker, APIReranker) + assert len(httpx_clients) == 1 + diff --git a/codex-lens/tests/test_hybrid_search_reranker_backend.py b/codex-lens/tests/test_hybrid_search_reranker_backend.py new file mode 100644 index 00000000..85a8564f --- /dev/null +++ b/codex-lens/tests/test_hybrid_search_reranker_backend.py @@ -0,0 +1,139 @@ +"""Tests for HybridSearchEngine reranker backend selection.""" + +from __future__ import annotations + +import pytest + +from codexlens.config import Config +from codexlens.search.hybrid_search import HybridSearchEngine + + +def test_get_cross_encoder_reranker_uses_factory_backend_legacy( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + calls: dict[str, object] = {} + + def fake_check_reranker_available(backend: str): + calls["check_backend"] = backend + return True, None + + sentinel = object() + + def fake_get_reranker(*, backend: str, model_name=None, device=None, **kwargs): + calls["get_args"] = { + "backend": backend, + "model_name": model_name, + "device": device, + "kwargs": kwargs, + } + return sentinel + + monkeypatch.setattr( 
+ "codexlens.semantic.reranker.check_reranker_available", + fake_check_reranker_available, + ) + monkeypatch.setattr( + "codexlens.semantic.reranker.get_reranker", + fake_get_reranker, + ) + + config = Config( + data_dir=tmp_path / "legacy", + enable_reranking=True, + enable_cross_encoder_rerank=True, + reranker_backend="legacy", + reranker_model="dummy-model", + ) + engine = HybridSearchEngine(config=config) + + reranker = engine._get_cross_encoder_reranker() + assert reranker is sentinel + assert calls["check_backend"] == "legacy" + + get_args = calls["get_args"] + assert isinstance(get_args, dict) + assert get_args["backend"] == "legacy" + assert get_args["model_name"] == "dummy-model" + assert get_args["device"] is None + + +def test_get_cross_encoder_reranker_uses_factory_backend_onnx_gpu_flag( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + calls: dict[str, object] = {} + + def fake_check_reranker_available(backend: str): + calls["check_backend"] = backend + return True, None + + sentinel = object() + + def fake_get_reranker(*, backend: str, model_name=None, device=None, **kwargs): + calls["get_args"] = { + "backend": backend, + "model_name": model_name, + "device": device, + "kwargs": kwargs, + } + return sentinel + + monkeypatch.setattr( + "codexlens.semantic.reranker.check_reranker_available", + fake_check_reranker_available, + ) + monkeypatch.setattr( + "codexlens.semantic.reranker.get_reranker", + fake_get_reranker, + ) + + config = Config( + data_dir=tmp_path / "onnx", + enable_reranking=True, + enable_cross_encoder_rerank=True, + reranker_backend="onnx", + embedding_use_gpu=False, + ) + engine = HybridSearchEngine(config=config) + + reranker = engine._get_cross_encoder_reranker() + assert reranker is sentinel + assert calls["check_backend"] == "onnx" + + get_args = calls["get_args"] + assert isinstance(get_args, dict) + assert get_args["backend"] == "onnx" + assert get_args["model_name"] is None + assert get_args["device"] is None + assert get_args["kwargs"]["use_gpu"] is False + + +def test_get_cross_encoder_reranker_returns_none_when_backend_unavailable( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + def fake_check_reranker_available(backend: str): + return False, "missing deps" + + def fake_get_reranker(*args, **kwargs): + raise AssertionError("get_reranker should not be called when backend is unavailable") + + monkeypatch.setattr( + "codexlens.semantic.reranker.check_reranker_available", + fake_check_reranker_available, + ) + monkeypatch.setattr( + "codexlens.semantic.reranker.get_reranker", + fake_get_reranker, + ) + + config = Config( + data_dir=tmp_path / "unavailable", + enable_reranking=True, + enable_cross_encoder_rerank=True, + reranker_backend="onnx", + ) + engine = HybridSearchEngine(config=config) + + assert engine._get_cross_encoder_reranker() is None diff --git a/codex-lens/tests/test_litellm_reranker.py b/codex-lens/tests/test_litellm_reranker.py new file mode 100644 index 00000000..60c843d8 --- /dev/null +++ b/codex-lens/tests/test_litellm_reranker.py @@ -0,0 +1,85 @@ +"""Tests for LiteLLMReranker (LLM-based reranking).""" + +from __future__ import annotations + +import sys +import types +from dataclasses import dataclass + +import pytest + +from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker + + +def _install_dummy_ccw_litellm( + monkeypatch: pytest.MonkeyPatch, *, responses: list[str] +) -> None: + @dataclass(frozen=True, slots=True) + class ChatMessage: + role: str + content: str + + class LiteLLMClient: + 
def __init__(self, model: str = "default", **kwargs) -> None: + self.model = model + self.kwargs = kwargs + self._responses = list(responses) + self.calls: list[list[ChatMessage]] = [] + + def chat(self, messages, **kwargs): + self.calls.append(list(messages)) + content = self._responses.pop(0) if self._responses else "" + return types.SimpleNamespace(content=content) + + dummy = types.ModuleType("ccw_litellm") + dummy.ChatMessage = ChatMessage + dummy.LiteLLMClient = LiteLLMClient + monkeypatch.setitem(sys.modules, "ccw_litellm", dummy) + + +def test_score_pairs_parses_numbers_and_normalizes_scales( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _install_dummy_ccw_litellm(monkeypatch, responses=["0.73", "7", "80"]) + + reranker = LiteLLMReranker(model="dummy") + scores = reranker.score_pairs([("q", "d1"), ("q", "d2"), ("q", "d3")]) + assert scores == pytest.approx([0.73, 0.7, 0.8]) + + +def test_score_pairs_parses_json_score_field(monkeypatch: pytest.MonkeyPatch) -> None: + _install_dummy_ccw_litellm(monkeypatch, responses=['{"score": 0.42}']) + + reranker = LiteLLMReranker(model="dummy") + scores = reranker.score_pairs([("q", "d")]) + assert scores == pytest.approx([0.42]) + + +def test_score_pairs_uses_default_score_on_parse_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _install_dummy_ccw_litellm(monkeypatch, responses=["N/A"]) + + reranker = LiteLLMReranker(model="dummy", default_score=0.123) + scores = reranker.score_pairs([("q", "d")]) + assert scores == pytest.approx([0.123]) + + +def test_rate_limiting_sleeps_between_requests(monkeypatch: pytest.MonkeyPatch) -> None: + _install_dummy_ccw_litellm(monkeypatch, responses=["0.1", "0.2"]) + + reranker = LiteLLMReranker(model="dummy", min_interval_seconds=1.0) + + import codexlens.semantic.reranker.litellm_reranker as litellm_reranker_module + + sleeps: list[float] = [] + times = iter([100.0, 100.0, 100.1, 100.1]) + + monkeypatch.setattr(litellm_reranker_module.time, "monotonic", lambda: next(times)) + monkeypatch.setattr( + litellm_reranker_module.time, "sleep", lambda seconds: sleeps.append(seconds) + ) + + _ = reranker.score_pairs([("q", "d1"), ("q", "d2")]) + assert sleeps == pytest.approx([0.9]) + diff --git a/codex-lens/tests/test_reranker_backends.py b/codex-lens/tests/test_reranker_backends.py new file mode 100644 index 00000000..439631ef --- /dev/null +++ b/codex-lens/tests/test_reranker_backends.py @@ -0,0 +1,115 @@ +"""Mocked smoke tests for all reranker backends.""" + +from __future__ import annotations + +import sys +import types +from dataclasses import dataclass + +import pytest + + +def test_reranker_backend_legacy_scores_pairs(monkeypatch: pytest.MonkeyPatch) -> None: + from codexlens.semantic.reranker import legacy as legacy_module + + class DummyCrossEncoder: + def __init__(self, model_name: str, *, device: str | None = None) -> None: + self.model_name = model_name + self.device = device + self.calls: list[dict[str, object]] = [] + + def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]: + self.calls.append({"pairs": list(pairs), "batch_size": int(batch_size)}) + return [0.5 for _ in pairs] + + monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder) + monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True) + monkeypatch.setattr(legacy_module, "_import_error", None) + + reranker = legacy_module.CrossEncoderReranker(model_name="dummy-model", device="cpu") + scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0) + assert scores == 
pytest.approx([0.5, 0.5]) + + +def test_reranker_backend_onnx_availability_check(monkeypatch: pytest.MonkeyPatch) -> None: + from codexlens.semantic.reranker.onnx_reranker import check_onnx_reranker_available + + dummy_numpy = types.ModuleType("numpy") + dummy_onnxruntime = types.ModuleType("onnxruntime") + + dummy_optimum = types.ModuleType("optimum") + dummy_optimum.__path__ = [] # Mark as package for submodule imports. + dummy_optimum_ort = types.ModuleType("optimum.onnxruntime") + dummy_optimum_ort.ORTModelForSequenceClassification = object() + + dummy_transformers = types.ModuleType("transformers") + dummy_transformers.AutoTokenizer = object() + + monkeypatch.setitem(sys.modules, "numpy", dummy_numpy) + monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime) + monkeypatch.setitem(sys.modules, "optimum", dummy_optimum) + monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort) + monkeypatch.setitem(sys.modules, "transformers", dummy_transformers) + + ok, err = check_onnx_reranker_available() + assert ok is True + assert err is None + + +def test_reranker_backend_api_constructs_with_dummy_httpx(monkeypatch: pytest.MonkeyPatch) -> None: + from codexlens.semantic.reranker.api_reranker import APIReranker + + created: list[object] = [] + + class DummyClient: + def __init__( + self, + *, + base_url: str | None = None, + headers: dict[str, str] | None = None, + timeout: float | None = None, + ) -> None: + self.base_url = base_url + self.headers = headers or {} + self.timeout = timeout + self.closed = False + created.append(self) + + def close(self) -> None: + self.closed = True + + dummy_httpx = types.ModuleType("httpx") + dummy_httpx.Client = DummyClient + monkeypatch.setitem(sys.modules, "httpx", dummy_httpx) + + reranker = APIReranker(api_key="k", provider="siliconflow") + assert reranker.provider == "siliconflow" + assert len(created) == 1 + assert created[0].headers["Authorization"] == "Bearer k" + reranker.close() + assert created[0].closed is True + + +def test_reranker_backend_litellm_scores_pairs(monkeypatch: pytest.MonkeyPatch) -> None: + from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker + + @dataclass(frozen=True, slots=True) + class ChatMessage: + role: str + content: str + + class DummyLiteLLMClient: + def __init__(self, model: str = "default", **_kwargs: object) -> None: + self.model = model + + def chat(self, _messages: list[ChatMessage]) -> object: + return types.SimpleNamespace(content="0.5") + + dummy_litellm = types.ModuleType("ccw_litellm") + dummy_litellm.ChatMessage = ChatMessage + dummy_litellm.LiteLLMClient = DummyLiteLLMClient + monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm) + + reranker = LiteLLMReranker(model="dummy") + assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5]) + diff --git a/codex-lens/tests/test_reranker_factory.py b/codex-lens/tests/test_reranker_factory.py new file mode 100644 index 00000000..682c410f --- /dev/null +++ b/codex-lens/tests/test_reranker_factory.py @@ -0,0 +1,315 @@ +"""Tests for reranker factory and availability checks.""" + +from __future__ import annotations + +import builtins +import math +import sys +import types + +import pytest + +from codexlens.semantic.reranker import ( + BaseReranker, + ONNXReranker, + check_reranker_available, + get_reranker, +) +from codexlens.semantic.reranker import legacy as legacy_module + + +def test_public_imports_work() -> None: + from codexlens.semantic.reranker import BaseReranker as ImportedBaseReranker + from 
codexlens.semantic.reranker import get_reranker as imported_get_reranker + + assert ImportedBaseReranker is BaseReranker + assert imported_get_reranker is get_reranker + + +def test_base_reranker_is_abstract() -> None: + with pytest.raises(TypeError): + BaseReranker() # type: ignore[abstract] + + +def test_check_reranker_available_invalid_backend() -> None: + ok, err = check_reranker_available("nope") + assert ok is False + assert "Invalid reranker backend" in (err or "") + + +def test_get_reranker_invalid_backend_raises_value_error() -> None: + with pytest.raises(ValueError, match="Unknown backend"): + get_reranker("nope") + + +def test_get_reranker_legacy_missing_dependency_raises_import_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", False) + monkeypatch.setattr(legacy_module, "_import_error", "missing sentence-transformers") + + with pytest.raises(ImportError, match="missing sentence-transformers"): + get_reranker(backend="legacy", model_name="dummy-model") + + +def test_get_reranker_legacy_returns_cross_encoder_reranker( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class DummyCrossEncoder: + def __init__(self, model_name: str, *, device: str | None = None) -> None: + self.model_name = model_name + self.device = device + self.last_batch_size: int | None = None + + def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]: + self.last_batch_size = int(batch_size) + return [0.5 for _ in pairs] + + monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder) + monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True) + monkeypatch.setattr(legacy_module, "_import_error", None) + + reranker = get_reranker(backend=" LEGACY ", model_name="dummy-model", device="cpu") + assert isinstance(reranker, legacy_module.CrossEncoderReranker) + + assert reranker.score_pairs([]) == [] + + scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0) + assert scores == pytest.approx([0.5, 0.5]) + assert reranker._model is not None + assert reranker._model.last_batch_size == 32 + + +def test_check_reranker_available_onnx_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None: + real_import = builtins.__import__ + + def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): + if name == "onnxruntime": + raise ImportError("no onnxruntime") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + ok, err = check_reranker_available("onnx") + assert ok is False + assert "onnxruntime not available" in (err or "") + + +def test_check_reranker_available_onnx_deps_present(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_onnxruntime = types.ModuleType("onnxruntime") + dummy_optimum = types.ModuleType("optimum") + dummy_optimum.__path__ = [] # Mark as package for submodule imports. 
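+    # Stubbing sys.modules is enough for the availability probe: the check
+    # only imports these names, it never calls into them.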
+ dummy_optimum_ort = types.ModuleType("optimum.onnxruntime") + dummy_optimum_ort.ORTModelForSequenceClassification = object() + + dummy_transformers = types.ModuleType("transformers") + dummy_transformers.AutoTokenizer = object() + + monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime) + monkeypatch.setitem(sys.modules, "optimum", dummy_optimum) + monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort) + monkeypatch.setitem(sys.modules, "transformers", dummy_transformers) + + ok, err = check_reranker_available("onnx") + assert ok is True + assert err is None + + +def test_check_reranker_available_litellm_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None: + real_import = builtins.__import__ + + def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): + if name == "ccw_litellm": + raise ImportError("no ccw-litellm") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + ok, err = check_reranker_available("litellm") + assert ok is False + assert "ccw-litellm not available" in (err or "") + + +def test_check_reranker_available_litellm_deps_present( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dummy_litellm = types.ModuleType("ccw_litellm") + monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm) + + ok, err = check_reranker_available("litellm") + assert ok is True + assert err is None + + +def test_check_reranker_available_api_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None: + real_import = builtins.__import__ + + def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): + if name == "httpx": + raise ImportError("no httpx") + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + ok, err = check_reranker_available("api") + assert ok is False + assert "httpx not available" in (err or "") + + +def test_check_reranker_available_api_deps_present(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_httpx = types.ModuleType("httpx") + monkeypatch.setitem(sys.modules, "httpx", dummy_httpx) + + ok, err = check_reranker_available("api") + assert ok is True + assert err is None + + +def test_get_reranker_litellm_returns_litellm_reranker( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from dataclasses import dataclass + + @dataclass(frozen=True, slots=True) + class ChatMessage: + role: str + content: str + + class DummyLiteLLMClient: + def __init__(self, model: str = "default", **kwargs) -> None: + self.model = model + self.kwargs = kwargs + + def chat(self, messages, **kwargs): + return types.SimpleNamespace(content="0.5") + + dummy_litellm = types.ModuleType("ccw_litellm") + dummy_litellm.ChatMessage = ChatMessage + dummy_litellm.LiteLLMClient = DummyLiteLLMClient + monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm) + + reranker = get_reranker(backend="litellm", model_name="dummy-model") + + from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker + + assert isinstance(reranker, LiteLLMReranker) + assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5]) + + +def test_get_reranker_onnx_raises_import_error_with_dependency_hint( + monkeypatch: pytest.MonkeyPatch, +) -> None: + real_import = builtins.__import__ + + def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0): + if name == "onnxruntime": + raise ImportError("no onnxruntime") + return real_import(name, globals, locals, fromlist, level) + + 
monkeypatch.setattr(builtins, "__import__", fake_import) + + with pytest.raises(ImportError) as exc: + get_reranker(backend="onnx", model_name="any") + + assert "onnxruntime" in str(exc.value) + + +def test_get_reranker_default_backend_is_onnx(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_onnxruntime = types.ModuleType("onnxruntime") + dummy_optimum = types.ModuleType("optimum") + dummy_optimum.__path__ = [] # Mark as package for submodule imports. + dummy_optimum_ort = types.ModuleType("optimum.onnxruntime") + dummy_optimum_ort.ORTModelForSequenceClassification = object() + + dummy_transformers = types.ModuleType("transformers") + dummy_transformers.AutoTokenizer = object() + + monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime) + monkeypatch.setitem(sys.modules, "optimum", dummy_optimum) + monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort) + monkeypatch.setitem(sys.modules, "transformers", dummy_transformers) + + reranker = get_reranker() + assert isinstance(reranker, ONNXReranker) + + +def test_onnx_reranker_scores_pairs_with_sigmoid_normalization( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import numpy as np + + dummy_onnxruntime = types.ModuleType("onnxruntime") + + dummy_optimum = types.ModuleType("optimum") + dummy_optimum.__path__ = [] # Mark as package for submodule imports. + dummy_optimum_ort = types.ModuleType("optimum.onnxruntime") + + class DummyModelOutput: + def __init__(self, logits: np.ndarray) -> None: + self.logits = logits + + class DummyModel: + input_names = ["input_ids", "attention_mask"] + + def __init__(self) -> None: + self.calls: list[int] = [] + self._next_logit = 0 + + def __call__(self, **inputs): + batch = int(inputs["input_ids"].shape[0]) + start = self._next_logit + self._next_logit += batch + self.calls.append(batch) + logits = np.arange(start, start + batch, dtype=np.float32).reshape(batch, 1) + return DummyModelOutput(logits=logits) + + class DummyORTModelForSequenceClassification: + @classmethod + def from_pretrained(cls, model_name: str, providers=None, **kwargs): + _ = model_name, providers, kwargs + return DummyModel() + + dummy_optimum_ort.ORTModelForSequenceClassification = DummyORTModelForSequenceClassification + + dummy_transformers = types.ModuleType("transformers") + + class DummyAutoTokenizer: + model_max_length = 512 + + @classmethod + def from_pretrained(cls, model_name: str, **kwargs): + _ = model_name, kwargs + return cls() + + def __call__(self, *, text, text_pair, return_tensors, **kwargs): + _ = text_pair, kwargs + assert return_tensors == "np" + batch = len(text) + # Include token_type_ids to ensure input filtering is exercised. 
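+            # ONNXReranker should drop token_type_ids, since DummyModel
+            # declares input_names = ["input_ids", "attention_mask"] only.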
+ return { + "input_ids": np.zeros((batch, 4), dtype=np.int64), + "attention_mask": np.ones((batch, 4), dtype=np.int64), + "token_type_ids": np.zeros((batch, 4), dtype=np.int64), + } + + dummy_transformers.AutoTokenizer = DummyAutoTokenizer + + monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime) + monkeypatch.setitem(sys.modules, "optimum", dummy_optimum) + monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort) + monkeypatch.setitem(sys.modules, "transformers", dummy_transformers) + + reranker = get_reranker(backend="onnx", model_name="dummy-model", use_gpu=False) + assert isinstance(reranker, ONNXReranker) + assert reranker._model is None + + pairs = [("q", f"d{idx}") for idx in range(5)] + scores = reranker.score_pairs(pairs, batch_size=2) + + assert reranker._model is not None + assert reranker._model.calls == [2, 2, 1] + assert len(scores) == len(pairs) + assert all(0.0 <= s <= 1.0 for s in scores) + + expected = [1.0 / (1.0 + math.exp(-float(i))) for i in range(len(pairs))] + assert scores == pytest.approx(expected, rel=1e-6, abs=1e-6) diff --git a/codex-lens/tests/test_watcher/__init__.py b/codex-lens/tests/test_watcher/__init__.py new file mode 100644 index 00000000..f736461b --- /dev/null +++ b/codex-lens/tests/test_watcher/__init__.py @@ -0,0 +1 @@ +"""Tests for watcher module.""" diff --git a/codex-lens/tests/test_watcher/conftest.py b/codex-lens/tests/test_watcher/conftest.py new file mode 100644 index 00000000..f3fcecfb --- /dev/null +++ b/codex-lens/tests/test_watcher/conftest.py @@ -0,0 +1,43 @@ +"""Fixtures for watcher tests.""" + +from __future__ import annotations + +import tempfile +from pathlib import Path +from typing import Generator + +import pytest + + +@pytest.fixture +def temp_project() -> Generator[Path, None, None]: + """Create a temporary project directory with sample files.""" + with tempfile.TemporaryDirectory() as tmpdir: + project = Path(tmpdir) + + # Create sample Python file + py_file = project / "main.py" + py_file.write_text("def hello():\n print('Hello')\n") + + # Create sample JavaScript file + js_file = project / "app.js" + js_file.write_text("function greet() {\n console.log('Hi');\n}\n") + + # Create subdirectory with file + sub_dir = project / "src" + sub_dir.mkdir() + (sub_dir / "utils.py").write_text("def add(a, b):\n return a + b\n") + + # Create ignored directory + git_dir = project / ".git" + git_dir.mkdir() + (git_dir / "config").write_text("[core]\n") + + yield project + + +@pytest.fixture +def watcher_config(): + """Create default watcher configuration.""" + from codexlens.watcher import WatcherConfig + return WatcherConfig(debounce_ms=100) # Short debounce for tests diff --git a/codex-lens/tests/test_watcher/test_events.py b/codex-lens/tests/test_watcher/test_events.py new file mode 100644 index 00000000..c3f3a53f --- /dev/null +++ b/codex-lens/tests/test_watcher/test_events.py @@ -0,0 +1,103 @@ +"""Tests for watcher event types.""" + +from __future__ import annotations + +import time +from pathlib import Path + +import pytest + +from codexlens.watcher import ChangeType, FileEvent, WatcherConfig, IndexResult, WatcherStats + + +class TestChangeType: + """Tests for ChangeType enum.""" + + def test_change_types_exist(self): + """Verify all change types are defined.""" + assert ChangeType.CREATED.value == "created" + assert ChangeType.MODIFIED.value == "modified" + assert ChangeType.DELETED.value == "deleted" + assert ChangeType.MOVED.value == "moved" + + def test_change_type_count(self): + """Verify we have 
exactly 4 change types.""" + assert len(ChangeType) == 4 + + +class TestFileEvent: + """Tests for FileEvent dataclass.""" + + def test_create_event(self): + """Test creating a file event.""" + event = FileEvent( + path=Path("/test/file.py"), + change_type=ChangeType.CREATED, + timestamp=time.time(), + ) + assert event.path == Path("/test/file.py") + assert event.change_type == ChangeType.CREATED + assert event.old_path is None + + def test_moved_event(self): + """Test creating a moved event with old_path.""" + event = FileEvent( + path=Path("/test/new.py"), + change_type=ChangeType.MOVED, + timestamp=time.time(), + old_path=Path("/test/old.py"), + ) + assert event.old_path == Path("/test/old.py") + + +class TestWatcherConfig: + """Tests for WatcherConfig dataclass.""" + + def test_default_config(self): + """Test default configuration values.""" + config = WatcherConfig() + assert config.debounce_ms == 1000 + assert ".git" in config.ignored_patterns + assert "node_modules" in config.ignored_patterns + assert "__pycache__" in config.ignored_patterns + assert config.languages is None + + def test_custom_debounce(self): + """Test custom debounce setting.""" + config = WatcherConfig(debounce_ms=500) + assert config.debounce_ms == 500 + + +class TestIndexResult: + """Tests for IndexResult dataclass.""" + + def test_default_result(self): + """Test default result values.""" + result = IndexResult() + assert result.files_indexed == 0 + assert result.files_removed == 0 + assert result.symbols_added == 0 + assert result.errors == [] + + def test_custom_result(self): + """Test creating result with values.""" + result = IndexResult( + files_indexed=5, + files_removed=2, + symbols_added=50, + errors=["error1"], + ) + assert result.files_indexed == 5 + assert result.files_removed == 2 + + +class TestWatcherStats: + """Tests for WatcherStats dataclass.""" + + def test_default_stats(self): + """Test default stats values.""" + stats = WatcherStats() + assert stats.files_watched == 0 + assert stats.events_processed == 0 + assert stats.last_event_time is None + assert stats.is_running is False diff --git a/codex-lens/tests/test_watcher/test_file_watcher.py b/codex-lens/tests/test_watcher/test_file_watcher.py new file mode 100644 index 00000000..50aa352a --- /dev/null +++ b/codex-lens/tests/test_watcher/test_file_watcher.py @@ -0,0 +1,124 @@ +"""Tests for FileWatcher class.""" + +from __future__ import annotations + +import time +from pathlib import Path +from typing import List + +import pytest + +from codexlens.watcher import FileWatcher, WatcherConfig, FileEvent, ChangeType + + +class TestFileWatcherInit: + """Tests for FileWatcher initialization.""" + + def test_init_with_valid_path(self, temp_project: Path, watcher_config: WatcherConfig): + """Test initializing with valid path.""" + events: List[FileEvent] = [] + watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e)) + + assert watcher.root_path == temp_project.resolve() + assert watcher.config == watcher_config + assert not watcher.is_running + + def test_start_with_invalid_path(self, watcher_config: WatcherConfig): + """Test starting watcher with non-existent path.""" + events: List[FileEvent] = [] + watcher = FileWatcher(Path("/nonexistent/path"), watcher_config, lambda e: events.extend(e)) + + with pytest.raises(ValueError, match="does not exist"): + watcher.start() + + +class TestFileWatcherLifecycle: + """Tests for FileWatcher start/stop lifecycle.""" + + def test_start_stop(self, temp_project: Path, watcher_config: 
WatcherConfig): + """Test basic start and stop.""" + events: List[FileEvent] = [] + watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e)) + + watcher.start() + assert watcher.is_running + + watcher.stop() + assert not watcher.is_running + + def test_double_start(self, temp_project: Path, watcher_config: WatcherConfig): + """Test calling start twice.""" + events: List[FileEvent] = [] + watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e)) + + watcher.start() + watcher.start() # Should not raise + assert watcher.is_running + + watcher.stop() + + def test_double_stop(self, temp_project: Path, watcher_config: WatcherConfig): + """Test calling stop twice.""" + events: List[FileEvent] = [] + watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e)) + + watcher.start() + watcher.stop() + watcher.stop() # Should not raise + assert not watcher.is_running + + +class TestFileWatcherEvents: + """Tests for FileWatcher event detection.""" + + def test_detect_file_creation(self, temp_project: Path, watcher_config: WatcherConfig): + """Test detecting new file creation.""" + events: List[FileEvent] = [] + watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e)) + + try: + watcher.start() + time.sleep(0.3) # Let watcher start (longer for Windows) + + # Create new file + new_file = temp_project / "new_file.py" + new_file.write_text("# New file\n") + + # Wait for event with retries (watchdog timing varies by platform) + max_wait = 2.0 + waited = 0.0 + while waited < max_wait: + time.sleep(0.2) + waited += 0.2 + # Windows may report MODIFIED instead of CREATED + file_events = [e for e in events if e.change_type in (ChangeType.CREATED, ChangeType.MODIFIED)] + if any(e.path.name == "new_file.py" for e in file_events): + break + + # Check event was detected (Windows may report MODIFIED instead of CREATED) + relevant_events = [e for e in events if e.change_type in (ChangeType.CREATED, ChangeType.MODIFIED)] + assert len(relevant_events) >= 1, f"Expected file event, got: {events}" + assert any(e.path.name == "new_file.py" for e in relevant_events) + finally: + watcher.stop() + + def test_filter_ignored_directories(self, temp_project: Path, watcher_config: WatcherConfig): + """Test that files in ignored directories are filtered.""" + events: List[FileEvent] = [] + watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e)) + + try: + watcher.start() + time.sleep(0.1) + + # Create file in .git (should be ignored) + git_file = temp_project / ".git" / "test.py" + git_file.write_text("# In git\n") + + time.sleep(watcher_config.debounce_ms / 1000.0 + 0.2) + + # No events should be detected for .git files + git_events = [e for e in events if ".git" in str(e.path)] + assert len(git_events) == 0 + finally: + watcher.stop()