feat(codex-lens): add unified reranker architecture and file watcher

Unified Reranker Architecture:
- Add BaseReranker ABC with factory pattern
- Implement 4 backends: ONNX (default), API, LiteLLM, Legacy
- Add .env configuration parsing for API credentials
- Migrate from sentence-transformers to optimum+onnxruntime

File Watcher Module:
- Add real-time file system monitoring with watchdog
- Implement IncrementalIndexer for single-file updates
- Add WatcherManager with signal handling and graceful shutdown
- Add 'codexlens watch' CLI command
- Event filtering, debouncing, and deduplication
- Thread-safe design with proper resource cleanup

Tests: 16 watcher tests + 5 reranker test files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
catlog22
2026-01-01 13:23:52 +08:00
parent 8ac27548ad
commit 520f2d26f2
27 changed files with 3571 additions and 14 deletions

codex-lens/.env.example
View File

@@ -0,0 +1,66 @@
# CodexLens Environment Configuration
# Copy this file to .codexlens/.env and fill in your values
#
# Priority order:
# 1. Environment variables (already set in shell)
# 2. .codexlens/.env (workspace-local, this file)
# 3. .env (project root)
# ============================================
# RERANKER Configuration
# ============================================
# API key for reranker service (SiliconFlow/Cohere/Jina)
# Required for 'api' backend
# RERANKER_API_KEY=sk-xxxx
# Base URL for reranker API (overrides provider default)
# SiliconFlow: https://api.siliconflow.cn
# Cohere: https://api.cohere.ai
# Jina: https://api.jina.ai
# RERANKER_API_BASE=https://api.siliconflow.cn
# Reranker provider: siliconflow, cohere, jina
# RERANKER_PROVIDER=siliconflow
# Reranker model name
# SiliconFlow: BAAI/bge-reranker-v2-m3
# Cohere: rerank-english-v3.0
# Jina: jina-reranker-v2-base-multilingual
# RERANKER_MODEL=BAAI/bge-reranker-v2-m3
# ============================================
# EMBEDDING Configuration
# ============================================
# API key for embedding service (for litellm backend)
# EMBEDDING_API_KEY=sk-xxxx
# Base URL for embedding API
# EMBEDDING_API_BASE=https://api.openai.com
# Embedding model name
# EMBEDDING_MODEL=text-embedding-3-small
# ============================================
# LITELLM Configuration
# ============================================
# API key for LiteLLM (for litellm reranker backend)
# LITELLM_API_KEY=sk-xxxx
# Base URL for LiteLLM
# LITELLM_API_BASE=
# LiteLLM model name
# LITELLM_MODEL=gpt-4o-mini
# ============================================
# General Configuration
# ============================================
# Custom data directory path (default: ~/.codexlens)
# CODEXLENS_DATA_DIR=~/.codexlens
# Enable debug mode (true/false)
# CODEXLENS_DEBUG=false

View File

@@ -21,6 +21,7 @@ dependencies = [
"tree-sitter-javascript>=0.25",
"tree-sitter-typescript>=0.23",
"pathspec>=0.11",
"watchdog>=3.0",
]
[project.optional-dependencies]
@@ -50,11 +51,35 @@ semantic-directml = [
]
# Cross-encoder reranking (second-stage, optional)
# Install with: pip install codexlens[reranker]
reranker = [
# Install with: pip install codexlens[reranker] (default: ONNX backend)
reranker-onnx = [
"optimum>=1.16",
"onnxruntime>=1.15",
"transformers>=4.36",
]
# Remote reranking via HTTP API
reranker-api = [
"httpx>=0.25",
]
# LLM-based reranking via ccw-litellm
reranker-litellm = [
"ccw-litellm>=0.1",
]
# Legacy sentence-transformers CrossEncoder reranker
reranker-legacy = [
"sentence-transformers>=2.2",
]
# Backward-compatible alias for default reranker backend
reranker = [
"optimum>=1.16",
"onnxruntime>=1.15",
"transformers>=4.36",
]
# Encoding detection for non-UTF8 files
encoding = [
"chardet>=5.0",

View File

@@ -22,6 +22,7 @@ from codexlens.storage.registry import RegistryStore, ProjectInfo
from codexlens.storage.index_tree import IndexTreeBuilder
from codexlens.storage.dir_index import DirIndexStore
from codexlens.search.chain_search import ChainSearchEngine, SearchOptions
from codexlens.watcher import WatcherManager, WatcherConfig
from .output import (
console,
@@ -321,6 +322,91 @@ def init(
registry.close()
@app.command()
def watch(
path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to watch."),
language: Optional[List[str]] = typer.Option(
None,
"--language",
"-l",
help="Limit watching to specific languages (repeat or comma-separated).",
),
debounce: int = typer.Option(1000, "--debounce", "-d", min=100, max=10000, help="Debounce interval in milliseconds."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose logging."),
) -> None:
"""Watch directory for changes and update index incrementally.
Monitors filesystem events and automatically updates the index
when files are created, modified, or deleted.
The directory must already be indexed (run 'codexlens init' first).
Press Ctrl+C to stop watching.
Examples:
codexlens watch .
codexlens watch /path/to/project --debounce 500 --verbose
codexlens watch . --language python,typescript
"""
_configure_logging(verbose)
from codexlens.watcher.events import IndexResult
base_path = path.expanduser().resolve()
# Check if path is indexed
mapper = PathMapper()
index_db = mapper.source_to_index_db(base_path)
if not index_db.exists():
console.print(f"[red]Error:[/red] Directory not indexed: {base_path}")
console.print("Run 'codexlens init' first to create the index.")
raise typer.Exit(code=1)
# Parse languages
languages = _parse_languages(language)
# Create watcher config
watcher_config = WatcherConfig(
debounce_ms=debounce,
languages=languages,
)
# Callback for indexed files
def on_indexed(result: IndexResult) -> None:
if result.files_indexed > 0:
console.print(f" [green]Indexed:[/green] {result.files_indexed} files ({result.symbols_added} symbols)")
if result.files_removed > 0:
console.print(f" [yellow]Removed:[/yellow] {result.files_removed} files")
if result.errors:
for error in result.errors[:3]: # Show first 3 errors
console.print(f" [red]Error:[/red] {error}")
console.print(f"[bold]Watching:[/bold] {base_path}")
console.print(f" Debounce: {debounce}ms")
if languages:
console.print(f" Languages: {', '.join(languages)}")
console.print(" Press Ctrl+C to stop.\n")
manager: WatcherManager | None = None
try:
manager = WatcherManager(
root_path=base_path,
watcher_config=watcher_config,
on_indexed=on_indexed,
)
manager.start()
manager.wait()
except KeyboardInterrupt:
pass
except Exception as exc:
console.print(f"[red]Error:[/red] {exc}")
raise typer.Exit(code=1)
finally:
if manager is not None:
manager.stop()
console.print("\n[dim]Watcher stopped.[/dim]")
@app.command()
def search(
query: str = typer.Argument(..., help="FTS query to run."),
@@ -2293,3 +2379,102 @@ def gpu_reset(
if gpu_info.preferred_device_id is not None:
console.print(f" Auto-selected device: {gpu_info.preferred_device_id}")
console.print(f" Device: [cyan]{gpu_info.gpu_name}[/cyan]")
# ==================== Watch Command ====================
@app.command()
def watch(
path: Path = typer.Argument(Path("."), exists=True, file_okay=False, dir_okay=True, help="Project root to watch."),
language: Optional[List[str]] = typer.Option(None, "--language", "-l", help="Languages to watch (comma-separated)."),
debounce: int = typer.Option(1000, "--debounce", "-d", min=100, max=10000, help="Debounce interval in milliseconds."),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable debug logging."),
) -> None:
"""Watch a directory for file changes and incrementally update the index.
Monitors the specified directory for file system changes (create, modify, delete)
and automatically updates the CodexLens index. The directory must already be indexed
using 'codexlens init' before watching.
Examples:
# Watch current directory
codexlens watch .
# Watch with custom debounce interval
codexlens watch . --debounce 2000
# Watch only Python and JavaScript files
codexlens watch . --language python,javascript
Press Ctrl+C to stop watching.
"""
_configure_logging(verbose)
watch_path = path.expanduser().resolve()
registry: RegistryStore | None = None
try:
# Validate that path is indexed
registry = RegistryStore()
registry.initialize()
mapper = PathMapper()
project_record = registry.find_by_source_path(str(watch_path))
if not project_record:
console.print(f"[red]Error:[/red] Directory is not indexed: {watch_path}")
console.print("[dim]Run 'codexlens init' first to create an index.[/dim]")
raise typer.Exit(code=1)
# Parse languages
languages = _parse_languages(language)
# Create watcher config
watcher_config = WatcherConfig(
debounce_ms=debounce,
languages=languages,
)
# Display startup message
console.print(f"[green]Starting watcher for:[/green] {watch_path}")
console.print(f"[dim]Debounce interval: {debounce}ms[/dim]")
if languages:
console.print(f"[dim]Watching languages: {', '.join(languages)}[/dim]")
console.print("[dim]Press Ctrl+C to stop[/dim]\n")
# Create and start watcher manager
manager = WatcherManager(
root_path=watch_path,
watcher_config=watcher_config,
on_indexed=lambda result: _display_index_result(result),
)
manager.start()
manager.wait()
except KeyboardInterrupt:
console.print("\n[yellow]Stopping watcher...[/yellow]")
except CodexLensError as exc:
console.print(f"[red]Watch failed:[/red] {exc}")
raise typer.Exit(code=1)
except Exception as exc:
console.print(f"[red]Unexpected error:[/red] {exc}")
raise typer.Exit(code=1)
finally:
if registry is not None:
registry.close()
def _display_index_result(result) -> None:
"""Display indexing result in real-time."""
if result.files_indexed > 0 or result.files_removed > 0:
parts = []
if result.files_indexed > 0:
parts.append(f"[green]✓ Indexed {result.files_indexed} file(s)[/green]")
if result.files_removed > 0:
parts.append(f"[yellow]✗ Removed {result.files_removed} file(s)[/yellow]")
console.print(" | ".join(parts))
if result.errors:
for error in result.errors[:3]: # Show max 3 errors
console.print(f" [red]Error:[/red] {error}")
if len(result.errors) > 3:
console.print(f" [dim]... and {len(result.errors) - 3} more errors[/dim]")

View File

@@ -116,8 +116,9 @@ class Config:
reranking_top_k: int = 50
symbol_boost_factor: float = 1.5
# Optional cross-encoder reranking (second stage, requires codexlens[reranker])
# Optional cross-encoder reranking (second stage; requires optional reranker deps)
enable_cross_encoder_rerank: bool = False
reranker_backend: str = "onnx"
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker_top_k: int = 50
@@ -311,6 +312,35 @@ class WorkspaceConfig:
"""Cache directory for this workspace."""
return self.codexlens_dir / "cache"
@property
def env_path(self) -> Path:
"""Path to workspace .env file."""
return self.codexlens_dir / ".env"
def load_env(self, *, override: bool = False) -> int:
"""Load .env file and apply to os.environ.
Args:
override: If True, override existing environment variables
Returns:
Number of variables applied
"""
from .env_config import apply_workspace_env
return apply_workspace_env(self.workspace_root, override=override)
def get_api_config(self, prefix: str) -> dict:
"""Get API configuration from environment.
Args:
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
Returns:
Dictionary with api_key, api_base, model, etc.
"""
from .env_config import get_api_config
return get_api_config(prefix, workspace_root=self.workspace_root)
def initialize(self) -> None:
"""Create the .codexlens directory structure."""
try:
@@ -324,6 +354,7 @@ class WorkspaceConfig:
"# CodexLens workspace data\n"
"cache/\n"
"*.log\n"
".env\n" # Exclude .env from git
)
except Exception as exc:
raise ConfigError(f"Failed to initialize workspace at {self.codexlens_dir}: {exc}") from exc

View File

@@ -0,0 +1,260 @@
"""Environment configuration loader for CodexLens.
Loads .env files from workspace .codexlens directory with fallback to project root.
Provides unified access to API configurations.
Priority order:
1. Environment variables (already set)
2. .codexlens/.env (workspace-local)
3. .env (project root)
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
log = logging.getLogger(__name__)
# Supported environment variables with descriptions
ENV_VARS = {
# Reranker API configuration
"RERANKER_API_KEY": "API key for reranker service (SiliconFlow/Cohere/Jina)",
"RERANKER_API_BASE": "Base URL for reranker API (overrides provider default)",
"RERANKER_PROVIDER": "Reranker provider: siliconflow, cohere, jina",
"RERANKER_MODEL": "Reranker model name",
# Embedding API configuration
"EMBEDDING_API_KEY": "API key for embedding service",
"EMBEDDING_API_BASE": "Base URL for embedding API",
"EMBEDDING_MODEL": "Embedding model name",
# LiteLLM configuration
"LITELLM_API_KEY": "API key for LiteLLM",
"LITELLM_API_BASE": "Base URL for LiteLLM",
"LITELLM_MODEL": "LiteLLM model name",
# General configuration
"CODEXLENS_DATA_DIR": "Custom data directory path",
"CODEXLENS_DEBUG": "Enable debug mode (true/false)",
}
def _parse_env_line(line: str) -> tuple[str, str] | None:
"""Parse a single .env line, returning (key, value) or None."""
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith("#"):
return None
# Handle export prefix
if line.startswith("export "):
line = line[7:].strip()
# Split on first =
if "=" not in line:
return None
key, _, value = line.partition("=")
key = key.strip()
value = value.strip()
# Remove surrounding quotes
if len(value) >= 2:
if (value.startswith('"') and value.endswith('"')) or \
(value.startswith("'") and value.endswith("'")):
value = value[1:-1]
return key, value
def load_env_file(env_path: Path) -> Dict[str, str]:
"""Load environment variables from a .env file.
Args:
env_path: Path to .env file
Returns:
Dictionary of environment variables
"""
if not env_path.is_file():
return {}
env_vars: Dict[str, str] = {}
try:
content = env_path.read_text(encoding="utf-8")
for line in content.splitlines():
result = _parse_env_line(line)
if result:
key, value = result
env_vars[key] = value
except Exception as exc:
log.warning("Failed to load .env file %s: %s", env_path, exc)
return env_vars
def load_workspace_env(workspace_root: Path | None = None) -> Dict[str, str]:
"""Load environment variables from workspace .env files.
Priority (later overrides earlier):
1. Project root .env
2. .codexlens/.env
Args:
workspace_root: Workspace root directory. If None, uses current directory.
Returns:
Merged dictionary of environment variables
"""
if workspace_root is None:
workspace_root = Path.cwd()
workspace_root = Path(workspace_root).resolve()
env_vars: Dict[str, str] = {}
# Load from project root .env (lowest priority)
root_env = workspace_root / ".env"
if root_env.is_file():
env_vars.update(load_env_file(root_env))
log.debug("Loaded %d vars from %s", len(env_vars), root_env)
# Load from .codexlens/.env (higher priority)
codexlens_env = workspace_root / ".codexlens" / ".env"
if codexlens_env.is_file():
loaded = load_env_file(codexlens_env)
env_vars.update(loaded)
log.debug("Loaded %d vars from %s", len(loaded), codexlens_env)
return env_vars
def apply_workspace_env(workspace_root: Path | None = None, *, override: bool = False) -> int:
"""Load .env files and apply to os.environ.
Args:
workspace_root: Workspace root directory
override: If True, override existing environment variables
Returns:
Number of variables applied
"""
env_vars = load_workspace_env(workspace_root)
applied = 0
for key, value in env_vars.items():
if override or key not in os.environ:
os.environ[key] = value
applied += 1
log.debug("Applied env var: %s", key)
return applied
def get_env(key: str, default: str | None = None, *, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback.
Priority:
1. os.environ (already set)
2. .codexlens/.env
3. .env
4. default value
Args:
key: Environment variable name
default: Default value if not found
workspace_root: Workspace root for .env file lookup
Returns:
Value or default
"""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Load from .env files
env_vars = load_workspace_env(workspace_root)
if key in env_vars:
return env_vars[key]
return default
def get_api_config(
prefix: str,
*,
workspace_root: Path | None = None,
defaults: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
"""Get API configuration from environment.
Loads {PREFIX}_API_KEY, {PREFIX}_API_BASE, {PREFIX}_MODEL, etc.
Args:
prefix: Environment variable prefix (e.g., "RERANKER", "EMBEDDING")
workspace_root: Workspace root for .env file lookup
defaults: Default values
Returns:
Dictionary with api_key, api_base, model, etc.
"""
defaults = defaults or {}
config: Dict[str, Any] = {}
# Standard API config fields
field_mapping = {
"api_key": f"{prefix}_API_KEY",
"api_base": f"{prefix}_API_BASE",
"model": f"{prefix}_MODEL",
"provider": f"{prefix}_PROVIDER",
"timeout": f"{prefix}_TIMEOUT",
}
for field, env_key in field_mapping.items():
value = get_env(env_key, workspace_root=workspace_root)
if value is not None:
# Type conversion for specific fields
if field == "timeout":
try:
config[field] = float(value)
except ValueError:
pass
else:
config[field] = value
elif field in defaults:
config[field] = defaults[field]
return config
def generate_env_example() -> str:
"""Generate .env.example content with all supported variables.
Returns:
String content for .env.example file
"""
lines = [
"# CodexLens Environment Configuration",
"# Copy this file to .codexlens/.env and fill in your values",
"",
]
# Group by prefix
groups: Dict[str, list] = {}
for key, desc in ENV_VARS.items():
prefix = key.split("_")[0]
if prefix not in groups:
groups[prefix] = []
groups[prefix].append((key, desc))
for prefix, items in groups.items():
lines.append(f"# {prefix} Configuration")
for key, desc in items:
lines.append(f"# {desc}")
lines.append(f"# {key}=")
lines.append("")
return "\n".join(lines)

View File

@@ -258,20 +258,52 @@ class HybridSearchEngine:
return None
try:
from codexlens.semantic.reranker import CrossEncoderReranker, check_cross_encoder_available
from codexlens.semantic.reranker import (
check_reranker_available,
get_reranker,
)
except Exception as exc:
self.logger.debug("Cross-encoder reranker unavailable: %s", exc)
self.logger.debug("Reranker factory unavailable: %s", exc)
return None
ok, err = check_cross_encoder_available()
backend = (getattr(self._config, "reranker_backend", "") or "").strip().lower() or "onnx"
ok, err = check_reranker_available(backend)
if not ok:
self.logger.debug("Cross-encoder reranker unavailable: %s", err)
self.logger.debug(
"Reranker backend unavailable (backend=%s): %s",
backend,
err,
)
return None
try:
return CrossEncoderReranker(model_name=self._config.reranker_model)
model_name = (getattr(self._config, "reranker_model", "") or "").strip() or None
if backend != "legacy" and model_name == "cross-encoder/ms-marco-MiniLM-L-6-v2":
model_name = None
device: str | None = None
kwargs: dict[str, Any] = {}
if backend == "onnx":
kwargs["use_gpu"] = bool(getattr(self._config, "embedding_use_gpu", True))
elif backend == "legacy":
if not bool(getattr(self._config, "embedding_use_gpu", True)):
device = "cpu"
return get_reranker(
backend=backend,
model_name=model_name,
device=device,
**kwargs,
)
except Exception as exc:
self.logger.debug("Failed to initialize cross-encoder reranker: %s", exc)
self.logger.debug(
"Failed to initialize reranker (backend=%s): %s",
backend,
exc,
)
return None
def _search_parallel(

View File

@@ -0,0 +1,22 @@
"""Reranker backends for second-stage search ranking.
This subpackage provides a unified interface and factory for different reranking
implementations (e.g., ONNX, API-based, LiteLLM, and legacy sentence-transformers).
"""
from __future__ import annotations
from .base import BaseReranker
from .factory import check_reranker_available, get_reranker
from .legacy import CrossEncoderReranker, check_cross_encoder_available
from .onnx_reranker import ONNXReranker, check_onnx_reranker_available
__all__ = [
"BaseReranker",
"check_reranker_available",
"get_reranker",
"CrossEncoderReranker",
"check_cross_encoder_available",
"ONNXReranker",
"check_onnx_reranker_available",
]
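
Usage note (not part of the diff): a minimal sketch of the factory path, assuming the ONNX extras (optimum, onnxruntime, transformers) are installed; the query and document strings are placeholders.

from codexlens.semantic.reranker import check_reranker_available, get_reranker

ok, err = check_reranker_available("onnx")
if ok:
    reranker = get_reranker(backend="onnx")  # default model: Xenova/ms-marco-MiniLM-L-6-v2
    scores = reranker.score_pairs(
        [
            ("parse python file", "def parse(path): ..."),
            ("parse python file", "CSS color utilities"),
        ],
        batch_size=32,
    )
    print(scores)  # one sigmoid-normalized score in [0, 1] per pair
else:
    print(f"ONNX reranker unavailable: {err}")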

View File

@@ -0,0 +1,310 @@
"""API-based reranker using a remote HTTP provider.
Supported providers:
- SiliconFlow: https://api.siliconflow.cn/v1/rerank
- Cohere: https://api.cohere.ai/v1/rerank
- Jina: https://api.jina.ai/v1/rerank
"""
from __future__ import annotations
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Mapping, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_DEFAULT_ENV_API_KEY = "RERANKER_API_KEY"
def _get_env_with_fallback(key: str, workspace_root: Path | None = None) -> str | None:
"""Get environment variable with .env file fallback."""
# Check os.environ first
if key in os.environ:
return os.environ[key]
# Try loading from .env files
try:
from codexlens.env_config import get_env
return get_env(key, workspace_root=workspace_root)
except ImportError:
return None
def check_httpx_available() -> tuple[bool, str | None]:
try:
import httpx # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"httpx not available: {exc}. Install with: pip install httpx"
return True, None
class APIReranker(BaseReranker):
"""Reranker backed by a remote reranking HTTP API."""
_PROVIDER_DEFAULTS: Mapping[str, Mapping[str, str]] = {
"siliconflow": {
"api_base": "https://api.siliconflow.cn",
"endpoint": "/v1/rerank",
"default_model": "BAAI/bge-reranker-v2-m3",
},
"cohere": {
"api_base": "https://api.cohere.ai",
"endpoint": "/v1/rerank",
"default_model": "rerank-english-v3.0",
},
"jina": {
"api_base": "https://api.jina.ai",
"endpoint": "/v1/rerank",
"default_model": "jina-reranker-v2-base-multilingual",
},
}
def __init__(
self,
*,
provider: str = "siliconflow",
model_name: str | None = None,
api_key: str | None = None,
api_base: str | None = None,
timeout: float = 30.0,
max_retries: int = 3,
backoff_base_s: float = 0.5,
backoff_max_s: float = 8.0,
env_api_key: str = _DEFAULT_ENV_API_KEY,
workspace_root: Path | str | None = None,
) -> None:
ok, err = check_httpx_available()
if not ok: # pragma: no cover - exercised via factory availability tests
raise ImportError(err)
import httpx
self._workspace_root = Path(workspace_root) if workspace_root else None
self.provider = (provider or "").strip().lower()
if self.provider not in self._PROVIDER_DEFAULTS:
raise ValueError(
f"Unknown reranker provider: {provider}. "
f"Supported providers: {', '.join(sorted(self._PROVIDER_DEFAULTS))}"
)
defaults = self._PROVIDER_DEFAULTS[self.provider]
# Load api_base from env with .env fallback
env_api_base = _get_env_with_fallback("RERANKER_API_BASE", self._workspace_root)
self.api_base = (api_base or env_api_base or defaults["api_base"]).strip().rstrip("/")
self.endpoint = defaults["endpoint"]
# Load model from env with .env fallback
env_model = _get_env_with_fallback("RERANKER_MODEL", self._workspace_root)
self.model_name = (model_name or env_model or defaults["default_model"]).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
# Load API key from env with .env fallback
resolved_key = api_key or _get_env_with_fallback(env_api_key, self._workspace_root) or ""
resolved_key = resolved_key.strip()
if not resolved_key:
raise ValueError(
f"Missing API key for reranker provider '{self.provider}'. "
f"Pass api_key=... or set ${env_api_key}."
)
self._api_key = resolved_key
self.timeout_s = float(timeout) if timeout and float(timeout) > 0 else 30.0
self.max_retries = int(max_retries) if max_retries and int(max_retries) >= 0 else 3
self.backoff_base_s = float(backoff_base_s) if backoff_base_s and float(backoff_base_s) > 0 else 0.5
self.backoff_max_s = float(backoff_max_s) if backoff_max_s and float(backoff_max_s) > 0 else 8.0
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
}
if self.provider == "cohere":
headers.setdefault("Cohere-Version", "2022-12-06")
self._client = httpx.Client(
base_url=self.api_base,
headers=headers,
timeout=self.timeout_s,
)
def close(self) -> None:
try:
self._client.close()
except Exception: # pragma: no cover - defensive
return
def _sleep_backoff(self, attempt: int, *, retry_after_s: float | None = None) -> None:
if retry_after_s is not None and retry_after_s > 0:
time.sleep(min(float(retry_after_s), self.backoff_max_s))
return
exp = self.backoff_base_s * (2**attempt)
jitter = random.uniform(0, min(0.5, self.backoff_base_s))
time.sleep(min(self.backoff_max_s, exp + jitter))
@staticmethod
def _parse_retry_after_seconds(headers: Mapping[str, str]) -> float | None:
value = (headers.get("Retry-After") or "").strip()
if not value:
return None
try:
return float(value)
except ValueError:
return None
@staticmethod
def _should_retry_status(status_code: int) -> bool:
return status_code == 429 or 500 <= status_code <= 599
def _request_json(self, payload: Mapping[str, Any]) -> Mapping[str, Any]:
last_exc: Exception | None = None
for attempt in range(self.max_retries + 1):
try:
response = self._client.post(self.endpoint, json=dict(payload))
except Exception as exc: # httpx is optional at import-time
last_exc = exc
if attempt < self.max_retries:
self._sleep_backoff(attempt)
continue
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' after "
f"{self.max_retries + 1} attempts: {type(exc).__name__}: {exc}"
) from exc
status = int(getattr(response, "status_code", 0) or 0)
if status >= 400:
body_preview = ""
try:
body_preview = (response.text or "").strip()
except Exception:
body_preview = ""
if len(body_preview) > 300:
body_preview = body_preview[:300] + "…"
if self._should_retry_status(status) and attempt < self.max_retries:
retry_after = self._parse_retry_after_seconds(response.headers)
logger.warning(
"Rerank request to %s%s failed with HTTP %s (attempt %s/%s). Retrying…",
self.api_base,
self.endpoint,
status,
attempt + 1,
self.max_retries + 1,
)
self._sleep_backoff(attempt, retry_after_s=retry_after)
continue
if status in {401, 403}:
raise RuntimeError(
f"Rerank request unauthorized for provider '{self.provider}' (HTTP {status}). "
"Check your API key."
)
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}' (HTTP {status}). "
f"Response: {body_preview or '<empty>'}"
)
try:
data = response.json()
except Exception as exc:
raise RuntimeError(
f"Rerank response from provider '{self.provider}' is not valid JSON: "
f"{type(exc).__name__}: {exc}"
) from exc
if not isinstance(data, dict):
raise RuntimeError(
f"Rerank response from provider '{self.provider}' must be a JSON object; "
f"got {type(data).__name__}"
)
return data
raise RuntimeError(
f"Rerank request failed for provider '{self.provider}'. Last error: {last_exc}"
)
@staticmethod
def _extract_scores_from_results(results: Any, expected: int) -> list[float]:
if not isinstance(results, list):
raise RuntimeError(f"Invalid rerank response: 'results' must be a list, got {type(results).__name__}")
scores: list[float] = [0.0 for _ in range(expected)]
filled = 0
for item in results:
if not isinstance(item, dict):
continue
idx = item.get("index")
score = item.get("relevance_score", item.get("score"))
if idx is None or score is None:
continue
try:
idx_int = int(idx)
score_f = float(score)
except (TypeError, ValueError):
continue
if 0 <= idx_int < expected:
scores[idx_int] = score_f
filled += 1
if filled != expected:
raise RuntimeError(
f"Rerank response contained {filled}/{expected} scored documents; "
"ensure top_n matches the number of documents."
)
return scores
def _build_payload(self, *, query: str, documents: Sequence[str]) -> Mapping[str, Any]:
payload: dict[str, Any] = {
"model": self.model_name,
"query": query,
"documents": list(documents),
"top_n": len(documents),
"return_documents": False,
}
return payload
def _rerank_one_query(self, *, query: str, documents: Sequence[str]) -> list[float]:
if not documents:
return []
payload = self._build_payload(query=query, documents=documents)
data = self._request_json(payload)
results = data.get("results")
return self._extract_scores_from_results(results, expected=len(documents))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32, # noqa: ARG002 - kept for BaseReranker compatibility
) -> list[float]:
if not pairs:
return []
grouped: dict[str, list[tuple[int, str]]] = {}
for idx, (query, doc) in enumerate(pairs):
grouped.setdefault(str(query), []).append((idx, str(doc)))
scores: list[float] = [0.0 for _ in range(len(pairs))]
for query, items in grouped.items():
documents = [doc for _, doc in items]
query_scores = self._rerank_one_query(query=query, documents=documents)
for (orig_idx, _), score in zip(items, query_scores):
scores[orig_idx] = float(score)
return scores
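
Usage note (not part of the diff): a sketch assuming RERANKER_API_KEY is available via the shell or .codexlens/.env; the documents are placeholders. Because score_pairs groups pairs by query, the two pairs below result in a single rerank request.

from codexlens.semantic.reranker.api_reranker import APIReranker

reranker = APIReranker(provider="siliconflow")  # model defaults to BAAI/bge-reranker-v2-m3
pairs = [
    ("file watcher debounce", "The debounce thread flushes queued events each interval."),
    ("file watcher debounce", "Unrelated README paragraph."),
]
scores = reranker.score_pairs(pairs)  # one HTTP call for the single distinct query
reranker.close()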

View File

@@ -0,0 +1,36 @@
"""Base class for rerankers.
Defines the interface that all rerankers must implement.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence
class BaseReranker(ABC):
"""Base class for all rerankers.
All reranker implementations must inherit from this class and implement
the abstract methods to ensure a consistent interface.
"""
@abstractmethod
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs.
Args:
pairs: Sequence of (query, doc) string pairs to score.
batch_size: Batch size for scoring.
Returns:
List of scores (one per pair).
"""
...
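
Illustration (not part of the diff): a toy backend against this interface, showing the minimum a subclass must provide; the overlap heuristic is purely for demonstration.

from typing import Sequence

from codexlens.semantic.reranker.base import BaseReranker


class OverlapReranker(BaseReranker):
    """Toy reranker that scores by token overlap between query and document."""

    def score_pairs(
        self,
        pairs: Sequence[tuple[str, str]],
        *,
        batch_size: int = 32,
    ) -> list[float]:
        scores: list[float] = []
        for query, doc in pairs:
            q_tokens = set(query.lower().split())
            d_tokens = set(doc.lower().split())
            scores.append(len(q_tokens & d_tokens) / max(len(q_tokens), 1))
        return scores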

View File

@@ -0,0 +1,138 @@
"""Factory for creating rerankers.
Provides a unified interface for instantiating different reranker backends.
"""
from __future__ import annotations
from typing import Any
from .base import BaseReranker
def check_reranker_available(backend: str) -> tuple[bool, str | None]:
"""Check whether a specific reranker backend can be used.
Notes:
- "legacy" uses sentence-transformers CrossEncoder (pip install codexlens[reranker-legacy]).
- "onnx" uses Optimum + ONNX Runtime (pip install codexlens[reranker] or codexlens[reranker-onnx]).
- "api" uses a remote reranking HTTP API (requires httpx).
- "litellm" uses `ccw-litellm` for unified access to LLM providers.
"""
backend = (backend or "").strip().lower()
if backend == "legacy":
from .legacy import check_cross_encoder_available
return check_cross_encoder_available()
if backend == "onnx":
from .onnx_reranker import check_onnx_reranker_available
return check_onnx_reranker_available()
if backend == "litellm":
try:
import ccw_litellm # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"ccw-litellm not available: {exc}. Install with: pip install ccw-litellm",
)
try:
from .litellm_reranker import LiteLLMReranker # noqa: F401
except Exception as exc: # pragma: no cover - defensive
return False, f"LiteLLM reranker backend not available: {exc}"
return True, None
if backend == "api":
from .api_reranker import check_httpx_available
return check_httpx_available()
return False, (
f"Invalid reranker backend: {backend}. "
"Must be 'onnx', 'api', 'litellm', or 'legacy'."
)
def get_reranker(
backend: str = "onnx",
model_name: str | None = None,
*,
device: str | None = None,
**kwargs: Any,
) -> BaseReranker:
"""Factory function to create reranker based on backend.
Args:
backend: Reranker backend to use. Options:
- "onnx": Optimum + onnxruntime backend (default)
- "api": HTTP API backend (remote providers)
- "litellm": LiteLLM backend (LLM-based, experimental)
- "legacy": sentence-transformers CrossEncoder backend (optional)
model_name: Model identifier for model-based backends. Defaults depend on backend:
- onnx: Xenova/ms-marco-MiniLM-L-6-v2
- api: BAAI/bge-reranker-v2-m3 (SiliconFlow)
- legacy: cross-encoder/ms-marco-MiniLM-L-6-v2
- litellm: default
device: Optional device string for backends that support it (legacy only).
**kwargs: Additional backend-specific arguments.
Returns:
BaseReranker: Configured reranker instance.
Raises:
ValueError: If backend is not recognized.
ImportError: If required backend dependencies are not installed or backend is unavailable.
"""
backend = (backend or "").strip().lower()
if backend == "onnx":
ok, err = check_reranker_available("onnx")
if not ok:
raise ImportError(err)
from .onnx_reranker import ONNXReranker
resolved_model_name = (model_name or "").strip() or ONNXReranker.DEFAULT_MODEL
_ = device # Device selection is managed via ONNX Runtime providers.
return ONNXReranker(model_name=resolved_model_name, **kwargs)
if backend == "legacy":
ok, err = check_reranker_available("legacy")
if not ok:
raise ImportError(err)
from .legacy import CrossEncoderReranker
resolved_model_name = (model_name or "").strip() or "cross-encoder/ms-marco-MiniLM-L-6-v2"
return CrossEncoderReranker(model_name=resolved_model_name, device=device)
if backend == "litellm":
ok, err = check_reranker_available("litellm")
if not ok:
raise ImportError(err)
from .litellm_reranker import LiteLLMReranker
_ = device # Device selection is not applicable to remote LLM backends.
resolved_model_name = (model_name or "").strip() or "default"
return LiteLLMReranker(model=resolved_model_name, **kwargs)
if backend == "api":
ok, err = check_reranker_available("api")
if not ok:
raise ImportError(err)
from .api_reranker import APIReranker
_ = device # Device selection is not applicable to remote HTTP backends.
resolved_model_name = (model_name or "").strip() or None
return APIReranker(model_name=resolved_model_name, **kwargs)
raise ValueError(
f"Unknown backend: {backend}. Supported backends: 'onnx', 'api', 'litellm', 'legacy'"
)
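
Usage note (not part of the diff): one way a caller might combine the two functions above to fall back across backends; pick_reranker is a hypothetical helper, not part of this commit.

from codexlens.semantic.reranker import check_reranker_available, get_reranker

def pick_reranker(preferred=("onnx", "api", "legacy")):
    """Return a reranker for the first available backend, or None."""
    for backend in preferred:
        ok, _err = check_reranker_available(backend)
        if ok:
            return get_reranker(backend=backend)
    return None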

View File

@@ -1,6 +1,6 @@
"""Optional cross-encoder reranker for second-stage search ranking.
"""Legacy sentence-transformers cross-encoder reranker.
Install with: pip install codexlens[reranker]
Install with: pip install codexlens[reranker-legacy]
"""
from __future__ import annotations
@@ -9,6 +9,8 @@ import logging
import threading
from typing import List, Sequence, Tuple
from .base import BaseReranker
logger = logging.getLogger(__name__)
try:
@@ -25,10 +27,14 @@ except ImportError as exc: # pragma: no cover - optional dependency
def check_cross_encoder_available() -> tuple[bool, str | None]:
if CROSS_ENCODER_AVAILABLE:
return True, None
return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]"
return (
False,
_import_error
or "sentence-transformers not available. Install with: pip install codexlens[reranker-legacy]",
)
class CrossEncoderReranker:
class CrossEncoderReranker(BaseReranker):
"""Cross-encoder reranker with lazy model loading."""
def __init__(self, model_name: str, *, device: str | None = None) -> None:
@@ -83,4 +89,3 @@ class CrossEncoderReranker:
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
return [float(s) for s in scores]

View File

@@ -0,0 +1,214 @@
"""Experimental LiteLLM reranker backend.
This module provides :class:`LiteLLMReranker`, which uses an LLM to score the
relevance of a single (query, document) pair per request.
Notes:
- This backend is experimental and may be slow/expensive compared to local
rerankers.
- It relies on `ccw-litellm` for a unified LLM API across providers.
"""
from __future__ import annotations
import json
import logging
import re
import threading
import time
from typing import Any, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")
def _coerce_score_to_unit_interval(score: float) -> float:
"""Coerce a numeric score into [0, 1].
The prompt asks for a float in [0, 1], but some models may respond with 0-10
or 0-100 scales. This function attempts a conservative normalization.
"""
if 0.0 <= score <= 1.0:
return score
if 0.0 <= score <= 10.0:
return score / 10.0
if 0.0 <= score <= 100.0:
return score / 100.0
return max(0.0, min(1.0, score))
def _extract_score(text: str) -> float | None:
"""Extract a numeric relevance score from an LLM response."""
content = (text or "").strip()
if not content:
return None
# Prefer JSON if present.
if "{" in content and "}" in content:
try:
start = content.index("{")
end = content.rindex("}") + 1
payload = json.loads(content[start:end])
if isinstance(payload, dict) and "score" in payload:
return float(payload["score"])
except Exception:
pass
match = _NUMBER_RE.search(content)
if not match:
return None
try:
return float(match.group(0))
except ValueError:
return None
class LiteLLMReranker(BaseReranker):
"""Experimental reranker that uses a LiteLLM-compatible model.
This reranker scores each (query, doc) pair in isolation (single-pair mode)
to improve prompt reliability across providers.
"""
_SYSTEM_PROMPT = (
"You are a relevance scoring assistant.\n"
"Given a search query and a document snippet, output a single numeric "
"relevance score between 0 and 1.\n\n"
"Scoring guidance:\n"
"- 1.0: The document directly answers the query.\n"
"- 0.5: The document is partially relevant.\n"
"- 0.0: The document is unrelated.\n\n"
"Output requirements:\n"
"- Output ONLY the number (e.g., 0.73).\n"
"- Do not include any other text."
)
def __init__(
self,
model: str = "default",
*,
requests_per_minute: float | None = None,
min_interval_seconds: float | None = None,
default_score: float = 0.0,
max_doc_chars: int = 8000,
**litellm_kwargs: Any,
) -> None:
"""Initialize the reranker.
Args:
model: Model name from ccw-litellm configuration (default: "default").
requests_per_minute: Optional rate limit in requests per minute.
min_interval_seconds: Optional minimum interval between requests. If set,
it takes precedence over requests_per_minute.
default_score: Score to use when an API call fails or parsing fails.
max_doc_chars: Maximum number of document characters to include in the prompt.
**litellm_kwargs: Passed through to `ccw_litellm.LiteLLMClient`.
Raises:
ImportError: If ccw-litellm is not installed.
ValueError: If model is blank.
"""
self.model_name = (model or "").strip()
if not self.model_name:
raise ValueError("model cannot be blank")
self.default_score = float(default_score)
self.max_doc_chars = int(max_doc_chars) if int(max_doc_chars) > 0 else 0
if min_interval_seconds is not None:
self._min_interval_seconds = max(0.0, float(min_interval_seconds))
elif requests_per_minute is not None and float(requests_per_minute) > 0:
self._min_interval_seconds = 60.0 / float(requests_per_minute)
else:
self._min_interval_seconds = 0.0
# Prefer deterministic output by default; allow overrides via kwargs.
litellm_kwargs = dict(litellm_kwargs)
litellm_kwargs.setdefault("temperature", 0.0)
litellm_kwargs.setdefault("max_tokens", 16)
try:
from ccw_litellm import ChatMessage, LiteLLMClient
except ImportError as exc: # pragma: no cover - optional dependency
raise ImportError(
"ccw-litellm not installed. Install with: pip install ccw-litellm"
) from exc
self._ChatMessage = ChatMessage
self._client = LiteLLMClient(model=self.model_name, **litellm_kwargs)
self._lock = threading.RLock()
self._last_request_at = 0.0
def _sanitize_text(self, text: str) -> str:
# Keep consistent with LiteLLMEmbedderWrapper workaround.
if text.startswith("import"):
return " " + text
return text
def _rate_limit(self) -> None:
if self._min_interval_seconds <= 0:
return
with self._lock:
now = time.monotonic()
elapsed = now - self._last_request_at
if elapsed < self._min_interval_seconds:
time.sleep(self._min_interval_seconds - elapsed)
self._last_request_at = time.monotonic()
def _build_user_prompt(self, query: str, doc: str) -> str:
sanitized_query = self._sanitize_text(query or "")
sanitized_doc = self._sanitize_text(doc or "")
if self.max_doc_chars and len(sanitized_doc) > self.max_doc_chars:
sanitized_doc = sanitized_doc[: self.max_doc_chars]
return (
"Query:\n"
f"{sanitized_query}\n\n"
"Document:\n"
f"{sanitized_doc}\n\n"
"Return the relevance score (0 to 1) as a single number:"
)
def _score_single_pair(self, query: str, doc: str) -> float:
messages = [
self._ChatMessage(role="system", content=self._SYSTEM_PROMPT),
self._ChatMessage(role="user", content=self._build_user_prompt(query, doc)),
]
try:
self._rate_limit()
response = self._client.chat(messages)
except Exception as exc:
logger.debug("LiteLLM reranker request failed: %s", exc)
return self.default_score
raw = getattr(response, "content", "") or ""
score = _extract_score(raw)
if score is None:
logger.debug("Failed to parse LiteLLM reranker score from response: %r", raw)
return self.default_score
return _coerce_score_to_unit_interval(float(score))
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with per-pair LLM calls."""
if not pairs:
return []
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for i in range(0, len(pairs), bs):
batch = pairs[i : i + bs]
for query, doc in batch:
scores.append(self._score_single_pair(query, doc))
return scores
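
Usage note (not part of the diff): a sketch assuming ccw-litellm is installed with a "default" model configured; the document text is a placeholder.

from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker

reranker = LiteLLMReranker(
    model="default",
    requests_per_minute=30,  # throttle the per-pair LLM calls
    default_score=0.0,       # returned when a call fails or the score cannot be parsed
)
scores = reranker.score_pairs([
    ("incremental indexing", "IncrementalIndexer updates per-directory indexes on change."),
])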

View File

@@ -0,0 +1,268 @@
"""Optimum + ONNX Runtime reranker backend.
This reranker uses Hugging Face Optimum's ONNXRuntime backend for sequence
classification models. It is designed to run without requiring PyTorch at
runtime by using numpy tensors and ONNX Runtime execution providers.
Install (CPU):
pip install onnxruntime optimum[onnxruntime] transformers
"""
from __future__ import annotations
import logging
import threading
from typing import Any, Iterable, Sequence
from .base import BaseReranker
logger = logging.getLogger(__name__)
def check_onnx_reranker_available() -> tuple[bool, str | None]:
"""Check whether Optimum + ONNXRuntime reranker dependencies are available."""
try:
import numpy # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return False, f"numpy not available: {exc}. Install with: pip install numpy"
try:
import onnxruntime # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"onnxruntime not available: {exc}. Install with: pip install onnxruntime",
)
try:
from optimum.onnxruntime import ORTModelForSequenceClassification # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"optimum[onnxruntime] not available: {exc}. Install with: pip install optimum[onnxruntime]",
)
try:
from transformers import AutoTokenizer # noqa: F401
except ImportError as exc: # pragma: no cover - optional dependency
return (
False,
f"transformers not available: {exc}. Install with: pip install transformers",
)
return True, None
def _iter_batches(items: Sequence[Any], batch_size: int) -> Iterable[Sequence[Any]]:
for i in range(0, len(items), batch_size):
yield items[i : i + batch_size]
class ONNXReranker(BaseReranker):
"""Cross-encoder reranker using Optimum + ONNX Runtime with lazy loading."""
DEFAULT_MODEL = "Xenova/ms-marco-MiniLM-L-6-v2"
def __init__(
self,
model_name: str | None = None,
*,
use_gpu: bool = True,
providers: list[Any] | None = None,
max_length: int | None = None,
) -> None:
self.model_name = (model_name or self.DEFAULT_MODEL).strip()
if not self.model_name:
raise ValueError("model_name cannot be blank")
self.use_gpu = bool(use_gpu)
self.providers = providers
self.max_length = int(max_length) if max_length is not None else None
self._tokenizer: Any | None = None
self._model: Any | None = None
self._model_input_names: set[str] | None = None
self._lock = threading.RLock()
def _load_model(self) -> None:
if self._model is not None and self._tokenizer is not None:
return
ok, err = check_onnx_reranker_available()
if not ok:
raise ImportError(err)
with self._lock:
if self._model is not None and self._tokenizer is not None:
return
from inspect import signature
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
if self.providers is None:
from ..gpu_support import get_optimal_providers
# Include device_id options for DirectML/CUDA selection when available.
self.providers = get_optimal_providers(
use_gpu=self.use_gpu, with_device_options=True
)
# Some Optimum versions accept `providers`, others accept a single `provider`.
# Prefer passing the full providers list, with a conservative fallback.
model_kwargs: dict[str, Any] = {}
try:
params = signature(ORTModelForSequenceClassification.from_pretrained).parameters
if "providers" in params:
model_kwargs["providers"] = self.providers
elif "provider" in params:
provider_name = "CPUExecutionProvider"
if self.providers:
first = self.providers[0]
provider_name = first[0] if isinstance(first, tuple) else str(first)
model_kwargs["provider"] = provider_name
except Exception:
model_kwargs = {}
try:
self._model = ORTModelForSequenceClassification.from_pretrained(
self.model_name,
**model_kwargs,
)
except TypeError:
# Fallback for older Optimum versions: retry without provider arguments.
self._model = ORTModelForSequenceClassification.from_pretrained(self.model_name)
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
# Cache model input names to filter tokenizer outputs defensively.
input_names: set[str] | None = None
for attr in ("input_names", "model_input_names"):
names = getattr(self._model, attr, None)
if isinstance(names, (list, tuple)) and names:
input_names = {str(n) for n in names}
break
if input_names is None:
try:
session = getattr(self._model, "model", None)
if session is not None and hasattr(session, "get_inputs"):
input_names = {i.name for i in session.get_inputs()}
except Exception:
input_names = None
self._model_input_names = input_names
@staticmethod
def _sigmoid(x: "Any") -> "Any":
import numpy as np
x = np.clip(x, -50.0, 50.0)
return 1.0 / (1.0 + np.exp(-x))
@staticmethod
def _select_relevance_logit(logits: "Any") -> "Any":
import numpy as np
arr = np.asarray(logits)
if arr.ndim == 0:
return arr.reshape(1)
if arr.ndim == 1:
return arr
if arr.ndim >= 2:
# Common cases:
# - Regression: (batch, 1)
# - Binary classification: (batch, 2)
if arr.shape[-1] == 1:
return arr[..., 0]
if arr.shape[-1] == 2:
# Convert 2-logit softmax into a single logit via difference.
return arr[..., 1] - arr[..., 0]
return arr.max(axis=-1)
return arr.reshape(-1)
def _tokenize_batch(self, batch: Sequence[tuple[str, str]]) -> dict[str, Any]:
if self._tokenizer is None:
raise RuntimeError("Tokenizer not loaded") # pragma: no cover - defensive
queries = [q for q, _ in batch]
docs = [d for _, d in batch]
tokenizer_kwargs: dict[str, Any] = {
"text": queries,
"text_pair": docs,
"padding": True,
"truncation": True,
"return_tensors": "np",
}
max_len = self.max_length
if max_len is None:
try:
model_max = int(getattr(self._tokenizer, "model_max_length", 0) or 0)
if 0 < model_max < 10_000:
max_len = model_max
else:
max_len = 512
except Exception:
max_len = 512
if max_len is not None and max_len > 0:
tokenizer_kwargs["max_length"] = int(max_len)
encoded = self._tokenizer(**tokenizer_kwargs)
inputs = dict(encoded)
# Some models do not accept token_type_ids; filter to known input names if available.
if self._model_input_names:
inputs = {k: v for k, v in inputs.items() if k in self._model_input_names}
return inputs
def _forward_logits(self, inputs: dict[str, Any]) -> Any:
if self._model is None:
raise RuntimeError("Model not loaded") # pragma: no cover - defensive
outputs = self._model(**inputs)
if hasattr(outputs, "logits"):
return outputs.logits
if isinstance(outputs, dict) and "logits" in outputs:
return outputs["logits"]
if isinstance(outputs, (list, tuple)) and outputs:
return outputs[0]
raise RuntimeError("Unexpected model output format") # pragma: no cover - defensive
def score_pairs(
self,
pairs: Sequence[tuple[str, str]],
*,
batch_size: int = 32,
) -> list[float]:
"""Score (query, doc) pairs with sigmoid-normalized outputs in [0, 1]."""
if not pairs:
return []
self._load_model()
if self._model is None or self._tokenizer is None: # pragma: no cover - defensive
return []
import numpy as np
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
scores: list[float] = []
for batch in _iter_batches(list(pairs), bs):
inputs = self._tokenize_batch(batch)
logits = self._forward_logits(inputs)
rel_logits = self._select_relevance_logit(logits)
probs = self._sigmoid(rel_logits)
probs = np.clip(probs, 0.0, 1.0)
scores.extend([float(p) for p in probs.reshape(-1).tolist()])
if len(scores) != len(pairs):
logger.debug(
"ONNX reranker produced %d scores for %d pairs", len(scores), len(pairs)
)
return scores[: len(pairs)]
return scores
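
Illustration (not part of the diff): the logit handling above in plain numpy, showing how a (batch, 2) classification head collapses to a single relevance logit before the clipped sigmoid.

import numpy as np

logits = np.array([[0.2, 2.3], [1.5, -0.5]])   # (batch, 2): [irrelevant, relevant]
rel = logits[..., 1] - logits[..., 0]          # -> [2.1, -2.0]
scores = 1.0 / (1.0 + np.exp(-np.clip(rel, -50.0, 50.0)))
print(scores.round(3))                         # approx [0.891, 0.119]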

View File

@@ -0,0 +1,17 @@
"""File watcher module for real-time index updates."""
from .events import ChangeType, FileEvent, IndexResult, WatcherConfig, WatcherStats
from .file_watcher import FileWatcher
from .incremental_indexer import IncrementalIndexer
from .manager import WatcherManager
__all__ = [
"ChangeType",
"FileEvent",
"IndexResult",
"WatcherConfig",
"WatcherStats",
"FileWatcher",
"IncrementalIndexer",
"WatcherManager",
]

View File

@@ -0,0 +1,54 @@
"""Event types for file watcher."""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import List, Optional, Set
class ChangeType(Enum):
"""Type of file system change."""
CREATED = "created"
MODIFIED = "modified"
DELETED = "deleted"
MOVED = "moved"
@dataclass
class FileEvent:
"""A file system change event."""
path: Path
change_type: ChangeType
timestamp: float
old_path: Optional[Path] = None # For MOVED events
@dataclass
class WatcherConfig:
"""Configuration for file watcher."""
debounce_ms: int = 1000
ignored_patterns: Set[str] = field(default_factory=lambda: {
".git", ".venv", "venv", "node_modules",
"__pycache__", ".codexlens", ".idea", ".vscode",
})
languages: Optional[List[str]] = None # None = all supported
@dataclass
class IndexResult:
"""Result of processing file changes."""
files_indexed: int = 0
files_removed: int = 0
symbols_added: int = 0
errors: List[str] = field(default_factory=list)
@dataclass
class WatcherStats:
"""Runtime statistics for watcher."""
files_watched: int = 0
events_processed: int = 0
last_event_time: Optional[float] = None
is_running: bool = False
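
Illustration (not part of the diff): constructing the config and an event by hand, as the watcher does internally; the file path is hypothetical.

import time
from pathlib import Path

from codexlens.watcher.events import ChangeType, FileEvent, WatcherConfig

config = WatcherConfig(
    debounce_ms=500,
    languages=["python", "typescript"],  # None means all supported languages
)
event = FileEvent(
    path=Path("src/app.py"),             # hypothetical source file
    change_type=ChangeType.MODIFIED,
    timestamp=time.time(),
)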

View File

@@ -0,0 +1,245 @@
"""File system watcher using watchdog library."""
from __future__ import annotations
import logging
import threading
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from .events import ChangeType, FileEvent, WatcherConfig
from ..config import Config
logger = logging.getLogger(__name__)
class _CodexLensHandler(FileSystemEventHandler):
"""Internal handler for watchdog events."""
def __init__(
self,
watcher: "FileWatcher",
on_event: Callable[[FileEvent], None],
) -> None:
super().__init__()
self._watcher = watcher
self._on_event = on_event
def on_created(self, event) -> None:
if event.is_directory:
return
self._emit(event.src_path, ChangeType.CREATED)
def on_modified(self, event) -> None:
if event.is_directory:
return
self._emit(event.src_path, ChangeType.MODIFIED)
def on_deleted(self, event) -> None:
if event.is_directory:
return
self._emit(event.src_path, ChangeType.DELETED)
def on_moved(self, event) -> None:
if event.is_directory:
return
self._emit(event.dest_path, ChangeType.MOVED, old_path=event.src_path)
def _emit(
self,
path: str,
change_type: ChangeType,
old_path: Optional[str] = None,
) -> None:
path_obj = Path(path)
# Filter out files that should not be indexed
if not self._watcher._should_index_file(path_obj):
return
event = FileEvent(
path=path_obj,
change_type=change_type,
timestamp=time.time(),
old_path=Path(old_path) if old_path else None,
)
self._on_event(event)
class FileWatcher:
"""File system watcher for monitoring directory changes.
Uses watchdog library for cross-platform file system monitoring.
Events are forwarded to the on_changes callback.
Example:
def handle_changes(events: List[FileEvent]) -> None:
for event in events:
print(f"{event.change_type}: {event.path}")
watcher = FileWatcher(Path("."), WatcherConfig(), handle_changes)
watcher.start()
watcher.wait() # Block until stopped
"""
def __init__(
self,
root_path: Path,
config: WatcherConfig,
on_changes: Callable[[List[FileEvent]], None],
) -> None:
"""Initialize file watcher.
Args:
root_path: Directory to watch recursively
config: Watcher configuration
on_changes: Callback invoked with batched events
"""
self.root_path = Path(root_path).resolve()
self.config = config
self.on_changes = on_changes
self._observer: Optional[Observer] = None
self._running = False
self._stop_event = threading.Event()
self._lock = threading.RLock()
# Event queue for batching
self._event_queue: List[FileEvent] = []
self._queue_lock = threading.Lock()
# Debounce thread
self._debounce_thread: Optional[threading.Thread] = None
# Config instance for language checking
self._codexlens_config = Config()
def _should_index_file(self, path: Path) -> bool:
"""Check if file should be indexed based on extension and ignore patterns.
Args:
path: File path to check
Returns:
True if file should be indexed, False otherwise
"""
# Check against ignore patterns
parts = path.parts
for pattern in self.config.ignored_patterns:
if pattern in parts:
return False
# Check extension against supported languages
language = self._codexlens_config.language_for_path(path)
return language is not None
def _on_raw_event(self, event: FileEvent) -> None:
"""Handle raw event from watchdog handler."""
with self._queue_lock:
self._event_queue.append(event)
# Debouncing is handled by background thread
def _debounce_loop(self) -> None:
"""Background thread for debounced event batching."""
while self._running:
time.sleep(self.config.debounce_ms / 1000.0)
self._flush_events()
def _flush_events(self) -> None:
"""Flush queued events with deduplication."""
with self._queue_lock:
if not self._event_queue:
return
# Deduplicate: keep latest event per path
deduped: Dict[Path, FileEvent] = {}
for event in self._event_queue:
deduped[event.path] = event
events = list(deduped.values())
self._event_queue.clear()
if events:
try:
self.on_changes(events)
except Exception as exc:
logger.error("Error in on_changes callback: %s", exc)
def start(self) -> None:
"""Start watching the directory.
Non-blocking. Use wait() to block until stopped.
"""
with self._lock:
if self._running:
logger.warning("Watcher already running")
return
if not self.root_path.exists():
raise ValueError(f"Root path does not exist: {self.root_path}")
self._observer = Observer()
handler = _CodexLensHandler(self, self._on_raw_event)
self._observer.schedule(handler, str(self.root_path), recursive=True)
self._running = True
self._stop_event.clear()
self._observer.start()
# Start debounce thread
self._debounce_thread = threading.Thread(
target=self._debounce_loop,
daemon=True,
name="FileWatcher-Debounce",
)
self._debounce_thread.start()
logger.info("Started watching: %s", self.root_path)
def stop(self) -> None:
"""Stop watching the directory.
Gracefully stops the observer and flushes remaining events.
"""
with self._lock:
if not self._running:
return
self._running = False
self._stop_event.set()
if self._observer:
self._observer.stop()
self._observer.join(timeout=5.0)
self._observer = None
# Wait for debounce thread to finish
if self._debounce_thread and self._debounce_thread.is_alive():
self._debounce_thread.join(timeout=2.0)
self._debounce_thread = None
# Flush any remaining events
self._flush_events()
logger.info("Stopped watching: %s", self.root_path)
def wait(self) -> None:
"""Block until watcher is stopped.
Use Ctrl+C or call stop() from another thread to unblock.
"""
try:
while self._running:
self._stop_event.wait(timeout=1.0)
except KeyboardInterrupt:
logger.info("Received interrupt, stopping watcher...")
self.stop()
@property
def is_running(self) -> bool:
"""Check if watcher is currently running."""
return self._running

View File

@@ -0,0 +1,359 @@
"""Incremental indexer for processing file changes."""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from codexlens.config import Config
from codexlens.parsers.factory import ParserFactory
from codexlens.storage.dir_index import DirIndexStore
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
from .events import ChangeType, FileEvent, IndexResult
logger = logging.getLogger(__name__)
@dataclass
class FileIndexResult:
"""Result of indexing a single file."""
path: Path
symbols_count: int
success: bool
error: Optional[str] = None
class IncrementalIndexer:
"""Incremental indexer for processing file change events.
Processes file events (create, modify, delete, move) and updates
the corresponding index databases incrementally.
Reuses existing infrastructure:
- ParserFactory for symbol extraction
- DirIndexStore for per-directory storage
- GlobalSymbolIndex for cross-file symbols
- PathMapper for source-to-index path conversion
Example:
indexer = IncrementalIndexer(registry, mapper, config)
result = indexer.process_changes([
FileEvent(Path("foo.py"), ChangeType.MODIFIED, time.time()),
])
print(f"Indexed {result.files_indexed} files")
"""
def __init__(
self,
registry: RegistryStore,
mapper: PathMapper,
config: Optional[Config] = None,
) -> None:
"""Initialize incremental indexer.
Args:
registry: Global project registry
mapper: Path mapper for source-to-index conversion
config: CodexLens configuration (uses defaults if None)
"""
self.registry = registry
self.mapper = mapper
self.config = config or Config()
self.parser_factory = ParserFactory(self.config)
self._global_index: Optional[GlobalSymbolIndex] = None
self._dir_stores: dict[Path, DirIndexStore] = {}
self._lock = threading.RLock()
def _get_global_index(self, index_root: Path) -> Optional[GlobalSymbolIndex]:
"""Get or create global symbol index."""
if not self.config.global_symbol_index_enabled:
return None
if self._global_index is None:
global_db_path = index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
if global_db_path.exists():
self._global_index = GlobalSymbolIndex(global_db_path)
return self._global_index
def _get_dir_store(self, dir_path: Path) -> Optional[DirIndexStore]:
"""Get DirIndexStore for a directory, if indexed."""
with self._lock:
if dir_path in self._dir_stores:
return self._dir_stores[dir_path]
index_db = self.mapper.source_to_index_db(dir_path)
if not index_db.exists():
logger.debug("No index found for directory: %s", dir_path)
return None
# Get index root for global index
index_root = self.mapper.source_to_index_dir(
self.mapper.get_project_root(dir_path) or dir_path
)
global_index = self._get_global_index(index_root)
store = DirIndexStore(
index_db,
config=self.config,
global_index=global_index,
)
self._dir_stores[dir_path] = store
return store
def process_changes(self, events: List[FileEvent]) -> IndexResult:
"""Process a batch of file change events.
Args:
events: List of file events to process
Returns:
IndexResult with statistics
"""
result = IndexResult()
for event in events:
try:
if event.change_type in (ChangeType.CREATED, ChangeType.MODIFIED):
file_result = self._index_file(event.path)
if file_result.success:
result.files_indexed += 1
result.symbols_added += file_result.symbols_count
else:
result.errors.append(file_result.error or f"Failed to index: {event.path}")
elif event.change_type == ChangeType.DELETED:
self._remove_file(event.path)
result.files_removed += 1
elif event.change_type == ChangeType.MOVED:
# Remove from old location, add at new location
if event.old_path:
self._remove_file(event.old_path)
result.files_removed += 1
file_result = self._index_file(event.path)
if file_result.success:
result.files_indexed += 1
result.symbols_added += file_result.symbols_count
else:
result.errors.append(file_result.error or f"Failed to index: {event.path}")
except Exception as exc:
error_msg = f"Error processing {event.path}: {type(exc).__name__}: {exc}"
logger.error(error_msg)
result.errors.append(error_msg)
return result
def _index_file(self, path: Path) -> FileIndexResult:
"""Index a single file.
Args:
path: Path to the file to index
Returns:
FileIndexResult with status
"""
path = Path(path).resolve()
# Check if file exists
if not path.exists():
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=f"File not found: {path}",
)
# Check if language is supported
language = self.config.language_for_path(path)
if not language:
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=f"Unsupported language for: {path}",
)
# Get directory store
dir_path = path.parent
store = self._get_dir_store(dir_path)
if store is None:
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=f"Directory not indexed: {dir_path}",
)
# Read file content, falling back to lenient decoding if strict UTF-8 fails
try:
content = path.read_text(encoding="utf-8")
except UnicodeDecodeError:
logger.debug("UTF-8 decode failed for %s, using fallback with errors='ignore'", path)
try:
content = path.read_text(encoding="utf-8", errors="ignore")
except Exception as exc:
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=f"Failed to read file: {exc}",
)
except Exception as exc:
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=f"Failed to read file: {exc}",
)
# Parse symbols
try:
parser = self.parser_factory.get_parser(language)
indexed_file = parser.parse(content, path)
except Exception as exc:
error_msg = f"Failed to parse {path}: {type(exc).__name__}: {exc}"
logger.error(error_msg)
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=error_msg,
)
# Update store with retry logic for transient database errors
max_retries = 3
for attempt in range(max_retries):
try:
store.add_file(
name=path.name,
full_path=str(path),
content=content,
language=language,
symbols=indexed_file.symbols,
relationships=indexed_file.relationships,
)
# Update merkle root
store.update_merkle_root()
logger.debug("Indexed file: %s (%d symbols)", path, len(indexed_file.symbols))
return FileIndexResult(
path=path,
symbols_count=len(indexed_file.symbols),
success=True,
)
except sqlite3.OperationalError as exc:
# Transient database errors (e.g., database locked)
if attempt < max_retries - 1:
wait_time = 0.1 * (2 ** attempt)  # Exponential backoff
logger.debug("Database operation failed (attempt %d/%d), retrying in %.2fs: %s",
attempt + 1, max_retries, wait_time, exc)
time.sleep(wait_time)
continue
else:
error_msg = f"Failed to store {path} after {max_retries} attempts: {exc}"
logger.error(error_msg)
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=error_msg,
)
except Exception as exc:
error_msg = f"Failed to store {path}: {type(exc).__name__}: {exc}"
logger.error(error_msg)
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error=error_msg,
)
# Should never reach here
return FileIndexResult(
path=path,
symbols_count=0,
success=False,
error="Unexpected error in indexing loop",
)
def _remove_file(self, path: Path) -> bool:
"""Remove a file from the index.
Args:
path: Path to the file to remove
Returns:
True if removed successfully
"""
path = Path(path).resolve()
dir_path = path.parent
store = self._get_dir_store(dir_path)
if store is None:
logger.debug("Cannot remove file, directory not indexed: %s", dir_path)
return False
# Retry logic for transient database errors
max_retries = 3
for attempt in range(max_retries):
try:
store.remove_file(str(path))
store.update_merkle_root()
logger.debug("Removed file from index: %s", path)
return True
except sqlite3.OperationalError as exc:
# Transient database errors (e.g., database locked)
if attempt < max_retries - 1:
wait_time = 0.1 * (2 ** attempt)  # Exponential backoff
logger.debug("Database operation failed (attempt %d/%d), retrying in %.2fs: %s",
attempt + 1, max_retries, wait_time, exc)
time.sleep(wait_time)
continue
else:
logger.error("Failed to remove %s after %d attempts: %s", path, max_retries, exc)
return False
except Exception as exc:
logger.error("Failed to remove %s: %s", path, exc)
return False
# Should never reach here
return False
def close(self) -> None:
"""Close all open stores."""
with self._lock:
for store in self._dir_stores.values():
try:
store.close()
except Exception:
pass
self._dir_stores.clear()
if self._global_index:
try:
self._global_index.close()
except Exception:
pass
self._global_index = None
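
A minimal sketch of driving the indexer directly, assuming the same wiring WatcherManager uses below (a default RegistryStore and PathMapper) and importing from the submodule path; the file path is a placeholder and its directory must already have been indexed.

import time
from pathlib import Path

from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
from codexlens.watcher import ChangeType, FileEvent
from codexlens.watcher.incremental_indexer import IncrementalIndexer  # assumed module path

registry = RegistryStore()
registry.initialize()
indexer = IncrementalIndexer(registry, PathMapper())
try:
    result = indexer.process_changes(
        [FileEvent(Path("src/foo.py"), ChangeType.MODIFIED, time.time())]  # placeholder file
    )
    print(f"indexed={result.files_indexed} removed={result.files_removed} "
          f"symbols={result.symbols_added} errors={len(result.errors)}")
finally:
    indexer.close()
    registry.close()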

View File

@@ -0,0 +1,194 @@
"""Watcher manager for coordinating file watching and incremental indexing."""
from __future__ import annotations
import logging
import signal
import threading
import time
from pathlib import Path
from typing import Callable, List, Optional
from codexlens.config import Config
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
from .events import FileEvent, IndexResult, WatcherConfig, WatcherStats
from .file_watcher import FileWatcher
from .incremental_indexer import IncrementalIndexer
logger = logging.getLogger(__name__)
class WatcherManager:
"""High-level manager for file watching and incremental indexing.
Coordinates FileWatcher and IncrementalIndexer with:
- Lifecycle management (start/stop)
- Signal handling (SIGINT/SIGTERM)
- Statistics tracking
- Graceful shutdown
"""
def __init__(
self,
root_path: Path,
config: Optional[Config] = None,
watcher_config: Optional[WatcherConfig] = None,
on_indexed: Optional[Callable[[IndexResult], None]] = None,
) -> None:
self.root_path = Path(root_path).resolve()
self.config = config or Config()
self.watcher_config = watcher_config or WatcherConfig()
self.on_indexed = on_indexed
self._registry: Optional[RegistryStore] = None
self._mapper: Optional[PathMapper] = None
self._watcher: Optional[FileWatcher] = None
self._indexer: Optional[IncrementalIndexer] = None
self._running = False
self._stop_event = threading.Event()
self._lock = threading.RLock()
# Statistics
self._stats = WatcherStats()
self._original_sigint = None
self._original_sigterm = None
def _handle_changes(self, events: List[FileEvent]) -> None:
"""Handle file change events from watcher."""
if not self._indexer or not events:
return
logger.info("Processing %d file changes", len(events))
result = self._indexer.process_changes(events)
# Update stats
self._stats.events_processed += len(events)
self._stats.last_event_time = time.time()
if result.files_indexed > 0 or result.files_removed > 0:
logger.info(
"Indexed %d files, removed %d files, %d errors",
result.files_indexed, result.files_removed, len(result.errors)
)
if self.on_indexed:
try:
self.on_indexed(result)
except Exception as exc:
logger.error("Error in on_indexed callback: %s", exc)
def _signal_handler(self, signum, frame) -> None:
"""Handle shutdown signals."""
logger.info("Received signal %d, stopping...", signum)
self.stop()
def _install_signal_handlers(self) -> None:
"""Install signal handlers for graceful shutdown."""
try:
self._original_sigint = signal.signal(signal.SIGINT, self._signal_handler)
if hasattr(signal, 'SIGTERM'):
self._original_sigterm = signal.signal(signal.SIGTERM, self._signal_handler)
except (ValueError, OSError):
# Signal handling not available (e.g., not main thread)
pass
def _restore_signal_handlers(self) -> None:
"""Restore original signal handlers."""
try:
if self._original_sigint is not None:
signal.signal(signal.SIGINT, self._original_sigint)
if self._original_sigterm is not None and hasattr(signal, 'SIGTERM'):
signal.signal(signal.SIGTERM, self._original_sigterm)
except (ValueError, OSError):
pass
def start(self) -> None:
"""Start watching and indexing."""
with self._lock:
if self._running:
logger.warning("WatcherManager already running")
return
# Validate path
if not self.root_path.exists():
raise ValueError(f"Root path does not exist: {self.root_path}")
# Initialize components
self._registry = RegistryStore()
self._registry.initialize()
self._mapper = PathMapper()
self._indexer = IncrementalIndexer(
self._registry, self._mapper, self.config
)
self._watcher = FileWatcher(
self.root_path, self.watcher_config, self._handle_changes
)
# Install signal handlers
self._install_signal_handlers()
# Start watcher
self._running = True
self._stats.is_running = True
self._stop_event.clear()
self._watcher.start()
logger.info("WatcherManager started for: %s", self.root_path)
def stop(self) -> None:
"""Stop watching and clean up."""
with self._lock:
if not self._running:
return
self._running = False
self._stats.is_running = False
self._stop_event.set()
# Stop watcher
if self._watcher:
self._watcher.stop()
self._watcher = None
# Close indexer
if self._indexer:
self._indexer.close()
self._indexer = None
# Close registry
if self._registry:
self._registry.close()
self._registry = None
# Restore signal handlers
self._restore_signal_handlers()
logger.info("WatcherManager stopped")
def wait(self) -> None:
"""Block until stopped."""
try:
while self._running:
self._stop_event.wait(timeout=1.0)
except KeyboardInterrupt:
logger.info("Interrupted, stopping...")
self.stop()
@property
def is_running(self) -> bool:
"""Check if manager is running."""
return self._running
def get_stats(self) -> WatcherStats:
"""Get runtime statistics."""
return WatcherStats(
files_watched=self._stats.files_watched,
events_processed=self._stats.events_processed,
last_event_time=self._stats.last_event_time,
is_running=self._running,
)
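
A usage sketch for the manager, roughly what the 'codexlens watch' command drives; it assumes WatcherManager is exported from codexlens.watcher like the other watcher types, the on_indexed callback is optional, and the root path is a placeholder.

from pathlib import Path

from codexlens.watcher import WatcherManager  # export from the package is assumed

def report(result):
    # Invoked after each processed batch with an IndexResult.
    print(f"indexed={result.files_indexed} removed={result.files_removed}")

manager = WatcherManager(Path("."), on_indexed=report)
manager.start()      # installs SIGINT/SIGTERM handlers and starts the FileWatcher
try:
    manager.wait()   # blocks until a signal arrives or stop() is called
finally:
    manager.stop()   # closes watcher, indexer, and registry in order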

View File

@@ -0,0 +1,171 @@
"""Tests for APIReranker backend."""
from __future__ import annotations
import sys
import types
from typing import Any
import pytest
from codexlens.semantic.reranker import get_reranker
from codexlens.semantic.reranker.api_reranker import APIReranker
class DummyResponse:
def __init__(
self,
*,
status_code: int = 200,
json_data: Any = None,
text: str = "",
headers: dict[str, str] | None = None,
) -> None:
self.status_code = int(status_code)
self._json_data = json_data
self.text = text
self.headers = headers or {}
def json(self) -> Any:
return self._json_data
class DummyClient:
def __init__(self, *, base_url: str | None = None, headers: dict[str, str] | None = None, timeout: float | None = None) -> None:
self.base_url = base_url
self.headers = headers or {}
self.timeout = timeout
self.closed = False
self.calls: list[dict[str, Any]] = []
self._responses: list[DummyResponse] = []
def queue(self, response: DummyResponse) -> None:
self._responses.append(response)
def post(self, endpoint: str, *, json: dict[str, Any] | None = None) -> DummyResponse:
self.calls.append({"endpoint": endpoint, "json": json})
if not self._responses:
raise AssertionError("DummyClient has no queued responses")
return self._responses.pop(0)
def close(self) -> None:
self.closed = True
@pytest.fixture
def httpx_clients(monkeypatch: pytest.MonkeyPatch) -> list[DummyClient]:
clients: list[DummyClient] = []
dummy_httpx = types.ModuleType("httpx")
def Client(*, base_url: str | None = None, headers: dict[str, str] | None = None, timeout: float | None = None) -> DummyClient:
client = DummyClient(base_url=base_url, headers=headers, timeout=timeout)
clients.append(client)
return client
dummy_httpx.Client = Client
monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)
return clients
def test_api_reranker_requires_api_key(
monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
monkeypatch.delenv("RERANKER_API_KEY", raising=False)
with pytest.raises(ValueError, match="Missing API key"):
APIReranker()
assert httpx_clients == []
def test_api_reranker_reads_api_key_from_env(
monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
monkeypatch.setenv("RERANKER_API_KEY", "test-key")
reranker = APIReranker()
assert len(httpx_clients) == 1
assert httpx_clients[0].headers["Authorization"] == "Bearer test-key"
reranker.close()
assert httpx_clients[0].closed is True
def test_api_reranker_scores_pairs_siliconflow(
monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
monkeypatch.delenv("RERANKER_API_KEY", raising=False)
reranker = APIReranker(api_key="k", provider="siliconflow")
client = httpx_clients[0]
client.queue(
DummyResponse(
json_data={
"results": [
{"index": 0, "relevance_score": 0.9},
{"index": 1, "relevance_score": 0.1},
]
}
)
)
scores = reranker.score_pairs([("q", "d1"), ("q", "d2")])
assert scores == pytest.approx([0.9, 0.1])
assert client.calls[0]["endpoint"] == "/v1/rerank"
payload = client.calls[0]["json"]
assert payload["model"] == "BAAI/bge-reranker-v2-m3"
assert payload["query"] == "q"
assert payload["documents"] == ["d1", "d2"]
assert payload["top_n"] == 2
assert payload["return_documents"] is False
def test_api_reranker_retries_on_5xx(
monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
monkeypatch.setenv("RERANKER_API_KEY", "k")
from codexlens.semantic.reranker import api_reranker as api_reranker_module
monkeypatch.setattr(api_reranker_module.time, "sleep", lambda *_args, **_kwargs: None)
reranker = APIReranker(max_retries=1)
client = httpx_clients[0]
client.queue(DummyResponse(status_code=500, text="oops", json_data={"error": "oops"}))
client.queue(
DummyResponse(
json_data={"results": [{"index": 0, "relevance_score": 0.7}]},
)
)
scores = reranker.score_pairs([("q", "d")])
assert scores == pytest.approx([0.7])
assert len(client.calls) == 2
def test_api_reranker_unauthorized_raises(
monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
monkeypatch.setenv("RERANKER_API_KEY", "k")
reranker = APIReranker()
client = httpx_clients[0]
client.queue(DummyResponse(status_code=401, text="unauthorized"))
with pytest.raises(RuntimeError, match="unauthorized"):
reranker.score_pairs([("q", "d")])
def test_factory_api_backend_constructs_reranker(
monkeypatch: pytest.MonkeyPatch, httpx_clients: list[DummyClient]
) -> None:
monkeypatch.setenv("RERANKER_API_KEY", "k")
reranker = get_reranker(backend="api")
assert isinstance(reranker, APIReranker)
assert len(httpx_clients) == 1
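
For context, a minimal sketch of what these tests exercise: the remote backend reads RERANKER_API_KEY from the environment (or takes api_key directly) and returns one relevance score per (query, document) pair. Running it for real needs the reranker-api extra (httpx) and a valid key; the pairs and printed scores here are illustrative only.

from codexlens.semantic.reranker.api_reranker import APIReranker

# Assumes RERANKER_API_KEY is set (see .env.example); provider/model defaults match the tests above.
reranker = APIReranker(provider="siliconflow")
try:
    scores = reranker.score_pairs([
        ("how to parse python", "def parse(source): ..."),
        ("how to parse python", "SELECT * FROM files;"),
    ])
    print(scores)  # e.g. [0.91, 0.07] -- higher means more relevant
finally:
    reranker.close()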

View File

@@ -0,0 +1,139 @@
"""Tests for HybridSearchEngine reranker backend selection."""
from __future__ import annotations
import pytest
from codexlens.config import Config
from codexlens.search.hybrid_search import HybridSearchEngine
def test_get_cross_encoder_reranker_uses_factory_backend_legacy(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
calls: dict[str, object] = {}
def fake_check_reranker_available(backend: str):
calls["check_backend"] = backend
return True, None
sentinel = object()
def fake_get_reranker(*, backend: str, model_name=None, device=None, **kwargs):
calls["get_args"] = {
"backend": backend,
"model_name": model_name,
"device": device,
"kwargs": kwargs,
}
return sentinel
monkeypatch.setattr(
"codexlens.semantic.reranker.check_reranker_available",
fake_check_reranker_available,
)
monkeypatch.setattr(
"codexlens.semantic.reranker.get_reranker",
fake_get_reranker,
)
config = Config(
data_dir=tmp_path / "legacy",
enable_reranking=True,
enable_cross_encoder_rerank=True,
reranker_backend="legacy",
reranker_model="dummy-model",
)
engine = HybridSearchEngine(config=config)
reranker = engine._get_cross_encoder_reranker()
assert reranker is sentinel
assert calls["check_backend"] == "legacy"
get_args = calls["get_args"]
assert isinstance(get_args, dict)
assert get_args["backend"] == "legacy"
assert get_args["model_name"] == "dummy-model"
assert get_args["device"] is None
def test_get_cross_encoder_reranker_uses_factory_backend_onnx_gpu_flag(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
calls: dict[str, object] = {}
def fake_check_reranker_available(backend: str):
calls["check_backend"] = backend
return True, None
sentinel = object()
def fake_get_reranker(*, backend: str, model_name=None, device=None, **kwargs):
calls["get_args"] = {
"backend": backend,
"model_name": model_name,
"device": device,
"kwargs": kwargs,
}
return sentinel
monkeypatch.setattr(
"codexlens.semantic.reranker.check_reranker_available",
fake_check_reranker_available,
)
monkeypatch.setattr(
"codexlens.semantic.reranker.get_reranker",
fake_get_reranker,
)
config = Config(
data_dir=tmp_path / "onnx",
enable_reranking=True,
enable_cross_encoder_rerank=True,
reranker_backend="onnx",
embedding_use_gpu=False,
)
engine = HybridSearchEngine(config=config)
reranker = engine._get_cross_encoder_reranker()
assert reranker is sentinel
assert calls["check_backend"] == "onnx"
get_args = calls["get_args"]
assert isinstance(get_args, dict)
assert get_args["backend"] == "onnx"
assert get_args["model_name"] is None
assert get_args["device"] is None
assert get_args["kwargs"]["use_gpu"] is False
def test_get_cross_encoder_reranker_returns_none_when_backend_unavailable(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
) -> None:
def fake_check_reranker_available(backend: str):
return False, "missing deps"
def fake_get_reranker(*args, **kwargs):
raise AssertionError("get_reranker should not be called when backend is unavailable")
monkeypatch.setattr(
"codexlens.semantic.reranker.check_reranker_available",
fake_check_reranker_available,
)
monkeypatch.setattr(
"codexlens.semantic.reranker.get_reranker",
fake_get_reranker,
)
config = Config(
data_dir=tmp_path / "unavailable",
enable_reranking=True,
enable_cross_encoder_rerank=True,
reranker_backend="onnx",
)
engine = HybridSearchEngine(config=config)
assert engine._get_cross_encoder_reranker() is None
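
The selection logic under test comes down to configuration. A sketch of the Config fields involved (field names taken from the tests above; the data_dir value is a placeholder): the ONNX backend is the default, and the engine returns None instead of a reranker when the chosen backend's dependencies are missing.

from pathlib import Path

from codexlens.config import Config
from codexlens.search.hybrid_search import HybridSearchEngine

config = Config(
    data_dir=Path.home() / ".codexlens",
    enable_reranking=True,
    enable_cross_encoder_rerank=True,
    reranker_backend="api",            # one of: onnx (default), api, litellm, legacy
    reranker_model="BAAI/bge-reranker-v2-m3",
)
engine = HybridSearchEngine(config=config)
# BaseReranker instance, or None if the backend's dependencies are unavailable.
reranker = engine._get_cross_encoder_reranker()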

View File

@@ -0,0 +1,85 @@
"""Tests for LiteLLMReranker (LLM-based reranking)."""
from __future__ import annotations
import sys
import types
from dataclasses import dataclass
import pytest
from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker
def _install_dummy_ccw_litellm(
monkeypatch: pytest.MonkeyPatch, *, responses: list[str]
) -> None:
@dataclass(frozen=True, slots=True)
class ChatMessage:
role: str
content: str
class LiteLLMClient:
def __init__(self, model: str = "default", **kwargs) -> None:
self.model = model
self.kwargs = kwargs
self._responses = list(responses)
self.calls: list[list[ChatMessage]] = []
def chat(self, messages, **kwargs):
self.calls.append(list(messages))
content = self._responses.pop(0) if self._responses else ""
return types.SimpleNamespace(content=content)
dummy = types.ModuleType("ccw_litellm")
dummy.ChatMessage = ChatMessage
dummy.LiteLLMClient = LiteLLMClient
monkeypatch.setitem(sys.modules, "ccw_litellm", dummy)
def test_score_pairs_parses_numbers_and_normalizes_scales(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_install_dummy_ccw_litellm(monkeypatch, responses=["0.73", "7", "80"])
reranker = LiteLLMReranker(model="dummy")
scores = reranker.score_pairs([("q", "d1"), ("q", "d2"), ("q", "d3")])
assert scores == pytest.approx([0.73, 0.7, 0.8])
def test_score_pairs_parses_json_score_field(monkeypatch: pytest.MonkeyPatch) -> None:
_install_dummy_ccw_litellm(monkeypatch, responses=['{"score": 0.42}'])
reranker = LiteLLMReranker(model="dummy")
scores = reranker.score_pairs([("q", "d")])
assert scores == pytest.approx([0.42])
def test_score_pairs_uses_default_score_on_parse_failure(
monkeypatch: pytest.MonkeyPatch,
) -> None:
_install_dummy_ccw_litellm(monkeypatch, responses=["N/A"])
reranker = LiteLLMReranker(model="dummy", default_score=0.123)
scores = reranker.score_pairs([("q", "d")])
assert scores == pytest.approx([0.123])
def test_rate_limiting_sleeps_between_requests(monkeypatch: pytest.MonkeyPatch) -> None:
_install_dummy_ccw_litellm(monkeypatch, responses=["0.1", "0.2"])
reranker = LiteLLMReranker(model="dummy", min_interval_seconds=1.0)
import codexlens.semantic.reranker.litellm_reranker as litellm_reranker_module
sleeps: list[float] = []
times = iter([100.0, 100.0, 100.1, 100.1])
monkeypatch.setattr(litellm_reranker_module.time, "monotonic", lambda: next(times))
monkeypatch.setattr(
litellm_reranker_module.time, "sleep", lambda seconds: sleeps.append(seconds)
)
_ = reranker.score_pairs([("q", "d1"), ("q", "d2")])
assert sleeps == pytest.approx([0.9])
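
A usage sketch for the LLM-based backend these tests cover: it asks the configured chat model for a relevance score per pair, normalizes 0-10 and 0-100 replies down to 0-1, falls back to default_score when the reply cannot be parsed, and rate-limits requests. Requires the reranker-litellm extra; the model name is a placeholder.

from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker

reranker = LiteLLMReranker(
    model="gpt-4o-mini",           # placeholder; routed through ccw-litellm
    default_score=0.0,             # used when the reply cannot be parsed as a number
    min_interval_seconds=1.0,      # throttle: at most one request per second
)
scores = reranker.score_pairs([("query", "candidate document text")])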

View File

@@ -0,0 +1,115 @@
"""Mocked smoke tests for all reranker backends."""
from __future__ import annotations
import sys
import types
from dataclasses import dataclass
import pytest
def test_reranker_backend_legacy_scores_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
from codexlens.semantic.reranker import legacy as legacy_module
class DummyCrossEncoder:
def __init__(self, model_name: str, *, device: str | None = None) -> None:
self.model_name = model_name
self.device = device
self.calls: list[dict[str, object]] = []
def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]:
self.calls.append({"pairs": list(pairs), "batch_size": int(batch_size)})
return [0.5 for _ in pairs]
monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder)
monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True)
monkeypatch.setattr(legacy_module, "_import_error", None)
reranker = legacy_module.CrossEncoderReranker(model_name="dummy-model", device="cpu")
scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0)
assert scores == pytest.approx([0.5, 0.5])
def test_reranker_backend_onnx_availability_check(monkeypatch: pytest.MonkeyPatch) -> None:
from codexlens.semantic.reranker.onnx_reranker import check_onnx_reranker_available
dummy_numpy = types.ModuleType("numpy")
dummy_onnxruntime = types.ModuleType("onnxruntime")
dummy_optimum = types.ModuleType("optimum")
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
dummy_optimum_ort.ORTModelForSequenceClassification = object()
dummy_transformers = types.ModuleType("transformers")
dummy_transformers.AutoTokenizer = object()
monkeypatch.setitem(sys.modules, "numpy", dummy_numpy)
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
ok, err = check_onnx_reranker_available()
assert ok is True
assert err is None
def test_reranker_backend_api_constructs_with_dummy_httpx(monkeypatch: pytest.MonkeyPatch) -> None:
from codexlens.semantic.reranker.api_reranker import APIReranker
created: list[object] = []
class DummyClient:
def __init__(
self,
*,
base_url: str | None = None,
headers: dict[str, str] | None = None,
timeout: float | None = None,
) -> None:
self.base_url = base_url
self.headers = headers or {}
self.timeout = timeout
self.closed = False
created.append(self)
def close(self) -> None:
self.closed = True
dummy_httpx = types.ModuleType("httpx")
dummy_httpx.Client = DummyClient
monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)
reranker = APIReranker(api_key="k", provider="siliconflow")
assert reranker.provider == "siliconflow"
assert len(created) == 1
assert created[0].headers["Authorization"] == "Bearer k"
reranker.close()
assert created[0].closed is True
def test_reranker_backend_litellm_scores_pairs(monkeypatch: pytest.MonkeyPatch) -> None:
from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker
@dataclass(frozen=True, slots=True)
class ChatMessage:
role: str
content: str
class DummyLiteLLMClient:
def __init__(self, model: str = "default", **_kwargs: object) -> None:
self.model = model
def chat(self, _messages: list[ChatMessage]) -> object:
return types.SimpleNamespace(content="0.5")
dummy_litellm = types.ModuleType("ccw_litellm")
dummy_litellm.ChatMessage = ChatMessage
dummy_litellm.LiteLLMClient = DummyLiteLLMClient
monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)
reranker = LiteLLMReranker(model="dummy")
assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5])

View File

@@ -0,0 +1,315 @@
"""Tests for reranker factory and availability checks."""
from __future__ import annotations
import builtins
import math
import sys
import types
import pytest
from codexlens.semantic.reranker import (
BaseReranker,
ONNXReranker,
check_reranker_available,
get_reranker,
)
from codexlens.semantic.reranker import legacy as legacy_module
def test_public_imports_work() -> None:
from codexlens.semantic.reranker import BaseReranker as ImportedBaseReranker
from codexlens.semantic.reranker import get_reranker as imported_get_reranker
assert ImportedBaseReranker is BaseReranker
assert imported_get_reranker is get_reranker
def test_base_reranker_is_abstract() -> None:
with pytest.raises(TypeError):
BaseReranker() # type: ignore[abstract]
def test_check_reranker_available_invalid_backend() -> None:
ok, err = check_reranker_available("nope")
assert ok is False
assert "Invalid reranker backend" in (err or "")
def test_get_reranker_invalid_backend_raises_value_error() -> None:
with pytest.raises(ValueError, match="Unknown backend"):
get_reranker("nope")
def test_get_reranker_legacy_missing_dependency_raises_import_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", False)
monkeypatch.setattr(legacy_module, "_import_error", "missing sentence-transformers")
with pytest.raises(ImportError, match="missing sentence-transformers"):
get_reranker(backend="legacy", model_name="dummy-model")
def test_get_reranker_legacy_returns_cross_encoder_reranker(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class DummyCrossEncoder:
def __init__(self, model_name: str, *, device: str | None = None) -> None:
self.model_name = model_name
self.device = device
self.last_batch_size: int | None = None
def predict(self, pairs: list[tuple[str, str]], *, batch_size: int = 32) -> list[float]:
self.last_batch_size = int(batch_size)
return [0.5 for _ in pairs]
monkeypatch.setattr(legacy_module, "_CrossEncoder", DummyCrossEncoder)
monkeypatch.setattr(legacy_module, "CROSS_ENCODER_AVAILABLE", True)
monkeypatch.setattr(legacy_module, "_import_error", None)
reranker = get_reranker(backend=" LEGACY ", model_name="dummy-model", device="cpu")
assert isinstance(reranker, legacy_module.CrossEncoderReranker)
assert reranker.score_pairs([]) == []
scores = reranker.score_pairs([("q", "d1"), ("q", "d2")], batch_size=0)
assert scores == pytest.approx([0.5, 0.5])
assert reranker._model is not None
assert reranker._model.last_batch_size == 32
def test_check_reranker_available_onnx_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
real_import = builtins.__import__
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
if name == "onnxruntime":
raise ImportError("no onnxruntime")
return real_import(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, "__import__", fake_import)
ok, err = check_reranker_available("onnx")
assert ok is False
assert "onnxruntime not available" in (err or "")
def test_check_reranker_available_onnx_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_onnxruntime = types.ModuleType("onnxruntime")
dummy_optimum = types.ModuleType("optimum")
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
dummy_optimum_ort.ORTModelForSequenceClassification = object()
dummy_transformers = types.ModuleType("transformers")
dummy_transformers.AutoTokenizer = object()
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
ok, err = check_reranker_available("onnx")
assert ok is True
assert err is None
def test_check_reranker_available_litellm_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
real_import = builtins.__import__
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
if name == "ccw_litellm":
raise ImportError("no ccw-litellm")
return real_import(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, "__import__", fake_import)
ok, err = check_reranker_available("litellm")
assert ok is False
assert "ccw-litellm not available" in (err or "")
def test_check_reranker_available_litellm_deps_present(
monkeypatch: pytest.MonkeyPatch,
) -> None:
dummy_litellm = types.ModuleType("ccw_litellm")
monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)
ok, err = check_reranker_available("litellm")
assert ok is True
assert err is None
def test_check_reranker_available_api_missing_deps(monkeypatch: pytest.MonkeyPatch) -> None:
real_import = builtins.__import__
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
if name == "httpx":
raise ImportError("no httpx")
return real_import(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, "__import__", fake_import)
ok, err = check_reranker_available("api")
assert ok is False
assert "httpx not available" in (err or "")
def test_check_reranker_available_api_deps_present(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_httpx = types.ModuleType("httpx")
monkeypatch.setitem(sys.modules, "httpx", dummy_httpx)
ok, err = check_reranker_available("api")
assert ok is True
assert err is None
def test_get_reranker_litellm_returns_litellm_reranker(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from dataclasses import dataclass
@dataclass(frozen=True, slots=True)
class ChatMessage:
role: str
content: str
class DummyLiteLLMClient:
def __init__(self, model: str = "default", **kwargs) -> None:
self.model = model
self.kwargs = kwargs
def chat(self, messages, **kwargs):
return types.SimpleNamespace(content="0.5")
dummy_litellm = types.ModuleType("ccw_litellm")
dummy_litellm.ChatMessage = ChatMessage
dummy_litellm.LiteLLMClient = DummyLiteLLMClient
monkeypatch.setitem(sys.modules, "ccw_litellm", dummy_litellm)
reranker = get_reranker(backend="litellm", model_name="dummy-model")
from codexlens.semantic.reranker.litellm_reranker import LiteLLMReranker
assert isinstance(reranker, LiteLLMReranker)
assert reranker.score_pairs([("q", "d")]) == pytest.approx([0.5])
def test_get_reranker_onnx_raises_import_error_with_dependency_hint(
monkeypatch: pytest.MonkeyPatch,
) -> None:
real_import = builtins.__import__
def fake_import(name: str, globals=None, locals=None, fromlist=(), level: int = 0):
if name == "onnxruntime":
raise ImportError("no onnxruntime")
return real_import(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, "__import__", fake_import)
with pytest.raises(ImportError) as exc:
get_reranker(backend="onnx", model_name="any")
assert "onnxruntime" in str(exc.value)
def test_get_reranker_default_backend_is_onnx(monkeypatch: pytest.MonkeyPatch) -> None:
dummy_onnxruntime = types.ModuleType("onnxruntime")
dummy_optimum = types.ModuleType("optimum")
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
dummy_optimum_ort.ORTModelForSequenceClassification = object()
dummy_transformers = types.ModuleType("transformers")
dummy_transformers.AutoTokenizer = object()
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
reranker = get_reranker()
assert isinstance(reranker, ONNXReranker)
def test_onnx_reranker_scores_pairs_with_sigmoid_normalization(
monkeypatch: pytest.MonkeyPatch,
) -> None:
import numpy as np
dummy_onnxruntime = types.ModuleType("onnxruntime")
dummy_optimum = types.ModuleType("optimum")
dummy_optimum.__path__ = [] # Mark as package for submodule imports.
dummy_optimum_ort = types.ModuleType("optimum.onnxruntime")
class DummyModelOutput:
def __init__(self, logits: np.ndarray) -> None:
self.logits = logits
class DummyModel:
input_names = ["input_ids", "attention_mask"]
def __init__(self) -> None:
self.calls: list[int] = []
self._next_logit = 0
def __call__(self, **inputs):
batch = int(inputs["input_ids"].shape[0])
start = self._next_logit
self._next_logit += batch
self.calls.append(batch)
logits = np.arange(start, start + batch, dtype=np.float32).reshape(batch, 1)
return DummyModelOutput(logits=logits)
class DummyORTModelForSequenceClassification:
@classmethod
def from_pretrained(cls, model_name: str, providers=None, **kwargs):
_ = model_name, providers, kwargs
return DummyModel()
dummy_optimum_ort.ORTModelForSequenceClassification = DummyORTModelForSequenceClassification
dummy_transformers = types.ModuleType("transformers")
class DummyAutoTokenizer:
model_max_length = 512
@classmethod
def from_pretrained(cls, model_name: str, **kwargs):
_ = model_name, kwargs
return cls()
def __call__(self, *, text, text_pair, return_tensors, **kwargs):
_ = text_pair, kwargs
assert return_tensors == "np"
batch = len(text)
# Include token_type_ids to ensure input filtering is exercised.
return {
"input_ids": np.zeros((batch, 4), dtype=np.int64),
"attention_mask": np.ones((batch, 4), dtype=np.int64),
"token_type_ids": np.zeros((batch, 4), dtype=np.int64),
}
dummy_transformers.AutoTokenizer = DummyAutoTokenizer
monkeypatch.setitem(sys.modules, "onnxruntime", dummy_onnxruntime)
monkeypatch.setitem(sys.modules, "optimum", dummy_optimum)
monkeypatch.setitem(sys.modules, "optimum.onnxruntime", dummy_optimum_ort)
monkeypatch.setitem(sys.modules, "transformers", dummy_transformers)
reranker = get_reranker(backend="onnx", model_name="dummy-model", use_gpu=False)
assert isinstance(reranker, ONNXReranker)
assert reranker._model is None
pairs = [("q", f"d{idx}") for idx in range(5)]
scores = reranker.score_pairs(pairs, batch_size=2)
assert reranker._model is not None
assert reranker._model.calls == [2, 2, 1]
assert len(scores) == len(pairs)
assert all(0.0 <= s <= 1.0 for s in scores)
expected = [1.0 / (1.0 + math.exp(-float(i))) for i in range(len(pairs))]
assert scores == pytest.approx(expected, rel=1e-6, abs=1e-6)
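
To tie the factory tests together, a sketch of the intended call pattern: probe availability first, then construct through the factory. The model name is optional (each backend has a default), "onnx" is what you get when no backend is given, and the query/document pairs are placeholders.

from codexlens.semantic.reranker import check_reranker_available, get_reranker

ok, err = check_reranker_available("onnx")
if not ok:
    raise SystemExit(f"reranker backend unavailable: {err}")

reranker = get_reranker(backend="onnx", use_gpu=False)
pairs = [("query", "doc one"), ("query", "doc two")]
scores = reranker.score_pairs(pairs, batch_size=32)  # sigmoid-normalized, one score per pair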

View File

@@ -0,0 +1 @@
"""Tests for watcher module."""

View File

@@ -0,0 +1,43 @@
"""Fixtures for watcher tests."""
from __future__ import annotations
import tempfile
from pathlib import Path
from typing import Generator
import pytest
@pytest.fixture
def temp_project() -> Generator[Path, None, None]:
"""Create a temporary project directory with sample files."""
with tempfile.TemporaryDirectory() as tmpdir:
project = Path(tmpdir)
# Create sample Python file
py_file = project / "main.py"
py_file.write_text("def hello():\n print('Hello')\n")
# Create sample JavaScript file
js_file = project / "app.js"
js_file.write_text("function greet() {\n console.log('Hi');\n}\n")
# Create subdirectory with file
sub_dir = project / "src"
sub_dir.mkdir()
(sub_dir / "utils.py").write_text("def add(a, b):\n return a + b\n")
# Create ignored directory
git_dir = project / ".git"
git_dir.mkdir()
(git_dir / "config").write_text("[core]\n")
yield project
@pytest.fixture
def watcher_config():
"""Create default watcher configuration."""
from codexlens.watcher import WatcherConfig
return WatcherConfig(debounce_ms=100) # Short debounce for tests

View File

@@ -0,0 +1,103 @@
"""Tests for watcher event types."""
from __future__ import annotations
import time
from pathlib import Path
import pytest
from codexlens.watcher import ChangeType, FileEvent, WatcherConfig, IndexResult, WatcherStats
class TestChangeType:
"""Tests for ChangeType enum."""
def test_change_types_exist(self):
"""Verify all change types are defined."""
assert ChangeType.CREATED.value == "created"
assert ChangeType.MODIFIED.value == "modified"
assert ChangeType.DELETED.value == "deleted"
assert ChangeType.MOVED.value == "moved"
def test_change_type_count(self):
"""Verify we have exactly 4 change types."""
assert len(ChangeType) == 4
class TestFileEvent:
"""Tests for FileEvent dataclass."""
def test_create_event(self):
"""Test creating a file event."""
event = FileEvent(
path=Path("/test/file.py"),
change_type=ChangeType.CREATED,
timestamp=time.time(),
)
assert event.path == Path("/test/file.py")
assert event.change_type == ChangeType.CREATED
assert event.old_path is None
def test_moved_event(self):
"""Test creating a moved event with old_path."""
event = FileEvent(
path=Path("/test/new.py"),
change_type=ChangeType.MOVED,
timestamp=time.time(),
old_path=Path("/test/old.py"),
)
assert event.old_path == Path("/test/old.py")
class TestWatcherConfig:
"""Tests for WatcherConfig dataclass."""
def test_default_config(self):
"""Test default configuration values."""
config = WatcherConfig()
assert config.debounce_ms == 1000
assert ".git" in config.ignored_patterns
assert "node_modules" in config.ignored_patterns
assert "__pycache__" in config.ignored_patterns
assert config.languages is None
def test_custom_debounce(self):
"""Test custom debounce setting."""
config = WatcherConfig(debounce_ms=500)
assert config.debounce_ms == 500
class TestIndexResult:
"""Tests for IndexResult dataclass."""
def test_default_result(self):
"""Test default result values."""
result = IndexResult()
assert result.files_indexed == 0
assert result.files_removed == 0
assert result.symbols_added == 0
assert result.errors == []
def test_custom_result(self):
"""Test creating result with values."""
result = IndexResult(
files_indexed=5,
files_removed=2,
symbols_added=50,
errors=["error1"],
)
assert result.files_indexed == 5
assert result.files_removed == 2
class TestWatcherStats:
"""Tests for WatcherStats dataclass."""
def test_default_stats(self):
"""Test default stats values."""
stats = WatcherStats()
assert stats.files_watched == 0
assert stats.events_processed == 0
assert stats.last_event_time is None
assert stats.is_running is False

View File

@@ -0,0 +1,124 @@
"""Tests for FileWatcher class."""
from __future__ import annotations
import time
from pathlib import Path
from typing import List
import pytest
from codexlens.watcher import FileWatcher, WatcherConfig, FileEvent, ChangeType
class TestFileWatcherInit:
"""Tests for FileWatcher initialization."""
def test_init_with_valid_path(self, temp_project: Path, watcher_config: WatcherConfig):
"""Test initializing with valid path."""
events: List[FileEvent] = []
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
assert watcher.root_path == temp_project.resolve()
assert watcher.config == watcher_config
assert not watcher.is_running
def test_start_with_invalid_path(self, watcher_config: WatcherConfig):
"""Test starting watcher with non-existent path."""
events: List[FileEvent] = []
watcher = FileWatcher(Path("/nonexistent/path"), watcher_config, lambda e: events.extend(e))
with pytest.raises(ValueError, match="does not exist"):
watcher.start()
class TestFileWatcherLifecycle:
"""Tests for FileWatcher start/stop lifecycle."""
def test_start_stop(self, temp_project: Path, watcher_config: WatcherConfig):
"""Test basic start and stop."""
events: List[FileEvent] = []
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
watcher.start()
assert watcher.is_running
watcher.stop()
assert not watcher.is_running
def test_double_start(self, temp_project: Path, watcher_config: WatcherConfig):
"""Test calling start twice."""
events: List[FileEvent] = []
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
watcher.start()
watcher.start() # Should not raise
assert watcher.is_running
watcher.stop()
def test_double_stop(self, temp_project: Path, watcher_config: WatcherConfig):
"""Test calling stop twice."""
events: List[FileEvent] = []
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
watcher.start()
watcher.stop()
watcher.stop() # Should not raise
assert not watcher.is_running
class TestFileWatcherEvents:
"""Tests for FileWatcher event detection."""
def test_detect_file_creation(self, temp_project: Path, watcher_config: WatcherConfig):
"""Test detecting new file creation."""
events: List[FileEvent] = []
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
try:
watcher.start()
time.sleep(0.3) # Let watcher start (longer for Windows)
# Create new file
new_file = temp_project / "new_file.py"
new_file.write_text("# New file\n")
# Wait for event with retries (watchdog timing varies by platform)
max_wait = 2.0
waited = 0.0
while waited < max_wait:
time.sleep(0.2)
waited += 0.2
# Windows may report MODIFIED instead of CREATED
file_events = [e for e in events if e.change_type in (ChangeType.CREATED, ChangeType.MODIFIED)]
if any(e.path.name == "new_file.py" for e in file_events):
break
# Check event was detected (Windows may report MODIFIED instead of CREATED)
relevant_events = [e for e in events if e.change_type in (ChangeType.CREATED, ChangeType.MODIFIED)]
assert len(relevant_events) >= 1, f"Expected file event, got: {events}"
assert any(e.path.name == "new_file.py" for e in relevant_events)
finally:
watcher.stop()
def test_filter_ignored_directories(self, temp_project: Path, watcher_config: WatcherConfig):
"""Test that files in ignored directories are filtered."""
events: List[FileEvent] = []
watcher = FileWatcher(temp_project, watcher_config, lambda e: events.extend(e))
try:
watcher.start()
time.sleep(0.1)
# Create file in .git (should be ignored)
git_file = temp_project / ".git" / "test.py"
git_file.write_text("# In git\n")
time.sleep(watcher_config.debounce_ms / 1000.0 + 0.2)
# No events should be detected for .git files
git_events = [e for e in events if ".git" in str(e.path)]
assert len(git_events) == 0
finally:
watcher.stop()