# Mirror of https://github.com/catlog22/Claude-Code-Workflow.git
# Synced 2026-02-05 01:50:27 +08:00
"""Ranking algorithms for hybrid search result fusion.
|
|
|
|
Implements Reciprocal Rank Fusion (RRF) and score normalization utilities
|
|
for combining results from heterogeneous search backends (SPLADE, exact FTS, fuzzy FTS, vector search).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
import math
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from codexlens.entities import SearchResult, AdditionalLocation
|
|
|
|
|
|
# Default RRF weights for SPLADE-based hybrid search
DEFAULT_WEIGHTS = {
    "splade": 0.35,  # Replaces exact(0.25) + fuzzy(0.1)
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}

# Legacy weights for FTS fallback mode (when SPLADE unavailable)
FTS_FALLBACK_WEIGHTS = {
    "exact": 0.25,
    "fuzzy": 0.1,
    "vector": 0.5,
    "lsp_graph": 0.15,  # Real-time LSP-based graph expansion
}

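# Illustrative note (not from the original source): both weight tables sum to 1.0,
# so either can be passed directly to reciprocal_rank_fusion(). A minimal sketch of
# how a caller might choose a table, assuming the caller knows whether a SPLADE
# index is available:
#
#     weights = DEFAULT_WEIGHTS if splade_available else FTS_FALLBACK_WEIGHTS
#     fused = reciprocal_rank_fusion(results_map, weights=weights)
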
class QueryIntent(str, Enum):
    """Query intent for adaptive RRF weights (Python/TypeScript parity)."""

    KEYWORD = "keyword"
    SEMANTIC = "semantic"
    MIXED = "mixed"

def normalize_weights(weights: Dict[str, float | None]) -> Dict[str, float | None]:
    """Normalize weights to sum to 1.0 (best-effort)."""
    total = sum(float(v) for v in weights.values() if v is not None)

    # NaN total: do not attempt to normalize (division would propagate NaNs).
    if math.isnan(total):
        return dict(weights)

    # Infinite total: do not attempt to normalize (division yields 0 or NaN).
    if not math.isfinite(total):
        return dict(weights)

    # Zero/negative total: do not attempt to normalize (invalid denominator).
    if total <= 0:
        return dict(weights)

    return {k: (float(v) / total if v is not None else None) for k, v in weights.items()}

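# Illustrative behaviour of normalize_weights (hypothetical values, not from the
# original source): a dict that already sums to 1.0 is returned unchanged in value,
# other finite positive totals are rescaled proportionally, and invalid totals are
# passed through untouched.
#
#     normalize_weights({"splade": 0.6, "vector": 0.4})   # -> {"splade": 0.6, "vector": 0.4}
#     normalize_weights({"exact": 2.0, "vector": 2.0})    # -> {"exact": 0.5, "vector": 0.5}
#     normalize_weights({"exact": 0.0, "vector": 0.0})    # -> returned as-is (invalid total)
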
def detect_query_intent(query: str) -> QueryIntent:
    """Detect whether a query is code-like, natural-language, or mixed.

    Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
    """
    trimmed = (query or "").strip()
    if not trimmed:
        return QueryIntent.MIXED

    lower = trimmed.lower()
    word_count = len([w for w in re.split(r"\s+", trimmed) if w])

    has_code_signals = bool(
        re.search(r"(::|->|\.)", trimmed)
        or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
        or re.search(r"\b\w+_\w+\b", trimmed)
        or re.search(
            r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
            lower,
            flags=re.IGNORECASE,
        )
    )
    has_natural_signals = bool(
        word_count > 5
        or "?" in trimmed
        or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
        or re.search(
            r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
            trimmed,
            flags=re.IGNORECASE,
        )
    )

    if has_code_signals and has_natural_signals:
        return QueryIntent.MIXED
    if has_code_signals:
        return QueryIntent.KEYWORD
    if has_natural_signals:
        return QueryIntent.SEMANTIC
    return QueryIntent.MIXED

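# Illustrative classifications (hypothetical queries, not from the original source):
#
#     detect_query_intent("parse_config")
#         # -> QueryIntent.KEYWORD (snake_case token, no natural-language signal)
#     detect_query_intent("how do we handle token refresh errors?")
#         # -> QueryIntent.SEMANTIC (>5 words, question mark, "how"/"handle")
#     detect_query_intent("how does UserService.authenticate handle expired tokens")
#         # -> QueryIntent.MIXED (camelCase + dotted access, plus natural-language signals)
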
def adjust_weights_by_intent(
    intent: QueryIntent,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Adjust RRF weights based on query intent."""
    # Check if using SPLADE or FTS mode
    use_splade = "splade" in base_weights

    if intent == QueryIntent.KEYWORD:
        if use_splade:
            target = {"splade": 0.6, "vector": 0.4}
        else:
            target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
    elif intent == QueryIntent.SEMANTIC:
        if use_splade:
            target = {"splade": 0.3, "vector": 0.7}
        else:
            target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
    else:
        target = dict(base_weights)

    # Filter to active backends
    keys = list(base_weights.keys())
    filtered = {k: float(target.get(k, 0.0)) for k in keys}
    return normalize_weights(filtered)

def get_rrf_weights(
    query: str,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Compute adaptive RRF weights from query intent."""
    return adjust_weights_by_intent(detect_query_intent(query), base_weights)

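# Illustrative end-to-end use (hypothetical query, not from the original source):
# a code-like query shifts weight toward the sparse/lexical channel, and backends
# missing from the intent target (here "lsp_graph") are zeroed before normalization.
#
#     get_rrf_weights("parse_config", DEFAULT_WEIGHTS)
#         # -> {"splade": 0.6, "vector": 0.4, "lsp_graph": 0.0}
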
# File extensions to category mapping for fast lookup
_EXT_TO_CATEGORY: Dict[str, str] = {
    # Code extensions
    ".py": "code", ".js": "code", ".jsx": "code", ".ts": "code", ".tsx": "code",
    ".java": "code", ".go": "code", ".zig": "code", ".m": "code", ".mm": "code",
    ".c": "code", ".h": "code", ".cc": "code", ".cpp": "code", ".hpp": "code", ".cxx": "code",
    ".rs": "code",
    # Doc extensions
    ".md": "doc", ".mdx": "doc", ".txt": "doc", ".rst": "doc",
}

def get_file_category(path: str) -> Optional[str]:
    """Get file category ('code' or 'doc') from path extension.

    Args:
        path: File path string

    Returns:
        'code', 'doc', or None if unknown
    """
    ext = Path(path).suffix.lower()
    return _EXT_TO_CATEGORY.get(ext)

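# Illustrative lookups (hypothetical paths, not from the original source):
#
#     get_file_category("src/rankers.py")   # -> "code"
#     get_file_category("docs/README.md")   # -> "doc"
#     get_file_category("Makefile")         # -> None (no recognized extension)
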
def filter_results_by_category(
    results: List[SearchResult],
    intent: QueryIntent,
    allow_mixed: bool = True,
) -> List[SearchResult]:
    """Filter results by category based on query intent.

    Strategy:
        - KEYWORD (code intent): Only return code files
        - SEMANTIC (doc intent): Prefer docs, but allow code if allow_mixed=True
        - MIXED: Return all results

    Args:
        results: List of SearchResult objects
        intent: Query intent from detect_query_intent()
        allow_mixed: If True, SEMANTIC intent includes code files with lower priority

    Returns:
        Filtered and re-ranked list of SearchResult objects
    """
    if not results or intent == QueryIntent.MIXED:
        return results

    code_results = []
    doc_results = []
    unknown_results = []

    for r in results:
        category = get_file_category(r.path)
        if category == "code":
            code_results.append(r)
        elif category == "doc":
            doc_results.append(r)
        else:
            unknown_results.append(r)

    if intent == QueryIntent.KEYWORD:
        # Code intent: return only code files + unknown (might be code)
        filtered = code_results + unknown_results
    elif intent == QueryIntent.SEMANTIC:
        if allow_mixed:
            # Semantic intent with mixed: docs first, then code
            filtered = doc_results + code_results + unknown_results
        else:
            # Semantic intent strict: only docs
            filtered = doc_results + unknown_results
    else:
        filtered = results

    return filtered

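# Illustrative filtering (hypothetical paths, not from the original source): with
# KEYWORD intent only code files (plus unknown extensions) survive, while SEMANTIC
# intent with allow_mixed=True re-orders docs ahead of code.
#
#     results = [SearchResult(path="guide.md", score=0.9, excerpt="..."),
#                SearchResult(path="auth.py", score=0.8, excerpt="...")]
#     filter_results_by_category(results, QueryIntent.KEYWORD)
#         # -> [auth.py]
#     filter_results_by_category(results, QueryIntent.SEMANTIC, allow_mixed=True)
#         # -> [guide.md, auth.py]
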
def simple_weighted_fusion(
    results_map: Dict[str, List[SearchResult]],
    weights: Optional[Dict[str, float]] = None,
) -> List[SearchResult]:
    """Combine search results using simple weighted sum of normalized scores.

    This is an alternative to RRF that preserves score magnitude information.
    Scores are min-max normalized per source before weighted combination.

    Formula: score(d) = Σ weight_source * normalized_score_source(d)

    Args:
        results_map: Dictionary mapping source name to list of SearchResult objects
            Sources: 'exact', 'fuzzy', 'vector', 'splade'
        weights: Dictionary mapping source name to weight (default: equal weights)
            Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}

    Returns:
        List of SearchResult objects sorted by fused score (descending)

    Examples:
        >>> fts_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
        >>> vector_results = [SearchResult(path="b.py", score=0.85, excerpt="...")]
        >>> results_map = {'exact': fts_results, 'vector': vector_results}
        >>> fused = simple_weighted_fusion(results_map)
    """
    if not results_map:
        return []

    # Default equal weights if not provided
    if weights is None:
        num_sources = len(results_map)
        weights = {source: 1.0 / num_sources for source in results_map}

    # Normalize weights to sum to 1.0
    weight_sum = sum(weights.values())
    if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
        weights = {source: w / weight_sum for source, w in weights.items()}

    # Compute min-max normalization parameters per source
    source_stats: Dict[str, tuple] = {}
    for source_name, results in results_map.items():
        if not results:
            continue
        scores = [r.score for r in results]
        min_s, max_s = min(scores), max(scores)
        source_stats[source_name] = (min_s, max_s)

    def normalize_score(score: float, source: str) -> float:
        """Normalize score to [0, 1] range using min-max scaling."""
        if source not in source_stats:
            return 0.0
        min_s, max_s = source_stats[source]
        if max_s == min_s:
            return 1.0 if score >= min_s else 0.0
        return (score - min_s) / (max_s - min_s)

    # Build unified result set with weighted scores
    path_to_result: Dict[str, SearchResult] = {}
    path_to_fusion_score: Dict[str, float] = {}
    path_to_source_scores: Dict[str, Dict[str, float]] = {}

    for source_name, results in results_map.items():
        weight = weights.get(source_name, 0.0)
        if weight == 0:
            continue

        for result in results:
            path = result.path
            normalized = normalize_score(result.score, source_name)
            contribution = weight * normalized

            if path not in path_to_fusion_score:
                path_to_fusion_score[path] = 0.0
                path_to_result[path] = result
                path_to_source_scores[path] = {}

            path_to_fusion_score[path] += contribution
            path_to_source_scores[path][source_name] = normalized

    # Create final results with fusion scores
    fused_results = []
    for path, base_result in path_to_result.items():
        fusion_score = path_to_fusion_score[path]

        fused_result = SearchResult(
            path=base_result.path,
            score=fusion_score,
            excerpt=base_result.excerpt,
            content=base_result.content,
            symbol=base_result.symbol,
            chunk=base_result.chunk,
            metadata={
                **base_result.metadata,
                "fusion_method": "simple_weighted",
                "fusion_score": fusion_score,
                "original_score": base_result.score,
                "source_scores": path_to_source_scores[path],
            },
            start_line=base_result.start_line,
            end_line=base_result.end_line,
            symbol_name=base_result.symbol_name,
            symbol_kind=base_result.symbol_kind,
        )
        fused_results.append(fused_result)

    fused_results.sort(key=lambda r: r.score, reverse=True)
    return fused_results

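# Worked example of the weighted-sum fusion above (hypothetical scores, not from
# the original source). With weights {'exact': 0.4, 'vector': 0.6}:
#
#     exact:  a.py=10.0, b.py=6.0, c.py=2.0  -> min-max normalized a=1.0, b=0.5, c=0.0
#     vector: b.py=0.9,  a.py=0.8, d.py=0.7  -> min-max normalized b=1.0, a=0.5, d=0.0
#
#     fused(a) = 0.4 * 1.0 + 0.6 * 0.5 = 0.70
#     fused(b) = 0.4 * 0.5 + 0.6 * 1.0 = 0.80
#
# so b.py outranks a.py because the vector channel carries more weight.
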
def reciprocal_rank_fusion(
    results_map: Dict[str, List[SearchResult]],
    weights: Optional[Dict[str, float]] = None,
    k: int = 60,
) -> List[SearchResult]:
    """Combine search results from multiple sources using Reciprocal Rank Fusion.

    RRF formula: score(d) = Σ weight_source / (k + rank_source(d))

    Supports three-way fusion with FTS, Vector, and SPLADE sources.

    Args:
        results_map: Dictionary mapping source name to list of SearchResult objects
            Sources: 'exact', 'fuzzy', 'vector', 'splade'
        weights: Dictionary mapping source name to weight (default: equal weights)
            Example: {'exact': 0.3, 'fuzzy': 0.1, 'vector': 0.6}
            Or: {'splade': 0.4, 'vector': 0.6}
        k: Constant to avoid division by zero and control rank influence (default 60)

    Returns:
        List of SearchResult objects sorted by fused score (descending)

    Examples:
        >>> exact_results = [SearchResult(path="a.py", score=10.0, excerpt="...")]
        >>> fuzzy_results = [SearchResult(path="b.py", score=8.0, excerpt="...")]
        >>> results_map = {'exact': exact_results, 'fuzzy': fuzzy_results}
        >>> fused = reciprocal_rank_fusion(results_map)

        # Three-way fusion with SPLADE
        >>> results_map = {
        ...     'exact': exact_results,
        ...     'vector': vector_results,
        ...     'splade': splade_results
        ... }
        >>> fused = reciprocal_rank_fusion(results_map, k=60)
    """
    if not results_map:
        return []

    # Default equal weights if not provided
    if weights is None:
        num_sources = len(results_map)
        weights = {source: 1.0 / num_sources for source in results_map}

    # Validate weights sum to 1.0; skip renormalization if the sum is non-positive
    weight_sum = sum(weights.values())
    if not math.isclose(weight_sum, 1.0, abs_tol=0.01) and weight_sum > 0:
        # Normalize weights to sum to 1.0
        weights = {source: w / weight_sum for source, w in weights.items()}

    # Build unified result set with RRF scores
    path_to_result: Dict[str, SearchResult] = {}
    path_to_fusion_score: Dict[str, float] = {}
    path_to_source_ranks: Dict[str, Dict[str, int]] = {}

    for source_name, results in results_map.items():
        weight = weights.get(source_name, 0.0)
        if weight == 0:
            continue

        for rank, result in enumerate(results, start=1):
            path = result.path
            rrf_contribution = weight / (k + rank)

            # Initialize or accumulate fusion score
            if path not in path_to_fusion_score:
                path_to_fusion_score[path] = 0.0
                path_to_result[path] = result
                path_to_source_ranks[path] = {}

            path_to_fusion_score[path] += rrf_contribution
            path_to_source_ranks[path][source_name] = rank

    # Create final results with fusion scores
    fused_results = []
    for path, base_result in path_to_result.items():
        fusion_score = path_to_fusion_score[path]

        # Create new SearchResult with fusion_score in metadata
        fused_result = SearchResult(
            path=base_result.path,
            score=fusion_score,
            excerpt=base_result.excerpt,
            content=base_result.content,
            symbol=base_result.symbol,
            chunk=base_result.chunk,
            metadata={
                **base_result.metadata,
                "fusion_method": "rrf",
                "fusion_score": fusion_score,
                "original_score": base_result.score,
                "rrf_k": k,
                "source_ranks": path_to_source_ranks[path],
            },
            start_line=base_result.start_line,
            end_line=base_result.end_line,
            symbol_name=base_result.symbol_name,
            symbol_kind=base_result.symbol_kind,
        )
        fused_results.append(fused_result)

    # Sort by fusion score descending
    fused_results.sort(key=lambda r: r.score, reverse=True)

    return fused_results

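# Worked RRF example (hypothetical ranks, not from the original source). With
# k=60 and weights {'exact': 0.4, 'vector': 0.6}:
#
#     a.py: rank 1 in exact, rank 2 in vector -> 0.4/61 + 0.6/62 ≈ 0.01623
#     b.py: rank 2 in exact, rank 1 in vector -> 0.4/62 + 0.6/61 ≈ 0.01629
#
# b.py edges ahead because its top rank comes from the more heavily weighted
# vector source; only ranks matter here, never the raw backend scores.
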
def apply_symbol_boost(
    results: List[SearchResult],
    boost_factor: float = 1.5,
) -> List[SearchResult]:
    """Boost fused scores for results that include an explicit symbol match.

    The boost is multiplicative on the current result.score (typically the RRF fusion score).
    When boosted, the original score is preserved in metadata["original_fusion_score"] and
    metadata["boosted"] is set to True.
    """
    if not results:
        return []

    if boost_factor <= 1.0:
        # Still return new objects to follow immutable transformation pattern.
        return [
            SearchResult(
                path=r.path,
                score=r.score,
                excerpt=r.excerpt,
                content=r.content,
                symbol=r.symbol,
                chunk=r.chunk,
                metadata={**r.metadata},
                start_line=r.start_line,
                end_line=r.end_line,
                symbol_name=r.symbol_name,
                symbol_kind=r.symbol_kind,
                additional_locations=list(r.additional_locations),
            )
            for r in results
        ]

    boosted_results: List[SearchResult] = []
    for result in results:
        has_symbol = bool(result.symbol_name)
        original_score = float(result.score)
        boosted_score = original_score * boost_factor if has_symbol else original_score

        metadata = {**result.metadata}
        if has_symbol:
            metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
            metadata["boosted"] = True
            metadata["symbol_boost_factor"] = boost_factor

        boosted_results.append(
            SearchResult(
                path=result.path,
                score=boosted_score,
                excerpt=result.excerpt,
                content=result.content,
                symbol=result.symbol,
                chunk=result.chunk,
                metadata=metadata,
                start_line=result.start_line,
                end_line=result.end_line,
                symbol_name=result.symbol_name,
                symbol_kind=result.symbol_kind,
                additional_locations=list(result.additional_locations),
            )
        )

    boosted_results.sort(key=lambda r: r.score, reverse=True)
    return boosted_results

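# Illustrative boost (hypothetical values, not from the original source):
#
#     boosted = apply_symbol_boost(fused, boost_factor=1.5)
#     # a result with symbol_name set and fused score 0.020 becomes 0.030;
#     # results without a symbol match keep their original score.
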
def rerank_results(
    query: str,
    results: List[SearchResult],
    embedder: Any,
    top_k: int = 50,
) -> List[SearchResult]:
    """Re-rank results with embedding cosine similarity, combined with current score.

    Combined score formula:
        0.5 * rrf_score + 0.5 * cosine_similarity

    If embedder is None or embedding fails, returns results as-is.
    """
    if not results:
        return []

    if embedder is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
        # Defensive: handle mismatched lengths and zero vectors.
        n = min(len(vec_a), len(vec_b))
        if n == 0:
            return 0.0
        dot = 0.0
        norm_a = 0.0
        norm_b = 0.0
        for i in range(n):
            a = float(vec_a[i])
            b = float(vec_b[i])
            dot += a * b
            norm_a += a * a
            norm_b += b * b
        if norm_a <= 0.0 or norm_b <= 0.0:
            return 0.0
        sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
        # SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
        return max(0.0, min(1.0, sim))

    def text_for_embedding(r: SearchResult) -> str:
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        # Fallback: stable, non-empty text.
        return r.symbol_name or r.path

    try:
        if hasattr(embedder, "embed_single"):
            query_vec = embedder.embed_single(query)
        else:
            query_vec = embedder.embed(query)[0]

        doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
        doc_vecs = embedder.embed(doc_texts)
    except Exception:
        return results

    reranked_results: List[SearchResult] = []

    for idx, result in enumerate(results):
        if idx < rerank_count:
            rrf_score = float(result.score)
            sim = cosine_similarity(query_vec, doc_vecs[idx])
            combined_score = 0.5 * rrf_score + 0.5 * sim

            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=combined_score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={
                        **result.metadata,
                        "rrf_score": rrf_score,
                        "cosine_similarity": sim,
                        "reranked": True,
                    },
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )
        else:
            # Preserve remaining results without re-ranking, but keep immutability.
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=result.score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={**result.metadata},
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results

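# Minimal embedder sketch for rerank_results (an assumption for illustration, not
# the project's real embedder): any object exposing embed_single(text) -> List[float]
# and embed(texts) -> List[List[float]] satisfies the duck-typed interface used above.
#
#     class HashingEmbedder:
#         """Toy embedder: deterministic bag-of-characters vectors, illustration only."""
#
#         def embed_single(self, text: str) -> List[float]:
#             vec = [0.0] * 64
#             for ch in text:
#                 vec[ord(ch) % 64] += 1.0
#             return vec
#
#         def embed(self, texts: List[str]) -> List[List[float]]:
#             return [self.embed_single(t) for t in texts]
#
#     reranked = rerank_results("token refresh", fused, HashingEmbedder(), top_k=20)
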
def cross_encoder_rerank(
    query: str,
    results: List[SearchResult],
    reranker: Any,
    top_k: int = 50,
    batch_size: int = 32,
    chunk_type_weights: Optional[Dict[str, float]] = None,
    test_file_penalty: float = 0.0,
) -> List[SearchResult]:
    """Second-stage reranking using a cross-encoder model.

    This function is dependency-agnostic: callers can pass any object that exposes
    a compatible `score_pairs(pairs, batch_size=...)` method.

    Args:
        query: Search query string
        results: List of search results to rerank
        reranker: Cross-encoder model with score_pairs or predict method
        top_k: Number of top results to rerank
        batch_size: Batch size for reranking
        chunk_type_weights: Optional weights for different chunk types.
            Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
        test_file_penalty: Penalty applied to test files (0.0-1.0).
            Example: 0.2 means test files get 20% score reduction
    """
    if not results:
        return []

    if reranker is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def text_for_pair(r: SearchResult) -> str:
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        return r.symbol_name or r.path

    pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]

    try:
        if hasattr(reranker, "score_pairs"):
            raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
        elif hasattr(reranker, "predict"):
            raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
        else:
            return results
    except Exception:
        return results

    # Materialize scores as floats before validating. This avoids truth-value
    # checks on array-like returns (e.g. numpy arrays from predict()).
    if raw_scores is None:
        return results
    scores = [float(s) for s in raw_scores]
    if len(scores) != rerank_count:
        return results

    min_s = min(scores)
    max_s = max(scores)

    def sigmoid(x: float) -> float:
        # Clamp to keep exp() stable.
        x = max(-50.0, min(50.0, x))
        return 1.0 / (1.0 + math.exp(-x))

    if 0.0 <= min_s and max_s <= 1.0:
        probs = scores
    else:
        probs = [sigmoid(s) for s in scores]

    reranked_results: List[SearchResult] = []

    # Helper to detect test files
    def is_test_file(path: str) -> bool:
        if not path:
            return False
        basename = path.split("/")[-1].split("\\")[-1]
        return (
            basename.startswith("test_") or
            basename.endswith("_test.py") or
            basename.endswith(".test.ts") or
            basename.endswith(".test.js") or
            basename.endswith(".spec.ts") or
            basename.endswith(".spec.js") or
            "/tests/" in path or
            "\\tests\\" in path or
            "/test/" in path or
            "\\test\\" in path
        )

    for idx, result in enumerate(results):
        if idx < rerank_count:
            prev_score = float(result.score)
            ce_score = scores[idx]
            ce_prob = probs[idx]

            # Base combined score
            combined_score = 0.5 * prev_score + 0.5 * ce_prob

            # Apply chunk_type weight adjustment
            if chunk_type_weights:
                chunk_type = None
                if result.chunk and hasattr(result.chunk, "metadata"):
                    chunk_type = result.chunk.metadata.get("chunk_type")
                elif result.metadata:
                    chunk_type = result.metadata.get("chunk_type")

                if chunk_type and chunk_type in chunk_type_weights:
                    weight = chunk_type_weights[chunk_type]
                    # Apply weight to CE contribution only
                    combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight

            # Apply test file penalty
            if test_file_penalty > 0 and is_test_file(result.path):
                combined_score = combined_score * (1.0 - test_file_penalty)

            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=combined_score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={
                        **result.metadata,
                        "pre_cross_encoder_score": prev_score,
                        "cross_encoder_score": ce_score,
                        "cross_encoder_prob": ce_prob,
                        "cross_encoder_reranked": True,
                    },
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )
        else:
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=result.score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={**result.metadata},
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results

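# Minimal reranker sketch for cross_encoder_rerank (an assumption for illustration,
# not the project's real model wrapper): anything with score_pairs(pairs, batch_size=...)
# returning one relevance score per (query, text) pair satisfies the duck-typed
# interface above.
#
#     class OverlapReranker:
#         """Toy cross-encoder stand-in: scores pairs by word overlap, illustration only."""
#
#         def score_pairs(self, pairs: List[tuple], batch_size: int = 32) -> List[float]:
#             scores = []
#             for q, text in pairs:
#                 q_words = set(q.lower().split())
#                 t_words = set(text.lower().split())
#                 scores.append(len(q_words & t_words) / max(1, len(q_words)))
#             return scores
#
#     reranked = cross_encoder_rerank("token refresh", fused, OverlapReranker(),
#                                     top_k=20, test_file_penalty=0.2)
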
def normalize_bm25_score(score: float) -> float:
    """Normalize BM25 scores from SQLite FTS5 to 0-1 range.

    SQLite FTS5 returns negative BM25 scores (more negative = better match).
    Uses sigmoid transformation for normalization.

    Args:
        score: Raw BM25 score from SQLite (typically negative)

    Returns:
        Normalized score in range [0, 1]

    Examples:
        >>> round(normalize_bm25_score(-10.5), 2)  # Good match
        0.74
        >>> round(normalize_bm25_score(-1.2), 2)  # Weak match
        0.53
    """
    # Take absolute value (BM25 is negative in SQLite)
    abs_score = abs(score)

    # Sigmoid transformation: 1 / (1 + e^(-x))
    # Scale factor of 0.1 compresses the typical BM25 range (-20 to 0) into roughly (0.5, 0.9)
    normalized = 1.0 / (1.0 + math.exp(-abs_score * 0.1))

    return normalized

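# Worked values for the sigmoid normalization above (computed from the formula, for
# illustration): abs(-10.5) * 0.1 = 1.05 -> 1 / (1 + e^-1.05) ≈ 0.74, and
# abs(-1.2) * 0.1 = 0.12 -> 1 / (1 + e^-0.12) ≈ 0.53, so stronger (more negative)
# BM25 scores map to values closer to 1.0.
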
def tag_search_source(results: List[SearchResult], source: str) -> List[SearchResult]:
    """Tag search results with their source for RRF tracking.

    Args:
        results: List of SearchResult objects
        source: Source identifier ('exact', 'fuzzy', 'vector')

    Returns:
        List of SearchResult objects with 'search_source' in metadata
    """
    tagged_results = []
    for result in results:
        tagged_result = SearchResult(
            path=result.path,
            score=result.score,
            excerpt=result.excerpt,
            content=result.content,
            symbol=result.symbol,
            chunk=result.chunk,
            metadata={**result.metadata, "search_source": source},
            start_line=result.start_line,
            end_line=result.end_line,
            symbol_name=result.symbol_name,
            symbol_kind=result.symbol_kind,
        )
        tagged_results.append(tagged_result)

    return tagged_results

def group_similar_results(
    results: List[SearchResult],
    score_threshold_abs: float = 0.01,
    content_field: str = "excerpt",
) -> List[SearchResult]:
    """Group search results by content and score similarity.

    Groups results that have similar content and similar scores into a single
    representative result, with other locations stored in additional_locations.

    Algorithm:
        1. Group results by content (using excerpt or content field)
        2. Within each content group, create subgroups based on score similarity
        3. Select highest-scoring result as representative for each subgroup
        4. Store other results in subgroup as additional_locations

    Args:
        results: A list of SearchResult objects (typically sorted by score)
        score_threshold_abs: Absolute score difference to consider results similar.
            Results with |score_a - score_b| <= threshold are grouped.
            Default 0.01 is suitable for RRF fusion scores.
        content_field: The field to use for content grouping ('excerpt' or 'content')

    Returns:
        A new list of SearchResult objects where similar items are grouped.
        The list is sorted by score descending.

    Examples:
        >>> results = [SearchResult(path="a.py", score=0.5, excerpt="def foo()"),
        ...            SearchResult(path="b.py", score=0.5, excerpt="def foo()")]
        >>> grouped = group_similar_results(results)
        >>> len(grouped)  # Two results merged into one
        1
        >>> len(grouped[0].additional_locations)  # One additional location
        1
    """
    if not results:
        return []

    # Group results by content
    content_map: Dict[str, List[SearchResult]] = {}
    unidentifiable_results: List[SearchResult] = []

    for r in results:
        key = getattr(r, content_field, None)
        if key and key.strip():
            content_map.setdefault(key, []).append(r)
        else:
            # Results without content can't be grouped by content
            unidentifiable_results.append(r)

    final_results: List[SearchResult] = []

    # Process each content group
    for content_group in content_map.values():
        # Sort by score descending within group
        content_group.sort(key=lambda r: r.score, reverse=True)

        while content_group:
            # Take highest scoring as representative
            representative = content_group.pop(0)
            others_in_group = []
            remaining_for_next_pass = []

            # Find results with similar scores
            for item in content_group:
                if abs(representative.score - item.score) <= score_threshold_abs:
                    others_in_group.append(item)
                else:
                    remaining_for_next_pass.append(item)

            # Create grouped result with additional locations
            if others_in_group:
                # Build new result with additional_locations populated
                grouped_result = SearchResult(
                    path=representative.path,
                    score=representative.score,
                    excerpt=representative.excerpt,
                    content=representative.content,
                    symbol=representative.symbol,
                    chunk=representative.chunk,
                    metadata={
                        **representative.metadata,
                        "grouped_count": len(others_in_group) + 1,
                    },
                    start_line=representative.start_line,
                    end_line=representative.end_line,
                    symbol_name=representative.symbol_name,
                    symbol_kind=representative.symbol_kind,
                    additional_locations=[
                        AdditionalLocation(
                            path=other.path,
                            score=other.score,
                            start_line=other.start_line,
                            end_line=other.end_line,
                            symbol_name=other.symbol_name,
                        )
                        for other in others_in_group
                    ],
                )
                final_results.append(grouped_result)
            else:
                final_results.append(representative)

            content_group = remaining_for_next_pass

    # Add ungroupable results
    final_results.extend(unidentifiable_results)

    # Sort final results by score descending
    final_results.sort(key=lambda r: r.score, reverse=True)

    return final_results