feat: Implement adaptive RRF weights and query intent detection

- Added integration tests for adaptive RRF weights in hybrid search.
- Enhanced query intent detection with new classifications: keyword, semantic, and mixed.
- Introduced symbol boosting in search results based on explicit symbol matches.
- Implemented embedding-based reranking with configurable options.
- Added global symbol index for efficient symbol lookups across projects.
- Improved file deletion handling on Windows to avoid permission errors.
- Updated chunk configuration to increase the default overlap from 100 to 200 for better cross-chunk context.
- Modified package.json test script to target specific test files.
- Created comprehensive writing style guidelines for documentation.
- Added TypeScript tests for query intent detection and adaptive weights.
- Established performance benchmarks for global symbol indexing.
catlog22
2025-12-26 15:08:47 +08:00
parent ecd5085e51
commit 4061ae48c4
29 changed files with 2685 additions and 828 deletions

Binary file not shown.

codex-lens/CHANGELOG.md (new file, 41 lines)

@@ -0,0 +1,41 @@
# CodexLens Optimization Plan Changelog
This changelog tracks the **CodexLens optimization plan** milestones (not the Python package version in `pyproject.toml`).
## v1.0 (Optimization) - 2025-12-26
### Optimizations
1. **P0: Context-aware hybrid chunking**
- Docstrings are extracted into dedicated chunks and excluded from code chunks.
- Docstring chunks include `parent_symbol` metadata when the docstring belongs to a function/class/method.
- Sliding-window chunk boundaries are deterministic for identical input.
2. **P1: Adaptive RRF weights (QueryIntent)**
- Query intent is classified as `keyword` / `semantic` / `mixed`.
- RRF weights adapt to intent:
- `keyword`: exact-heavy (favors lexical matches)
- `semantic`: vector-heavy (favors semantic matches)
- `mixed`: keeps base/default weights
3. **P2: Symbol boost**
- Fused results with an explicit symbol match (`symbol_name`) receive a multiplicative boost (default `1.5x`).
4. **P2: Embedding-based re-ranking (optional)**
- A second-stage ranker can reorder top results by semantic similarity.
- Re-ranking runs only when `Config.enable_reranking=True`.
5. **P3: Global symbol index (incremental + fast path)**
- `GlobalSymbolIndex` stores project-wide symbols in one SQLite DB for fast symbol lookups.
- `ChainSearchEngine.search_symbols()` uses the global index fast path when enabled (see the sketch after this list).
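A minimal sketch of populating and querying the global index, based on this commit's tests (paths are placeholders; the tests create the referenced files on disk before adding symbols):

```python
from pathlib import Path

from codexlens.entities import Symbol
from codexlens.storage.global_index import GlobalSymbolIndex

db_path = Path("indexes/_global_symbols.db")   # project-wide symbol DB (placeholder path)
index_path = Path("indexes/_index.db")         # per-directory index the symbol points back to
file_path = Path("src/auth.py")                # source file that defines the symbol

with GlobalSymbolIndex(db_path, project_id=1) as store:
    store.add_symbol(
        Symbol(name="AuthManager", kind="class", range=(1, 2)),
        file_path=file_path,
        index_path=index_path,
    )
    matches = store.search("AuthManager", kind="class", limit=10, prefix_mode=True)
```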
### Migration Notes
- **Reindexing (recommended)**: deterministic chunking and docstring metadata affect stored chunks. For best results, regenerate indexes/embeddings after upgrading:
- Rebuild indexes and/or re-run embedding generation for existing projects.
- **New config flags**:
- `Config.enable_reranking` (default `False`)
- `Config.reranking_top_k` (default `50`)
- `Config.symbol_boost_factor` (default `1.5`)
- `Config.global_symbol_index_enabled` (default `True`)
- **Breaking changes**: none (behavioral improvements only).
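A minimal, illustrative sketch of opting into the optional features (the `data_dir` argument mirrors this commit's tests; the other fields are the flags listed above, shown with explicit values):

```python
from pathlib import Path

from codexlens.config import Config
from codexlens.search.hybrid_search import HybridSearchEngine

config = Config(
    data_dir=Path("./codexlens-data"),   # placeholder location
    enable_reranking=True,               # default False
    reranking_top_k=50,                  # default 50
    symbol_boost_factor=1.5,             # default 1.5
    global_symbol_index_enabled=True,    # default True
)

# The engine reads symbol boosting and reranking settings from the config;
# when no embedder is passed, one is created lazily for reranking.
engine = HybridSearchEngine(config=config)
```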


@@ -103,6 +103,11 @@ class Config:
# Indexing/search optimizations
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
# Optional search reranking (disabled by default)
enable_reranking: bool = False
reranking_top_k: int = 50
symbol_boost_factor: float = 1.5
# Multi-endpoint configuration for litellm backend
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]


@@ -7,12 +7,38 @@ results via Reciprocal Rank Fusion (RRF) algorithm.
from __future__ import annotations
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
@contextmanager
def timer(name: str, logger: logging.Logger, level: int = logging.DEBUG):
"""Context manager for timing code blocks.
Args:
name: Name of the operation being timed
logger: Logger instance to use
level: Logging level (default DEBUG)
"""
start = time.perf_counter()
try:
yield
finally:
elapsed_ms = (time.perf_counter() - start) * 1000
logger.log(level, "[TIMING] %s: %.2fms", name, elapsed_ms)
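# Illustrative usage (the operation name is a placeholder): any block can be timed, e.g.
#     with timer("index_load", logging.getLogger(__name__)):
#         ...  # wrapped work; logs "[TIMING] index_load: <elapsed>ms" at DEBUG level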
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.ranking import reciprocal_rank_fusion, tag_search_source
from codexlens.search.ranking import (
apply_symbol_boost,
get_rrf_weights,
reciprocal_rank_fusion,
rerank_results,
tag_search_source,
)
from codexlens.storage.dir_index import DirIndexStore
@@ -34,14 +60,23 @@ class HybridSearchEngine:
"vector": 0.6,
}
def __init__(self, weights: Optional[Dict[str, float]] = None):
def __init__(
self,
weights: Optional[Dict[str, float]] = None,
config: Optional[Config] = None,
embedder: Any = None,
):
"""Initialize hybrid search engine.
Args:
weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
config: Optional runtime config (enables optional reranking features)
embedder: Optional embedder instance for embedding-based reranking
"""
self.logger = logging.getLogger(__name__)
self.weights = weights or self.DEFAULT_WEIGHTS.copy()
self._config = config
self.embedder = embedder
def search(
self,
@@ -101,7 +136,8 @@ class HybridSearchEngine:
backends["vector"] = True
# Execute parallel searches
results_map = self._search_parallel(index_path, query, backends, limit)
with timer("parallel_search_total", self.logger):
results_map = self._search_parallel(index_path, query, backends, limit)
# Provide helpful message if pure-vector mode returns no results
if pure_vector and enable_vector and len(results_map.get("vector", [])) == 0:
@@ -120,11 +156,72 @@ class HybridSearchEngine:
if source in results_map
}
fused_results = reciprocal_rank_fusion(results_map, active_weights)
with timer("rrf_fusion", self.logger):
adaptive_weights = get_rrf_weights(query, active_weights)
fused_results = reciprocal_rank_fusion(results_map, adaptive_weights)
# Optional: boost results that include explicit symbol matches
boost_factor = (
self._config.symbol_boost_factor
if self._config is not None
else 1.5
)
with timer("symbol_boost", self.logger):
fused_results = apply_symbol_boost(
fused_results, boost_factor=boost_factor
)
# Optional: embedding-based reranking on top results
if self._config is not None and self._config.enable_reranking:
with timer("reranking", self.logger):
if self.embedder is None:
self.embedder = self._get_reranking_embedder()
fused_results = rerank_results(
query,
fused_results[:100],
self.embedder,
top_k=self._config.reranking_top_k,
)
# Apply final limit
return fused_results[:limit]
def _get_reranking_embedder(self) -> Any:
"""Create an embedder for reranking based on Config embedding settings."""
if self._config is None:
return None
try:
from codexlens.semantic.factory import get_embedder
except Exception as exc:
self.logger.debug("Reranking embedder unavailable: %s", exc)
return None
try:
if self._config.embedding_backend == "fastembed":
return get_embedder(
backend="fastembed",
profile=self._config.embedding_model,
use_gpu=self._config.embedding_use_gpu,
)
if self._config.embedding_backend == "litellm":
return get_embedder(
backend="litellm",
model=self._config.embedding_model,
endpoints=self._config.embedding_endpoints,
strategy=self._config.embedding_strategy,
cooldown=self._config.embedding_cooldown,
)
except Exception as exc:
self.logger.debug("Failed to initialize reranking embedder: %s", exc)
return None
self.logger.debug(
"Unknown embedding backend for reranking: %s",
self._config.embedding_backend,
)
return None
def _search_parallel(
self,
index_path: Path,
@@ -144,25 +241,30 @@ class HybridSearchEngine:
Dictionary mapping source name to results list
"""
results_map: Dict[str, List[SearchResult]] = {}
timing_data: Dict[str, float] = {}
# Use ThreadPoolExecutor for parallel I/O-bound searches
with ThreadPoolExecutor(max_workers=len(backends)) as executor:
# Submit search tasks
# Submit search tasks with timing
future_to_source = {}
submit_times = {}
if backends.get("exact"):
submit_times["exact"] = time.perf_counter()
future = executor.submit(
self._search_exact, index_path, query, limit
)
future_to_source[future] = "exact"
if backends.get("fuzzy"):
submit_times["fuzzy"] = time.perf_counter()
future = executor.submit(
self._search_fuzzy, index_path, query, limit
)
future_to_source[future] = "fuzzy"
if backends.get("vector"):
submit_times["vector"] = time.perf_counter()
future = executor.submit(
self._search_vector, index_path, query, limit
)
@@ -171,18 +273,26 @@ class HybridSearchEngine:
# Collect results as they complete
for future in as_completed(future_to_source):
source = future_to_source[future]
elapsed_ms = (time.perf_counter() - submit_times[source]) * 1000
timing_data[source] = elapsed_ms
try:
results = future.result()
# Tag results with source for debugging
tagged_results = tag_search_source(results, source)
results_map[source] = tagged_results
self.logger.debug(
"Got %d results from %s search", len(results), source
"[TIMING] %s_search: %.2fms (%d results)",
source, elapsed_ms, len(results)
)
except Exception as exc:
self.logger.error("Search failed for %s: %s", source, exc)
results_map[source] = []
# Log timing summary
if timing_data:
timing_str = ", ".join(f"{k}={v:.1f}ms" for k, v in timing_data.items())
self.logger.debug("[TIMING] search_backends: {%s}", timing_str)
return results_map
def _search_exact(
@@ -245,6 +355,8 @@ class HybridSearchEngine:
try:
# Check if semantic chunks table exists
import sqlite3
start_check = time.perf_counter()
try:
with sqlite3.connect(index_path) as conn:
cursor = conn.execute(
@@ -254,6 +366,10 @@ class HybridSearchEngine:
except sqlite3.Error as e:
self.logger.error("Database check failed in vector search: %s", e)
return []
self.logger.debug(
"[TIMING] vector_table_check: %.2fms",
(time.perf_counter() - start_check) * 1000
)
if not has_semantic_table:
self.logger.info(
@@ -267,7 +383,12 @@ class HybridSearchEngine:
from codexlens.semantic.factory import get_embedder
from codexlens.semantic.vector_store import VectorStore
start_init = time.perf_counter()
vector_store = VectorStore(index_path)
self.logger.debug(
"[TIMING] vector_store_init: %.2fms",
(time.perf_counter() - start_init) * 1000
)
# Check if vector store has data
if vector_store.count_chunks() == 0:
@@ -279,6 +400,7 @@ class HybridSearchEngine:
return []
# Get stored model configuration (preferred) or auto-detect from dimension
start_embedder = time.perf_counter()
model_config = vector_store.get_model_config()
if model_config:
backend = model_config.get("backend", "fastembed")
@@ -288,7 +410,7 @@ class HybridSearchEngine:
"Using stored model config: %s backend, %s (%s, %dd)",
backend, model_profile, model_name, model_config["embedding_dim"]
)
# Get embedder based on backend
if backend == "litellm":
embedder = get_embedder(backend="litellm", model=model_name)
@@ -324,21 +446,32 @@ class HybridSearchEngine:
detected_dim
)
embedder = get_embedder(backend="fastembed", profile="code")
self.logger.debug(
"[TIMING] embedder_init: %.2fms",
(time.perf_counter() - start_embedder) * 1000
)
# Generate query embedding
start_embed = time.perf_counter()
query_embedding = embedder.embed_single(query)
self.logger.debug(
"[TIMING] query_embedding: %.2fms",
(time.perf_counter() - start_embed) * 1000
)
# Search for similar chunks
start_search = time.perf_counter()
results = vector_store.search_similar(
query_embedding=query_embedding,
top_k=limit,
min_score=0.0, # Return all results, let RRF handle filtering
return_full_content=True,
)
self.logger.debug(
"[TIMING] vector_similarity_search: %.2fms (%d results)",
(time.perf_counter() - start_search) * 1000, len(results)
)
self.logger.debug("Vector search found %d results", len(results))
return results
except ImportError as exc:


@@ -6,12 +6,98 @@ for combining results from heterogeneous search backends (exact FTS, fuzzy FTS,
from __future__ import annotations
import re
import math
from typing import Dict, List
from enum import Enum
from typing import Any, Dict, List
from codexlens.entities import SearchResult, AdditionalLocation
class QueryIntent(str, Enum):
"""Query intent for adaptive RRF weights (Python/TypeScript parity)."""
KEYWORD = "keyword"
SEMANTIC = "semantic"
MIXED = "mixed"
def normalize_weights(weights: Dict[str, float]) -> Dict[str, float]:
"""Normalize weights to sum to 1.0 (best-effort)."""
total = sum(float(v) for v in weights.values() if v is not None)
if not math.isfinite(total) or total <= 0:
return {k: float(v) for k, v in weights.items()}
return {k: float(v) / total for k, v in weights.items()}
def detect_query_intent(query: str) -> QueryIntent:
"""Detect whether a query is code-like, natural-language, or mixed.
Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
"""
trimmed = (query or "").strip()
if not trimmed:
return QueryIntent.MIXED
lower = trimmed.lower()
word_count = len([w for w in re.split(r"\s+", trimmed) if w])
has_code_signals = bool(
re.search(r"(::|->|\.)", trimmed)
or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
or re.search(r"\b\w+_\w+\b", trimmed)
or re.search(
r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
lower,
flags=re.IGNORECASE,
)
)
has_natural_signals = bool(
word_count > 5
or "?" in trimmed
or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
or re.search(
r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
trimmed,
flags=re.IGNORECASE,
)
)
if has_code_signals and has_natural_signals:
return QueryIntent.MIXED
if has_code_signals:
return QueryIntent.KEYWORD
if has_natural_signals:
return QueryIntent.SEMANTIC
return QueryIntent.MIXED
def adjust_weights_by_intent(
intent: QueryIntent,
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Map intent → weights (kept aligned with TypeScript mapping)."""
if intent == QueryIntent.KEYWORD:
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
elif intent == QueryIntent.SEMANTIC:
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
else:
target = dict(base_weights)
# Preserve only keys that are present in base_weights (active backends).
keys = list(base_weights.keys())
filtered = {k: float(target.get(k, 0.0)) for k in keys}
return normalize_weights(filtered)
def get_rrf_weights(
query: str,
base_weights: Dict[str, float],
) -> Dict[str, float]:
"""Compute adaptive RRF weights from query intent."""
return adjust_weights_by_intent(detect_query_intent(query), base_weights)
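# Illustrative example: a code-like query maps to the keyword weights, normalized over
# the active backends, e.g.
#     get_rrf_weights("def authenticate", {"exact": 0.3, "fuzzy": 0.1, "vector": 0.6})
#     -> {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}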
def reciprocal_rank_fusion(
results_map: Dict[str, List[SearchResult]],
weights: Dict[str, float] = None,
@@ -102,6 +188,186 @@ def reciprocal_rank_fusion(
return fused_results
def apply_symbol_boost(
results: List[SearchResult],
boost_factor: float = 1.5,
) -> List[SearchResult]:
"""Boost fused scores for results that include an explicit symbol match.
The boost is multiplicative on the current result.score (typically the RRF fusion score).
When boosted, the original score is preserved in metadata["original_fusion_score"] and
metadata["boosted"] is set to True.
"""
if not results:
return []
if boost_factor <= 1.0:
# Still return new objects to follow immutable transformation pattern.
return [
SearchResult(
path=r.path,
score=r.score,
excerpt=r.excerpt,
content=r.content,
symbol=r.symbol,
chunk=r.chunk,
metadata={**r.metadata},
start_line=r.start_line,
end_line=r.end_line,
symbol_name=r.symbol_name,
symbol_kind=r.symbol_kind,
additional_locations=list(r.additional_locations),
)
for r in results
]
boosted_results: List[SearchResult] = []
for result in results:
has_symbol = bool(result.symbol_name)
original_score = float(result.score)
boosted_score = original_score * boost_factor if has_symbol else original_score
metadata = {**result.metadata}
if has_symbol:
metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
metadata["boosted"] = True
metadata["symbol_boost_factor"] = boost_factor
boosted_results.append(
SearchResult(
path=result.path,
score=boosted_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata=metadata,
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
boosted_results.sort(key=lambda r: r.score, reverse=True)
return boosted_results
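# Illustrative example: with the default factor of 1.5, a fused score of 0.2 on a result
# carrying symbol_name becomes 0.3, while results without a symbol match keep their score;
# the pre-boost value is preserved under metadata["original_fusion_score"].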
def rerank_results(
query: str,
results: List[SearchResult],
embedder: Any,
top_k: int = 50,
) -> List[SearchResult]:
"""Re-rank results with embedding cosine similarity, combined with current score.
Combined score formula:
0.5 * rrf_score + 0.5 * cosine_similarity
If embedder is None or embedding fails, returns results as-is.
"""
if not results:
return []
if embedder is None or top_k <= 0:
return results
rerank_count = min(int(top_k), len(results))
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
# Defensive: handle mismatched lengths and zero vectors.
n = min(len(vec_a), len(vec_b))
if n == 0:
return 0.0
dot = 0.0
norm_a = 0.0
norm_b = 0.0
for i in range(n):
a = float(vec_a[i])
b = float(vec_b[i])
dot += a * b
norm_a += a * a
norm_b += b * b
if norm_a <= 0.0 or norm_b <= 0.0:
return 0.0
sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
# SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
return max(0.0, min(1.0, sim))
def text_for_embedding(r: SearchResult) -> str:
if r.excerpt and r.excerpt.strip():
return r.excerpt
if r.content and r.content.strip():
return r.content
if r.chunk and r.chunk.content and r.chunk.content.strip():
return r.chunk.content
# Fallback: stable, non-empty text.
return r.symbol_name or r.path
try:
if hasattr(embedder, "embed_single"):
query_vec = embedder.embed_single(query)
else:
query_vec = embedder.embed(query)[0]
doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
doc_vecs = embedder.embed(doc_texts)
except Exception:
return results
reranked_results: List[SearchResult] = []
for idx, result in enumerate(results):
if idx < rerank_count:
rrf_score = float(result.score)
sim = cosine_similarity(query_vec, doc_vecs[idx])
combined_score = 0.5 * rrf_score + 0.5 * sim
reranked_results.append(
SearchResult(
path=result.path,
score=combined_score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={
**result.metadata,
"rrf_score": rrf_score,
"cosine_similarity": sim,
"reranked": True,
},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
else:
# Preserve remaining results without re-ranking, but keep immutability.
reranked_results.append(
SearchResult(
path=result.path,
score=result.score,
excerpt=result.excerpt,
content=result.content,
symbol=result.symbol,
chunk=result.chunk,
metadata={**result.metadata},
start_line=result.start_line,
end_line=result.end_line,
symbol_name=result.symbol_name,
symbol_kind=result.symbol_kind,
additional_locations=list(result.additional_locations),
)
)
reranked_results.sort(key=lambda r: r.score, reverse=True)
return reranked_results
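# Illustrative example (mirrors the unit tests): a result with RRF score 0.2 and cosine
# similarity 1.0 scores 0.5 * 0.2 + 0.5 * 1.0 = 0.6, overtaking a result with RRF score
# 0.9 and similarity 0.0, which scores 0.45.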
def normalize_bm25_score(score: float) -> float:
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.


@@ -392,6 +392,22 @@ class HybridChunker:
filtered.append(symbol)
return filtered
def _find_parent_symbol(
self,
start_line: int,
end_line: int,
symbols: List[Symbol],
) -> Optional[Symbol]:
"""Find the smallest symbol range that fully contains a docstring span."""
candidates: List[Symbol] = []
for symbol in symbols:
sym_start, sym_end = symbol.range
if sym_start <= start_line and end_line <= sym_end:
candidates.append(symbol)
if not candidates:
return None
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
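# Illustrative example: for a docstring spanning lines 4-4, a function with range (3, 5)
# and an enclosing class with range (1, 20) both contain it; the function has the smaller
# span, so it is chosen and the docstring chunk records it as parent_symbol.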
def chunk_file(
self,
content: str,
@@ -414,24 +430,53 @@ class HybridChunker:
chunks: List[SemanticChunk] = []
# Step 1: Extract docstrings as dedicated chunks
docstrings = self.docstring_extractor.extract_docstrings(content, language)
docstrings: List[Tuple[str, int, int]] = []
if language == "python":
# Fast path: avoid expensive docstring extraction if delimiters are absent.
if '"""' in content or "'''" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
elif language in {"javascript", "typescript"}:
if "/**" in content:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
else:
docstrings = self.docstring_extractor.extract_docstrings(content, language)
# Fast path: no docstrings -> delegate to base chunker directly.
if not docstrings:
if symbols:
base_chunks = self.base_chunker.chunk_by_symbol(
content, symbols, file_path, language, symbol_token_counts
)
else:
base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)
for chunk in base_chunks:
chunk.metadata["strategy"] = "hybrid"
chunk.metadata["chunk_type"] = "code"
return base_chunks
for docstring_content, start_line, end_line in docstrings:
if len(docstring_content.strip()) >= self.config.min_chunk_size:
parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
# Use base chunker's token estimation method
token_count = self.base_chunker._estimate_token_count(docstring_content)
metadata = {
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
if parent_symbol is not None:
metadata["parent_symbol"] = parent_symbol.name
metadata["parent_symbol_kind"] = parent_symbol.kind
metadata["parent_symbol_range"] = parent_symbol.range
chunks.append(SemanticChunk(
content=docstring_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"chunk_type": "docstring",
"start_line": start_line,
"end_line": end_line,
"strategy": "hybrid",
"token_count": token_count,
}
metadata=metadata
))
# Step 2: Get line ranges occupied by docstrings


@@ -0,0 +1,293 @@
import sqlite3
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from codexlens.config import Config
from codexlens.entities import Symbol
from codexlens.errors import StorageError
from codexlens.search.chain_search import ChainSearchEngine
from codexlens.storage.global_index import GlobalSymbolIndex
from codexlens.storage.path_mapper import PathMapper
from codexlens.storage.registry import RegistryStore
@pytest.fixture()
def temp_paths():
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
root = Path(tmpdir.name)
yield root
try:
tmpdir.cleanup()
except (PermissionError, OSError):
pass
def test_add_symbol(temp_paths: Path):
db_path = temp_paths / "indexes" / "_global_symbols.db"
index_path = temp_paths / "indexes" / "_index.db"
file_path = temp_paths / "src" / "a.py"
index_path.parent.mkdir(parents=True, exist_ok=True)
index_path.write_text("", encoding="utf-8")
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
with GlobalSymbolIndex(db_path, project_id=1) as store:
store.add_symbol(
Symbol(name="AuthManager", kind="class", range=(1, 2)),
file_path=file_path,
index_path=index_path,
)
matches = store.search("AuthManager", kind="class", limit=10, prefix_mode=True)
assert len(matches) == 1
assert matches[0].name == "AuthManager"
assert matches[0].file == str(file_path.resolve())
# Schema version safety: newer schema versions should be rejected.
bad_db = temp_paths / "indexes" / "_global_symbols_bad.db"
bad_db.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(bad_db)
conn.execute("PRAGMA user_version = 999")
conn.close()
with pytest.raises(StorageError):
GlobalSymbolIndex(bad_db, project_id=1).initialize()
def test_search_symbols(temp_paths: Path):
db_path = temp_paths / "indexes" / "_global_symbols.db"
index_path = temp_paths / "indexes" / "_index.db"
file_path = temp_paths / "src" / "mod.py"
index_path.parent.mkdir(parents=True, exist_ok=True)
index_path.write_text("", encoding="utf-8")
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("def authenticate():\n pass\n", encoding="utf-8")
with GlobalSymbolIndex(db_path, project_id=7) as store:
store.add_symbol(
Symbol(name="authenticate", kind="function", range=(1, 2)),
file_path=file_path,
index_path=index_path,
)
locations = store.search_symbols("auth", kind="function", limit=10, prefix_mode=True)
assert locations
assert any(p.endswith("mod.py") for p, _ in locations)
assert any(rng == (1, 2) for _, rng in locations)
def test_update_file_symbols(temp_paths: Path):
db_path = temp_paths / "indexes" / "_global_symbols.db"
file_path = temp_paths / "src" / "mod.py"
index_path = temp_paths / "indexes" / "_index.db"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("def a():\n pass\n", encoding="utf-8")
index_path.parent.mkdir(parents=True, exist_ok=True)
index_path.write_text("", encoding="utf-8")
with GlobalSymbolIndex(db_path, project_id=7) as store:
store.update_file_symbols(
file_path=file_path,
symbols=[
Symbol(name="old_func", kind="function", range=(1, 2)),
Symbol(name="Other", kind="class", range=(10, 20)),
],
index_path=index_path,
)
assert any(s.name == "old_func" for s in store.search("old_", prefix_mode=True))
store.update_file_symbols(
file_path=file_path,
symbols=[Symbol(name="new_func", kind="function", range=(3, 4))],
index_path=index_path,
)
assert not any(s.name == "old_func" for s in store.search("old_", prefix_mode=True))
assert any(s.name == "new_func" for s in store.search("new_", prefix_mode=True))
# Backward-compatible path: index_path can be omitted after it's been established.
store.update_file_symbols(
file_path=file_path,
symbols=[Symbol(name="new_func2", kind="function", range=(5, 6))],
index_path=None,
)
assert any(s.name == "new_func2" for s in store.search("new_func2", prefix_mode=True))
# New file + symbols without index_path should raise.
missing_index_file = temp_paths / "src" / "new_file.py"
with pytest.raises(StorageError):
store.update_file_symbols(
file_path=missing_index_file,
symbols=[Symbol(name="must_fail", kind="function", range=(1, 1))],
index_path=None,
)
deleted = store.delete_file_symbols(file_path)
assert deleted > 0
def test_incremental_updates(temp_paths: Path, monkeypatch):
db_path = temp_paths / "indexes" / "_global_symbols.db"
file_path = temp_paths / "src" / "same.py"
idx_a = temp_paths / "indexes" / "a" / "_index.db"
idx_b = temp_paths / "indexes" / "b" / "_index.db"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
idx_a.parent.mkdir(parents=True, exist_ok=True)
idx_a.write_text("", encoding="utf-8")
idx_b.parent.mkdir(parents=True, exist_ok=True)
idx_b.write_text("", encoding="utf-8")
with GlobalSymbolIndex(db_path, project_id=42) as store:
sym = Symbol(name="AuthManager", kind="class", range=(1, 2))
store.add_symbol(sym, file_path=file_path, index_path=idx_a)
store.add_symbol(sym, file_path=file_path, index_path=idx_b)
# prefix_mode=False exercises substring matching.
assert store.search("Manager", prefix_mode=False)
conn = sqlite3.connect(db_path)
row = conn.execute(
"""
SELECT index_path
FROM global_symbols
WHERE project_id=? AND symbol_name=? AND symbol_kind=? AND file_path=?
""",
(42, "AuthManager", "class", str(file_path.resolve())),
).fetchone()
conn.close()
assert row is not None
assert str(Path(row[0]).resolve()) == str(idx_b.resolve())
# Migration path coverage: simulate a future schema version and an older DB version.
migrating_db = temp_paths / "indexes" / "_global_symbols_migrate.db"
migrating_db.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(migrating_db)
conn.execute("PRAGMA user_version = 1")
conn.close()
monkeypatch.setattr(GlobalSymbolIndex, "SCHEMA_VERSION", 2)
GlobalSymbolIndex(migrating_db, project_id=1).initialize()
def test_concurrent_access(temp_paths: Path):
db_path = temp_paths / "indexes" / "_global_symbols.db"
index_path = temp_paths / "indexes" / "_index.db"
file_path = temp_paths / "src" / "a.py"
index_path.parent.mkdir(parents=True, exist_ok=True)
index_path.write_text("", encoding="utf-8")
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("class A:\n pass\n", encoding="utf-8")
with GlobalSymbolIndex(db_path, project_id=1) as store:
def add_many(worker_id: int):
for i in range(50):
store.add_symbol(
Symbol(name=f"Sym{worker_id}_{i}", kind="class", range=(1, 2)),
file_path=file_path,
index_path=index_path,
)
with ThreadPoolExecutor(max_workers=8) as ex:
list(ex.map(add_many, range(8)))
matches = store.search("Sym", kind="class", limit=1000, prefix_mode=True)
assert len(matches) >= 200
def test_chain_search_integration(temp_paths: Path):
project_root = temp_paths / "project"
project_root.mkdir(parents=True, exist_ok=True)
index_root = temp_paths / "indexes"
mapper = PathMapper(index_root=index_root)
index_db_path = mapper.source_to_index_db(project_root)
index_db_path.parent.mkdir(parents=True, exist_ok=True)
index_db_path.write_text("", encoding="utf-8")
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
project_info = registry.register_project(project_root, mapper.source_to_index_dir(project_root))
global_db_path = project_info.index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
with GlobalSymbolIndex(global_db_path, project_id=project_info.id) as global_index:
file_path = project_root / "auth.py"
global_index.update_file_symbols(
file_path=file_path,
symbols=[
Symbol(name="AuthManager", kind="class", range=(1, 10)),
Symbol(name="authenticate", kind="function", range=(12, 20)),
],
index_path=index_db_path,
)
config = Config(data_dir=temp_paths / "data", global_symbol_index_enabled=True)
engine = ChainSearchEngine(registry, mapper, config=config)
engine._search_symbols_parallel = MagicMock(side_effect=AssertionError("should not traverse chain"))
symbols = engine.search_symbols("Auth", project_root)
assert any(s.name == "AuthManager" for s in symbols)
registry.close()
def test_disabled_fallback(temp_paths: Path):
project_root = temp_paths / "project"
project_root.mkdir(parents=True, exist_ok=True)
index_root = temp_paths / "indexes"
mapper = PathMapper(index_root=index_root)
index_db_path = mapper.source_to_index_db(project_root)
index_db_path.parent.mkdir(parents=True, exist_ok=True)
index_db_path.write_text("", encoding="utf-8")
registry = RegistryStore(db_path=temp_paths / "registry.db")
registry.initialize()
registry.register_project(project_root, mapper.source_to_index_dir(project_root))
config = Config(data_dir=temp_paths / "data", global_symbol_index_enabled=False)
engine = ChainSearchEngine(registry, mapper, config=config)
engine._collect_index_paths = MagicMock(return_value=[index_db_path])
engine._search_symbols_parallel = MagicMock(
return_value=[Symbol(name="FallbackSymbol", kind="function", range=(1, 2))]
)
symbols = engine.search_symbols("Fallback", project_root)
assert any(s.name == "FallbackSymbol" for s in symbols)
assert engine._search_symbols_parallel.called
registry.close()
def test_performance_benchmark(temp_paths: Path):
db_path = temp_paths / "indexes" / "_global_symbols.db"
index_path = temp_paths / "indexes" / "_index.db"
file_path = temp_paths / "src" / "perf.py"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
index_path.parent.mkdir(parents=True, exist_ok=True)
index_path.write_text("", encoding="utf-8")
with GlobalSymbolIndex(db_path, project_id=1) as store:
for i in range(500):
store.add_symbol(
Symbol(name=f"AuthManager{i}", kind="class", range=(1, 2)),
file_path=file_path,
index_path=index_path,
)
start = time.perf_counter()
results = store.search("AuthManager", kind="class", limit=50, prefix_mode=True)
elapsed_ms = (time.perf_counter() - start) * 1000
assert elapsed_ms < 100.0
assert results


@@ -551,3 +551,72 @@ class UserProfile:
# Verify <15% overhead (reasonable threshold for performance tests with system variance)
assert overhead < 15.0, f"Overhead {overhead:.2f}% exceeds 15% threshold (base={base_time:.4f}s, hybrid={hybrid_time:.4f}s)"
class TestHybridChunkerV1Optimizations:
"""Tests for v1.0 optimization behaviors (parent metadata + determinism)."""
def test_merged_docstring_metadata(self):
"""Docstring chunks include parent_symbol metadata when applicable."""
config = ChunkConfig(min_chunk_size=1)
chunker = HybridChunker(config=config)
content = '''"""Module docstring."""
def hello():
"""Function docstring."""
return 1
'''
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
chunks = chunker.chunk_file(content, symbols, "m.py", "python")
func_doc_chunks = [
c for c in chunks
if c.metadata.get("chunk_type") == "docstring" and c.metadata.get("start_line") == 4
]
assert len(func_doc_chunks) == 1
assert func_doc_chunks[0].metadata.get("parent_symbol") == "hello"
assert func_doc_chunks[0].metadata.get("parent_symbol_kind") == "function"
def test_deterministic_chunk_boundaries(self):
"""Chunk boundaries are stable across repeated runs on identical input."""
config = ChunkConfig(max_chunk_size=80, overlap=10, min_chunk_size=1)
chunker = HybridChunker(config=config)
# No docstrings, no symbols -> sliding window path.
content = "\n".join([f"line {i}: x = {i}" for i in range(1, 200)]) + "\n"
boundaries = []
for _ in range(3):
chunks = chunker.chunk_file(content, [], "deterministic.py", "python")
boundaries.append([
(c.metadata.get("start_line"), c.metadata.get("end_line"))
for c in chunks
if c.metadata.get("chunk_type") == "code"
])
assert boundaries[0] == boundaries[1] == boundaries[2]
def test_orphan_docstrings(self):
"""Module-level docstrings remain standalone (no parent_symbol assigned)."""
config = ChunkConfig(min_chunk_size=1)
chunker = HybridChunker(config=config)
content = '''"""Module-level docstring."""
def hello():
"""Function docstring."""
return 1
'''
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
chunks = chunker.chunk_file(content, symbols, "orphan.py", "python")
module_doc = [
c for c in chunks
if c.metadata.get("chunk_type") == "docstring" and c.metadata.get("start_line") == 1
]
assert len(module_doc) == 1
assert module_doc[0].metadata.get("parent_symbol") is None
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
assert code_chunks, "Expected at least one code chunk"
assert all("Module-level docstring" not in c.content for c in code_chunks)


@@ -10,6 +10,7 @@ from pathlib import Path
import pytest
from codexlens.config import Config
from codexlens.entities import SearchResult
from codexlens.search.hybrid_search import HybridSearchEngine
from codexlens.storage.dir_index import DirIndexStore
@@ -774,3 +775,97 @@ class TestHybridSearchWithVectorMock:
assert hasattr(result, 'score')
assert result.score > 0 # RRF fusion scores are positive
class TestHybridSearchAdaptiveWeights:
"""Integration tests for adaptive RRF weights + reranking gating."""
def test_adaptive_weights_code_query(self):
"""Exact weight should dominate for code-like queries."""
from unittest.mock import patch
engine = HybridSearchEngine()
results_map = {
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
}
captured = {}
from codexlens.search import ranking as ranking_module
def capture_rrf(map_in, weights_in, k=60):
captured["weights"] = dict(weights_in)
return ranking_module.reciprocal_rank_fusion(map_in, weights_in, k=k)
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
"codexlens.search.hybrid_search.reciprocal_rank_fusion",
side_effect=capture_rrf,
):
engine.search(Path("dummy.db"), "def authenticate", enable_vector=True)
assert captured["weights"]["exact"] > 0.4
def test_adaptive_weights_nl_query(self):
"""Vector weight should dominate for natural-language queries."""
from unittest.mock import patch
engine = HybridSearchEngine()
results_map = {
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
}
captured = {}
from codexlens.search import ranking as ranking_module
def capture_rrf(map_in, weights_in, k=60):
captured["weights"] = dict(weights_in)
return ranking_module.reciprocal_rank_fusion(map_in, weights_in, k=k)
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
"codexlens.search.hybrid_search.reciprocal_rank_fusion",
side_effect=capture_rrf,
):
engine.search(Path("dummy.db"), "how to handle user login", enable_vector=True)
assert captured["weights"]["vector"] > 0.6
def test_reranking_enabled(self, tmp_path):
"""Reranking runs only when explicitly enabled via config."""
from unittest.mock import patch
results_map = {
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
}
class DummyEmbedder:
def embed(self, texts):
if isinstance(texts, str):
texts = [texts]
return [[1.0, 0.0] for _ in texts]
# Disabled: should not invoke rerank_results
config_off = Config(data_dir=tmp_path / "off", enable_reranking=False)
engine_off = HybridSearchEngine(config=config_off, embedder=DummyEmbedder())
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
"codexlens.search.hybrid_search.rerank_results"
) as rerank_mock:
engine_off.search(Path("dummy.db"), "query", enable_vector=True)
rerank_mock.assert_not_called()
# Enabled: should invoke rerank_results once
config_on = Config(data_dir=tmp_path / "on", enable_reranking=True, reranking_top_k=10)
engine_on = HybridSearchEngine(config=config_on, embedder=DummyEmbedder())
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
"codexlens.search.hybrid_search.rerank_results",
side_effect=lambda q, r, e, top_k=50: r,
) as rerank_mock:
engine_on.search(Path("dummy.db"), "query", enable_vector=True)
assert rerank_mock.call_count == 1


@@ -3,6 +3,7 @@
import pytest
import sqlite3
import tempfile
import time
from pathlib import Path
from codexlens.search.hybrid_search import HybridSearchEngine
@@ -16,6 +17,22 @@ except ImportError:
SEMANTIC_DEPS_AVAILABLE = False
def _safe_unlink(path: Path, retries: int = 5, delay_s: float = 0.05) -> None:
"""Best-effort unlink for Windows where SQLite can keep files locked briefly."""
for attempt in range(retries):
try:
path.unlink()
return
except FileNotFoundError:
return
except PermissionError:
time.sleep(delay_s * (attempt + 1))
try:
path.unlink(missing_ok=True)
except (PermissionError, OSError):
pass
class TestPureVectorSearch:
"""Tests for pure vector search mode."""
@@ -48,7 +65,7 @@ class TestPureVectorSearch:
store.close()
if db_path.exists():
db_path.unlink()
_safe_unlink(db_path)
def test_pure_vector_without_embeddings(self, sample_db):
"""Test pure_vector mode returns empty when no embeddings exist."""
@@ -200,12 +217,8 @@ def login_handler(credentials: dict) -> bool:
yield db_path
store.close()
# Ignore file deletion errors on Windows (SQLite file lock)
try:
if db_path.exists():
db_path.unlink()
except PermissionError:
pass # Ignore Windows file lock errors
if db_path.exists():
_safe_unlink(db_path)
def test_pure_vector_with_embeddings(self, db_with_embeddings):
"""Test pure vector search returns results when embeddings exist."""
@@ -289,7 +302,7 @@ class TestSearchModeComparison:
store.close()
if db_path.exists():
db_path.unlink()
_safe_unlink(db_path)
def test_mode_comparison_without_embeddings(self, comparison_db):
"""Compare all search modes without embeddings."""


@@ -7,8 +7,12 @@ import pytest
from codexlens.entities import SearchResult
from codexlens.search.ranking import (
apply_symbol_boost,
QueryIntent,
detect_query_intent,
normalize_bm25_score,
reciprocal_rank_fusion,
rerank_results,
tag_search_source,
)
@@ -342,6 +346,62 @@ class TestTagSearchSource:
assert tagged[0].symbol_kind == "function"
class TestSymbolBoost:
"""Tests for apply_symbol_boost function."""
def test_symbol_boost(self):
results = [
SearchResult(path="a.py", score=0.2, excerpt="...", symbol_name="foo"),
SearchResult(path="b.py", score=0.21, excerpt="..."),
]
boosted = apply_symbol_boost(results, boost_factor=1.5)
assert boosted[0].path == "a.py"
assert boosted[0].score == pytest.approx(0.2 * 1.5)
assert boosted[0].metadata["boosted"] is True
assert boosted[0].metadata["original_fusion_score"] == pytest.approx(0.2)
assert boosted[1].path == "b.py"
assert boosted[1].score == pytest.approx(0.21)
assert "boosted" not in boosted[1].metadata
class TestEmbeddingReranking:
"""Tests for rerank_results embedding-based similarity."""
def test_rerank_embedding_similarity(self):
class DummyEmbedder:
def embed(self, texts):
if isinstance(texts, str):
texts = [texts]
mapping = {
"query": [1.0, 0.0],
"doc1": [1.0, 0.0],
"doc2": [0.0, 1.0],
}
return [mapping[t] for t in texts]
results = [
SearchResult(path="a.py", score=0.2, excerpt="doc1"),
SearchResult(path="b.py", score=0.9, excerpt="doc2"),
]
reranked = rerank_results("query", results, DummyEmbedder(), top_k=2)
assert reranked[0].path == "a.py"
assert reranked[0].metadata["reranked"] is True
assert reranked[0].metadata["rrf_score"] == pytest.approx(0.2)
assert reranked[0].metadata["cosine_similarity"] == pytest.approx(1.0)
assert reranked[0].score == pytest.approx(0.5 * 0.2 + 0.5 * 1.0)
assert reranked[1].path == "b.py"
assert reranked[1].metadata["reranked"] is True
assert reranked[1].metadata["rrf_score"] == pytest.approx(0.9)
assert reranked[1].metadata["cosine_similarity"] == pytest.approx(0.0)
assert reranked[1].score == pytest.approx(0.5 * 0.9 + 0.5 * 0.0)
@pytest.mark.parametrize("k_value", [30, 60, 100])
class TestRRFParameterized:
"""Parameterized tests for RRF with different k values."""
@@ -419,3 +479,41 @@ class TestRRFEdgeCases:
# Should work with normalization
assert len(fused) == 1 # Deduplicated
assert fused[0].score > 0
class TestSymbolBoostAndIntentV1:
"""Tests for symbol boosting and query intent detection (v1.0)."""
def test_symbol_boost_application(self):
"""Results with symbol_name receive a multiplicative boost (default 1.5x)."""
results = [
SearchResult(path="a.py", score=0.4, excerpt="...", symbol_name="AuthManager"),
SearchResult(path="b.py", score=0.41, excerpt="..."),
]
boosted = apply_symbol_boost(results, boost_factor=1.5)
assert boosted[0].score == pytest.approx(0.4 * 1.5)
assert boosted[0].metadata["boosted"] is True
assert boosted[0].metadata["original_fusion_score"] == pytest.approx(0.4)
assert boosted[1].score == pytest.approx(0.41)
assert "boosted" not in boosted[1].metadata
@pytest.mark.parametrize(
("query", "expected"),
[
("def authenticate", QueryIntent.KEYWORD),
("MyClass", QueryIntent.KEYWORD),
("user_id", QueryIntent.KEYWORD),
("UserService::authenticate", QueryIntent.KEYWORD),
("ptr->next", QueryIntent.KEYWORD),
("how to handle user login", QueryIntent.SEMANTIC),
("what is authentication?", QueryIntent.SEMANTIC),
("where is this used?", QueryIntent.SEMANTIC),
("why does FooBar crash?", QueryIntent.MIXED),
("how to use user_id in query", QueryIntent.MIXED),
],
)
def test_query_intent_detection(self, query, expected):
"""Detect intent for representative queries (Python/TypeScript parity)."""
assert detect_query_intent(query) == expected


@@ -466,7 +466,18 @@ class TestDiagnostics:
yield db_path
if db_path.exists():
db_path.unlink()
for attempt in range(5):
try:
db_path.unlink()
break
except PermissionError:
time.sleep(0.05 * (attempt + 1))
else:
# Best-effort cleanup (Windows SQLite locks can linger briefly).
try:
db_path.unlink(missing_ok=True)
except (PermissionError, OSError):
pass
def test_diagnose_empty_database(self, empty_db):
"""Diagnose behavior with empty database."""


@@ -13,7 +13,7 @@ class TestChunkConfig:
"""Test default configuration values."""
config = ChunkConfig()
assert config.max_chunk_size == 1000
assert config.overlap == 100
assert config.overlap == 200
assert config.min_chunk_size == 50
def test_custom_config(self):