mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-12 02:37:45 +08:00
Add graph expansion and cross-encoder reranking features
- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors.
- Added CrossEncoderReranker for second-stage search ranking, allowing for improved result scoring.
- Created migrations to establish necessary database tables for relationships and graph neighbors.
- Developed tests for graph expansion functionality, ensuring related results are populated correctly.
- Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead.
- Updated schema cleanup tests to reflect changes in versioning and deprecated fields.
- Added new test cases for the Tree-sitter parser to validate relationship extraction with alias resolution.
This commit is contained in:
@@ -379,6 +379,117 @@ def rerank_results(
|
||||
return reranked_results
|
||||
|
||||
|
||||
def cross_encoder_rerank(
    query: str,
    results: List[SearchResult],
    reranker: Any,
    top_k: int = 50,
    batch_size: int = 32,
) -> List[SearchResult]:
    """Second-stage reranking using a cross-encoder model.

    This function is dependency-agnostic: callers can pass any object that exposes
    a compatible `score_pairs(pairs, batch_size=...)` method. An object exposing
    `predict(pairs, batch_size=...)` is accepted as a fallback.

    Only the first ``min(top_k, len(results))`` results are scored. For each of
    them the new score is an equal-weight blend of the previous score and the
    cross-encoder probability (``0.5 * previous + 0.5 * probability``), and the
    raw inputs to that blend are recorded in the result's metadata. Results
    past the cutoff are copied with their score untouched. The combined list is
    then sorted by score, descending.

    Reranking is best-effort: whenever it cannot be carried out, the input
    ``results`` list is returned as-is (same object, not a copy).

    Args:
        query: Query text paired with each candidate's text.
        results: Candidates from the previous ranking stage, in rank order.
        reranker: Scoring object, or ``None`` to skip reranking.
        top_k: Number of leading results to score; ``<= 0`` skips reranking.
        batch_size: Forwarded to the reranker as ``batch_size``.

    Returns:
        ``[]`` when ``results`` is empty; the original ``results`` list when
        reranking is skipped or fails; otherwise a new, score-sorted list of
        new ``SearchResult`` objects.
    """
    if not results:
        return []

    if reranker is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def text_for_pair(r: SearchResult) -> str:
        # Prefer the most focused non-blank text available for the pair.
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        return r.symbol_name or r.path

    pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]

    try:
        if hasattr(reranker, "score_pairs"):
            raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
        elif hasattr(reranker, "predict"):
            raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
        else:
            return results

        # Materialise inside the ``try`` so that array-like or lazy return
        # values (NumPy arrays, tensors, iterators) and malformed entries are
        # covered by the same best-effort fallback. Truth-testing the raw
        # value is deliberately avoided: it raises for multi-element arrays.
        scores = [] if raw_scores is None else [float(s) for s in raw_scores]
    except Exception:
        # Deliberate best-effort boundary: the reranker is an arbitrary
        # third-party object, so any failure degrades to the prior ranking.
        import logging

        logging.getLogger(__name__).debug(
            "Cross-encoder reranking failed; keeping previous ranking",
            exc_info=True,
        )
        return results

    # rerank_count is >= 1 here, so this also rejects an empty score list.
    if len(scores) != rerank_count:
        return results

    def sigmoid(x: float) -> float:
        # Clamp to keep exp() stable.
        x = max(-50.0, min(50.0, x))
        return 1.0 / (1.0 + math.exp(-x))

    # Scores already inside [0, 1] are used as probabilities directly;
    # anything else is squashed through a sigmoid.
    if 0.0 <= min(scores) and max(scores) <= 1.0:
        probs = scores
    else:
        probs = [sigmoid(s) for s in scores]

    reranked_results: List[SearchResult] = []

    for idx, result in enumerate(results):
        score = result.score
        metadata = {**result.metadata}

        if idx < rerank_count:
            prev_score = float(result.score)
            ce_prob = probs[idx]
            score = 0.5 * prev_score + 0.5 * ce_prob
            metadata.update(
                {
                    "pre_cross_encoder_score": prev_score,
                    "cross_encoder_score": scores[idx],
                    "cross_encoder_prob": ce_prob,
                    "cross_encoder_reranked": True,
                }
            )

        # Always build a fresh object so the caller's results are not mutated.
        reranked_results.append(
            SearchResult(
                path=result.path,
                score=score,
                excerpt=result.excerpt,
                content=result.content,
                symbol=result.symbol,
                chunk=result.chunk,
                metadata=metadata,
                start_line=result.start_line,
                end_line=result.end_line,
                symbol_name=result.symbol_name,
                symbol_kind=result.symbol_kind,
                additional_locations=list(result.additional_locations),
            )
        )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results
|
||||
|
||||
|
||||
def normalize_bm25_score(score: float) -> float:
|
||||
"""Normalize BM25 scores from SQLite FTS5 to 0-1 range.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user