Mirror of https://github.com/catlog22/Claude-Code-Workflow.git, synced 2026-02-11 02:33:51 +08:00

feat: Implement adaptive RRF weights and query intent detection

- Added integration tests for adaptive RRF weights in hybrid search.
- Enhanced query intent detection with new classifications: keyword, semantic, and mixed.
- Introduced symbol boosting in search results based on explicit symbol matches.
- Implemented embedding-based reranking with configurable options.
- Added a global symbol index for efficient symbol lookups across projects.
- Improved file deletion handling on Windows to avoid permission errors.
- Updated the chunk configuration to increase overlap for better context.
- Modified the package.json test script to target specific test files.
- Created comprehensive writing style guidelines for documentation.
- Added TypeScript tests for query intent detection and adaptive weights.
- Established performance benchmarks for global symbol indexing.

Binary file not shown.

codex-lens/CHANGELOG.md (new file, 41 lines added)
@@ -0,0 +1,41 @@
# CodexLens – Optimization Plan Changelog

This changelog tracks the **CodexLens optimization plan** milestones (not the Python package version in `pyproject.toml`).

## v1.0 (Optimization) – 2025-12-26

### Optimizations

1. **P0: Context-aware hybrid chunking**
   - Docstrings are extracted into dedicated chunks and excluded from code chunks.
   - Docstring chunks include `parent_symbol` metadata when the docstring belongs to a function/class/method (an illustrative metadata shape follows below).
   - Sliding-window chunk boundaries are deterministic for identical input.
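   A minimal sketch of that metadata, with keys mirroring the `chunk_file` changes further down; the concrete values here are illustrative assumptions, not output from a real run:

   ```python
   # Hypothetical metadata attached to a docstring chunk by the hybrid chunker.
   docstring_chunk_metadata = {
       "file": "src/auth.py",          # placeholder path
       "language": "python",
       "chunk_type": "docstring",
       "start_line": 4,
       "end_line": 6,
       "strategy": "hybrid",
       "token_count": 12,
       # Present only when the docstring sits inside a function/class/method:
       "parent_symbol": "authenticate",
       "parent_symbol_kind": "function",
       "parent_symbol_range": (3, 20),
   }
   ```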
2. **P1: Adaptive RRF weights (QueryIntent)**
   - Query intent is classified as `keyword` / `semantic` / `mixed`.
   - RRF weights adapt to intent (sketched below):
     - `keyword`: exact-heavy (favors lexical matches)
     - `semantic`: vector-heavy (favors semantic matches)
     - `mixed`: keeps the base/default weights
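   The starting weights per intent, as a sketch; the numbers match `adjust_weights_by_intent` in the `ranking.py` diff below and are then normalized over whichever backends are active:

   ```python
   # Intent-to-weight mapping applied before normalization over active backends.
   INTENT_WEIGHTS = {
       "keyword":  {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4},  # exact-heavy
       "semantic": {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7},  # vector-heavy
       # "mixed" keeps whatever base weights the engine was configured with.
   }
   ```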
3. **P2: Symbol boost**
   - Fused results with an explicit symbol match (`symbol_name`) receive a multiplicative boost (default `1.5x`); see the usage sketch below.
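   A minimal usage sketch, mirroring the unit test added further down:

   ```python
   from codexlens.entities import SearchResult
   from codexlens.search.ranking import apply_symbol_boost

   results = [
       SearchResult(path="a.py", score=0.4, excerpt="...", symbol_name="AuthManager"),
       SearchResult(path="b.py", score=0.41, excerpt="..."),
   ]
   boosted = apply_symbol_boost(results, boost_factor=1.5)
   # The symbol match now ranks first (0.4 * 1.5 = 0.6 vs. 0.41), and its metadata
   # records boosted=True plus the original fusion score.
   ```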
4. **P2: Embedding-based re-ranking (optional)**
   - A second-stage ranker can reorder top results by semantic similarity.
   - Re-ranking runs only when `Config.enable_reranking=True`; the score blend is shown below.
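   The second-stage score is a simple blend of the fusion score and embedding similarity (formula taken from the `rerank_results` docstring below), shown here with example values:

   ```python
   # Applied to the top reranking_top_k candidates; results beyond top_k keep
   # their original fusion score.
   rrf_score, cosine_similarity = 0.2, 1.0                        # example values
   combined_score = 0.5 * rrf_score + 0.5 * cosine_similarity     # -> 0.6
   ```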
5. **P3: Global symbol index (incremental + fast path)**
   - `GlobalSymbolIndex` stores project-wide symbols in one SQLite DB for fast symbol lookups (usage sketch below).
   - `ChainSearchEngine.search_symbols()` uses the global index fast path when enabled.
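   A minimal usage sketch, following the API exercised by `tests/test_global_index.py` below; the paths are placeholders:

   ```python
   from pathlib import Path

   from codexlens.entities import Symbol
   from codexlens.storage.global_index import GlobalSymbolIndex

   # One SQLite database holds symbols for the whole project; project_id scopes the rows.
   with GlobalSymbolIndex(Path("indexes/_global_symbols.db"), project_id=1) as store:
       store.add_symbol(
           Symbol(name="AuthManager", kind="class", range=(1, 2)),
           file_path=Path("src/auth.py"),
           index_path=Path("indexes/_index.db"),
       )
       matches = store.search("AuthManager", kind="class", limit=10, prefix_mode=True)
   ```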
### Migration Notes

- **Reindexing (recommended)**: deterministic chunking and docstring metadata affect stored chunks. For best results, regenerate indexes/embeddings after upgrading:
  - Rebuild indexes and/or re-run embedding generation for existing projects.
- **New config flags** (example below):
  - `Config.enable_reranking` (default `False`)
  - `Config.reranking_top_k` (default `50`)
  - `Config.symbol_boost_factor` (default `1.5`)
  - `Config.global_symbol_index_enabled` (default `True`)
- **Breaking changes**: none (behavioral improvements only).
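A sketch of the new flags in context; `data_dir` is a placeholder and the values shown are the defaults:

```python
from pathlib import Path

from codexlens.config import Config

config = Config(
    data_dir=Path("codexlens-data"),        # placeholder location
    global_symbol_index_enabled=True,       # project-wide symbol fast path
    enable_reranking=False,                 # opt-in second-stage reranking
    reranking_top_k=50,                     # candidates considered when reranking
    symbol_boost_factor=1.5,                # multiplier for explicit symbol matches
)
```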
@@ -103,6 +103,11 @@ class Config:
|
||||
# Indexing/search optimizations
|
||||
global_symbol_index_enabled: bool = True # Enable project-wide symbol index fast path
|
||||
|
||||
# Optional search reranking (disabled by default)
|
||||
enable_reranking: bool = False
|
||||
reranking_top_k: int = 50
|
||||
symbol_boost_factor: float = 1.5
|
||||
|
||||
# Multi-endpoint configuration for litellm backend
|
||||
embedding_endpoints: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# List of endpoint configs: [{"model": "...", "api_key": "...", "api_base": "...", "weight": 1.0}]
|
||||
|
||||
@@ -7,12 +7,38 @@ results via Reciprocal Rank Fusion (RRF) algorithm.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timer(name: str, logger: logging.Logger, level: int = logging.DEBUG):
|
||||
"""Context manager for timing code blocks.
|
||||
|
||||
Args:
|
||||
name: Name of the operation being timed
|
||||
logger: Logger instance to use
|
||||
level: Logging level (default DEBUG)
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
logger.log(level, "[TIMING] %s: %.2fms", name, elapsed_ms)
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import reciprocal_rank_fusion, tag_search_source
|
||||
from codexlens.search.ranking import (
|
||||
apply_symbol_boost,
|
||||
get_rrf_weights,
|
||||
reciprocal_rank_fusion,
|
||||
rerank_results,
|
||||
tag_search_source,
|
||||
)
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
|
||||
|
||||
@@ -34,14 +60,23 @@ class HybridSearchEngine:
|
||||
"vector": 0.6,
|
||||
}
|
||||
|
||||
def __init__(self, weights: Optional[Dict[str, float]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
weights: Optional[Dict[str, float]] = None,
|
||||
config: Optional[Config] = None,
|
||||
embedder: Any = None,
|
||||
):
|
||||
"""Initialize hybrid search engine.
|
||||
|
||||
Args:
|
||||
weights: Optional custom RRF weights (default: DEFAULT_WEIGHTS)
|
||||
config: Optional runtime config (enables optional reranking features)
|
||||
embedder: Optional embedder instance for embedding-based reranking
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.weights = weights or self.DEFAULT_WEIGHTS.copy()
|
||||
self._config = config
|
||||
self.embedder = embedder
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -101,7 +136,8 @@ class HybridSearchEngine:
|
||||
backends["vector"] = True
|
||||
|
||||
# Execute parallel searches
|
||||
results_map = self._search_parallel(index_path, query, backends, limit)
|
||||
with timer("parallel_search_total", self.logger):
|
||||
results_map = self._search_parallel(index_path, query, backends, limit)
|
||||
|
||||
# Provide helpful message if pure-vector mode returns no results
|
||||
if pure_vector and enable_vector and len(results_map.get("vector", [])) == 0:
|
||||
@@ -120,11 +156,72 @@ class HybridSearchEngine:
|
||||
if source in results_map
|
||||
}
|
||||
|
||||
fused_results = reciprocal_rank_fusion(results_map, active_weights)
|
||||
with timer("rrf_fusion", self.logger):
|
||||
adaptive_weights = get_rrf_weights(query, active_weights)
|
||||
fused_results = reciprocal_rank_fusion(results_map, adaptive_weights)
|
||||
|
||||
# Optional: boost results that include explicit symbol matches
|
||||
boost_factor = (
|
||||
self._config.symbol_boost_factor
|
||||
if self._config is not None
|
||||
else 1.5
|
||||
)
|
||||
with timer("symbol_boost", self.logger):
|
||||
fused_results = apply_symbol_boost(
|
||||
fused_results, boost_factor=boost_factor
|
||||
)
|
||||
|
||||
# Optional: embedding-based reranking on top results
|
||||
if self._config is not None and self._config.enable_reranking:
|
||||
with timer("reranking", self.logger):
|
||||
if self.embedder is None:
|
||||
self.embedder = self._get_reranking_embedder()
|
||||
fused_results = rerank_results(
|
||||
query,
|
||||
fused_results[:100],
|
||||
self.embedder,
|
||||
top_k=self._config.reranking_top_k,
|
||||
)
|
||||
|
||||
# Apply final limit
|
||||
return fused_results[:limit]
|
||||
|
||||
def _get_reranking_embedder(self) -> Any:
|
||||
"""Create an embedder for reranking based on Config embedding settings."""
|
||||
if self._config is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
from codexlens.semantic.factory import get_embedder
|
||||
except Exception as exc:
|
||||
self.logger.debug("Reranking embedder unavailable: %s", exc)
|
||||
return None
|
||||
|
||||
try:
|
||||
if self._config.embedding_backend == "fastembed":
|
||||
return get_embedder(
|
||||
backend="fastembed",
|
||||
profile=self._config.embedding_model,
|
||||
use_gpu=self._config.embedding_use_gpu,
|
||||
)
|
||||
if self._config.embedding_backend == "litellm":
|
||||
return get_embedder(
|
||||
backend="litellm",
|
||||
model=self._config.embedding_model,
|
||||
endpoints=self._config.embedding_endpoints,
|
||||
strategy=self._config.embedding_strategy,
|
||||
cooldown=self._config.embedding_cooldown,
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.debug("Failed to initialize reranking embedder: %s", exc)
|
||||
return None
|
||||
|
||||
self.logger.debug(
|
||||
"Unknown embedding backend for reranking: %s",
|
||||
self._config.embedding_backend,
|
||||
)
|
||||
return None
|
||||
|
||||
def _search_parallel(
|
||||
self,
|
||||
index_path: Path,
|
||||
@@ -144,25 +241,30 @@ class HybridSearchEngine:
|
||||
Dictionary mapping source name to results list
|
||||
"""
|
||||
results_map: Dict[str, List[SearchResult]] = {}
|
||||
timing_data: Dict[str, float] = {}
|
||||
|
||||
# Use ThreadPoolExecutor for parallel I/O-bound searches
|
||||
with ThreadPoolExecutor(max_workers=len(backends)) as executor:
|
||||
# Submit search tasks
|
||||
# Submit search tasks with timing
|
||||
future_to_source = {}
|
||||
submit_times = {}
|
||||
|
||||
if backends.get("exact"):
|
||||
submit_times["exact"] = time.perf_counter()
|
||||
future = executor.submit(
|
||||
self._search_exact, index_path, query, limit
|
||||
)
|
||||
future_to_source[future] = "exact"
|
||||
|
||||
if backends.get("fuzzy"):
|
||||
submit_times["fuzzy"] = time.perf_counter()
|
||||
future = executor.submit(
|
||||
self._search_fuzzy, index_path, query, limit
|
||||
)
|
||||
future_to_source[future] = "fuzzy"
|
||||
|
||||
if backends.get("vector"):
|
||||
submit_times["vector"] = time.perf_counter()
|
||||
future = executor.submit(
|
||||
self._search_vector, index_path, query, limit
|
||||
)
|
||||
@@ -171,18 +273,26 @@ class HybridSearchEngine:
|
||||
# Collect results as they complete
|
||||
for future in as_completed(future_to_source):
|
||||
source = future_to_source[future]
|
||||
elapsed_ms = (time.perf_counter() - submit_times[source]) * 1000
|
||||
timing_data[source] = elapsed_ms
|
||||
try:
|
||||
results = future.result()
|
||||
# Tag results with source for debugging
|
||||
tagged_results = tag_search_source(results, source)
|
||||
results_map[source] = tagged_results
|
||||
self.logger.debug(
|
||||
"Got %d results from %s search", len(results), source
|
||||
"[TIMING] %s_search: %.2fms (%d results)",
|
||||
source, elapsed_ms, len(results)
|
||||
)
|
||||
except Exception as exc:
|
||||
self.logger.error("Search failed for %s: %s", source, exc)
|
||||
results_map[source] = []
|
||||
|
||||
# Log timing summary
|
||||
if timing_data:
|
||||
timing_str = ", ".join(f"{k}={v:.1f}ms" for k, v in timing_data.items())
|
||||
self.logger.debug("[TIMING] search_backends: {%s}", timing_str)
|
||||
|
||||
return results_map
|
||||
|
||||
def _search_exact(
|
||||
@@ -245,6 +355,8 @@ class HybridSearchEngine:
|
||||
try:
|
||||
# Check if semantic chunks table exists
|
||||
import sqlite3
|
||||
|
||||
start_check = time.perf_counter()
|
||||
try:
|
||||
with sqlite3.connect(index_path) as conn:
|
||||
cursor = conn.execute(
|
||||
@@ -254,6 +366,10 @@ class HybridSearchEngine:
|
||||
except sqlite3.Error as e:
|
||||
self.logger.error("Database check failed in vector search: %s", e)
|
||||
return []
|
||||
self.logger.debug(
|
||||
"[TIMING] vector_table_check: %.2fms",
|
||||
(time.perf_counter() - start_check) * 1000
|
||||
)
|
||||
|
||||
if not has_semantic_table:
|
||||
self.logger.info(
|
||||
@@ -267,7 +383,12 @@ class HybridSearchEngine:
|
||||
from codexlens.semantic.factory import get_embedder
|
||||
from codexlens.semantic.vector_store import VectorStore
|
||||
|
||||
start_init = time.perf_counter()
|
||||
vector_store = VectorStore(index_path)
|
||||
self.logger.debug(
|
||||
"[TIMING] vector_store_init: %.2fms",
|
||||
(time.perf_counter() - start_init) * 1000
|
||||
)
|
||||
|
||||
# Check if vector store has data
|
||||
if vector_store.count_chunks() == 0:
|
||||
@@ -279,6 +400,7 @@ class HybridSearchEngine:
|
||||
return []
|
||||
|
||||
# Get stored model configuration (preferred) or auto-detect from dimension
|
||||
start_embedder = time.perf_counter()
|
||||
model_config = vector_store.get_model_config()
|
||||
if model_config:
|
||||
backend = model_config.get("backend", "fastembed")
|
||||
@@ -288,7 +410,7 @@ class HybridSearchEngine:
|
||||
"Using stored model config: %s backend, %s (%s, %dd)",
|
||||
backend, model_profile, model_name, model_config["embedding_dim"]
|
||||
)
|
||||
|
||||
|
||||
# Get embedder based on backend
|
||||
if backend == "litellm":
|
||||
embedder = get_embedder(backend="litellm", model=model_name)
|
||||
@@ -324,21 +446,32 @@ class HybridSearchEngine:
|
||||
detected_dim
|
||||
)
|
||||
embedder = get_embedder(backend="fastembed", profile="code")
|
||||
|
||||
|
||||
self.logger.debug(
|
||||
"[TIMING] embedder_init: %.2fms",
|
||||
(time.perf_counter() - start_embedder) * 1000
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
start_embed = time.perf_counter()
|
||||
query_embedding = embedder.embed_single(query)
|
||||
self.logger.debug(
|
||||
"[TIMING] query_embedding: %.2fms",
|
||||
(time.perf_counter() - start_embed) * 1000
|
||||
)
|
||||
|
||||
# Search for similar chunks
|
||||
start_search = time.perf_counter()
|
||||
results = vector_store.search_similar(
|
||||
query_embedding=query_embedding,
|
||||
top_k=limit,
|
||||
min_score=0.0, # Return all results, let RRF handle filtering
|
||||
return_full_content=True,
|
||||
)
|
||||
self.logger.debug(
|
||||
"[TIMING] vector_similarity_search: %.2fms (%d results)",
|
||||
(time.perf_counter() - start_search) * 1000, len(results)
|
||||
)
|
||||
|
||||
self.logger.debug("Vector search found %d results", len(results))
|
||||
return results
|
||||
|
||||
except ImportError as exc:
|
||||
|
||||
@@ -6,12 +6,98 @@ for combining results from heterogeneous search backends (exact FTS, fuzzy FTS,
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import math
|
||||
from typing import Dict, List
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from codexlens.entities import SearchResult, AdditionalLocation
|
||||
|
||||
|
||||
class QueryIntent(str, Enum):
|
||||
"""Query intent for adaptive RRF weights (Python/TypeScript parity)."""
|
||||
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
def normalize_weights(weights: Dict[str, float]) -> Dict[str, float]:
|
||||
"""Normalize weights to sum to 1.0 (best-effort)."""
|
||||
total = sum(float(v) for v in weights.values() if v is not None)
|
||||
if not math.isfinite(total) or total <= 0:
|
||||
return {k: float(v) for k, v in weights.items()}
|
||||
return {k: float(v) / total for k, v in weights.items()}
|
||||
|
||||
|
||||
def detect_query_intent(query: str) -> QueryIntent:
|
||||
"""Detect whether a query is code-like, natural-language, or mixed.
|
||||
|
||||
Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
|
||||
"""
|
||||
trimmed = (query or "").strip()
|
||||
if not trimmed:
|
||||
return QueryIntent.MIXED
|
||||
|
||||
lower = trimmed.lower()
|
||||
word_count = len([w for w in re.split(r"\s+", trimmed) if w])
|
||||
|
||||
has_code_signals = bool(
|
||||
re.search(r"(::|->|\.)", trimmed)
|
||||
or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
|
||||
or re.search(r"\b\w+_\w+\b", trimmed)
|
||||
or re.search(
|
||||
r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
|
||||
lower,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
)
|
||||
has_natural_signals = bool(
|
||||
word_count > 5
|
||||
or "?" in trimmed
|
||||
or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
|
||||
or re.search(
|
||||
r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
|
||||
trimmed,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
)
|
||||
|
||||
if has_code_signals and has_natural_signals:
|
||||
return QueryIntent.MIXED
|
||||
if has_code_signals:
|
||||
return QueryIntent.KEYWORD
|
||||
if has_natural_signals:
|
||||
return QueryIntent.SEMANTIC
|
||||
return QueryIntent.MIXED
|
||||
|
||||
|
||||
def adjust_weights_by_intent(
|
||||
intent: QueryIntent,
|
||||
base_weights: Dict[str, float],
|
||||
) -> Dict[str, float]:
|
||||
"""Map intent → weights (kept aligned with TypeScript mapping)."""
|
||||
if intent == QueryIntent.KEYWORD:
|
||||
target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
|
||||
elif intent == QueryIntent.SEMANTIC:
|
||||
target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
|
||||
else:
|
||||
target = dict(base_weights)
|
||||
|
||||
# Preserve only keys that are present in base_weights (active backends).
|
||||
keys = list(base_weights.keys())
|
||||
filtered = {k: float(target.get(k, 0.0)) for k in keys}
|
||||
return normalize_weights(filtered)
|
||||
|
||||
|
||||
def get_rrf_weights(
|
||||
query: str,
|
||||
base_weights: Dict[str, float],
|
||||
) -> Dict[str, float]:
|
||||
"""Compute adaptive RRF weights from query intent."""
|
||||
return adjust_weights_by_intent(detect_query_intent(query), base_weights)
|
||||
|
||||
|
||||
def reciprocal_rank_fusion(
|
||||
results_map: Dict[str, List[SearchResult]],
|
||||
weights: Dict[str, float] = None,
|
||||
@@ -102,6 +188,186 @@ def reciprocal_rank_fusion(
|
||||
return fused_results
|
||||
|
||||
|
||||
def apply_symbol_boost(
|
||||
results: List[SearchResult],
|
||||
boost_factor: float = 1.5,
|
||||
) -> List[SearchResult]:
|
||||
"""Boost fused scores for results that include an explicit symbol match.
|
||||
|
||||
The boost is multiplicative on the current result.score (typically the RRF fusion score).
|
||||
When boosted, the original score is preserved in metadata["original_fusion_score"] and
|
||||
metadata["boosted"] is set to True.
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
if boost_factor <= 1.0:
|
||||
# Still return new objects to follow immutable transformation pattern.
|
||||
return [
|
||||
SearchResult(
|
||||
path=r.path,
|
||||
score=r.score,
|
||||
excerpt=r.excerpt,
|
||||
content=r.content,
|
||||
symbol=r.symbol,
|
||||
chunk=r.chunk,
|
||||
metadata={**r.metadata},
|
||||
start_line=r.start_line,
|
||||
end_line=r.end_line,
|
||||
symbol_name=r.symbol_name,
|
||||
symbol_kind=r.symbol_kind,
|
||||
additional_locations=list(r.additional_locations),
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
boosted_results: List[SearchResult] = []
|
||||
for result in results:
|
||||
has_symbol = bool(result.symbol_name)
|
||||
original_score = float(result.score)
|
||||
boosted_score = original_score * boost_factor if has_symbol else original_score
|
||||
|
||||
metadata = {**result.metadata}
|
||||
if has_symbol:
|
||||
metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
|
||||
metadata["boosted"] = True
|
||||
metadata["symbol_boost_factor"] = boost_factor
|
||||
|
||||
boosted_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=boosted_score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata=metadata,
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
|
||||
boosted_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return boosted_results
|
||||
|
||||
|
||||
def rerank_results(
|
||||
query: str,
|
||||
results: List[SearchResult],
|
||||
embedder: Any,
|
||||
top_k: int = 50,
|
||||
) -> List[SearchResult]:
|
||||
"""Re-rank results with embedding cosine similarity, combined with current score.
|
||||
|
||||
Combined score formula:
|
||||
0.5 * rrf_score + 0.5 * cosine_similarity
|
||||
|
||||
If embedder is None or embedding fails, returns results as-is.
|
||||
"""
|
||||
if not results:
|
||||
return []
|
||||
|
||||
if embedder is None or top_k <= 0:
|
||||
return results
|
||||
|
||||
rerank_count = min(int(top_k), len(results))
|
||||
|
||||
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
|
||||
# Defensive: handle mismatched lengths and zero vectors.
|
||||
n = min(len(vec_a), len(vec_b))
|
||||
if n == 0:
|
||||
return 0.0
|
||||
dot = 0.0
|
||||
norm_a = 0.0
|
||||
norm_b = 0.0
|
||||
for i in range(n):
|
||||
a = float(vec_a[i])
|
||||
b = float(vec_b[i])
|
||||
dot += a * b
|
||||
norm_a += a * a
|
||||
norm_b += b * b
|
||||
if norm_a <= 0.0 or norm_b <= 0.0:
|
||||
return 0.0
|
||||
sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
|
||||
# SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
|
||||
return max(0.0, min(1.0, sim))
|
||||
|
||||
def text_for_embedding(r: SearchResult) -> str:
|
||||
if r.excerpt and r.excerpt.strip():
|
||||
return r.excerpt
|
||||
if r.content and r.content.strip():
|
||||
return r.content
|
||||
if r.chunk and r.chunk.content and r.chunk.content.strip():
|
||||
return r.chunk.content
|
||||
# Fallback: stable, non-empty text.
|
||||
return r.symbol_name or r.path
|
||||
|
||||
try:
|
||||
if hasattr(embedder, "embed_single"):
|
||||
query_vec = embedder.embed_single(query)
|
||||
else:
|
||||
query_vec = embedder.embed(query)[0]
|
||||
|
||||
doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
|
||||
doc_vecs = embedder.embed(doc_texts)
|
||||
except Exception:
|
||||
return results
|
||||
|
||||
reranked_results: List[SearchResult] = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
if idx < rerank_count:
|
||||
rrf_score = float(result.score)
|
||||
sim = cosine_similarity(query_vec, doc_vecs[idx])
|
||||
combined_score = 0.5 * rrf_score + 0.5 * sim
|
||||
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=combined_score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={
|
||||
**result.metadata,
|
||||
"rrf_score": rrf_score,
|
||||
"cosine_similarity": sim,
|
||||
"reranked": True,
|
||||
},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Preserve remaining results without re-ranking, but keep immutability.
|
||||
reranked_results.append(
|
||||
SearchResult(
|
||||
path=result.path,
|
||||
score=result.score,
|
||||
excerpt=result.excerpt,
|
||||
content=result.content,
|
||||
symbol=result.symbol,
|
||||
chunk=result.chunk,
|
||||
metadata={**result.metadata},
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
symbol_name=result.symbol_name,
|
||||
symbol_kind=result.symbol_kind,
|
||||
additional_locations=list(result.additional_locations),
|
||||
)
|
||||
)
|
||||
|
||||
reranked_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return reranked_results
|
||||
|
||||
|
||||
def normalize_bm25_score(score: float) -> float:
|
||||
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.
|
||||
|
||||
|
||||
@@ -392,6 +392,22 @@ class HybridChunker:
|
||||
filtered.append(symbol)
|
||||
return filtered
|
||||
|
||||
def _find_parent_symbol(
|
||||
self,
|
||||
start_line: int,
|
||||
end_line: int,
|
||||
symbols: List[Symbol],
|
||||
) -> Optional[Symbol]:
|
||||
"""Find the smallest symbol range that fully contains a docstring span."""
|
||||
candidates: List[Symbol] = []
|
||||
for symbol in symbols:
|
||||
sym_start, sym_end = symbol.range
|
||||
if sym_start <= start_line and end_line <= sym_end:
|
||||
candidates.append(symbol)
|
||||
if not candidates:
|
||||
return None
|
||||
return min(candidates, key=lambda s: (s.range[1] - s.range[0], s.range[0]))
|
||||
|
||||
def chunk_file(
|
||||
self,
|
||||
content: str,
|
||||
@@ -414,24 +430,53 @@ class HybridChunker:
|
||||
chunks: List[SemanticChunk] = []
|
||||
|
||||
# Step 1: Extract docstrings as dedicated chunks
|
||||
docstrings = self.docstring_extractor.extract_docstrings(content, language)
|
||||
docstrings: List[Tuple[str, int, int]] = []
|
||||
if language == "python":
|
||||
# Fast path: avoid expensive docstring extraction if delimiters are absent.
|
||||
if '"""' in content or "'''" in content:
|
||||
docstrings = self.docstring_extractor.extract_docstrings(content, language)
|
||||
elif language in {"javascript", "typescript"}:
|
||||
if "/**" in content:
|
||||
docstrings = self.docstring_extractor.extract_docstrings(content, language)
|
||||
else:
|
||||
docstrings = self.docstring_extractor.extract_docstrings(content, language)
|
||||
|
||||
# Fast path: no docstrings -> delegate to base chunker directly.
|
||||
if not docstrings:
|
||||
if symbols:
|
||||
base_chunks = self.base_chunker.chunk_by_symbol(
|
||||
content, symbols, file_path, language, symbol_token_counts
|
||||
)
|
||||
else:
|
||||
base_chunks = self.base_chunker.chunk_sliding_window(content, file_path, language)
|
||||
|
||||
for chunk in base_chunks:
|
||||
chunk.metadata["strategy"] = "hybrid"
|
||||
chunk.metadata["chunk_type"] = "code"
|
||||
return base_chunks
|
||||
|
||||
for docstring_content, start_line, end_line in docstrings:
|
||||
if len(docstring_content.strip()) >= self.config.min_chunk_size:
|
||||
parent_symbol = self._find_parent_symbol(start_line, end_line, symbols)
|
||||
# Use base chunker's token estimation method
|
||||
token_count = self.base_chunker._estimate_token_count(docstring_content)
|
||||
metadata = {
|
||||
"file": str(file_path),
|
||||
"language": language,
|
||||
"chunk_type": "docstring",
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "hybrid",
|
||||
"token_count": token_count,
|
||||
}
|
||||
if parent_symbol is not None:
|
||||
metadata["parent_symbol"] = parent_symbol.name
|
||||
metadata["parent_symbol_kind"] = parent_symbol.kind
|
||||
metadata["parent_symbol_range"] = parent_symbol.range
|
||||
chunks.append(SemanticChunk(
|
||||
content=docstring_content,
|
||||
embedding=None,
|
||||
metadata={
|
||||
"file": str(file_path),
|
||||
"language": language,
|
||||
"chunk_type": "docstring",
|
||||
"start_line": start_line,
|
||||
"end_line": end_line,
|
||||
"strategy": "hybrid",
|
||||
"token_count": token_count,
|
||||
}
|
||||
metadata=metadata
|
||||
))
|
||||
|
||||
# Step 2: Get line ranges occupied by docstrings
|
||||
|
||||
codex-lens/tests/test_global_index.py (new file, 293 lines added)
@@ -0,0 +1,293 @@
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import Symbol
|
||||
from codexlens.errors import StorageError
|
||||
from codexlens.search.chain_search import ChainSearchEngine
|
||||
from codexlens.storage.global_index import GlobalSymbolIndex
|
||||
from codexlens.storage.path_mapper import PathMapper
|
||||
from codexlens.storage.registry import RegistryStore
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def temp_paths():
|
||||
tmpdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
|
||||
root = Path(tmpdir.name)
|
||||
yield root
|
||||
try:
|
||||
tmpdir.cleanup()
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def test_add_symbol(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "a.py"
|
||||
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=1) as store:
|
||||
store.add_symbol(
|
||||
Symbol(name="AuthManager", kind="class", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
matches = store.search("AuthManager", kind="class", limit=10, prefix_mode=True)
|
||||
assert len(matches) == 1
|
||||
assert matches[0].name == "AuthManager"
|
||||
assert matches[0].file == str(file_path.resolve())
|
||||
|
||||
# Schema version safety: newer schema versions should be rejected.
|
||||
bad_db = temp_paths / "indexes" / "_global_symbols_bad.db"
|
||||
bad_db.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(bad_db)
|
||||
conn.execute("PRAGMA user_version = 999")
|
||||
conn.close()
|
||||
|
||||
with pytest.raises(StorageError):
|
||||
GlobalSymbolIndex(bad_db, project_id=1).initialize()
|
||||
|
||||
|
||||
def test_search_symbols(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "mod.py"
|
||||
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("def authenticate():\n pass\n", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=7) as store:
|
||||
store.add_symbol(
|
||||
Symbol(name="authenticate", kind="function", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
locations = store.search_symbols("auth", kind="function", limit=10, prefix_mode=True)
|
||||
assert locations
|
||||
assert any(p.endswith("mod.py") for p, _ in locations)
|
||||
assert any(rng == (1, 2) for _, rng in locations)
|
||||
|
||||
|
||||
def test_update_file_symbols(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
file_path = temp_paths / "src" / "mod.py"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("def a():\n pass\n", encoding="utf-8")
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=7) as store:
|
||||
store.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[
|
||||
Symbol(name="old_func", kind="function", range=(1, 2)),
|
||||
Symbol(name="Other", kind="class", range=(10, 20)),
|
||||
],
|
||||
index_path=index_path,
|
||||
)
|
||||
assert any(s.name == "old_func" for s in store.search("old_", prefix_mode=True))
|
||||
|
||||
store.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[Symbol(name="new_func", kind="function", range=(3, 4))],
|
||||
index_path=index_path,
|
||||
)
|
||||
assert not any(s.name == "old_func" for s in store.search("old_", prefix_mode=True))
|
||||
assert any(s.name == "new_func" for s in store.search("new_", prefix_mode=True))
|
||||
|
||||
# Backward-compatible path: index_path can be omitted after it's been established.
|
||||
store.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[Symbol(name="new_func2", kind="function", range=(5, 6))],
|
||||
index_path=None,
|
||||
)
|
||||
assert any(s.name == "new_func2" for s in store.search("new_func2", prefix_mode=True))
|
||||
|
||||
# New file + symbols without index_path should raise.
|
||||
missing_index_file = temp_paths / "src" / "new_file.py"
|
||||
with pytest.raises(StorageError):
|
||||
store.update_file_symbols(
|
||||
file_path=missing_index_file,
|
||||
symbols=[Symbol(name="must_fail", kind="function", range=(1, 1))],
|
||||
index_path=None,
|
||||
)
|
||||
|
||||
deleted = store.delete_file_symbols(file_path)
|
||||
assert deleted > 0
|
||||
|
||||
|
||||
def test_incremental_updates(temp_paths: Path, monkeypatch):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
file_path = temp_paths / "src" / "same.py"
|
||||
idx_a = temp_paths / "indexes" / "a" / "_index.db"
|
||||
idx_b = temp_paths / "indexes" / "b" / "_index.db"
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
|
||||
idx_a.parent.mkdir(parents=True, exist_ok=True)
|
||||
idx_a.write_text("", encoding="utf-8")
|
||||
idx_b.parent.mkdir(parents=True, exist_ok=True)
|
||||
idx_b.write_text("", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=42) as store:
|
||||
sym = Symbol(name="AuthManager", kind="class", range=(1, 2))
|
||||
store.add_symbol(sym, file_path=file_path, index_path=idx_a)
|
||||
store.add_symbol(sym, file_path=file_path, index_path=idx_b)
|
||||
|
||||
# prefix_mode=False exercises substring matching.
|
||||
assert store.search("Manager", prefix_mode=False)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT index_path
|
||||
FROM global_symbols
|
||||
WHERE project_id=? AND symbol_name=? AND symbol_kind=? AND file_path=?
|
||||
""",
|
||||
(42, "AuthManager", "class", str(file_path.resolve())),
|
||||
).fetchone()
|
||||
conn.close()
|
||||
|
||||
assert row is not None
|
||||
assert str(Path(row[0]).resolve()) == str(idx_b.resolve())
|
||||
|
||||
# Migration path coverage: simulate a future schema version and an older DB version.
|
||||
migrating_db = temp_paths / "indexes" / "_global_symbols_migrate.db"
|
||||
migrating_db.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(migrating_db)
|
||||
conn.execute("PRAGMA user_version = 1")
|
||||
conn.close()
|
||||
|
||||
monkeypatch.setattr(GlobalSymbolIndex, "SCHEMA_VERSION", 2)
|
||||
GlobalSymbolIndex(migrating_db, project_id=1).initialize()
|
||||
|
||||
|
||||
def test_concurrent_access(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "a.py"
|
||||
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class A:\n pass\n", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=1) as store:
|
||||
def add_many(worker_id: int):
|
||||
for i in range(50):
|
||||
store.add_symbol(
|
||||
Symbol(name=f"Sym{worker_id}_{i}", kind="class", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=8) as ex:
|
||||
list(ex.map(add_many, range(8)))
|
||||
|
||||
matches = store.search("Sym", kind="class", limit=1000, prefix_mode=True)
|
||||
assert len(matches) >= 200
|
||||
|
||||
|
||||
def test_chain_search_integration(temp_paths: Path):
|
||||
project_root = temp_paths / "project"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
index_root = temp_paths / "indexes"
|
||||
mapper = PathMapper(index_root=index_root)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
index_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_db_path.write_text("", encoding="utf-8")
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
project_info = registry.register_project(project_root, mapper.source_to_index_dir(project_root))
|
||||
|
||||
global_db_path = project_info.index_root / GlobalSymbolIndex.DEFAULT_DB_NAME
|
||||
with GlobalSymbolIndex(global_db_path, project_id=project_info.id) as global_index:
|
||||
file_path = project_root / "auth.py"
|
||||
global_index.update_file_symbols(
|
||||
file_path=file_path,
|
||||
symbols=[
|
||||
Symbol(name="AuthManager", kind="class", range=(1, 10)),
|
||||
Symbol(name="authenticate", kind="function", range=(12, 20)),
|
||||
],
|
||||
index_path=index_db_path,
|
||||
)
|
||||
|
||||
config = Config(data_dir=temp_paths / "data", global_symbol_index_enabled=True)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
engine._search_symbols_parallel = MagicMock(side_effect=AssertionError("should not traverse chain"))
|
||||
|
||||
symbols = engine.search_symbols("Auth", project_root)
|
||||
assert any(s.name == "AuthManager" for s in symbols)
|
||||
registry.close()
|
||||
|
||||
|
||||
def test_disabled_fallback(temp_paths: Path):
|
||||
project_root = temp_paths / "project"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
index_root = temp_paths / "indexes"
|
||||
mapper = PathMapper(index_root=index_root)
|
||||
index_db_path = mapper.source_to_index_db(project_root)
|
||||
index_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_db_path.write_text("", encoding="utf-8")
|
||||
|
||||
registry = RegistryStore(db_path=temp_paths / "registry.db")
|
||||
registry.initialize()
|
||||
registry.register_project(project_root, mapper.source_to_index_dir(project_root))
|
||||
|
||||
config = Config(data_dir=temp_paths / "data", global_symbol_index_enabled=False)
|
||||
engine = ChainSearchEngine(registry, mapper, config=config)
|
||||
engine._collect_index_paths = MagicMock(return_value=[index_db_path])
|
||||
engine._search_symbols_parallel = MagicMock(
|
||||
return_value=[Symbol(name="FallbackSymbol", kind="function", range=(1, 2))]
|
||||
)
|
||||
|
||||
symbols = engine.search_symbols("Fallback", project_root)
|
||||
assert any(s.name == "FallbackSymbol" for s in symbols)
|
||||
assert engine._search_symbols_parallel.called
|
||||
registry.close()
|
||||
|
||||
|
||||
def test_performance_benchmark(temp_paths: Path):
|
||||
db_path = temp_paths / "indexes" / "_global_symbols.db"
|
||||
index_path = temp_paths / "indexes" / "_index.db"
|
||||
file_path = temp_paths / "src" / "perf.py"
|
||||
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("class AuthManager:\n pass\n", encoding="utf-8")
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
index_path.write_text("", encoding="utf-8")
|
||||
|
||||
with GlobalSymbolIndex(db_path, project_id=1) as store:
|
||||
for i in range(500):
|
||||
store.add_symbol(
|
||||
Symbol(name=f"AuthManager{i}", kind="class", range=(1, 2)),
|
||||
file_path=file_path,
|
||||
index_path=index_path,
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
results = store.search("AuthManager", kind="class", limit=50, prefix_mode=True)
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
assert elapsed_ms < 100.0
|
||||
assert results
|
||||
@@ -551,3 +551,72 @@ class UserProfile:
|
||||
# Verify <15% overhead (reasonable threshold for performance tests with system variance)
|
||||
assert overhead < 15.0, f"Overhead {overhead:.2f}% exceeds 15% threshold (base={base_time:.4f}s, hybrid={hybrid_time:.4f}s)"
|
||||
|
||||
|
||||
class TestHybridChunkerV1Optimizations:
|
||||
"""Tests for v1.0 optimization behaviors (parent metadata + determinism)."""
|
||||
|
||||
def test_merged_docstring_metadata(self):
|
||||
"""Docstring chunks include parent_symbol metadata when applicable."""
|
||||
config = ChunkConfig(min_chunk_size=1)
|
||||
chunker = HybridChunker(config=config)
|
||||
|
||||
content = '''"""Module docstring."""
|
||||
|
||||
def hello():
|
||||
"""Function docstring."""
|
||||
return 1
|
||||
'''
|
||||
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
|
||||
|
||||
chunks = chunker.chunk_file(content, symbols, "m.py", "python")
|
||||
func_doc_chunks = [
|
||||
c for c in chunks
|
||||
if c.metadata.get("chunk_type") == "docstring" and c.metadata.get("start_line") == 4
|
||||
]
|
||||
assert len(func_doc_chunks) == 1
|
||||
assert func_doc_chunks[0].metadata.get("parent_symbol") == "hello"
|
||||
assert func_doc_chunks[0].metadata.get("parent_symbol_kind") == "function"
|
||||
|
||||
def test_deterministic_chunk_boundaries(self):
|
||||
"""Chunk boundaries are stable across repeated runs on identical input."""
|
||||
config = ChunkConfig(max_chunk_size=80, overlap=10, min_chunk_size=1)
|
||||
chunker = HybridChunker(config=config)
|
||||
|
||||
# No docstrings, no symbols -> sliding window path.
|
||||
content = "\n".join([f"line {i}: x = {i}" for i in range(1, 200)]) + "\n"
|
||||
|
||||
boundaries = []
|
||||
for _ in range(3):
|
||||
chunks = chunker.chunk_file(content, [], "deterministic.py", "python")
|
||||
boundaries.append([
|
||||
(c.metadata.get("start_line"), c.metadata.get("end_line"))
|
||||
for c in chunks
|
||||
if c.metadata.get("chunk_type") == "code"
|
||||
])
|
||||
|
||||
assert boundaries[0] == boundaries[1] == boundaries[2]
|
||||
|
||||
def test_orphan_docstrings(self):
|
||||
"""Module-level docstrings remain standalone (no parent_symbol assigned)."""
|
||||
config = ChunkConfig(min_chunk_size=1)
|
||||
chunker = HybridChunker(config=config)
|
||||
|
||||
content = '''"""Module-level docstring."""
|
||||
|
||||
def hello():
|
||||
"""Function docstring."""
|
||||
return 1
|
||||
'''
|
||||
symbols = [Symbol(name="hello", kind="function", range=(3, 5))]
|
||||
chunks = chunker.chunk_file(content, symbols, "orphan.py", "python")
|
||||
|
||||
module_doc = [
|
||||
c for c in chunks
|
||||
if c.metadata.get("chunk_type") == "docstring" and c.metadata.get("start_line") == 1
|
||||
]
|
||||
assert len(module_doc) == 1
|
||||
assert module_doc[0].metadata.get("parent_symbol") is None
|
||||
|
||||
code_chunks = [c for c in chunks if c.metadata.get("chunk_type") == "code"]
|
||||
assert code_chunks, "Expected at least one code chunk"
|
||||
assert all("Module-level docstring" not in c.content for c in code_chunks)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codexlens.config import Config
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.hybrid_search import HybridSearchEngine
|
||||
from codexlens.storage.dir_index import DirIndexStore
|
||||
@@ -774,3 +775,97 @@ class TestHybridSearchWithVectorMock:
|
||||
assert hasattr(result, 'score')
|
||||
assert result.score > 0 # RRF fusion scores are positive
|
||||
|
||||
|
||||
class TestHybridSearchAdaptiveWeights:
|
||||
"""Integration tests for adaptive RRF weights + reranking gating."""
|
||||
|
||||
def test_adaptive_weights_code_query(self):
|
||||
"""Exact weight should dominate for code-like queries."""
|
||||
from unittest.mock import patch
|
||||
|
||||
engine = HybridSearchEngine()
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
captured = {}
|
||||
from codexlens.search import ranking as ranking_module
|
||||
|
||||
def capture_rrf(map_in, weights_in, k=60):
|
||||
captured["weights"] = dict(weights_in)
|
||||
return ranking_module.reciprocal_rank_fusion(map_in, weights_in, k=k)
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.reciprocal_rank_fusion",
|
||||
side_effect=capture_rrf,
|
||||
):
|
||||
engine.search(Path("dummy.db"), "def authenticate", enable_vector=True)
|
||||
|
||||
assert captured["weights"]["exact"] > 0.4
|
||||
|
||||
def test_adaptive_weights_nl_query(self):
|
||||
"""Vector weight should dominate for natural-language queries."""
|
||||
from unittest.mock import patch
|
||||
|
||||
engine = HybridSearchEngine()
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
captured = {}
|
||||
from codexlens.search import ranking as ranking_module
|
||||
|
||||
def capture_rrf(map_in, weights_in, k=60):
|
||||
captured["weights"] = dict(weights_in)
|
||||
return ranking_module.reciprocal_rank_fusion(map_in, weights_in, k=k)
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.reciprocal_rank_fusion",
|
||||
side_effect=capture_rrf,
|
||||
):
|
||||
engine.search(Path("dummy.db"), "how to handle user login", enable_vector=True)
|
||||
|
||||
assert captured["weights"]["vector"] > 0.6
|
||||
|
||||
def test_reranking_enabled(self, tmp_path):
|
||||
"""Reranking runs only when explicitly enabled via config."""
|
||||
from unittest.mock import patch
|
||||
|
||||
results_map = {
|
||||
"exact": [SearchResult(path="a.py", score=10.0, excerpt="a")],
|
||||
"fuzzy": [SearchResult(path="b.py", score=9.0, excerpt="b")],
|
||||
"vector": [SearchResult(path="c.py", score=0.9, excerpt="c")],
|
||||
}
|
||||
|
||||
class DummyEmbedder:
|
||||
def embed(self, texts):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
return [[1.0, 0.0] for _ in texts]
|
||||
|
||||
# Disabled: should not invoke rerank_results
|
||||
config_off = Config(data_dir=tmp_path / "off", enable_reranking=False)
|
||||
engine_off = HybridSearchEngine(config=config_off, embedder=DummyEmbedder())
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.rerank_results"
|
||||
) as rerank_mock:
|
||||
engine_off.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
rerank_mock.assert_not_called()
|
||||
|
||||
# Enabled: should invoke rerank_results once
|
||||
config_on = Config(data_dir=tmp_path / "on", enable_reranking=True, reranking_top_k=10)
|
||||
engine_on = HybridSearchEngine(config=config_on, embedder=DummyEmbedder())
|
||||
|
||||
with patch.object(HybridSearchEngine, "_search_parallel", return_value=results_map), patch(
|
||||
"codexlens.search.hybrid_search.rerank_results",
|
||||
side_effect=lambda q, r, e, top_k=50: r,
|
||||
) as rerank_mock:
|
||||
engine_on.search(Path("dummy.db"), "query", enable_vector=True)
|
||||
assert rerank_mock.call_count == 1
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import pytest
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from codexlens.search.hybrid_search import HybridSearchEngine
|
||||
@@ -16,6 +17,22 @@ except ImportError:
|
||||
SEMANTIC_DEPS_AVAILABLE = False
|
||||
|
||||
|
||||
def _safe_unlink(path: Path, retries: int = 5, delay_s: float = 0.05) -> None:
|
||||
"""Best-effort unlink for Windows where SQLite can keep files locked briefly."""
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
path.unlink()
|
||||
return
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except PermissionError:
|
||||
time.sleep(delay_s * (attempt + 1))
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
class TestPureVectorSearch:
|
||||
"""Tests for pure vector search mode."""
|
||||
|
||||
@@ -48,7 +65,7 @@ class TestPureVectorSearch:
|
||||
store.close()
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
_safe_unlink(db_path)
|
||||
|
||||
def test_pure_vector_without_embeddings(self, sample_db):
|
||||
"""Test pure_vector mode returns empty when no embeddings exist."""
|
||||
@@ -200,12 +217,8 @@ def login_handler(credentials: dict) -> bool:
|
||||
yield db_path
|
||||
store.close()
|
||||
|
||||
# Ignore file deletion errors on Windows (SQLite file lock)
|
||||
try:
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
except PermissionError:
|
||||
pass # Ignore Windows file lock errors
|
||||
if db_path.exists():
|
||||
_safe_unlink(db_path)
|
||||
|
||||
def test_pure_vector_with_embeddings(self, db_with_embeddings):
|
||||
"""Test pure vector search returns results when embeddings exist."""
|
||||
@@ -289,7 +302,7 @@ class TestSearchModeComparison:
|
||||
store.close()
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
_safe_unlink(db_path)
|
||||
|
||||
def test_mode_comparison_without_embeddings(self, comparison_db):
|
||||
"""Compare all search modes without embeddings."""
|
||||
|
||||
@@ -7,8 +7,12 @@ import pytest
|
||||
|
||||
from codexlens.entities import SearchResult
|
||||
from codexlens.search.ranking import (
|
||||
apply_symbol_boost,
|
||||
QueryIntent,
|
||||
detect_query_intent,
|
||||
normalize_bm25_score,
|
||||
reciprocal_rank_fusion,
|
||||
rerank_results,
|
||||
tag_search_source,
|
||||
)
|
||||
|
||||
@@ -342,6 +346,62 @@ class TestTagSearchSource:
|
||||
assert tagged[0].symbol_kind == "function"
|
||||
|
||||
|
||||
class TestSymbolBoost:
|
||||
"""Tests for apply_symbol_boost function."""
|
||||
|
||||
def test_symbol_boost(self):
|
||||
results = [
|
||||
SearchResult(path="a.py", score=0.2, excerpt="...", symbol_name="foo"),
|
||||
SearchResult(path="b.py", score=0.21, excerpt="..."),
|
||||
]
|
||||
|
||||
boosted = apply_symbol_boost(results, boost_factor=1.5)
|
||||
|
||||
assert boosted[0].path == "a.py"
|
||||
assert boosted[0].score == pytest.approx(0.2 * 1.5)
|
||||
assert boosted[0].metadata["boosted"] is True
|
||||
assert boosted[0].metadata["original_fusion_score"] == pytest.approx(0.2)
|
||||
|
||||
assert boosted[1].path == "b.py"
|
||||
assert boosted[1].score == pytest.approx(0.21)
|
||||
assert "boosted" not in boosted[1].metadata
|
||||
|
||||
|
||||
class TestEmbeddingReranking:
|
||||
"""Tests for rerank_results embedding-based similarity."""
|
||||
|
||||
def test_rerank_embedding_similarity(self):
|
||||
class DummyEmbedder:
|
||||
def embed(self, texts):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
mapping = {
|
||||
"query": [1.0, 0.0],
|
||||
"doc1": [1.0, 0.0],
|
||||
"doc2": [0.0, 1.0],
|
||||
}
|
||||
return [mapping[t] for t in texts]
|
||||
|
||||
results = [
|
||||
SearchResult(path="a.py", score=0.2, excerpt="doc1"),
|
||||
SearchResult(path="b.py", score=0.9, excerpt="doc2"),
|
||||
]
|
||||
|
||||
reranked = rerank_results("query", results, DummyEmbedder(), top_k=2)
|
||||
|
||||
assert reranked[0].path == "a.py"
|
||||
assert reranked[0].metadata["reranked"] is True
|
||||
assert reranked[0].metadata["rrf_score"] == pytest.approx(0.2)
|
||||
assert reranked[0].metadata["cosine_similarity"] == pytest.approx(1.0)
|
||||
assert reranked[0].score == pytest.approx(0.5 * 0.2 + 0.5 * 1.0)
|
||||
|
||||
assert reranked[1].path == "b.py"
|
||||
assert reranked[1].metadata["reranked"] is True
|
||||
assert reranked[1].metadata["rrf_score"] == pytest.approx(0.9)
|
||||
assert reranked[1].metadata["cosine_similarity"] == pytest.approx(0.0)
|
||||
assert reranked[1].score == pytest.approx(0.5 * 0.9 + 0.5 * 0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k_value", [30, 60, 100])
|
||||
class TestRRFParameterized:
|
||||
"""Parameterized tests for RRF with different k values."""
|
||||
@@ -419,3 +479,41 @@ class TestRRFEdgeCases:
|
||||
# Should work with normalization
|
||||
assert len(fused) == 1 # Deduplicated
|
||||
assert fused[0].score > 0
|
||||
|
||||
|
||||
class TestSymbolBoostAndIntentV1:
|
||||
"""Tests for symbol boosting and query intent detection (v1.0)."""
|
||||
|
||||
def test_symbol_boost_application(self):
|
||||
"""Results with symbol_name receive a multiplicative boost (default 1.5x)."""
|
||||
results = [
|
||||
SearchResult(path="a.py", score=0.4, excerpt="...", symbol_name="AuthManager"),
|
||||
SearchResult(path="b.py", score=0.41, excerpt="..."),
|
||||
]
|
||||
|
||||
boosted = apply_symbol_boost(results, boost_factor=1.5)
|
||||
|
||||
assert boosted[0].score == pytest.approx(0.4 * 1.5)
|
||||
assert boosted[0].metadata["boosted"] is True
|
||||
assert boosted[0].metadata["original_fusion_score"] == pytest.approx(0.4)
|
||||
assert boosted[1].score == pytest.approx(0.41)
|
||||
assert "boosted" not in boosted[1].metadata
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("query", "expected"),
|
||||
[
|
||||
("def authenticate", QueryIntent.KEYWORD),
|
||||
("MyClass", QueryIntent.KEYWORD),
|
||||
("user_id", QueryIntent.KEYWORD),
|
||||
("UserService::authenticate", QueryIntent.KEYWORD),
|
||||
("ptr->next", QueryIntent.KEYWORD),
|
||||
("how to handle user login", QueryIntent.SEMANTIC),
|
||||
("what is authentication?", QueryIntent.SEMANTIC),
|
||||
("where is this used?", QueryIntent.SEMANTIC),
|
||||
("why does FooBar crash?", QueryIntent.MIXED),
|
||||
("how to use user_id in query", QueryIntent.MIXED),
|
||||
],
|
||||
)
|
||||
def test_query_intent_detection(self, query, expected):
|
||||
"""Detect intent for representative queries (Python/TypeScript parity)."""
|
||||
assert detect_query_intent(query) == expected
|
||||
|
||||
@@ -466,7 +466,18 @@ class TestDiagnostics:
|
||||
|
||||
yield db_path
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
for attempt in range(5):
|
||||
try:
|
||||
db_path.unlink()
|
||||
break
|
||||
except PermissionError:
|
||||
time.sleep(0.05 * (attempt + 1))
|
||||
else:
|
||||
# Best-effort cleanup (Windows SQLite locks can linger briefly).
|
||||
try:
|
||||
db_path.unlink(missing_ok=True)
|
||||
except (PermissionError, OSError):
|
||||
pass
|
||||
|
||||
def test_diagnose_empty_database(self, empty_db):
|
||||
"""Diagnose behavior with empty database."""
|
||||
|
||||
@@ -13,7 +13,7 @@ class TestChunkConfig:
|
||||
"""Test default configuration values."""
|
||||
config = ChunkConfig()
|
||||
assert config.max_chunk_size == 1000
|
||||
assert config.overlap == 100
|
||||
assert config.overlap == 200
|
||||
assert config.min_chunk_size == 50
|
||||
|
||||
def test_custom_config(self):