mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-05 01:50:27 +08:00)
feat(codex-lens): add unified reranker architecture and file watcher
Unified Reranker Architecture:
- Add BaseReranker ABC with factory pattern
- Implement 4 backends: ONNX (default), API, LiteLLM, Legacy
- Add .env configuration parsing for API credentials
- Migrate from sentence-transformers to optimum+onnxruntime

File Watcher Module:
- Add real-time file system monitoring with watchdog
- Implement IncrementalIndexer for single-file updates
- Add WatcherManager with signal handling and graceful shutdown
- Add 'codexlens watch' CLI command
- Event filtering, debouncing, and deduplication
- Thread-safe design with proper resource cleanup

Tests: 16 watcher tests + 5 reranker test files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
codex-lens/.env.example (new file, 66 lines)
@@ -0,0 +1,66 @@
# CodexLens Environment Configuration
# Copy this file to .codexlens/.env and fill in your values
#
# Priority order:
#   1. Environment variables (already set in shell)
#   2. .codexlens/.env (workspace-local, this file)
#   3. .env (project root)

# ============================================
# RERANKER Configuration
# ============================================

# API key for reranker service (SiliconFlow/Cohere/Jina)
# Required for 'api' backend
# RERANKER_API_KEY=sk-xxxx

# Base URL for reranker API (overrides provider default)
# SiliconFlow: https://api.siliconflow.cn
# Cohere: https://api.cohere.ai
# Jina: https://api.jina.ai
# RERANKER_API_BASE=https://api.siliconflow.cn

# Reranker provider: siliconflow, cohere, jina
# RERANKER_PROVIDER=siliconflow

# Reranker model name
# SiliconFlow: BAAI/bge-reranker-v2-m3
# Cohere: rerank-english-v3.0
# Jina: jina-reranker-v2-base-multilingual
# RERANKER_MODEL=BAAI/bge-reranker-v2-m3

# ============================================
# EMBEDDING Configuration
# ============================================

# API key for embedding service (for litellm backend)
# EMBEDDING_API_KEY=sk-xxxx

# Base URL for embedding API
# EMBEDDING_API_BASE=https://api.openai.com

# Embedding model name
# EMBEDDING_MODEL=text-embedding-3-small

# ============================================
# LITELLM Configuration
# ============================================

# API key for LiteLLM (for litellm reranker backend)
# LITELLM_API_KEY=sk-xxxx

# Base URL for LiteLLM
# LITELLM_API_BASE=

# LiteLLM model name
# LITELLM_MODEL=gpt-4o-mini

# ============================================
# General Configuration
# ============================================

# Custom data directory path (default: ~/.codexlens)
# CODEXLENS_DATA_DIR=~/.codexlens

# Enable debug mode (true/false)
# CODEXLENS_DEBUG=false
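Usage sketch (illustration, not part of this commit's diff): how the priority order documented above plays out through the env_config helpers added later in this commit. Values already set in the shell win over .codexlens/.env, which wins over the project-root .env.

from pathlib import Path

from codexlens.env_config import apply_workspace_env, get_env

# Applies values from ./.env then ./.codexlens/.env, never overriding
# variables already present in os.environ (override=False is the default).
applied = apply_workspace_env(Path("."))
print(f"applied {applied} variable(s)")

# Single-key lookup with the same fallback chain, plus a default.
provider = get_env("RERANKER_PROVIDER", default="siliconflow")
print(provider)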
@@ -21,6 +21,7 @@ dependencies = [
    "tree-sitter-javascript>=0.25",
    "tree-sitter-typescript>=0.23",
    "pathspec>=0.11",
    "watchdog>=3.0",
]

[project.optional-dependencies]
@@ -50,11 +51,35 @@ semantic-directml = [
]

# Cross-encoder reranking (second-stage, optional)
-# Install with: pip install codexlens[reranker]
-reranker = [
+# Install with: pip install codexlens[reranker] (default: ONNX backend)
+reranker-onnx = [
    "optimum>=1.16",
    "onnxruntime>=1.15",
    "transformers>=4.36",
]

# Remote reranking via HTTP API
reranker-api = [
    "httpx>=0.25",
]

# LLM-based reranking via ccw-litellm
reranker-litellm = [
    "ccw-litellm>=0.1",
]

# Legacy sentence-transformers CrossEncoder reranker
reranker-legacy = [
    "sentence-transformers>=2.2",
]

# Backward-compatible alias for default reranker backend
reranker = [
    "optimum>=1.16",
    "onnxruntime>=1.15",
    "transformers>=4.36",
]

# Encoding detection for non-UTF8 files
encoding = [
    "chardet>=5.0",
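A small sketch (illustration, not part of the diff) of how the optional extras above map to importable modules; the import names mirror the packages listed in each extra (ccw_litellm per the factory code further down).

from importlib.util import find_spec

# Probe which optional reranker backends are importable in this environment.
extras = {
    "reranker-onnx": ("optimum", "onnxruntime", "transformers"),
    "reranker-api": ("httpx",),
    "reranker-litellm": ("ccw_litellm",),
    "reranker-legacy": ("sentence_transformers",),
}
for extra, modules in extras.items():
    ok = all(find_spec(m) is not None for m in modules)
    print(f"{extra}: {'available' if ok else 'missing'}")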
@@ -22,6 +22,7 @@ from codexlens.storage.registry import RegistryStore, ProjectInfo
from codexlens.storage.index_tree import IndexTreeBuilder
from codexlens.storage.dir_index import DirIndexStore
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.watcher import WatcherManager, WatcherConfig

from .output import (
    console,
@@ -321,6 +322,91 @@ def init(
    registry.close()


@app.command()
def watch(
    path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to watch."),
    language: Optional[List[str]] = typer.Option(
        None,
        "--language",
        "-l",
        help="Limit watching to specific languages (repeat or comma-separated).",
    ),
    debounce: int = typer.Option(1000, "--debounce", "-d", min=100, max=10000, help="Debounce interval in milliseconds."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose logging."),
) -> None:
    """Watch directory for changes and update index incrementally.

    Monitors filesystem events and automatically updates the index
    when files are created, modified, or deleted.

    The directory must already be indexed (run 'codexlens init' first).

    Press Ctrl+C to stop watching.

    Examples:
        codexlens watch .
        codexlens watch /path/to/project --debounce 500 --verbose
        codexlens watch . --language python,typescript
    """
    _configure_logging(verbose)

    from codexlens.watcher.events import IndexResult

    base_path = path.expanduser().resolve()

    # Check if path is indexed
    mapper = PathMapper()
    index_db = mapper.source_to_index_db(base_path)
    if not index_db.exists():
        console.print(f"[red]Error:[/red] Directory not indexed: {base_path}")
        console.print("Run 'codexlens init' first to create the index.")
        raise typer.Exit(code=1)

    # Parse languages
    languages = _parse_languages(language)

    # Create watcher config
    watcher_config = WatcherConfig(
        debounce_ms=debounce,
        languages=languages,
    )

    # Callback for indexed files
    def on_indexed(result: IndexResult) -> None:
        if result.files_indexed > 0:
            console.print(f"  [green]Indexed:[/green] {result.files_indexed} files ({result.symbols_added} symbols)")
        if result.files_removed > 0:
            console.print(f"  [yellow]Removed:[/yellow] {result.files_removed} files")
        if result.errors:
            for error in result.errors[:3]:  # Show first 3 errors
                console.print(f"  [red]Error:[/red] {error}")

    console.print(f"[bold]Watching:[/bold] {base_path}")
    console.print(f"  Debounce: {debounce}ms")
    if languages:
        console.print(f"  Languages: {', '.join(languages)}")
    console.print("  Press Ctrl+C to stop.\n")

    manager: WatcherManager | None = None
    try:
        manager = WatcherManager(
            root_path=base_path,
            watcher_config=watcher_config,
            on_indexed=on_indexed,
        )
        manager.start()
        manager.wait()
    except KeyboardInterrupt:
        pass
    except Exception as exc:
        console.print(f"[red]Error:[/red] {exc}")
        raise typer.Exit(code=1)
    finally:
        if manager is not None:
            manager.stop()
            console.print("\n[dim]Watcher stopped.[/dim]")


@app.command()
def search(
    query: str = typer.Argument(..., help="FTS query to run."),
@@ -2293,3 +2379,102 @@ def gpu_reset(
    if gpu_info.preferred_device_id is not None:
        console.print(f"  Auto-selected device: {gpu_info.preferred_device_id}")
    console.print(f"  Device: [cyan]{gpu_info.gpu_name}[/cyan]")


# ==================== Watch Command ====================


@app.command()
def watch(
    path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to watch."),
    language: Optional[List[str]] = typer.Option(None, "--language", "-l", help="Languages to watch (comma-separated)."),
    debounce: int = typer.Option(1000, "--debounce", "-d", min=100, max=10000, help="Debounce interval in milliseconds."),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
    """Watch a directory for file changes and incrementally update the index.

    Monitors the specified directory for file system changes (create, modify, delete)
    and automatically updates the CodexLens index. The directory must already be indexed
    using 'codexlens init' before watching.

    Examples:
        # Watch current directory
        codexlens watch .

        # Watch with custom debounce interval
        codexlens watch . --debounce 2000

        # Watch only Python and JavaScript files
        codexlens watch . --language python,javascript

    Press Ctrl+C to stop watching.
    """
    _configure_logging(verbose)
    watch_path = path.expanduser().resolve()

    registry: RegistryStore | None = None
    try:
        # Validate that path is indexed
        registry = RegistryStore()
        registry.initialize()
        mapper = PathMapper()

        project_record = registry.find_by_source_path(str(watch_path))
        if not project_record:
            console.print(f"[red]Error:[/red] Directory is not indexed: {watch_path}")
            console.print("[dim]Run 'codexlens init' first to create an index.[/dim]")
            raise typer.Exit(code=1)

        # Parse languages
        languages = _parse_languages(language)

        # Create watcher config
        watcher_config = WatcherConfig(
            debounce_ms=debounce,
            languages=languages,
        )

        # Display startup message
        console.print(f"[green]Starting watcher for:[/green] {watch_path}")
        console.print(f"[dim]Debounce interval: {debounce}ms[/dim]")
        if languages:
            console.print(f"[dim]Watching languages: {', '.join(languages)}[/dim]")
        console.print("[dim]Press Ctrl+C to stop[/dim]\n")

        # Create and start watcher manager
        manager = WatcherManager(
            root_path=watch_path,
            watcher_config=watcher_config,
            on_indexed=lambda result: _display_index_result(result),
        )

        manager.start()
        manager.wait()

    except KeyboardInterrupt:
        console.print("\n[yellow]Stopping watcher...[/yellow]")
    except CodexLensError as exc:
        console.print(f"[red]Watch failed:[/red] {exc}")
        raise typer.Exit(code=1)
    except Exception as exc:
        console.print(f"[red]Unexpected error:[/red] {exc}")
        raise typer.Exit(code=1)
    finally:
        if registry is not None:
            registry.close()


def _display_index_result(result) -> None:
    """Display indexing result in real-time."""
    if result.files_indexed > 0 or result.files_removed > 0:
        parts = []
        if result.files_indexed > 0:
            parts.append(f"[green]✓ Indexed {result.files_indexed} file(s)[/green]")
        if result.files_removed > 0:
            parts.append(f"[yellow]✗ Removed {result.files_removed} file(s)[/yellow]")
        console.print(" | ".join(parts))

    if result.errors:
        for error in result.errors[:3]:  # Show max 3 errors
            console.print(f"  [red]Error:[/red] {error}")
        if len(result.errors) > 3:
            console.print(f"  [dim]... and {len(result.errors) - 3} more errors[/dim]")
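Usage sketch (illustration, not part of the diff): driving the watcher programmatically with the same WatcherManager/WatcherConfig surface the CLI command uses; the result fields match the on_indexed callback above.

from pathlib import Path

from codexlens.watcher import WatcherConfig, WatcherManager

config = WatcherConfig(debounce_ms=500, languages=["python"])
manager = WatcherManager(
    root_path=Path(".").resolve(),
    watcher_config=config,
    on_indexed=lambda r: print(f"indexed={r.files_indexed} removed={r.files_removed}"),
)
manager.start()
try:
    manager.wait()  # blocks until interrupted
except KeyboardInterrupt:
    pass
finally:
    manager.stop()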
@@ -116,8 +116,9 @@ class Config:
    reranking_top_k: int = 50
    symbol_boost_factor: float = 1.5

-    # Optional cross-encoder reranking (second stage, requires codexlens[reranker])
+    # Optional cross-encoder reranking (second stage; requires optional reranker deps)
    enable_cross_encoder_rerank: bool = False
    reranker_backend: str = "onnx"
    reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    reranker_top_k: int = 50

@@ -311,6 +312,35 @@ class WorkspaceConfig:
        """Cache directory for this workspace."""
        return self.codexlens_dir / "cache"

    @property
    def env_path(self) -> Path:
        """Path to workspace .env file."""
        return self.codexlens_dir / ".env"

    def load_env(self, *, override: bool = False) -> int:
        """Load .env file and apply to os.environ.

        Args:
            override: If True, override existing environment variables

        Returns:
            Number of variables applied
        """
        from .env_config import apply_workspace_env
        return apply_workspace_env(self.workspace_root, override=override)

    def get_api_config(self, prefix: str) -> dict:
        """Get API configuration from environment.

        Args:
            prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")

        Returns:
            Dictionary with api_key, api_base, model, etc.
        """
        from .env_config import get_api_config
        return get_api_config(prefix, workspace_root=self.workspace_root)

    def initialize(self) -> None:
        """Create the .codexlens directory structure."""
        try:
@@ -324,6 +354,7 @@ class WorkspaceConfig:
                "# CodexLens workspace data\n"
                "cache/\n"
                "*.log\n"
                ".env\n"  # Exclude .env from git
            )
        except Exception as exc:
            raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc
codex-lens/src/codexlens/env_config.py (new file, 260 lines)
@@ -0,0 +1,260 @@
"""Environment configuration loader for CodexLens.

Loads .env files from workspace .codexlens directory with fallback to project root.
Provides unified access to API configurations.

Priority order:
    1. Environment variables (already set)
    2. .codexlens/.env (workspace-local)
    3. .env (project root)
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional

log = logging.getLogger(__name__)

# Supported environment variables with descriptions
ENV_VARS = {
    # Reranker API configuration
    "RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
    "RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
    "RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
    "RERANKER_MODEL": "Reranker model name",
    # Embedding API configuration
    "EMBEDDING_API_KEY": "API key for embedding service",
    "EMBEDDING_API_BASE": "Base URL for embedding API",
    "EMBEDDING_MODEL": "Embedding model name",
    # LiteLLM configuration
    "LITELLM_API_KEY": "API key for LiteLLM",
    "LITELLM_API_BASE": "Base URL for LiteLLM",
    "LITELLM_MODEL": "LiteLLM model name",
    # General configuration
    "CODEXLENS_DATA_DIR": "Custom data directory path",
    "CODEXLENS_DEBUG": "Enable debug mode (true/false)",
}


def _parse_env_line(line: str) -> tuple[str, str] | None:
    """Parse a single .env line, returning (key, value) or None."""
    line = line.strip()

    # Skip empty lines and comments
    if not line or line.startswith("#"):
        return None

    # Handle export prefix
    if line.startswith("export "):
        line = line[7:].strip()

    # Split on first =
    if "=" not in line:
        return None

    key, _, value = line.partition("=")
    key = key.strip()
    value = value.strip()

    # Remove surrounding quotes
    if len(value) >= 2:
        if (value.startswith('"') and value.endswith('"')) or \
           (value.startswith("'") and value.endswith("'")):
            value = value[1:-1]

    return key, value


def load_env_file(env_path: Path) -> Dict[str, str]:
    """Load environment variables from a .env file.

    Args:
        env_path: Path to .env file

    Returns:
        Dictionary of environment variables
    """
    if not env_path.is_file():
        return {}

    env_vars: Dict[str, str] = {}

    try:
        content = env_path.read_text(encoding="utf-8")
        for line in content.splitlines():
            result = _parse_env_line(line)
            if result:
                key, value = result
                env_vars[key] = value
    except Exception as exc:
        log.warning("Failed to load .env file %s: %s", env_path, exc)

    return env_vars


def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
    """Load environment variables from workspace .env files.

    Priority (later overrides earlier):
        1. Project root .env
        2. .codexlens/.env

    Args:
        workspace_root: Workspace root directory. If None, uses current directory.

    Returns:
        Merged dictionary of environment variables
    """
    if workspace_root is None:
        workspace_root = Path.cwd()

    workspace_root = Path(workspace_root).resolve()

    env_vars: Dict[str, str] = {}

    # Load from project root .env (lowest priority)
    root_env = workspace_root / ".env"
    if root_env.is_file():
        env_vars.update(load_env_file(root_env))
        log.debug("Loaded %d vars from %s", len(env_vars), root_env)

    # Load from .codexlens/.env (higher priority)
    codexlens_env = workspace_root / ".codexlens" / ".env"
    if codexlens_env.is_file():
        loaded = load_env_file(codexlens_env)
        env_vars.update(loaded)
        log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)

    return env_vars


def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
    """Load .env files and apply to os.environ.

    Args:
        workspace_root: Workspace root directory
        override: If True, override existing environment variables

    Returns:
        Number of variables applied
    """
    env_vars = load_workspace_env(workspace_root)
    applied = 0

    for key, value in env_vars.items():
        if override or key not in os.environ:
            os.environ[key] = value
            applied += 1
            log.debug("Applied env var: %s", key)

    return applied


def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback.

    Priority:
        1. os.environ (already set)
        2. .codexlens/.env
        3. .env
        4. default value

    Args:
        key: Environment variable name
        default: Default value if not found
        workspace_root: Workspace root for .env file lookup

    Returns:
        Value or default
    """
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]

    # Load from .env files
    env_vars = load_workspace_env(workspace_root)
    if key in env_vars:
        return env_vars[key]

    return default


def get_api_config(
    prefix: str,
    *,
    workspace_root: Path | None = None,
    defaults: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """Get API configuration from environment.

    Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.

    Args:
        prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
        workspace_root: Workspace root for .env file lookup
        defaults: Default values

    Returns:
        Dictionary with api_key, api_base, model, etc.
    """
    defaults = defaults or {}

    config: Dict[str, Any] = {}

    # Standard API config fields
    field_mapping = {
        "api_key": f"{prefix}_API_KEY",
        "api_base": f"{prefix}_API_BASE",
        "model": f"{prefix}_MODEL",
        "provider": f"{prefix}_PROVIDER",
        "timeout": f"{prefix}_TIMEOUT",
    }

    for field, env_key in field_mapping.items():
        value = get_env(env_key, workspace_root=workspace_root)
        if value is not None:
            # Type conversion for specific fields
            if field == "timeout":
                try:
                    config[field] = float(value)
                except ValueError:
                    pass
            else:
                config[field] = value
        elif field in defaults:
            config[field] = defaults[field]

    return config


def generate_env_example() -> str:
    """Generate .env.example content with all supported variables.

    Returns:
        String content for .env.example file
    """
    lines = [
        "# CodexLens Environment Configuration",
        "# Copy this file to .codexlens/.env and fill in your values",
        "",
    ]

    # Group by prefix
    groups: Dict[str, list] = {}
    for key, desc in ENV_VARS.items():
        prefix = key.split("_")[0]
        if prefix not in groups:
            groups[prefix] = []
        groups[prefix].append((key, desc))

    for prefix, items in groups.items():
        lines.append(f"# {prefix} Configuration")
        for key, desc in items:
            lines.append(f"# {desc}")
            lines.append(f"# {key}=")
        lines.append("")

    return "\n".join(lines)
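Worked sketch (illustration, not part of the diff): what get_api_config returns for the RERANKER prefix. The output keys follow the field_mapping above, and the timeout string is coerced to float.

import os

from codexlens.env_config import get_api_config

os.environ["RERANKER_PROVIDER"] = "siliconflow"
os.environ["RERANKER_TIMEOUT"] = "15"

cfg = get_api_config("RERANKER", defaults={"model": "BAAI/bge-reranker-v2-m3"})
# -> {"model": "BAAI/bge-reranker-v2-m3", "provider": "siliconflow", "timeout": 15.0}
print(cfg)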
@@ -258,20 +258,52 @@ class HybridSearchEngine:
            return None

        try:
-            from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available
+            from codexlens.semantic.reranker import (
+                check_reranker_available,
+                get_reranker,
+            )
        except Exception as exc:
-            self.logger.debug("Cross-encoder reranker unavailable: %s", exc)
+            self.logger.debug("Reranker factory unavailable: %s", exc)
            return None

-        ok, err = check_cross_encoder_available()
+        backend = (getattr(self._config, "reranker_backend", "") or "").strip().lower() or "onnx"
+
+        ok, err = check_reranker_available(backend)
        if not ok:
-            self.logger.debug("Cross-encoder reranker unavailable: %s", err)
+            self.logger.debug(
+                "Reranker backend unavailable (backend=%s): %s",
+                backend,
+                err,
+            )
            return None

        try:
-            return CrossEncoderReranker(model_name=self._config.reranker_model)
+            model_name = (getattr(self._config, "reranker_model", "") or "").strip() or None
+
+            if backend != "legacy" and model_name == "cross-encoder/ms-marco-MiniLM-L-6-v2":
+                model_name = None
+
+            device: str | None = None
+            kwargs: dict[str, Any] = {}
+
+            if backend == "onnx":
+                kwargs["use_gpu"] = bool(getattr(self._config, "embedding_use_gpu", True))
+            elif backend == "legacy":
+                if not bool(getattr(self._config, "embedding_use_gpu", True)):
+                    device = "cpu"
+
+            return get_reranker(
+                backend=backend,
+                model_name=model_name,
+                device=device,
+                **kwargs,
+            )
        except Exception as exc:
-            self.logger.debug("Failed to initialize cross-encoder reranker: %s", exc)
+            self.logger.debug(
+                "Failed to initialize reranker (backend=%s): %s",
+                backend,
+                exc,
+            )
            return None

    def _search_parallel(
codex-lens/src/codexlens/semantic/reranker/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Reranker backends for second-stage search ranking.

This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""

from __future__ import annotations

from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available

__all__ = [
    "BaseReranker",
    "check_reranker_available",
    "get_reranker",
    "CrossEncoderReranker",
    "check_cross_encoder_available",
    "ONNXReranker",
    "check_onnx_reranker_available",
]
codex-lens/src/codexlens/semantic/reranker/api_reranker.py (new file, 310 lines)
@@ -0,0 +1,310 @@
"""API-based reranker using a remote HTTP provider.

Supported providers:
    - SiliconFlow: https://api.siliconflow.cn/v1/rerank
    - Cohere: https://api.cohere.ai/v1/rerank
    - Jina: https://api.jina.ai/v1/rerank
"""

from __future__ import annotations

import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"


def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
    """Get environment variable with .env file fallback."""
    # Check os.environ first
    if key in os.environ:
        return os.environ[key]

    # Try loading from .env files
    try:
        from codexlens.env_config import get_env
        return get_env(key, workspace_root=workspace_root)
    except ImportError:
        return None


def check_httpx_available() -> tuple[bool, str | None]:
    try:
        import httpx  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"httpx not available: {exc}. Install with: pip install httpx"
    return True, None


class APIReranker(BaseReranker):
    """Reranker backed by a remote reranking HTTP API."""

    _PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
        "siliconflow": {
            "api_base": "https://api.siliconflow.cn",
            "endpoint": "/v1/rerank",
            "default_model": "BAAI/bge-reranker-v2-m3",
        },
        "cohere": {
            "api_base": "https://api.cohere.ai",
            "endpoint": "/v1/rerank",
            "default_model": "rerank-english-v3.0",
        },
        "jina": {
            "api_base": "https://api.jina.ai",
            "endpoint": "/v1/rerank",
            "default_model": "jina-reranker-v2-base-multilingual",
        },
    }

    def __init__(
        self,
        *,
        provider: str = "siliconflow",
        model_name: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        timeout: float = 30.0,
        max_retries: int = 3,
        backoff_base_s: float = 0.5,
        backoff_max_s: float = 8.0,
        env_api_key: str = _DEFAULT_ENV_API_KEY,
        workspace_root: Path | str | None = None,
    ) -> None:
        ok, err = check_httpx_available()
        if not ok:  # pragma: no cover - exercised via factory availability tests
            raise ImportError(err)

        import httpx

        self._workspace_root = Path(workspace_root) if workspace_root else None

        self.provider = (provider or "").strip().lower()
        if self.provider not in self._PROVIDER_DEFAULTS:
            raise ValueError(
                f"Unknown reranker provider: {provider}. "
                f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
            )

        defaults = self._PROVIDER_DEFAULTS[self.provider]

        # Load api_base from env with .env fallback
        env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
        self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
        self.endpoint = defaults["endpoint"]

        # Load model from env with .env fallback
        env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
        self.model_name = (model_name or env_model or defaults["default_model"]).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        # Load API key from env with .env fallback
        resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
        resolved_key = resolved_key.strip()
        if not resolved_key:
            raise ValueError(
                f"Missing API key for reranker provider '{self.provider}'. "
                f"Pass api_key=... or set ${env_api_key}."
            )
        self._api_key = resolved_key

        self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
        self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
        self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
        self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }
        if self.provider == "cohere":
            headers.setdefault("Cohere-Version", "2022-12-06")

        self._client = httpx.Client(
            base_url=self.api_base,
            headers=headers,
            timeout=self.timeout_s,
        )

    def close(self) -> None:
        try:
            self._client.close()
        except Exception:  # pragma: no cover - defensive
            return

    def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
        if retry_after_s is not None and retry_after_s > 0:
            time.sleep(min(float(retry_after_s), self.backoff_max_s))
            return

        exp = self.backoff_base_s * (2**attempt)
        jitter = random.uniform(0, min(0.5, self.backoff_base_s))
        time.sleep(min(self.backoff_max_s, exp + jitter))

    @staticmethod
    def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
        value = (headers.get("Retry-After") or "").strip()
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            return None

    @staticmethod
    def _should_retry_status(status_code: int) -> bool:
        return status_code == 429 or 500 <= status_code <= 599

    def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
        last_exc: Exception | None = None

        for attempt in range(self.max_retries + 1):
            try:
                response = self._client.post(self.endpoint, json=dict(payload))
            except Exception as exc:  # httpx is optional at import-time
                last_exc = exc
                if attempt < self.max_retries:
                    self._sleep_backoff(attempt)
                    continue
                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' after "
                    f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
                ) from exc

            status = int(getattr(response, "status_code", 0) or 0)
            if status >= 400:
                body_preview = ""
                try:
                    body_preview = (response.text or "").strip()
                except Exception:
                    body_preview = ""
                if len(body_preview) > 300:
                    body_preview = body_preview[:300] + "…"

                if self._should_retry_status(status) and attempt < self.max_retries:
                    retry_after = self._parse_retry_after_seconds(response.headers)
                    logger.warning(
                        "Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
                        self.api_base,
                        self.endpoint,
                        status,
                        attempt + 1,
                        self.max_retries + 1,
                    )
                    self._sleep_backoff(attempt, retry_after_s=retry_after)
                    continue

                if status in {401, 403}:
                    raise RuntimeError(
                        f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
                        "Check your API key."
                    )

                raise RuntimeError(
                    f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
                    f"Response: {body_preview or '<empty>'}"
                )

            try:
                data = response.json()
            except Exception as exc:
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' is not valid JSON: "
                    f"{type(exc).__name__}: {exc}"
                ) from exc

            if not isinstance(data, dict):
                raise RuntimeError(
                    f"Rerank response from provider '{self.provider}' must be a JSON object; "
                    f"got {type(data).__name__}"
                )

            return data

        raise RuntimeError(
            f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
        )

    @staticmethod
    def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
        if not isinstance(results, list):
            raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")

        scores: list[float] = [0.0 for _ in range(expected)]
        filled = 0

        for item in results:
            if not isinstance(item, dict):
                continue
            idx = item.get("index")
            score = item.get("relevance_score", item.get("score"))
            if idx is None or score is None:
                continue
            try:
                idx_int = int(idx)
                score_f = float(score)
            except (TypeError, ValueError):
                continue
            if 0 <= idx_int < expected:
                scores[idx_int] = score_f
                filled += 1

        if filled != expected:
            raise RuntimeError(
                f"Rerank response contained {filled}/{expected} scored documents; "
                "ensure top_n matches the number of documents."
            )

        return scores

    def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
        payload: dict[str, Any] = {
            "model": self.model_name,
            "query": query,
            "documents": list(documents),
            "top_n": len(documents),
            "return_documents": False,
        }
        return payload

    def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
        if not documents:
            return []

        payload = self._build_payload(query=query, documents=documents)
        data = self._request_json(payload)

        results = data.get("results")
        return self._extract_scores_from_results(results, expected=len(documents))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,  # noqa: ARG002 - kept for BaseReranker compatibility
    ) -> list[float]:
        if not pairs:
            return []

        grouped: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, doc) in enumerate(pairs):
            grouped.setdefault(str(query), []).append((idx, str(doc)))

        scores: list[float] = [0.0 for _ in range(len(pairs))]

        for query, items in grouped.items():
            documents = [doc for _, doc in items]
            query_scores = self._rerank_one_query(query=query, documents=documents)
            for (orig_idx, _), score in zip(items, query_scores):
                scores[orig_idx] = float(score)

        return scores
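Usage sketch (illustration, not part of the diff; requires the reranker-api extra and a real key — the key below is a placeholder). Note that pairs sharing a query are grouped into a single rerank request.

from codexlens.semantic.reranker.api_reranker import APIReranker

reranker = APIReranker(provider="siliconflow", api_key="sk-...")  # placeholder key
scores = reranker.score_pairs(
    [
        ("parse json in python", "json.loads turns a JSON string into a dict"),
        ("parse json in python", "watchdog observes filesystem events"),
    ]
)
print(scores)  # one float per (query, doc) pair, higher = more relevant
reranker.close()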
codex-lens/src/codexlens/semantic/reranker/base.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""Base class for rerankers.

Defines the interface that all rerankers must implement.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Sequence


class BaseReranker(ABC):
    """Base class for all rerankers.

    All reranker implementations must inherit from this class and implement
    the abstract methods to ensure a consistent interface.
    """

    @abstractmethod
    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs.

        Args:
            pairs: Sequence of (query, doc) string pairs to score.
            batch_size: Batch size for scoring.

        Returns:
            List of scores (one per pair).
        """
        ...
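To illustrate the contract, a toy backend (hypothetical, e.g. for tests) that scores by term overlap; score_pairs is the only method a subclass must provide.

from typing import Sequence

from codexlens.semantic.reranker.base import BaseReranker


class OverlapReranker(BaseReranker):
    """Toy reranker: fraction of query terms present in the document."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        scores: list[float] = []
        for query, doc in pairs:
            q = set(query.lower().split())
            d = set(doc.lower().split())
            scores.append(len(q & d) / len(q) if q else 0.0)
        return scores


print(OverlapReranker().score_pairs([("rank these docs", "rank docs quickly")]))  # [0.666...]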
codex-lens/src/codexlens/semantic/reranker/factory.py (new file, 138 lines)
@@ -0,0 +1,138 @@
"""Factory for creating rerankers.

Provides a unified interface for instantiating different reranker backends.
"""

from __future__ import annotations

from typing import Any

from .base import BaseReranker


def check_reranker_available(backend: str) -> tuple[bool, str | None]:
    """Check whether a specific reranker backend can be used.

    Notes:
        - "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
        - "onnx" uses Optimum + ONNX Runtime (pip install codexlens[reranker] or codexlens[reranker-onnx]).
        - "api" uses a remote reranking HTTP API (requires httpx).
        - "litellm" uses `ccw-litellm` for unified access to LLM providers.
    """
    backend = (backend or "").strip().lower()

    if backend == "legacy":
        from .legacy import check_cross_encoder_available

        return check_cross_encoder_available()

    if backend == "onnx":
        from .onnx_reranker import check_onnx_reranker_available

        return check_onnx_reranker_available()

    if backend == "litellm":
        try:
            import ccw_litellm  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional dependency
            return (
                False,
                f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
            )

        try:
            from .litellm_reranker import LiteLLMReranker  # noqa: F401
        except Exception as exc:  # pragma: no cover - defensive
            return False, f"LiteLLM reranker backend not available: {exc}"

        return True, None

    if backend == "api":
        from .api_reranker import check_httpx_available

        return check_httpx_available()

    return False, (
        f"Invalid reranker backend: {backend}. "
        "Must be 'onnx', 'api', 'litellm', or 'legacy'."
    )


def get_reranker(
    backend: str = "onnx",
    model_name: str | None = None,
    *,
    device: str | None = None,
    **kwargs: Any,
) -> BaseReranker:
    """Factory function to create a reranker based on backend.

    Args:
        backend: Reranker backend to use. Options:
            - "onnx": Optimum + onnxruntime backend (default)
            - "api": HTTP API backend (remote providers)
            - "litellm": LiteLLM backend (LLM-based, experimental)
            - "legacy": sentence-transformers CrossEncoder backend (optional)
        model_name: Model identifier for model-based backends. Defaults depend on backend:
            - onnx: Xenova/ms-marco-MiniLM-L-6-v2
            - api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
            - legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
            - litellm: default
        device: Optional device string for backends that support it (legacy only).
        **kwargs: Additional backend-specific arguments.

    Returns:
        BaseReranker: Configured reranker instance.

    Raises:
        ValueError: If backend is not recognized.
        ImportError: If required backend dependencies are not installed or backend is unavailable.
    """
    backend = (backend or "").strip().lower()

    if backend == "onnx":
        ok, err = check_reranker_available("onnx")
        if not ok:
            raise ImportError(err)

        from .onnx_reranker import ONNXReranker

        resolved_model_name = (model_name or "").strip() or ONNXReranker.DEFAULT_MODEL
        _ = device  # Device selection is managed via ONNX Runtime providers.
        return ONNXReranker(model_name=resolved_model_name, **kwargs)

    if backend == "legacy":
        ok, err = check_reranker_available("legacy")
        if not ok:
            raise ImportError(err)

        from .legacy import CrossEncoderReranker

        resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
        return CrossEncoderReranker(model_name=resolved_model_name, device=device)

    if backend == "litellm":
        ok, err = check_reranker_available("litellm")
        if not ok:
            raise ImportError(err)

        from .litellm_reranker import LiteLLMReranker

        _ = device  # Device selection is not applicable to remote LLM backends.
        resolved_model_name = (model_name or "").strip() or "default"
        return LiteLLMReranker(model=resolved_model_name, **kwargs)

    if backend == "api":
        ok, err = check_reranker_available("api")
        if not ok:
            raise ImportError(err)

        from .api_reranker import APIReranker

        _ = device  # Device selection is not applicable to remote HTTP backends.
        resolved_model_name = (model_name or "").strip() or None
        return APIReranker(model_name=resolved_model_name, **kwargs)

    raise ValueError(
        f"Unknown backend: {backend}. Supported backends: 'onnx', 'api', 'litellm', 'legacy'"
    )
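Usage sketch (illustration, not part of the diff): selecting a backend with graceful fallback, using only the two factory entry points above.

from codexlens.semantic.reranker import check_reranker_available, get_reranker

backend = "onnx"
ok, err = check_reranker_available(backend)
if not ok:
    print(f"onnx unavailable ({err}); falling back to legacy")
    backend = "legacy"
    ok, err = check_reranker_available(backend)

if ok:
    reranker = get_reranker(backend=backend)  # per-backend default model
    print(reranker.score_pairs([("query text", "candidate document")]))
else:
    print(f"no reranker backend available: {err}")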
@@ -1,6 +1,6 @@
-"""Optional cross-encoder reranker for second-stage search ranking.
+"""Legacy sentence-transformers cross-encoder reranker.

-Install with: pip install codexlens[reranker]
+Install with: pip install codexlens[reranker-legacy]
"""

from __future__ import annotations
@@ -9,6 +9,8 @@ import logging
import threading
from typing import List, Sequence, Tuple

+from .base import BaseReranker
+
logger = logging.getLogger(__name__)

try:
@@ -25,10 +27,14 @@ except ImportError as exc:  # pragma: no cover - optional dependency
def check_cross_encoder_available() -> tuple[bool, str | None]:
    if CROSS_ENCODER_AVAILABLE:
        return True, None
-    return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]"
+    return (
+        False,
+        _import_error
+        or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
+    )


-class CrossEncoderReranker:
+class CrossEncoderReranker(BaseReranker):
    """Cross-encoder reranker with lazy model loading."""

    def __init__(self, model_name: str, *, device: str | None = None) -> None:
@@ -83,4 +89,3 @@ class CrossEncoderReranker:
        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
        scores = self._model.predict(list(pairs), batch_size=bs)  # type: ignore[union-attr]
        return [float(s) for s in scores]
codex-lens/src/codexlens/semantic/reranker/litellm_reranker.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.

This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.

Notes:
    - This backend is experimental and may be slow/expensive compared to local
      rerankers.
    - It relies on `ccw-litellm` for a unified LLM API across providers.
"""

from __future__ import annotations

import json
import logging
import re
import threading
import time
from typing import Any, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)

_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")


def _coerce_score_to_unit_interval(score: float) -> float:
    """Coerce a numeric score into [0, 1].

    The prompt asks for a float in [0, 1], but some models may respond with 0-10
    or 0-100 scales. This function attempts a conservative normalization.
    """
    if 0.0 <= score <= 1.0:
        return score
    if 0.0 <= score <= 10.0:
        return score / 10.0
    if 0.0 <= score <= 100.0:
        return score / 100.0
    return max(0.0, min(1.0, score))


def _extract_score(text: str) -> float | None:
    """Extract a numeric relevance score from an LLM response."""
    content = (text or "").strip()
    if not content:
        return None

    # Prefer JSON if present.
    if "{" in content and "}" in content:
        try:
            start = content.index("{")
            end = content.rindex("}") + 1
            payload = json.loads(content[start:end])
            if isinstance(payload, dict) and "score" in payload:
                return float(payload["score"])
        except Exception:
            pass

    match = _NUMBER_RE.search(content)
    if not match:
        return None
    try:
        return float(match.group(0))
    except ValueError:
        return None


class LiteLLMReranker(BaseReranker):
    """Experimental reranker that uses a LiteLLM-compatible model.

    This reranker scores each (query, doc) pair in isolation (single-pair mode)
    to improve prompt reliability across providers.
    """

    _SYSTEM_PROMPT = (
        "You are a relevance scoring assistant.\n"
        "Given a search query and a document snippet, output a single numeric "
        "relevance score between 0 and 1.\n\n"
        "Scoring guidance:\n"
        "- 1.0: The document directly answers the query.\n"
        "- 0.5: The document is partially relevant.\n"
        "- 0.0: The document is unrelated.\n\n"
        "Output requirements:\n"
        "- Output ONLY the number (e.g., 0.73).\n"
        "- Do not include any other text."
    )

    def __init__(
        self,
        model: str = "default",
        *,
        requests_per_minute: float | None = None,
        min_interval_seconds: float | None = None,
        default_score: float = 0.0,
        max_doc_chars: int = 8000,
        **litellm_kwargs: Any,
    ) -> None:
        """Initialize the reranker.

        Args:
            model: Model name from ccw-litellm configuration (default: "default").
            requests_per_minute: Optional rate limit in requests per minute.
            min_interval_seconds: Optional minimum interval between requests. If set,
                it takes precedence over requests_per_minute.
            default_score: Score to use when an API call fails or parsing fails.
            max_doc_chars: Maximum number of document characters to include in the prompt.
            **litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.

        Raises:
            ImportError: If ccw-litellm is not installed.
            ValueError: If model is blank.
        """
        self.model_name = (model or "").strip()
        if not self.model_name:
            raise ValueError("model cannot be blank")

        self.default_score = float(default_score)

        self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0

        if min_interval_seconds is not None:
            self._min_interval_seconds = max(0.0, float(min_interval_seconds))
        elif requests_per_minute is not None and float(requests_per_minute) > 0:
            self._min_interval_seconds = 60.0 / float(requests_per_minute)
        else:
            self._min_interval_seconds = 0.0

        # Prefer deterministic output by default; allow overrides via kwargs.
        litellm_kwargs = dict(litellm_kwargs)
        litellm_kwargs.setdefault("temperature", 0.0)
        litellm_kwargs.setdefault("max_tokens", 16)

        try:
            from ccw_litellm import ChatMessage, LiteLLMClient
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise ImportError(
                "ccw-litellm not installed. Install with: pip install ccw-litellm"
            ) from exc

        self._ChatMessage = ChatMessage
        self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)

        self._lock = threading.RLock()
        self._last_request_at = 0.0

    def _sanitize_text(self, text: str) -> str:
        # Keep consistent with LiteLLMEmbedderWrapper workaround.
        if text.startswith("import"):
            return " " + text
        return text

    def _rate_limit(self) -> None:
        if self._min_interval_seconds <= 0:
            return
        with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_request_at
            if elapsed < self._min_interval_seconds:
                time.sleep(self._min_interval_seconds - elapsed)
            self._last_request_at = time.monotonic()

    def _build_user_prompt(self, query: str, doc: str) -> str:
        sanitized_query = self._sanitize_text(query or "")
        sanitized_doc = self._sanitize_text(doc or "")
        if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
            sanitized_doc = sanitized_doc[: self.max_doc_chars]

        return (
            "Query:\n"
            f"{sanitized_query}\n\n"
            "Document:\n"
            f"{sanitized_doc}\n\n"
            "Return the relevance score (0 to 1) as a single number:"
        )

    def _score_single_pair(self, query: str, doc: str) -> float:
        messages = [
            self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
            self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
        ]

        try:
            self._rate_limit()
            response = self._client.chat(messages)
        except Exception as exc:
            logger.debug("LiteLLM reranker request failed: %s", exc)
            return self.default_score

        raw = getattr(response, "content", "") or ""
        score = _extract_score(raw)
        if score is None:
            logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
            return self.default_score
        return _coerce_score_to_unit_interval(float(score))

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        """Score (query, doc) pairs with per-pair LLM calls."""
        if not pairs:
            return []

        bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32

        scores: list[float] = []
        for i in range(0, len(pairs), bs):
            batch = pairs[i : i + bs]
            for query, doc in batch:
                scores.append(self._score_single_pair(query, doc))
        return scores
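The two parsing helpers above are pure functions, so their behavior can be checked without any LLM call (illustration, not part of the diff):

from codexlens.semantic.reranker.litellm_reranker import (
    _coerce_score_to_unit_interval,
    _extract_score,
)

print(_extract_score('{"score": 0.73}'))     # 0.73, JSON takes precedence
print(_extract_score("Relevance: 8"))        # 8.0, first number in the text
print(_coerce_score_to_unit_interval(8.0))   # 0.8, 0-10 scale normalized
print(_coerce_score_to_unit_interval(73.0))  # 0.73, 0-100 scale normalized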
codex-lens/src/codexlens/semantic/reranker/onnx_reranker.py (new file, 268 lines)
@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.

This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.

Install (CPU):
    pip install onnxruntime optimum[onnxruntime] transformers
"""

from __future__ import annotations

import logging
import threading
from typing import Any, Iterable, Sequence

from .base import BaseReranker

logger = logging.getLogger(__name__)


def check_onnx_reranker_available() -> tuple[bool, str | None]:
    """Check whether Optimum + ONNXRuntime reranker dependencies are available."""
    try:
        import numpy  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return False, f"numpy not available: {exc}. Install with: pip install numpy"

    try:
        import onnxruntime  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
        )

    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
        )

    try:
        from transformers import AutoTokenizer  # noqa: F401
    except ImportError as exc:  # pragma: no cover - optional dependency
        return (
            False,
            f"transformers not available: {exc}. Install with: pip install transformers",
        )

    return True, None


def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


class ONNXReranker(BaseReranker):
    """Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""

    DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        model_name: str | None = None,
        *,
        use_gpu: bool = True,
        providers: list[Any] | None = None,
        max_length: int | None = None,
    ) -> None:
        self.model_name = (model_name or self.DEFAULT_MODEL).strip()
        if not self.model_name:
            raise ValueError("model_name cannot be blank")

        self.use_gpu = bool(use_gpu)
        self.providers = providers

        self.max_length = int(max_length) if max_length is not None else None

        self._tokenizer: Any | None = None
        self._model: Any | None = None
        self._model_input_names: set[str] | None = None
        self._lock = threading.RLock()

    def _load_model(self) -> None:
        if self._model is not None and self._tokenizer is not None:
            return

        ok, err = check_onnx_reranker_available()
        if not ok:
            raise ImportError(err)

        with self._lock:
            if self._model is not None and self._tokenizer is not None:
                return

            from inspect import signature

            from optimum.onnxruntime import ORTModelForSequenceClassification
            from transformers import AutoTokenizer

            if self.providers is None:
                from ..gpu_support import get_optimal_providers

                # Include device_id options for DirectML/CUDA selection when available.
                self.providers = get_optimal_providers(
                    use_gpu=self.use_gpu, with_device_options=True
                )

            # Some Optimum versions accept `providers`, others accept a single `provider`.
            # Prefer passing the full providers list, with a conservative fallback.
            model_kwargs: dict[str, Any] = {}
            try:
                params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
                if "providers" in params:
                    model_kwargs["providers"] = self.providers
                elif "provider" in params:
                    provider_name = "CPUExecutionProvider"
                    if self.providers:
                        first = self.providers[0]
                        provider_name = first[0] if isinstance(first, tuple) else str(first)
                    model_kwargs["provider"] = provider_name
            except Exception:
                model_kwargs = {}

            try:
                self._model = ORTModelForSequenceClassification.from_pretrained(
                    self.model_name,
                    **model_kwargs,
                )
            except TypeError:
                # Fallback for older Optimum versions: retry without provider arguments.
                self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)

            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

            # Cache model input names to filter tokenizer outputs defensively.
            input_names: set[str] | None = None
            for attr in ("input_names", "model_input_names"):
                names = getattr(self._model, attr, None)
                if isinstance(names, (list, tuple)) and names:
                    input_names = {str(n) for n in names}
                    break
            if input_names is None:
                try:
                    session = getattr(self._model, "model", None)
                    if session is not None and hasattr(session, "get_inputs"):
                        input_names = {i.name for i in session.get_inputs()}
                except Exception:
                    input_names = None
            self._model_input_names = input_names

    @staticmethod
    def _sigmoid(x: "Any") -> "Any":
        import numpy as np

        x = np.clip(x, -50.0, 50.0)
        return 1.0 / (1.0 + np.exp(-x))

    @staticmethod
    def _select_relevance_logit(logits: "Any") -> "Any":
        import numpy as np

        arr = np.asarray(logits)
        if arr.ndim == 0:
            return arr.reshape(1)
        if arr.ndim == 1:
            return arr
        if arr.ndim >= 2:
            # Common cases:
            #   - Regression: (batch, 1)
            #   - Binary classification: (batch, 2)
            if arr.shape[-1] == 1:
                return arr[..., 0]
            if arr.shape[-1] == 2:
                # Convert 2-logit softmax into a single logit via difference.
                return arr[..., 1] - arr[..., 0]
            return arr.max(axis=-1)
        return arr.reshape(-1)

    def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
        if self._tokenizer is None:
            raise RuntimeError("Tokenizer not loaded")  # pragma: no cover - defensive

        queries = [q for q, _ in batch]
        docs = [d for _, d in batch]

        tokenizer_kwargs: dict[str, Any] = {
            "text": queries,
            "text_pair": docs,
            "padding": True,
            "truncation": True,
            "return_tensors": "np",
        }

        max_len = self.max_length
        if max_len is None:
            try:
                model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
                if 0 < model_max < 10_000:
                    max_len = model_max
                else:
                    max_len = 512
            except Exception:
                max_len = 512
        if max_len is not None and max_len > 0:
            tokenizer_kwargs["max_length"] = int(max_len)

        encoded = self._tokenizer(**tokenizer_kwargs)
        inputs = dict(encoded)

        # Some models do not accept token_type_ids; filter to known input names if available.
        if self._model_input_names:
            inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}

        return inputs

    def _forward_logits(self, inputs: dict[str, Any]) -> Any:
        if self._model is None:
            raise RuntimeError("Model not loaded")  # pragma: no cover - defensive

        outputs = self._model(**inputs)
        if hasattr(outputs, "logits"):
            return outputs.logits
        if isinstance(outputs, dict) and "logits" in outputs:
            return outputs["logits"]
        if isinstance(outputs, (list, tuple)) and outputs:
            return outputs[0]
        raise RuntimeError("Unexpected model output format")  # pragma: no cover - defensive
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> list[float]:
|
||||
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
import numpy as np
|
||||
|
||||
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||
scores: list[float] = []
|
||||
|
||||
for batch in _iter_batches(list(pairs), bs):
|
||||
inputs = self._tokenize_batch(batch)
|
||||
logits = self._forward_logits(inputs)
|
||||
rel_logits = self._select_relevance_logit(logits)
|
||||
probs = self._sigmoid(rel_logits)
|
||||
probs = np.clip(probs, 0.0, 1.0)
|
||||
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
|
||||
|
||||
if len(scores) != len(pairs):
|
||||
logger.debug(
|
||||
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
|
||||
)
|
||||
return scores[: len(pairs)]
|
||||
|
||||
return scores
|
||||
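Usage note (editor's sketch, not part of the diff): driving the ONNX backend through the factory. The snippet assumes the optional onnxruntime, optimum, and transformers dependencies are installed; the query/document strings are illustrative only.

from codexlens.semantic.reranker import get_reranker

# "onnx" is the default backend; the model is downloaded lazily on first use.
reranker = get_reranker(backend="onnx")
scores = reranker.score_pairs(
    [
        ("parse config file", "def load_config(path): ..."),
        ("parse config file", "class HttpServer: ..."),
    ],
    batch_size=32,
)
# Scores are sigmoid-normalized to [0, 1]; sort descending to rerank.
ranked = sorted(zip(scores, ["doc1", "doc2"]), reverse=True)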
17
codex-lens/src/codexlens/watcher/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""File watcher module for real-time index updates."""

from .events import ChangeType, FileEvent, IndexResult, WatcherConfig, WatcherStats
from .file_watcher import FileWatcher
from .incremental_indexer import IncrementalIndexer
from .manager import WatcherManager

__all__ = [
    "ChangeType",
    "FileEvent",
    "IndexResult",
    "WatcherConfig",
    "WatcherStats",
    "FileWatcher",
    "IncrementalIndexer",
    "WatcherManager",
]
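Since the package re-exports its public surface via __all__, callers can import directly from the subpackage; a one-line sketch:

from codexlens.watcher import FileWatcher, WatcherConfig, WatcherManager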
54
codex-lens/src/codexlens/watcher/events.py
Normal file
@@ -0,0 +1,54 @@
"""Event types for file watcher."""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Optional, Set


class ChangeType(Enum):
    """Type of file system change."""
    CREATED = "created"
    MODIFIED = "modified"
    DELETED = "deleted"
    MOVED = "moved"


@dataclass
class FileEvent:
    """A file system change event."""
    path: Path
    change_type: ChangeType
    timestamp: float
    old_path: Optional[Path] = None  # For MOVED events


@dataclass
class WatcherConfig:
    """Configuration for file watcher."""
    debounce_ms: int = 1000
    ignored_patterns: Set[str] = field(default_factory=lambda: {
        ".git", ".venv", "venv", "node_modules",
        "__pycache__", ".codexlens", ".idea", ".vscode",
    })
    languages: Optional[List[str]] = None  # None = all supported


@dataclass
class IndexResult:
    """Result of processing file changes."""
    files_indexed: int = 0
    files_removed: int = 0
    symbols_added: int = 0
    errors: List[str] = field(default_factory=list)


@dataclass
class WatcherStats:
    """Runtime statistics for watcher."""
    files_watched: int = 0
    events_processed: int = 0
    last_event_time: Optional[float] = None
    is_running: bool = False
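A short construction sketch for these dataclasses (the path and values below are illustrative only):

import time
from pathlib import Path

from codexlens.watcher.events import ChangeType, FileEvent, WatcherConfig

config = WatcherConfig(debounce_ms=500, languages=["python"])
event = FileEvent(
    path=Path("src/app.py"),
    change_type=ChangeType.MODIFIED,
    timestamp=time.time(),
)
assert event.old_path is None  # populated only for MOVED events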
245
codex-lens/src/codexlens/watcher/file_watcher.py
Normal file
@@ -0,0 +1,245 @@
"""File system watcher using watchdog library."""

from __future__ import annotations

import logging
import threading
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional

from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

from .events import ChangeType, FileEvent, WatcherConfig
from ..config import Config

logger = logging.getLogger(__name__)


class _CodexLensHandler(FileSystemEventHandler):
    """Internal handler for watchdog events."""

    def __init__(
        self,
        watcher: "FileWatcher",
        on_event: Callable[[FileEvent], None],
    ) -> None:
        super().__init__()
        self._watcher = watcher
        self._on_event = on_event

    def on_created(self, event) -> None:
        if event.is_directory:
            return
        self._emit(event.src_path, ChangeType.CREATED)

    def on_modified(self, event) -> None:
        if event.is_directory:
            return
        self._emit(event.src_path, ChangeType.MODIFIED)

    def on_deleted(self, event) -> None:
        if event.is_directory:
            return
        self._emit(event.src_path, ChangeType.DELETED)

    def on_moved(self, event) -> None:
        if event.is_directory:
            return
        self._emit(event.dest_path, ChangeType.MOVED, old_path=event.src_path)

    def _emit(
        self,
        path: str,
        change_type: ChangeType,
        old_path: Optional[str] = None,
    ) -> None:
        path_obj = Path(path)

        # Filter out files that should not be indexed
        if not self._watcher._should_index_file(path_obj):
            return

        event = FileEvent(
            path=path_obj,
            change_type=change_type,
            timestamp=time.time(),
            old_path=Path(old_path) if old_path else None,
        )
        self._on_event(event)


class FileWatcher:
    """File system watcher for monitoring directory changes.

    Uses the watchdog library for cross-platform file system monitoring.
    Events are forwarded to the on_changes callback.

    Example:
        def handle_changes(events: List[FileEvent]) -> None:
            for event in events:
                print(f"{event.change_type}: {event.path}")

        watcher = FileWatcher(Path("."), WatcherConfig(), handle_changes)
        watcher.start()
        watcher.wait()  # Block until stopped
    """

    def __init__(
        self,
        root_path: Path,
        config: WatcherConfig,
        on_changes: Callable[[List[FileEvent]], None],
    ) -> None:
        """Initialize file watcher.

        Args:
            root_path: Directory to watch recursively
            config: Watcher configuration
            on_changes: Callback invoked with batched events
        """
        self.root_path = Path(root_path).resolve()
        self.config = config
        self.on_changes = on_changes

        self._observer: Optional[Observer] = None
        self._running = False
        self._stop_event = threading.Event()
        self._lock = threading.RLock()

        # Event queue for batching
        self._event_queue: List[FileEvent] = []
        self._queue_lock = threading.Lock()

        # Debounce thread
        self._debounce_thread: Optional[threading.Thread] = None

        # Config instance for language checking
        self._codexlens_config = Config()

    def _should_index_file(self, path: Path) -> bool:
        """Check if file should be indexed based on extension and ignore patterns.

        Args:
            path: File path to check

        Returns:
            True if file should be indexed, False otherwise
        """
        # Check against ignore patterns
        parts = path.parts
        for pattern in self.config.ignored_patterns:
            if pattern in parts:
                return False

        # Check extension against supported languages
        language = self._codexlens_config.language_for_path(path)
        return language is not None

    def _on_raw_event(self, event: FileEvent) -> None:
        """Handle raw event from watchdog handler."""
        with self._queue_lock:
            self._event_queue.append(event)
        # Debouncing is handled by background thread

    def _debounce_loop(self) -> None:
        """Background thread for debounced event batching."""
        while self._running:
            time.sleep(self.config.debounce_ms / 1000.0)
            self._flush_events()

    def _flush_events(self) -> None:
        """Flush queued events with deduplication."""
        with self._queue_lock:
            if not self._event_queue:
                return

            # Deduplicate: keep latest event per path
            deduped: Dict[Path, FileEvent] = {}
            for event in self._event_queue:
                deduped[event.path] = event

            events = list(deduped.values())
            self._event_queue.clear()

        # Invoke the callback outside the queue lock.
        if events:
            try:
                self.on_changes(events)
            except Exception as exc:
                logger.error("Error in on_changes callback: %s", exc)

    def start(self) -> None:
        """Start watching the directory.

        Non-blocking. Use wait() to block until stopped.
        """
        with self._lock:
            if self._running:
                logger.warning("Watcher already running")
                return

            if not self.root_path.exists():
                raise ValueError(f"Root path does not exist: {self.root_path}")

            self._observer = Observer()
            handler = _CodexLensHandler(self, self._on_raw_event)
            self._observer.schedule(handler, str(self.root_path), recursive=True)

            self._running = True
            self._stop_event.clear()
            self._observer.start()

            # Start debounce thread
            self._debounce_thread = threading.Thread(
                target=self._debounce_loop,
                daemon=True,
                name="FileWatcher-Debounce",
            )
            self._debounce_thread.start()

            logger.info("Started watching: %s", self.root_path)

    def stop(self) -> None:
        """Stop watching the directory.

        Gracefully stops the observer and flushes remaining events.
        """
        with self._lock:
            if not self._running:
                return

            self._running = False
            self._stop_event.set()

            if self._observer:
                self._observer.stop()
                self._observer.join(timeout=5.0)
                self._observer = None

            # Wait for debounce thread to finish
            if self._debounce_thread and self._debounce_thread.is_alive():
                self._debounce_thread.join(timeout=2.0)
            self._debounce_thread = None

            # Flush any remaining events
            self._flush_events()

            logger.info("Stopped watching: %s", self.root_path)

    def wait(self) -> None:
        """Block until watcher is stopped.

        Use Ctrl+C or call stop() from another thread to unblock.
        """
        try:
            while self._running:
                self._stop_event.wait(timeout=1.0)
        except KeyboardInterrupt:
            logger.info("Received interrupt, stopping watcher...")
            self.stop()

    @property
    def is_running(self) -> bool:
        """Check if watcher is currently running."""
        return self._running
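Editor's note on the debounce/dedup interaction above: all events queued within one debounce window collapse to the latest event per path before the callback fires. A hypothetical sketch (actual batch contents depend on OS event timing):

from pathlib import Path

from codexlens.watcher import FileWatcher
from codexlens.watcher.events import FileEvent, WatcherConfig

batches: list[list[FileEvent]] = []
watcher = FileWatcher(Path("."), WatcherConfig(debounce_ms=200), batches.append)
watcher.start()
# ... an editor saves src/app.py three times within 200 ms ...
# The next flush delivers a single FileEvent for src/app.py, carrying the
# latest change_type and timestamp.
watcher.stop()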
359
codex-lens/src/codexlens/watcher/incremental_indexer.py
Normal file
@@ -0,0 +1,359 @@
"""Incremental indexer for processing file changes."""

from __future__ import annotations

import logging
import sqlite3
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

from .events import ChangeType, FileEvent, IndexResult

logger = logging.getLogger(__name__)


@dataclass
class FileIndexResult:
    """Result of indexing a single file."""
    path: Path
    symbols_count: int
    success: bool
    error: Optional[str] = None


class IncrementalIndexer:
    """Incremental indexer for processing file change events.

    Processes file events (create, modify, delete, move) and updates
    the corresponding index databases incrementally.

    Reuses existing infrastructure:
    - ParserFactory for symbol extraction
    - DirIndexStore for per-directory storage
    - GlobalSymbolIndex for cross-file symbols
    - PathMapper for source-to-index path conversion

    Example:
        indexer = IncrementalIndexer(registry, mapper, config)
        result = indexer.process_changes([
            FileEvent(Path("foo.py"), ChangeType.MODIFIED, time.time()),
        ])
        print(f"Indexed {result.files_indexed} files")
    """

    def __init__(
        self,
        registry: RegistryStore,
        mapper: PathMapper,
        config: Optional[Config] = None,
    ) -> None:
        """Initialize incremental indexer.

        Args:
            registry: Global project registry
            mapper: Path mapper for source-to-index conversion
            config: CodexLens configuration (uses defaults if None)
        """
        self.registry = registry
        self.mapper = mapper
        self.config = config or Config()
        self.parser_factory = ParserFactory(self.config)

        self._global_index: Optional[GlobalSymbolIndex] = None
        self._dir_stores: dict[Path, DirIndexStore] = {}
        self._lock = threading.RLock()

    def _get_global_index(self, index_root: Path) -> Optional[GlobalSymbolIndex]:
        """Get or create global symbol index."""
        if not self.config.global_symbol_index_enabled:
            return None

        if self._global_index is None:
            global_db_path = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
            if global_db_path.exists():
                self._global_index = GlobalSymbolIndex(global_db_path)

        return self._global_index

    def _get_dir_store(self, dir_path: Path) -> Optional[DirIndexStore]:
        """Get DirIndexStore for a directory, if indexed."""
        with self._lock:
            if dir_path in self._dir_stores:
                return self._dir_stores[dir_path]

            index_db = self.mapper.source_to_index_db(dir_path)
            if not index_db.exists():
                logger.debug("No index found for directory: %s", dir_path)
                return None

            # Get index root for global index
            index_root = self.mapper.source_to_index_dir(
                self.mapper.get_project_root(dir_path) or dir_path
            )
            global_index = self._get_global_index(index_root)

            store = DirIndexStore(
                index_db,
                config=self.config,
                global_index=global_index,
            )
            self._dir_stores[dir_path] = store
            return store

    def process_changes(self, events: List[FileEvent]) -> IndexResult:
        """Process a batch of file change events.

        Args:
            events: List of file events to process

        Returns:
            IndexResult with statistics
        """
        result = IndexResult()

        for event in events:
            try:
                # CREATED and MODIFIED are handled identically: (re)index the file.
                if event.change_type in (ChangeType.CREATED, ChangeType.MODIFIED):
                    file_result = self._index_file(event.path)
                    if file_result.success:
                        result.files_indexed += 1
                        result.symbols_added += file_result.symbols_count
                    else:
                        result.errors.append(file_result.error or f"Failed to index: {event.path}")

                elif event.change_type == ChangeType.DELETED:
                    self._remove_file(event.path)
                    result.files_removed += 1

                elif event.change_type == ChangeType.MOVED:
                    # Remove from old location, add at new location
                    if event.old_path:
                        self._remove_file(event.old_path)
                        result.files_removed += 1
                    file_result = self._index_file(event.path)
                    if file_result.success:
                        result.files_indexed += 1
                        result.symbols_added += file_result.symbols_count
                    else:
                        result.errors.append(file_result.error or f"Failed to index: {event.path}")

            except Exception as exc:
                error_msg = f"Error processing {event.path}: {type(exc).__name__}: {exc}"
                logger.error(error_msg)
                result.errors.append(error_msg)

        return result

    def _index_file(self, path: Path) -> FileIndexResult:
        """Index a single file.

        Args:
            path: Path to the file to index

        Returns:
            FileIndexResult with status
        """
        path = Path(path).resolve()

        # Check if file exists
        if not path.exists():
            return FileIndexResult(
                path=path,
                symbols_count=0,
                success=False,
                error=f"File not found: {path}",
            )

        # Check if language is supported
        language = self.config.language_for_path(path)
        if not language:
            return FileIndexResult(
                path=path,
                symbols_count=0,
                success=False,
                error=f"Unsupported language for: {path}",
            )

        # Get directory store
        dir_path = path.parent
        store = self._get_dir_store(dir_path)
        if store is None:
            return FileIndexResult(
                path=path,
                symbols_count=0,
                success=False,
                error=f"Directory not indexed: {dir_path}",
            )

        # Read file content as UTF-8, falling back to errors="ignore" on decode failure
        try:
            content = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            logger.debug("UTF-8 decode failed for %s, using fallback with errors='ignore'", path)
            try:
                content = path.read_text(encoding="utf-8", errors="ignore")
            except Exception as exc:
                return FileIndexResult(
                    path=path,
                    symbols_count=0,
                    success=False,
                    error=f"Failed to read file: {exc}",
                )
        except Exception as exc:
            return FileIndexResult(
                path=path,
                symbols_count=0,
                success=False,
                error=f"Failed to read file: {exc}",
            )

        # Parse symbols
        try:
            parser = self.parser_factory.get_parser(language)
            indexed_file = parser.parse(content, path)
        except Exception as exc:
            error_msg = f"Failed to parse {path}: {type(exc).__name__}: {exc}"
            logger.error(error_msg)
            return FileIndexResult(
                path=path,
                symbols_count=0,
                success=False,
                error=error_msg,
            )

        # Update store with retry logic for transient database errors
        max_retries = 3
        for attempt in range(max_retries):
            try:
                store.add_file(
                    name=path.name,
                    full_path=str(path),
                    content=content,
                    language=language,
                    symbols=indexed_file.symbols,
                    relationships=indexed_file.relationships,
                )

                # Update merkle root
                store.update_merkle_root()

                logger.debug("Indexed file: %s (%d symbols)", path, len(indexed_file.symbols))

                return FileIndexResult(
                    path=path,
                    symbols_count=len(indexed_file.symbols),
                    success=True,
                )

            except sqlite3.OperationalError as exc:
                # Transient database errors (e.g., database locked)
                if attempt < max_retries - 1:
                    wait_time = 0.1 * (2 ** attempt)  # Exponential backoff
                    logger.debug(
                        "Database operation failed (attempt %d/%d), retrying in %.2fs: %s",
                        attempt + 1, max_retries, wait_time, exc,
                    )
                    time.sleep(wait_time)
                    continue
                error_msg = f"Failed to store {path} after {max_retries} attempts: {exc}"
                logger.error(error_msg)
                return FileIndexResult(
                    path=path,
                    symbols_count=0,
                    success=False,
                    error=error_msg,
                )
            except Exception as exc:
                error_msg = f"Failed to store {path}: {type(exc).__name__}: {exc}"
                logger.error(error_msg)
                return FileIndexResult(
                    path=path,
                    symbols_count=0,
                    success=False,
                    error=error_msg,
                )

        # Should never reach here
        return FileIndexResult(
            path=path,
            symbols_count=0,
            success=False,
            error="Unexpected error in indexing loop",
        )

    def _remove_file(self, path: Path) -> bool:
        """Remove a file from the index.

        Args:
            path: Path to the file to remove

        Returns:
            True if removed successfully
        """
        path = Path(path).resolve()
        dir_path = path.parent

        store = self._get_dir_store(dir_path)
        if store is None:
            logger.debug("Cannot remove file, directory not indexed: %s", dir_path)
            return False

        # Retry logic for transient database errors
        max_retries = 3
        for attempt in range(max_retries):
            try:
                store.remove_file(str(path))
                store.update_merkle_root()
                logger.debug("Removed file from index: %s", path)
                return True

            except sqlite3.OperationalError as exc:
                # Transient database errors (e.g., database locked)
                if attempt < max_retries - 1:
                    wait_time = 0.1 * (2 ** attempt)  # Exponential backoff
                    logger.debug(
                        "Database operation failed (attempt %d/%d), retrying in %.2fs: %s",
                        attempt + 1, max_retries, wait_time, exc,
                    )
                    time.sleep(wait_time)
                    continue
                logger.error("Failed to remove %s after %d attempts: %s", path, max_retries, exc)
                return False
            except Exception as exc:
                logger.error("Failed to remove %s: %s", path, exc)
                return False

        # Should never reach here
        return False

    def close(self) -> None:
        """Close all open stores."""
        with self._lock:
            for store in self._dir_stores.values():
                try:
                    store.close()
                except Exception:
                    pass
            self._dir_stores.clear()

            if self._global_index:
                try:
                    self._global_index.close()
                except Exception:
                    pass
                self._global_index = None
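The retry schedule in _index_file and _remove_file is easy to verify by hand; with max_retries = 3, the sketch below reproduces the waits before the second and third attempts:

max_retries = 3
waits = [0.1 * (2 ** attempt) for attempt in range(max_retries - 1)]
print(waits)  # [0.1, 0.2]; no sleep follows the final failed attempt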
194
codex-lens/src/codexlens/watcher/manager.py
Normal file
@@ -0,0 +1,194 @@
"""Watcher manager for coordinating file watching and incremental indexing."""

from __future__ import annotations

import logging
import signal
import threading
import time
from pathlib import Path
from typing import Callable, List, Optional

from codexlens.config import Config
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore

from .events import FileEvent, IndexResult, WatcherConfig, WatcherStats
from .file_watcher import FileWatcher
from .incremental_indexer import IncrementalIndexer

logger = logging.getLogger(__name__)


class WatcherManager:
    """High-level manager for file watching and incremental indexing.

    Coordinates FileWatcher and IncrementalIndexer with:
    - Lifecycle management (start/stop)
    - Signal handling (SIGINT/SIGTERM)
    - Statistics tracking
    - Graceful shutdown
    """

    def __init__(
        self,
        root_path: Path,
        config: Optional[Config] = None,
        watcher_config: Optional[WatcherConfig] = None,
        on_indexed: Optional[Callable[[IndexResult], None]] = None,
    ) -> None:
        self.root_path = Path(root_path).resolve()
        self.config = config or Config()
        self.watcher_config = watcher_config or WatcherConfig()
        self.on_indexed = on_indexed

        self._registry: Optional[RegistryStore] = None
        self._mapper: Optional[PathMapper] = None
        self._watcher: Optional[FileWatcher] = None
        self._indexer: Optional[IncrementalIndexer] = None

        self._running = False
        self._stop_event = threading.Event()
        self._lock = threading.RLock()

        # Statistics
        self._stats = WatcherStats()
        self._original_sigint = None
        self._original_sigterm = None

    def _handle_changes(self, events: List[FileEvent]) -> None:
        """Handle file change events from watcher."""
        if not self._indexer or not events:
            return

        logger.info("Processing %d file changes", len(events))
        result = self._indexer.process_changes(events)

        # Update stats
        self._stats.events_processed += len(events)
        self._stats.last_event_time = time.time()

        if result.files_indexed > 0 or result.files_removed > 0:
            logger.info(
                "Indexed %d files, removed %d files, %d errors",
                result.files_indexed, result.files_removed, len(result.errors)
            )

        if self.on_indexed:
            try:
                self.on_indexed(result)
            except Exception as exc:
                logger.error("Error in on_indexed callback: %s", exc)

    def _signal_handler(self, signum, frame) -> None:
        """Handle shutdown signals."""
        logger.info("Received signal %d, stopping...", signum)
        self.stop()

    def _install_signal_handlers(self) -> None:
        """Install signal handlers for graceful shutdown."""
        try:
            self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
            if hasattr(signal, 'SIGTERM'):
                self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
        except (ValueError, OSError):
            # Signal handling not available (e.g., not main thread)
            pass

    def _restore_signal_handlers(self) -> None:
        """Restore original signal handlers."""
        try:
            if self._original_sigint is not None:
                signal.signal(signal.SIGINT, self._original_sigint)
            if self._original_sigterm is not None and hasattr(signal, 'SIGTERM'):
                signal.signal(signal.SIGTERM, self._original_sigterm)
        except (ValueError, OSError):
            pass

    def start(self) -> None:
        """Start watching and indexing."""
        with self._lock:
            if self._running:
                logger.warning("WatcherManager already running")
                return

            # Validate path
            if not self.root_path.exists():
                raise ValueError(f"Root path does not exist: {self.root_path}")

            # Initialize components
            self._registry = RegistryStore()
            self._registry.initialize()
            self._mapper = PathMapper()

            self._indexer = IncrementalIndexer(
                self._registry, self._mapper, self.config
            )

            self._watcher = FileWatcher(
                self.root_path, self.watcher_config, self._handle_changes
            )

            # Install signal handlers
            self._install_signal_handlers()

            # Start watcher
            self._running = True
            self._stats.is_running = True
            self._stop_event.clear()
            self._watcher.start()

            logger.info("WatcherManager started for: %s", self.root_path)

    def stop(self) -> None:
        """Stop watching and clean up."""
        with self._lock:
            if not self._running:
                return

            self._running = False
            self._stats.is_running = False
            self._stop_event.set()

            # Stop watcher
            if self._watcher:
                self._watcher.stop()
                self._watcher = None

            # Close indexer
            if self._indexer:
                self._indexer.close()
                self._indexer = None

            # Close registry
            if self._registry:
                self._registry.close()
                self._registry = None

            # Restore signal handlers
            self._restore_signal_handlers()

            logger.info("WatcherManager stopped")

    def wait(self) -> None:
        """Block until stopped."""
        try:
            while self._running:
                self._stop_event.wait(timeout=1.0)
        except KeyboardInterrupt:
            logger.info("Interrupted, stopping...")
            self.stop()

    @property
    def is_running(self) -> bool:
        """Check if manager is running."""
        return self._running

    def get_stats(self) -> WatcherStats:
        """Get runtime statistics."""
        return WatcherStats(
            files_watched=self._stats.files_watched,
            events_processed=self._stats.events_processed,
            last_event_time=self._stats.last_event_time,
            is_running=self._running,
        )
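Editor's sketch of wiring WatcherManager into a long-running entry point; the callback and path below are illustrative, not part of this diff:

from pathlib import Path

from codexlens.watcher import WatcherManager
from codexlens.watcher.events import IndexResult

def report(result: IndexResult) -> None:
    # Illustrative callback: surface incremental-index errors as they happen.
    if result.errors:
        print(f"{len(result.errors)} errors during incremental indexing")

manager = WatcherManager(Path("."), on_indexed=report)
manager.start()
manager.wait()  # Blocks; SIGINT/SIGTERM trigger a graceful stop()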
171
codex-lens/tests/test_api_reranker.py
Normal file
@@ -0,0 +1,171 @@
"""Tests for APIReranker backend."""

from __future__ import annotations

import sys
import types
from typing import Any

import pytest

from codexlens.semantic.reranker import get_reranker
from codexlens.semantic.reranker.api_reranker import APIReranker


class DummyResponse:
    def __init__(
        self,
        *,
        status_code: int = 200,
        json_data: Any = None,
        text: str = "",
        headers: dict[str, str] | None = None,
    ) -> None:
        self.status_code = int(status_code)
        self._json_data = json_data
        self.text = text
        self.headers = headers or {}

    def json(self) -> Any:
        return self._json_data


class DummyClient:
    def __init__(self, *, base_url: str | None = None, headers: dict[str, str] | None = None, timeout: float | None = None) -> None:
        self.base_url = base_url
        self.headers = headers or {}
        self.timeout = timeout
        self.closed = False
        self.calls: list[dict[str, Any]] = []
        self._responses: list[DummyResponse] = []

    def queue(self, response: DummyResponse) -> None:
        self._responses.append(response)

    def post(self, endpoint: str, *, json: dict[str, Any] | None = None) -> DummyResponse:
        self.calls.append({"endpoint": endpoint, "json": json})
        if not self._responses:
            raise AssertionError("DummyClient has no queued responses")
        return self._responses.pop(0)

    def close(self) -> None:
        self.closed = True


@pytest.fixture
def httpx_clients(monkeypatch: pytest.MonkeyPatch) -> list[DummyClient]:
    clients: list[DummyClient] = []

    dummy_httpx = types.ModuleType("httpx")

    def Client(*, base_url: str | None = None, headers: dict[str, str] | None = None, timeout: float | None = None) -> DummyClient:
        client = DummyClient(base_url=base_url, headers=headers, timeout=timeout)
        clients.append(client)
        return client

    dummy_httpx.Client = Client
    monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)

    return clients


def test_api_reranker_requires_api_key(
    monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
    monkeypatch.delenv("RERANKER_API_KEY", raising=False)

    with pytest.raises(ValueError, match="Missing API key"):
        APIReranker()

    assert httpx_clients == []


def test_api_reranker_reads_api_key_from_env(
    monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
    monkeypatch.setenv("RERANKER_API_KEY", "test-key")

    reranker = APIReranker()
    assert len(httpx_clients) == 1
    assert httpx_clients[0].headers["Authorization"] == "Bearer test-key"
    reranker.close()
    assert httpx_clients[0].closed is True


def test_api_reranker_scores_pairs_siliconflow(
    monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
    monkeypatch.delenv("RERANKER_API_KEY", raising=False)

    reranker = APIReranker(api_key="k", provider="siliconflow")
    client = httpx_clients[0]

    client.queue(
        DummyResponse(
            json_data={
                "results": [
                    {"index": 0, "relevance_score": 0.9},
                    {"index": 1, "relevance_score": 0.1},
                ]
            }
        )
    )

    scores = reranker.score_pairs([("q", "d1"), ("q", "d2")])
    assert scores == pytest.approx([0.9, 0.1])

    assert client.calls[0]["endpoint"] == "/v1/rerank"
    payload = client.calls[0]["json"]
    assert payload["model"] == "BAAI/bge-reranker-v2-m3"
    assert payload["query"] == "q"
    assert payload["documents"] == ["d1", "d2"]
    assert payload["top_n"] == 2
    assert payload["return_documents"] is False


def test_api_reranker_retries_on_5xx(
    monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
    monkeypatch.setenv("RERANKER_API_KEY", "k")

    from codexlens.semantic.reranker import api_reranker as api_reranker_module

    monkeypatch.setattr(api_reranker_module.time, "sleep", lambda *_args, **_kwargs: None)

    reranker = APIReranker(max_retries=1)
    client = httpx_clients[0]

    client.queue(DummyResponse(status_code=500, text="oops", json_data={"error": "oops"}))
    client.queue(
        DummyResponse(
            json_data={"results": [{"index": 0, "relevance_score": 0.7}]},
        )
    )

    scores = reranker.score_pairs([("q", "d")])
    assert scores == pytest.approx([0.7])
    assert len(client.calls) == 2


def test_api_reranker_unauthorized_raises(
    monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
    monkeypatch.setenv("RERANKER_API_KEY", "k")

    reranker = APIReranker()
    client = httpx_clients[0]
    client.queue(DummyResponse(status_code=401, text="unauthorized"))

    with pytest.raises(RuntimeError, match="unauthorized"):
        reranker.score_pairs([("q", "d")])


def test_factory_api_backend_constructs_reranker(
    monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
    monkeypatch.setenv("RERANKER_API_KEY", "k")

    reranker = get_reranker(backend="api")
    assert isinstance(reranker, APIReranker)
    assert len(httpx_clients) == 1
139
codex-lens/tests/test_hybrid_search_reranker_backend.py
Normal file
@@ -0,0 +1,139 @@
"""Tests for HybridSearchEngine reranker backend selection."""

from __future__ import annotations

import pytest

from codexlens.config import Config
from codexlens.search.hybrid_search import HybridSearchEngine


def test_get_cross_encoder_reranker_uses_factory_backend_legacy(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
) -> None:
    calls: dict[str, object] = {}

    def fake_check_reranker_available(backend: str):
        calls["check_backend"] = backend
        return True, None

    sentinel = object()

    def fake_get_reranker(*, backend: str, model_name=None, device=None, **kwargs):
        calls["get_args"] = {
            "backend": backend,
            "model_name": model_name,
            "device": device,
            "kwargs": kwargs,
        }
        return sentinel

    monkeypatch.setattr(
        "codexlens.semantic.reranker.check_reranker_available",
        fake_check_reranker_available,
    )
    monkeypatch.setattr(
        "codexlens.semantic.reranker.get_reranker",
        fake_get_reranker,
    )

    config = Config(
        data_dir=tmp_path / "legacy",
        enable_reranking=True,
        enable_cross_encoder_rerank=True,
        reranker_backend="legacy",
        reranker_model="dummy-model",
    )
    engine = HybridSearchEngine(config=config)

    reranker = engine._get_cross_encoder_reranker()
    assert reranker is sentinel
    assert calls["check_backend"] == "legacy"

    get_args = calls["get_args"]
    assert isinstance(get_args, dict)
    assert get_args["backend"] == "legacy"
    assert get_args["model_name"] == "dummy-model"
    assert get_args["device"] is None


def test_get_cross_encoder_reranker_uses_factory_backend_onnx_gpu_flag(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
) -> None:
    calls: dict[str, object] = {}

    def fake_check_reranker_available(backend: str):
        calls["check_backend"] = backend
        return True, None

    sentinel = object()

    def fake_get_reranker(*, backend: str, model_name=None, device=None, **kwargs):
        calls["get_args"] = {
            "backend": backend,
            "model_name": model_name,
            "device": device,
            "kwargs": kwargs,
        }
        return sentinel

    monkeypatch.setattr(
        "codexlens.semantic.reranker.check_reranker_available",
        fake_check_reranker_available,
    )
    monkeypatch.setattr(
        "codexlens.semantic.reranker.get_reranker",
        fake_get_reranker,
    )

    config = Config(
        data_dir=tmp_path / "onnx",
        enable_reranking=True,
        enable_cross_encoder_rerank=True,
        reranker_backend="onnx",
        embedding_use_gpu=False,
    )
    engine = HybridSearchEngine(config=config)

    reranker = engine._get_cross_encoder_reranker()
    assert reranker is sentinel
    assert calls["check_backend"] == "onnx"

    get_args = calls["get_args"]
    assert isinstance(get_args, dict)
    assert get_args["backend"] == "onnx"
    assert get_args["model_name"] is None
    assert get_args["device"] is None
    assert get_args["kwargs"]["use_gpu"] is False


def test_get_cross_encoder_reranker_returns_none_when_backend_unavailable(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
) -> None:
    def fake_check_reranker_available(backend: str):
        return False, "missing deps"

    def fake_get_reranker(*args, **kwargs):
        raise AssertionError("get_reranker should not be called when backend is unavailable")

    monkeypatch.setattr(
        "codexlens.semantic.reranker.check_reranker_available",
        fake_check_reranker_available,
    )
    monkeypatch.setattr(
        "codexlens.semantic.reranker.get_reranker",
        fake_get_reranker,
    )

    config = Config(
        data_dir=tmp_path / "unavailable",
        enable_reranking=True,
        enable_cross_encoder_rerank=True,
        reranker_backend="onnx",
    )
    engine = HybridSearchEngine(config=config)

    assert engine._get_cross_encoder_reranker() is None
85
codex-lens/tests/test_litellm_reranker.py
Normal file
@@ -0,0 +1,85 @@
"""Tests for LiteLLMReranker (LLM-based reranking)."""

from __future__ import annotations

import sys
import types
from dataclasses import dataclass

import pytest

from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker


def _install_dummy_ccw_litellm(
    monkeypatch: pytest.MonkeyPatch, *, responses: list[str]
) -> None:
    @dataclass(frozen=True, slots=True)
    class ChatMessage:
        role: str
        content: str

    class LiteLLMClient:
        def __init__(self, model: str = "default", **kwargs) -> None:
            self.model = model
            self.kwargs = kwargs
            self._responses = list(responses)
            self.calls: list[list[ChatMessage]] = []

        def chat(self, messages, **kwargs):
            self.calls.append(list(messages))
            content = self._responses.pop(0) if self._responses else ""
            return types.SimpleNamespace(content=content)

    dummy = types.ModuleType("ccw_litellm")
    dummy.ChatMessage = ChatMessage
    dummy.LiteLLMClient = LiteLLMClient
    monkeypatch.setitem(sys.modules, "ccw_litellm", dummy)


def test_score_pairs_parses_numbers_and_normalizes_scales(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    _install_dummy_ccw_litellm(monkeypatch, responses=["0.73", "7", "80"])

    reranker = LiteLLMReranker(model="dummy")
    scores = reranker.score_pairs([("q", "d1"), ("q", "d2"), ("q", "d3")])
    assert scores == pytest.approx([0.73, 0.7, 0.8])


def test_score_pairs_parses_json_score_field(monkeypatch: pytest.MonkeyPatch) -> None:
    _install_dummy_ccw_litellm(monkeypatch, responses=['{"score": 0.42}'])

    reranker = LiteLLMReranker(model="dummy")
    scores = reranker.score_pairs([("q", "d")])
    assert scores == pytest.approx([0.42])


def test_score_pairs_uses_default_score_on_parse_failure(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    _install_dummy_ccw_litellm(monkeypatch, responses=["N/A"])

    reranker = LiteLLMReranker(model="dummy", default_score=0.123)
    scores = reranker.score_pairs([("q", "d")])
    assert scores == pytest.approx([0.123])


def test_rate_limiting_sleeps_between_requests(monkeypatch: pytest.MonkeyPatch) -> None:
    _install_dummy_ccw_litellm(monkeypatch, responses=["0.1", "0.2"])

    reranker = LiteLLMReranker(model="dummy", min_interval_seconds=1.0)

    import codexlens.semantic.reranker.litellm_reranker as litellm_reranker_module

    sleeps: list[float] = []
    times = iter([100.0, 100.0, 100.1, 100.1])

    monkeypatch.setattr(litellm_reranker_module.time, "monotonic", lambda: next(times))
    monkeypatch.setattr(
        litellm_reranker_module.time, "sleep", lambda seconds: sleeps.append(seconds)
    )

    _ = reranker.score_pairs([("q", "d1"), ("q", "d2")])
    assert sleeps == pytest.approx([0.9])
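The expected sleep in the rate-limiting test above is plain arithmetic against the mocked clock: with min_interval_seconds=1.0 and roughly 0.1 s elapsed between the two monotonic() readings, the reranker sleeps for the remainder:

min_interval = 1.0
elapsed = 100.1 - 100.0  # from the mocked time.monotonic() values
print(min_interval - elapsed)  # ≈ 0.9, the single value collected in `sleeps`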
115
codex-lens/tests/test_reranker_backends.py
Normal file
@@ -0,0 +1,115 @@
"""Mocked smoke tests for all reranker backends."""

from __future__ import annotations

import sys
import types
from dataclasses import dataclass

import pytest


def test_reranker_backend_legacy_scores_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
    from codexlens.semantic.reranker import legacy as legacy_module

    class DummyCrossEncoder:
        def __init__(self, model_name: str, *, device: str | None = None) -> None:
            self.model_name = model_name
            self.device = device
            self.calls: list[dict[str, object]] = []

        def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]:
            self.calls.append({"pairs": list(pairs), "batch_size": int(batch_size)})
            return [0.5 for _ in pairs]

    monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder)
    monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True)
    monkeypatch.setattr(legacy_module, "_import_error", None)

    reranker = legacy_module.CrossEncoderReranker(model_name="dummy-model", device="cpu")
    scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0)
    assert scores == pytest.approx([0.5, 0.5])


def test_reranker_backend_onnx_availability_check(monkeypatch: pytest.MonkeyPatch) -> None:
    from codexlens.semantic.reranker.onnx_reranker import check_onnx_reranker_available

    dummy_numpy = types.ModuleType("numpy")
    dummy_onnxruntime = types.ModuleType("onnxruntime")

    dummy_optimum = types.ModuleType("optimum")
    dummy_optimum.__path__ = []  # Mark as package for submodule imports.
    dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
    dummy_optimum_ort.ORTModelForSequenceClassification = object()

    dummy_transformers = types.ModuleType("transformers")
    dummy_transformers.AutoTokenizer = object()

    monkeypatch.setitem(sys.modules, "numpy", dummy_numpy)
    monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
    monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
    monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
    monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)

    ok, err = check_onnx_reranker_available()
    assert ok is True
    assert err is None


def test_reranker_backend_api_constructs_with_dummy_httpx(monkeypatch: pytest.MonkeyPatch) -> None:
    from codexlens.semantic.reranker.api_reranker import APIReranker

    created: list[object] = []

    class DummyClient:
        def __init__(
            self,
            *,
            base_url: str | None = None,
            headers: dict[str, str] | None = None,
            timeout: float | None = None,
        ) -> None:
            self.base_url = base_url
            self.headers = headers or {}
            self.timeout = timeout
            self.closed = False
            created.append(self)

        def close(self) -> None:
            self.closed = True

    dummy_httpx = types.ModuleType("httpx")
    dummy_httpx.Client = DummyClient
    monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)

    reranker = APIReranker(api_key="k", provider="siliconflow")
    assert reranker.provider == "siliconflow"
    assert len(created) == 1
    assert created[0].headers["Authorization"] == "Bearer k"
    reranker.close()
    assert created[0].closed is True


def test_reranker_backend_litellm_scores_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
    from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker

    @dataclass(frozen=True, slots=True)
    class ChatMessage:
        role: str
        content: str

    class DummyLiteLLMClient:
        def __init__(self, model: str = "default", **_kwargs: object) -> None:
            self.model = model

        def chat(self, _messages: list[ChatMessage]) -> object:
            return types.SimpleNamespace(content="0.5")

    dummy_litellm = types.ModuleType("ccw_litellm")
    dummy_litellm.ChatMessage = ChatMessage
    dummy_litellm.LiteLLMClient = DummyLiteLLMClient
    monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)

    reranker = LiteLLMReranker(model="dummy")
    assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5])
315
codex-lens/tests/test_reranker_factory.py
Normal file
@@ -0,0 +1,315 @@
"""Tests for reranker factory and availability checks."""

from __future__ import annotations

import builtins
import math
import sys
import types

import pytest

from codexlens.semantic.reranker import (
    BaseReranker,
    ONNXReranker,
    check_reranker_available,
    get_reranker,
)
from codexlens.semantic.reranker import legacy as legacy_module


def test_public_imports_work() -> None:
    from codexlens.semantic.reranker import BaseReranker as ImportedBaseReranker
    from codexlens.semantic.reranker import get_reranker as imported_get_reranker

    assert ImportedBaseReranker is BaseReranker
    assert imported_get_reranker is get_reranker


def test_base_reranker_is_abstract() -> None:
    with pytest.raises(TypeError):
        BaseReranker()  # type: ignore[abstract]


def test_check_reranker_available_invalid_backend() -> None:
    ok, err = check_reranker_available("nope")
    assert ok is False
    assert "Invalid reranker backend" in (err or "")


def test_get_reranker_invalid_backend_raises_value_error() -> None:
    with pytest.raises(ValueError, match="Unknown backend"):
        get_reranker("nope")


def test_get_reranker_legacy_missing_dependency_raises_import_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", False)
    monkeypatch.setattr(legacy_module, "_import_error", "missing sentence-transformers")

    with pytest.raises(ImportError, match="missing sentence-transformers"):
        get_reranker(backend="legacy", model_name="dummy-model")


def test_get_reranker_legacy_returns_cross_encoder_reranker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    class DummyCrossEncoder:
        def __init__(self, model_name: str, *, device: str | None = None) -> None:
            self.model_name = model_name
            self.device = device
            self.last_batch_size: int | None = None

        def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]:
            self.last_batch_size = int(batch_size)
            return [0.5 for _ in pairs]

    monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder)
    monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True)
    monkeypatch.setattr(legacy_module, "_import_error", None)

    reranker = get_reranker(backend=" LEGACY ", model_name="dummy-model", device="cpu")
    assert isinstance(reranker, legacy_module.CrossEncoderReranker)

    assert reranker.score_pairs([]) == []

    scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0)
    assert scores == pytest.approx([0.5, 0.5])
    assert reranker._model is not None
    assert reranker._model.last_batch_size == 32


def test_check_reranker_available_onnx_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    real_import = builtins.__import__

    def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
        if name == "onnxruntime":
            raise ImportError("no onnxruntime")
        return real_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", fake_import)

    ok, err = check_reranker_available("onnx")
    assert ok is False
    assert "onnxruntime not available" in (err or "")


def test_check_reranker_available_onnx_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
    dummy_onnxruntime = types.ModuleType("onnxruntime")
    dummy_optimum = types.ModuleType("optimum")
    dummy_optimum.__path__ = []  # Mark as package for submodule imports.
    dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
    dummy_optimum_ort.ORTModelForSequenceClassification = object()

    dummy_transformers = types.ModuleType("transformers")
    dummy_transformers.AutoTokenizer = object()

    monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
    monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
    monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
    monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)

    ok, err = check_reranker_available("onnx")
    assert ok is True
    assert err is None


def test_check_reranker_available_litellm_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    real_import = builtins.__import__

    def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
        if name == "ccw_litellm":
            raise ImportError("no ccw-litellm")
        return real_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", fake_import)

    ok, err = check_reranker_available("litellm")
    assert ok is False
    assert "ccw-litellm not available" in (err or "")


def test_check_reranker_available_litellm_deps_present(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    dummy_litellm = types.ModuleType("ccw_litellm")
    monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)

    ok, err = check_reranker_available("litellm")
    assert ok is True
    assert err is None


def test_check_reranker_available_api_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
    real_import = builtins.__import__

    def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
        if name == "httpx":
            raise ImportError("no httpx")
        return real_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", fake_import)

    ok, err = check_reranker_available("api")
    assert ok is False
    assert "httpx not available" in (err or "")


def test_check_reranker_available_api_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
    dummy_httpx = types.ModuleType("httpx")
    monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)

    ok, err = check_reranker_available("api")
    assert ok is True
    assert err is None


def test_get_reranker_litellm_returns_litellm_reranker(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    from dataclasses import dataclass

    @dataclass(frozen=True, slots=True)
    class ChatMessage:
        role: str
        content: str

    class DummyLiteLLMClient:
        def __init__(self, model: str = "default", **kwargs) -> None:
            self.model = model
            self.kwargs = kwargs

        def chat(self, messages, **kwargs):
            return types.SimpleNamespace(content="0.5")

    dummy_litellm = types.ModuleType("ccw_litellm")
    dummy_litellm.ChatMessage = ChatMessage
    dummy_litellm.LiteLLMClient = DummyLiteLLMClient
    monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)

    reranker = get_reranker(backend="litellm", model_name="dummy-model")

    from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker

    assert isinstance(reranker, LiteLLMReranker)
    assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5])


def test_get_reranker_onnx_raises_import_error_with_dependency_hint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    real_import = builtins.__import__

    def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
        if name == "onnxruntime":
            raise ImportError("no onnxruntime")
        return real_import(name, globals, locals, fromlist, level)

    monkeypatch.setattr(builtins, "__import__", fake_import)

    with pytest.raises(ImportError) as exc:
        get_reranker(backend="onnx", model_name="any")

    assert "onnxruntime" in str(exc.value)


def test_get_reranker_default_backend_is_onnx(monkeypatch: pytest.MonkeyPatch) -> None:
    dummy_onnxruntime = types.ModuleType("onnxruntime")
    dummy_optimum = types.ModuleType("optimum")
    dummy_optimum.__path__ = []  # Mark as package for submodule imports.
    dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
    dummy_optimum_ort.ORTModelForSequenceClassification = object()

    dummy_transformers = types.ModuleType("transformers")
    dummy_transformers.AutoTokenizer = object()

    monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
    monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
    monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
    monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)

    reranker = get_reranker()
    assert isinstance(reranker, ONNXReranker)


def test_onnx_reranker_scores_pairs_with_sigmoid_normalization(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    import numpy as np

    dummy_onnxruntime = types.ModuleType("onnxruntime")
|
||||
|
||||
dummy_optimum = types.ModuleType("optimum")
|
||||
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
|
||||
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
|
||||
|
||||
class DummyModelOutput:
|
||||
def __init__(self, logits: np.ndarray) -> None:
|
||||
self.logits = logits
|
||||
|
||||
class DummyModel:
|
||||
input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[int] = []
|
||||
self._next_logit = 0
|
||||
|
||||
def __call__(self, **inputs):
|
||||
batch = int(inputs["input_ids"].shape[0])
|
||||
start = self._next_logit
|
||||
self._next_logit += batch
|
||||
self.calls.append(batch)
|
||||
logits = np.arange(start, start + batch, dtype=np.float32).reshape(batch, 1)
|
||||
return DummyModelOutput(logits=logits)
|
||||
|
||||
class DummyORTModelForSequenceClassification:
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name: str, providers=None, **kwargs):
|
||||
_ = model_name, providers, kwargs
|
||||
return DummyModel()
|
||||
|
||||
dummy_optimum_ort.ORTModelForSequenceClassification = DummyORTModelForSequenceClassification
|
||||
|
||||
dummy_transformers = types.ModuleType("transformers")
|
||||
|
||||
class DummyAutoTokenizer:
|
||||
model_max_length = 512
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name: str, **kwargs):
|
||||
_ = model_name, kwargs
|
||||
return cls()
|
||||
|
||||
def __call__(self, *, text, text_pair, return_tensors, **kwargs):
|
||||
_ = text_pair, kwargs
|
||||
assert return_tensors == "np"
|
||||
batch = len(text)
|
||||
# Include token_type_ids to ensure input filtering is exercised.
|
||||
return {
|
||||
"input_ids": np.zeros((batch, 4), dtype=np.int64),
|
||||
"attention_mask": np.ones((batch, 4), dtype=np.int64),
|
||||
"token_type_ids": np.zeros((batch, 4), dtype=np.int64),
|
||||
}
|
||||
|
||||
dummy_transformers.AutoTokenizer = DummyAutoTokenizer
|
||||
|
||||
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
|
||||
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
|
||||
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
|
||||
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
|
||||
|
||||
reranker = get_reranker(backend="onnx", model_name="dummy-model", use_gpu=False)
|
||||
assert isinstance(reranker, ONNXReranker)
|
||||
assert reranker._model is None
|
||||
|
||||
pairs = [("q", f"d{idx}") for idx in range(5)]
|
||||
scores = reranker.score_pairs(pairs, batch_size=2)
|
||||
|
||||
assert reranker._model is not None
|
||||
assert reranker._model.calls == [2, 2, 1]
|
||||
assert len(scores) == len(pairs)
|
||||
assert all(0.0 <= s <= 1.0 for s in scores)
|
||||
|
||||
expected = [1.0 / (1.0 + math.exp(-float(i))) for i in range(len(pairs))]
|
||||
assert scores == pytest.approx(expected, rel=1e-6, abs=1e-6)
|
||||
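Note on the final assertion: the expected values are the plain logistic sigmoid of the dummy model's sequential logits. A standalone check, standard library only (the sigmoid helper here is illustrative, not part of the codebase):

import math

def sigmoid(logit: float) -> float:
    # Map a raw relevance logit to a score in [0, 1].
    return 1.0 / (1.0 + math.exp(-logit))

# The dummy model emits logits 0..4 across batches of 2, 2, and 1:
print([round(sigmoid(i), 4) for i in range(5)])
# [0.5, 0.7311, 0.8808, 0.9526, 0.982]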
1
codex-lens/tests/test_watcher/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for watcher module."""
43
codex-lens/tests/test_watcher/conftest.py
Normal file
@@ -0,0 +1,43 @@
"""Fixtures for watcher tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_project() -> Generator[Path, None, None]:
|
||||
"""Create a temporary project directory with sample files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project = Path(tmpdir)
|
||||
|
||||
# Create sample Python file
|
||||
py_file = project / "main.py"
|
||||
py_file.write_text("def hello():\n print('Hello')\n")
|
||||
|
||||
# Create sample JavaScript file
|
||||
js_file = project / "app.js"
|
||||
js_file.write_text("function greet() {\n console.log('Hi');\n}\n")
|
||||
|
||||
# Create subdirectory with file
|
||||
sub_dir = project / "src"
|
||||
sub_dir.mkdir()
|
||||
(sub_dir / "utils.py").write_text("def add(a, b):\n return a + b\n")
|
||||
|
||||
# Create ignored directory
|
||||
git_dir = project / ".git"
|
||||
git_dir.mkdir()
|
||||
(git_dir / "config").write_text("[core]\n")
|
||||
|
||||
yield project
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def watcher_config():
|
||||
"""Create default watcher configuration."""
|
||||
from codexlens.watcher import WatcherConfig
|
||||
return WatcherConfig(debounce_ms=100) # Short debounce for tests
|
||||
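For orientation, the two fixtures compose as in the hypothetical test below (not part of the suite; it only touches files and fields the fixtures are shown to create):

def test_fixture_composition(temp_project, watcher_config):
    # temp_project seeds main.py, app.js, src/utils.py, and a .git directory.
    assert (temp_project / "main.py").exists()
    assert (temp_project / ".git" / "config").exists()
    # watcher_config shortens the debounce window so tests settle quickly.
    assert watcher_config.debounce_ms == 100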
103
codex-lens/tests/test_watcher/test_events.py
Normal file
@@ -0,0 +1,103 @@
"""Tests for watcher event types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.watcher import ChangeType, FileEvent, WatcherConfig, IndexResult, WatcherStats
|
||||
|
||||
|
||||
class TestChangeType:
|
||||
"""Tests for ChangeType enum."""
|
||||
|
||||
def test_change_types_exist(self):
|
||||
"""Verify all change types are defined."""
|
||||
assert ChangeType.CREATED.value == "created"
|
||||
assert ChangeType.MODIFIED.value == "modified"
|
||||
assert ChangeType.DELETED.value == "deleted"
|
||||
assert ChangeType.MOVED.value == "moved"
|
||||
|
||||
def test_change_type_count(self):
|
||||
"""Verify we have exactly 4 change types."""
|
||||
assert len(ChangeType) == 4
|
||||
|
||||
|
||||
class TestFileEvent:
|
||||
"""Tests for FileEvent dataclass."""
|
||||
|
||||
def test_create_event(self):
|
||||
"""Test creating a file event."""
|
||||
event = FileEvent(
|
||||
path=Path("/test/file.py"),
|
||||
change_type=ChangeType.CREATED,
|
||||
timestamp=time.time(),
|
||||
)
|
||||
assert event.path == Path("/test/file.py")
|
||||
assert event.change_type == ChangeType.CREATED
|
||||
assert event.old_path is None
|
||||
|
||||
def test_moved_event(self):
|
||||
"""Test creating a moved event with old_path."""
|
||||
event = FileEvent(
|
||||
path=Path("/test/new.py"),
|
||||
change_type=ChangeType.MOVED,
|
||||
timestamp=time.time(),
|
||||
old_path=Path("/test/old.py"),
|
||||
)
|
||||
assert event.old_path == Path("/test/old.py")
|
||||
|
||||
|
||||
class TestWatcherConfig:
|
||||
"""Tests for WatcherConfig dataclass."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = WatcherConfig()
|
||||
assert config.debounce_ms == 1000
|
||||
assert ".git" in config.ignored_patterns
|
||||
assert "node_modules" in config.ignored_patterns
|
||||
assert "__pycache__" in config.ignored_patterns
|
||||
assert config.languages is None
|
||||
|
||||
def test_custom_debounce(self):
|
||||
"""Test custom debounce setting."""
|
||||
config = WatcherConfig(debounce_ms=500)
|
||||
assert config.debounce_ms == 500
|
||||
|
||||
|
||||
class TestIndexResult:
|
||||
"""Tests for IndexResult dataclass."""
|
||||
|
||||
def test_default_result(self):
|
||||
"""Test default result values."""
|
||||
result = IndexResult()
|
||||
assert result.files_indexed == 0
|
||||
assert result.files_removed == 0
|
||||
assert result.symbols_added == 0
|
||||
assert result.errors == []
|
||||
|
||||
def test_custom_result(self):
|
||||
"""Test creating result with values."""
|
||||
result = IndexResult(
|
||||
files_indexed=5,
|
||||
files_removed=2,
|
||||
symbols_added=50,
|
||||
errors=["error1"],
|
||||
)
|
||||
assert result.files_indexed == 5
|
||||
assert result.files_removed == 2
|
||||
|
||||
|
||||
class TestWatcherStats:
|
||||
"""Tests for WatcherStats dataclass."""
|
||||
|
||||
def test_default_stats(self):
|
||||
"""Test default stats values."""
|
||||
stats = WatcherStats()
|
||||
assert stats.files_watched == 0
|
||||
assert stats.events_processed == 0
|
||||
assert stats.last_event_time is None
|
||||
assert stats.is_running is False
|
||||
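The debounce window these tests pin down (default 1000 ms, shortened to 100-500 ms for tests) exists to coalesce the bursts of events editors emit on save. A minimal sketch of the idea only, not the watcher's actual internals; coalesce is a hypothetical helper operating on FileEvent-like objects in arrival order:

import time

def coalesce(events, debounce_ms=1000):
    # Keep only the newest event per path inside the debounce window;
    # later events for the same path overwrite earlier ones.
    cutoff = time.time() - debounce_ms / 1000.0
    latest = {}
    for event in events:
        if event.timestamp >= cutoff:
            latest[event.path] = event
    return list(latest.values())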
124
codex-lens/tests/test_watcher/test_file_watcher.py
Normal file
@@ -0,0 +1,124 @@
"""Tests for FileWatcher class."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.watcher import FileWatcher, WatcherConfig, FileEvent, ChangeType
|
||||
|
||||
|
||||
class TestFileWatcherInit:
|
||||
"""Tests for FileWatcher initialization."""
|
||||
|
||||
def test_init_with_valid_path(self, temp_project: Path, watcher_config: WatcherConfig):
|
||||
"""Test initializing with valid path."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
|
||||
|
||||
assert watcher.root_path == temp_project.resolve()
|
||||
assert watcher.config == watcher_config
|
||||
assert not watcher.is_running
|
||||
|
||||
def test_start_with_invalid_path(self, watcher_config: WatcherConfig):
|
||||
"""Test starting watcher with non-existent path."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(Path("/nonexistent/path"), watcher_config, lambda e: events.extend(e))
|
||||
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
watcher.start()
|
||||
|
||||
|
||||
class TestFileWatcherLifecycle:
|
||||
"""Tests for FileWatcher start/stop lifecycle."""
|
||||
|
||||
def test_start_stop(self, temp_project: Path, watcher_config: WatcherConfig):
|
||||
"""Test basic start and stop."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
|
||||
|
||||
watcher.start()
|
||||
assert watcher.is_running
|
||||
|
||||
watcher.stop()
|
||||
assert not watcher.is_running
|
||||
|
||||
def test_double_start(self, temp_project: Path, watcher_config: WatcherConfig):
|
||||
"""Test calling start twice."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
|
||||
|
||||
watcher.start()
|
||||
watcher.start() # Should not raise
|
||||
assert watcher.is_running
|
||||
|
||||
watcher.stop()
|
||||
|
||||
def test_double_stop(self, temp_project: Path, watcher_config: WatcherConfig):
|
||||
"""Test calling stop twice."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
|
||||
|
||||
watcher.start()
|
||||
watcher.stop()
|
||||
watcher.stop() # Should not raise
|
||||
assert not watcher.is_running
|
||||
|
||||
|
||||
class TestFileWatcherEvents:
|
||||
"""Tests for FileWatcher event detection."""
|
||||
|
||||
def test_detect_file_creation(self, temp_project: Path, watcher_config: WatcherConfig):
|
||||
"""Test detecting new file creation."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
|
||||
|
||||
try:
|
||||
watcher.start()
|
||||
time.sleep(0.3) # Let watcher start (longer for Windows)
|
||||
|
||||
# Create new file
|
||||
new_file = temp_project / "new_file.py"
|
||||
new_file.write_text("# New file\n")
|
||||
|
||||
# Wait for event with retries (watchdog timing varies by platform)
|
||||
max_wait = 2.0
|
||||
waited = 0.0
|
||||
while waited < max_wait:
|
||||
time.sleep(0.2)
|
||||
waited += 0.2
|
||||
# Windows may report MODIFIED instead of CREATED
|
||||
file_events = [e for e in events if e.change_type in (ChangeType.CREATED, ChangeType.MODIFIED)]
|
||||
if any(e.path.name == "new_file.py" for e in file_events):
|
||||
break
|
||||
|
||||
# Check event was detected (Windows may report MODIFIED instead of CREATED)
|
||||
relevant_events = [e for e in events if e.change_type in (ChangeType.CREATED, ChangeType.MODIFIED)]
|
||||
assert len(relevant_events) >= 1, f"Expected file event, got: {events}"
|
||||
assert any(e.path.name == "new_file.py" for e in relevant_events)
|
||||
finally:
|
||||
watcher.stop()
|
||||
|
||||
def test_filter_ignored_directories(self, temp_project: Path, watcher_config: WatcherConfig):
|
||||
"""Test that files in ignored directories are filtered."""
|
||||
events: List[FileEvent] = []
|
||||
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
|
||||
|
||||
try:
|
||||
watcher.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
# Create file in .git (should be ignored)
|
||||
git_file = temp_project / ".git" / "test.py"
|
||||
git_file.write_text("# In git\n")
|
||||
|
||||
time.sleep(watcher_config.debounce_ms / 1000.0 + 0.2)
|
||||
|
||||
# No events should be detected for .git files
|
||||
git_events = [e for e in events if ".git" in str(e.path)]
|
||||
assert len(git_events) == 0
|
||||
finally:
|
||||
watcher.stop()
|
||||
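Taken together, these tests fix the watcher's public surface: construct with a root path, a WatcherConfig, and a callback that receives batches of FileEvent objects, then drive it with start()/stop(). A minimal usage sketch built only from that surface; the run loop and the list annotation are assumptions:

import time
from pathlib import Path

from codexlens.watcher import FileEvent, FileWatcher, WatcherConfig

def on_changes(events: list[FileEvent]) -> None:
    # Called with a batch of debounced events.
    for event in events:
        print(f"{event.change_type.value}: {event.path}")

watcher = FileWatcher(Path("."), WatcherConfig(debounce_ms=1000), on_changes)
watcher.start()
try:
    while watcher.is_running:
        time.sleep(1.0)
finally:
    watcher.stop()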