mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-02-11 02:33:51 +08:00
Add graph expansion and cross-encoder reranking features
- Implemented GraphExpander to enhance search results with related symbols using precomputed neighbors. - Added CrossEncoderReranker for second-stage search ranking, allowing for improved result scoring. - Created migrations to establish necessary database tables for relationships and graph neighbors. - Developed tests for graph expansion functionality, ensuring related results are populated correctly. - Enhanced performance benchmarks for cross-encoder reranking latency and graph expansion overhead. - Updated schema cleanup tests to reflect changes in versioning and deprecated fields. - Added new test cases for Treesitter parser to validate relationship extraction with alias resolution.
This commit is contained in:
86
codex-lens/src/codexlens/semantic/reranker.py
Normal file
86
codex-lens/src/codexlens/semantic/reranker.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Optional cross-encoder reranker for second-stage search ranking.
|
||||
|
||||
Install with: pip install codexlens[reranker]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder as _CrossEncoder
|
||||
|
||||
CROSS_ENCODER_AVAILABLE = True
|
||||
_import_error: str | None = None
|
||||
except ImportError as exc: # pragma: no cover - optional dependency
|
||||
_CrossEncoder = None # type: ignore[assignment]
|
||||
CROSS_ENCODER_AVAILABLE = False
|
||||
_import_error = str(exc)
|
||||
|
||||
|
||||
def check_cross_encoder_available() -> tuple[bool, str | None]:
|
||||
if CROSS_ENCODER_AVAILABLE:
|
||||
return True, None
|
||||
return False, _import_error or "sentence-transformers not available. Install with: pip install codexlens[reranker]"
|
||||
|
||||
|
||||
class CrossEncoderReranker:
|
||||
"""Cross-encoder reranker with lazy model loading."""
|
||||
|
||||
def __init__(self, model_name: str, *, device: str | None = None) -> None:
|
||||
self.model_name = (model_name or "").strip()
|
||||
if not self.model_name:
|
||||
raise ValueError("model_name cannot be blank")
|
||||
|
||||
self.device = (device or "").strip() or None
|
||||
self._model = None
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def _load_model(self) -> None:
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
ok, err = check_cross_encoder_available()
|
||||
if not ok:
|
||||
raise ImportError(err)
|
||||
|
||||
with self._lock:
|
||||
if self._model is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.device:
|
||||
self._model = _CrossEncoder(self.model_name, device=self.device) # type: ignore[misc]
|
||||
else:
|
||||
self._model = _CrossEncoder(self.model_name) # type: ignore[misc]
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load cross-encoder model %s: %s", self.model_name, exc)
|
||||
raise
|
||||
|
||||
def score_pairs(
|
||||
self,
|
||||
pairs: Sequence[Tuple[str, str]],
|
||||
*,
|
||||
batch_size: int = 32,
|
||||
) -> List[float]:
|
||||
"""Score (query, doc) pairs using the cross-encoder.
|
||||
|
||||
Returns:
|
||||
List of scores (one per pair) in the model's native scale (usually logits).
|
||||
"""
|
||||
if not pairs:
|
||||
return []
|
||||
|
||||
self._load_model()
|
||||
|
||||
if self._model is None: # pragma: no cover - defensive
|
||||
return []
|
||||
|
||||
bs = int(batch_size) if batch_size and int(batch_size) > 0 else 32
|
||||
scores = self._model.predict(list(pairs), batch_size=bs) # type: ignore[union-attr]
|
||||
return [float(s) for s in scores]
|
||||
|
||||
Reference in New Issue
Block a user