Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-13 02:41:50 +08:00)
feat: Implement adaptive RRF weights and query intent detection
- Added integration tests for adaptive RRF weights in hybrid search.
- Enhanced query intent detection with new classifications: keyword, semantic, and mixed.
- Introduced symbol boosting in search results based on explicit symbol matches.
- Implemented embedding-based reranking with configurable options.
- Added global symbol index for efficient symbol lookups across projects.
- Improved file deletion handling on Windows to avoid permission errors.
- Updated chunk configuration to increase overlap for better context.
- Modified package.json test script to target specific test files.
- Created comprehensive writing style guidelines for documentation.
- Added TypeScript tests for query intent detection and adaptive weights.
- Established performance benchmarks for global symbol indexing.
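The hybrid-search pieces touched by this commit are meant to compose roughly as follows. This is a hedged usage sketch rather than code from the repository: the import path, the sample query, and the empty per-backend result lists are assumptions; only the function names and signatures come from the diff below.

# Usage sketch (assumed import path; the functions are the ones added in this diff).
from codexlens.search.rrf import (
    detect_query_intent,
    get_rrf_weights,
    reciprocal_rank_fusion,
    apply_symbol_boost,
    rerank_results,
)

query = "how does UserRepository.find_by_email handle missing users?"

# Classify the query, then derive per-backend RRF weights from that intent.
intent = detect_query_intent(query)  # code + natural signals -> QueryIntent.MIXED
weights = get_rrf_weights(query, {"exact": 0.4, "fuzzy": 0.2, "vector": 0.4})

# Fuse the per-backend result lists (normally returned by the exact/fuzzy/vector
# backends), boost explicit symbol matches, then optionally rerank with an embedder.
results_map = {"exact": [], "fuzzy": [], "vector": []}  # placeholder lists
fused = reciprocal_rank_fusion(results_map, weights=weights)
boosted = apply_symbol_boost(fused, boost_factor=1.5)
final = rerank_results(query, boosted, embedder=None, top_k=50)  # None skips reranking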
@@ -6,12 +6,98 @@ for combining results from heterogeneous search backends (exact FTS, fuzzy FTS,
from __future__ import annotations

import re
import math
from typing import Dict, List
from enum import Enum
from typing import Any, Dict, List

from codexlens.entities import SearchResult, AdditionalLocation


class QueryIntent(str, Enum):
    """Query intent for adaptive RRF weights (Python/TypeScript parity)."""

    KEYWORD = "keyword"
    SEMANTIC = "semantic"
    MIXED = "mixed"


def normalize_weights(weights: Dict[str, float]) -> Dict[str, float]:
    """Normalize weights to sum to 1.0 (best-effort)."""
    total = sum(float(v) for v in weights.values() if v is not None)
    if not math.isfinite(total) or total <= 0:
        return {k: float(v) for k, v in weights.items()}
    return {k: float(v) / total for k, v in weights.items()}


def detect_query_intent(query: str) -> QueryIntent:
    """Detect whether a query is code-like, natural-language, or mixed.

    Heuristic signals kept aligned with `ccw/src/tools/smart-search.ts`.
    """
    trimmed = (query or "").strip()
    if not trimmed:
        return QueryIntent.MIXED

    lower = trimmed.lower()
    word_count = len([w for w in re.split(r"\s+", trimmed) if w])

    has_code_signals = bool(
        re.search(r"(::|->|\.)", trimmed)
        or re.search(r"[A-Z][a-z]+[A-Z]", trimmed)
        or re.search(r"\b\w+_\w+\b", trimmed)
        or re.search(
            r"\b(def|class|function|const|let|var|import|from|return|async|await|interface|type)\b",
            lower,
            flags=re.IGNORECASE,
        )
    )
    has_natural_signals = bool(
        word_count > 5
        or "?" in trimmed
        or re.search(r"\b(how|what|why|when|where)\b", trimmed, flags=re.IGNORECASE)
        or re.search(
            r"\b(handle|explain|fix|implement|create|build|use|find|search|convert|parse|generate|support)\b",
            trimmed,
            flags=re.IGNORECASE,
        )
    )

    if has_code_signals and has_natural_signals:
        return QueryIntent.MIXED
    if has_code_signals:
        return QueryIntent.KEYWORD
    if has_natural_signals:
        return QueryIntent.SEMANTIC
    return QueryIntent.MIXED


def adjust_weights_by_intent(
    intent: QueryIntent,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Map intent → weights (kept aligned with TypeScript mapping)."""
    if intent == QueryIntent.KEYWORD:
        target = {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4}
    elif intent == QueryIntent.SEMANTIC:
        target = {"exact": 0.2, "fuzzy": 0.1, "vector": 0.7}
    else:
        target = dict(base_weights)

    # Preserve only keys that are present in base_weights (active backends).
    keys = list(base_weights.keys())
    filtered = {k: float(target.get(k, 0.0)) for k in keys}
    return normalize_weights(filtered)


def get_rrf_weights(
    query: str,
    base_weights: Dict[str, float],
) -> Dict[str, float]:
    """Compute adaptive RRF weights from query intent."""
    return adjust_weights_by_intent(detect_query_intent(query), base_weights)


def reciprocal_rank_fusion(
    results_map: Dict[str, List[SearchResult]],
    weights: Dict[str, float] = None,
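As a small worked example of the mapping above (the call is hypothetical; the numbers follow directly from adjust_weights_by_intent and normalize_weights as defined in this hunk), a keyword-intent query with only two active backends keeps just those keys and renormalizes them:

# Hypothetical call, reusing the functions defined above.
base = {"exact": 0.4, "vector": 0.6}  # "fuzzy" backend not active
w = adjust_weights_by_intent(QueryIntent.KEYWORD, base)
# The KEYWORD target {"exact": 0.5, "fuzzy": 0.1, "vector": 0.4} is filtered to the
# active keys {"exact": 0.5, "vector": 0.4}, then renormalized to sum to 1.0:
# w == {"exact": 0.5 / 0.9, "vector": 0.4 / 0.9}  (about 0.556 and 0.444)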
@@ -102,6 +188,186 @@ def reciprocal_rank_fusion(
    return fused_results


def apply_symbol_boost(
    results: List[SearchResult],
    boost_factor: float = 1.5,
) -> List[SearchResult]:
    """Boost fused scores for results that include an explicit symbol match.

    The boost is multiplicative on the current result.score (typically the RRF fusion score).
    When boosted, the original score is preserved in metadata["original_fusion_score"] and
    metadata["boosted"] is set to True.
    """
    if not results:
        return []

    if boost_factor <= 1.0:
        # Still return new objects to follow immutable transformation pattern.
        return [
            SearchResult(
                path=r.path,
                score=r.score,
                excerpt=r.excerpt,
                content=r.content,
                symbol=r.symbol,
                chunk=r.chunk,
                metadata={**r.metadata},
                start_line=r.start_line,
                end_line=r.end_line,
                symbol_name=r.symbol_name,
                symbol_kind=r.symbol_kind,
                additional_locations=list(r.additional_locations),
            )
            for r in results
        ]

    boosted_results: List[SearchResult] = []
    for result in results:
        has_symbol = bool(result.symbol_name)
        original_score = float(result.score)
        boosted_score = original_score * boost_factor if has_symbol else original_score

        metadata = {**result.metadata}
        if has_symbol:
            metadata.setdefault("original_fusion_score", metadata.get("fusion_score", original_score))
            metadata["boosted"] = True
            metadata["symbol_boost_factor"] = boost_factor

        boosted_results.append(
            SearchResult(
                path=result.path,
                score=boosted_score,
                excerpt=result.excerpt,
                content=result.content,
                symbol=result.symbol,
                chunk=result.chunk,
                metadata=metadata,
                start_line=result.start_line,
                end_line=result.end_line,
                symbol_name=result.symbol_name,
                symbol_kind=result.symbol_kind,
                additional_locations=list(result.additional_locations),
            )
        )

    boosted_results.sort(key=lambda r: r.score, reverse=True)
    return boosted_results


def rerank_results(
    query: str,
    results: List[SearchResult],
    embedder: Any,
    top_k: int = 50,
) -> List[SearchResult]:
    """Re-rank results with embedding cosine similarity, combined with current score.

    Combined score formula:
        0.5 * rrf_score + 0.5 * cosine_similarity

    If embedder is None or embedding fails, returns results as-is.
    """
    if not results:
        return []

    if embedder is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
        # Defensive: handle mismatched lengths and zero vectors.
        n = min(len(vec_a), len(vec_b))
        if n == 0:
            return 0.0
        dot = 0.0
        norm_a = 0.0
        norm_b = 0.0
        for i in range(n):
            a = float(vec_a[i])
            b = float(vec_b[i])
            dot += a * b
            norm_a += a * a
            norm_b += b * b
        if norm_a <= 0.0 or norm_b <= 0.0:
            return 0.0
        sim = dot / (math.sqrt(norm_a) * math.sqrt(norm_b))
        # SearchResult.score requires non-negative scores; clamp cosine similarity to [0, 1].
        return max(0.0, min(1.0, sim))

    def text_for_embedding(r: SearchResult) -> str:
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        # Fallback: stable, non-empty text.
        return r.symbol_name or r.path

    try:
        if hasattr(embedder, "embed_single"):
            query_vec = embedder.embed_single(query)
        else:
            query_vec = embedder.embed(query)[0]

        doc_texts = [text_for_embedding(r) for r in results[:rerank_count]]
        doc_vecs = embedder.embed(doc_texts)
    except Exception:
        return results

    reranked_results: List[SearchResult] = []

    for idx, result in enumerate(results):
        if idx < rerank_count:
            rrf_score = float(result.score)
            sim = cosine_similarity(query_vec, doc_vecs[idx])
            combined_score = 0.5 * rrf_score + 0.5 * sim

            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=combined_score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={
                        **result.metadata,
                        "rrf_score": rrf_score,
                        "cosine_similarity": sim,
                        "reranked": True,
                    },
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )
        else:
            # Preserve remaining results without re-ranking, but keep immutability.
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=result.score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={**result.metadata},
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results


def normalize_bm25_score(score: float) -> float:
    """Normalize BM25 scores from SQLite FTS5 to 0-1 range.
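For the reranking path, rerank_results only assumes an embedder object exposing embed_single(text) and/or embed(texts). The stub below is a toy stand-in used purely to illustrate that duck-typed contract (it is not part of this commit), and the final call is left commented out because real SearchResult inputs would come from the fusion step.

from typing import List


class ToyEmbedder:
    """Toy embedder used only to exercise the interface rerank_results expects."""

    def embed_single(self, text: str) -> List[float]:
        return self.embed([text])[0]

    def embed(self, texts: List[str]) -> List[List[float]]:
        # Bag-of-characters vectors; enough to produce finite cosine similarities.
        return [[float(text.count(ch)) for ch in "abcdefghij"] for text in texts]


# results would come from reciprocal_rank_fusion() / apply_symbol_boost(); the top_k
# hits then get score = 0.5 * rrf_score + 0.5 * cosine_similarity, as documented above.
# reranked = rerank_results("parse config file", results, ToyEmbedder(), top_k=20)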