Add graph expansion and cross-encoder reranking features

- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors (see the sketch after this list).
- Added CrossEncoderReranker for second-stage search ranking to improve result scoring.
- Created migrations to establish necessary database tables for relationships and graph neighbors.
- Developed tests for graph expansion functionality, ensuring related results are populated correctly.
- Extended performance benchmarks to cover cross-encoder reranking latency and graph expansion overhead.
- Updated schema cleanup tests to reflect changes in versioning and deprecated fields.
- Added new test cases for the Treesitter parser to validate relationship extraction with alias resolution.
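
The GraphExpander class itself is not part of the hunk shown below, which covers the reranking half. As a rough sketch of the precomputed-neighbor lookup the first bullet describes, assuming a `graph_neighbors` table along the lines the migrations bullet implies (`neighbors_for` and the column names here are hypothetical, not the committed API):

import sqlite3
from typing import List

def neighbors_for(conn: sqlite3.Connection, symbol_id: int, limit: int = 5) -> List[int]:
    # Hypothetical query against a precomputed (symbol_id, neighbor_id, weight)
    # graph_neighbors table; the actual schema lives in the new migrations.
    rows = conn.execute(
        "SELECT neighbor_id FROM graph_neighbors"
        " WHERE symbol_id = ? ORDER BY weight DESC LIMIT ?",
        (symbol_id, limit),
    ).fetchall()
    return [row[0] for row in rows]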
catlog22 committed on 2025-12-31 16:58:59 +08:00
parent 4bde13e83a
commit 31a45f1f30
27 changed files with 2566 additions and 97 deletions


@@ -379,6 +379,117 @@ def rerank_results(
    return reranked_results


def cross_encoder_rerank(
    query: str,
    results: List[SearchResult],
    reranker: Any,
    top_k: int = 50,
    batch_size: int = 32,
) -> List[SearchResult]:
    """Second-stage reranking using a cross-encoder model.

    This function is dependency-agnostic: callers can pass any object that
    exposes a compatible `score_pairs(pairs, batch_size=...)` method (a
    sentence-transformers-style `predict(pairs, batch_size=...)` also works).
    """
    if not results:
        return []
    if reranker is None or top_k <= 0:
        return results

    rerank_count = min(int(top_k), len(results))

    def text_for_pair(r: SearchResult) -> str:
        # Prefer the richest non-empty text available for scoring.
        if r.excerpt and r.excerpt.strip():
            return r.excerpt
        if r.content and r.content.strip():
            return r.content
        if r.chunk and r.chunk.content and r.chunk.content.strip():
            return r.chunk.content
        return r.symbol_name or r.path

    pairs = [(query, text_for_pair(r)) for r in results[:rerank_count]]

    try:
        if hasattr(reranker, "score_pairs"):
            raw_scores = reranker.score_pairs(pairs, batch_size=int(batch_size))
        elif hasattr(reranker, "predict"):
            raw_scores = reranker.predict(pairs, batch_size=int(batch_size))
        else:
            return results
    except Exception:
        # Reranking is best-effort: fall back to the first-stage ordering.
        return results

    # Coerce to a plain list so array-like returns (e.g. numpy arrays) do not
    # trip an ambiguous truth-value check.
    raw_scores = [] if raw_scores is None else list(raw_scores)
    if len(raw_scores) != rerank_count:
        return results

    scores = [float(s) for s in raw_scores]
    min_s = min(scores)
    max_s = max(scores)

    def sigmoid(x: float) -> float:
        # Clamp to keep exp() stable; relies on `math` imported at module top.
        x = max(-50.0, min(50.0, x))
        return 1.0 / (1.0 + math.exp(-x))

    # If the model already emits probabilities in [0, 1], use them as-is;
    # otherwise squash raw logits through a sigmoid.
    if 0.0 <= min_s and max_s <= 1.0:
        probs = scores
    else:
        probs = [sigmoid(s) for s in scores]

    reranked_results: List[SearchResult] = []
    for idx, result in enumerate(results):
        if idx < rerank_count:
            prev_score = float(result.score)
            ce_score = scores[idx]
            ce_prob = probs[idx]
            # Blend the first-stage score and cross-encoder probability 50/50.
            combined_score = 0.5 * prev_score + 0.5 * ce_prob
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=combined_score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={
                        **result.metadata,
                        "pre_cross_encoder_score": prev_score,
                        "cross_encoder_score": ce_score,
                        "cross_encoder_prob": ce_prob,
                        "cross_encoder_reranked": True,
                    },
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )
        else:
            # Results beyond the rerank window keep their original score.
            reranked_results.append(
                SearchResult(
                    path=result.path,
                    score=result.score,
                    excerpt=result.excerpt,
                    content=result.content,
                    symbol=result.symbol,
                    chunk=result.chunk,
                    metadata={**result.metadata},
                    start_line=result.start_line,
                    end_line=result.end_line,
                    symbol_name=result.symbol_name,
                    symbol_kind=result.symbol_kind,
                    additional_locations=list(result.additional_locations),
                )
            )

    reranked_results.sort(key=lambda r: r.score, reverse=True)
    return reranked_results
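
# Usage sketch: per the docstring, any object exposing
# score_pairs(pairs, batch_size=...) satisfies the reranker contract.
# KeywordOverlapReranker below is a hypothetical stand-in for the real
# CrossEncoderReranker, purely to illustrate the duck-typed interface:
#
#     class KeywordOverlapReranker:
#         def score_pairs(self, pairs, batch_size=32):
#             scores = []
#             for query, text in pairs:
#                 q = set(query.lower().split())
#                 t = set(text.lower().split())
#                 # Query-token overlap fraction, already in [0, 1], so
#                 # cross_encoder_rerank skips the sigmoid squash.
#                 scores.append(len(q & t) / max(len(q), 1))
#             return scores
#
#     # `candidates` is a List[SearchResult] from first-stage retrieval.
#     reranked = cross_encoder_rerank(query, candidates, KeywordOverlapReranker())
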
def normalize_bm25_score(score: float) -> float:
    """Normalize BM25 scores from SQLite FTS5 to 0-1 range.