feat: 添加配置选项以调整重排序模型的权重和测试文件惩罚,增强语义搜索功能

This commit is contained in:
catlog22
2026-01-13 10:44:26 +08:00
parent bf06f4ddcc
commit 8c2d39d517
9 changed files with 1043 additions and 23 deletions

View File

@@ -141,6 +141,12 @@ class Config:
reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
reranker_top_k: int = 50
reranker_max_input_tokens: int = 8192 # Maximum tokens for reranker API batching
reranker_chunk_type_weights: Optional[Dict[str, float]] = None # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
reranker_test_file_penalty: float = 0.0 # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)
# Chunk stripping configuration (for semantic embedding)
chunk_strip_comments: bool = True # Strip comments from code chunks
chunk_strip_docstrings: bool = True # Strip docstrings from code chunks
# Cascade search configuration (two-stage retrieval)
enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking)
@@ -545,6 +551,35 @@ class Config:
except ValueError:
log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)
# Reranker tuning from environment
test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
if test_penalty:
try:
self.reranker_test_file_penalty = float(test_penalty)
log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
except ValueError:
log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)
docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
if docstring_weight:
try:
weight = float(docstring_weight)
self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
log.debug("Overriding reranker docstring weight from .env: %s", weight)
except ValueError:
log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)
# Chunk stripping from environment
strip_comments = get_env("CHUNK_STRIP_COMMENTS")
if strip_comments:
self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)
strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
if strip_docstrings:
self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)
@classmethod
def load(cls) -> "Config":
"""Load config with settings from file."""

View File

@@ -45,6 +45,12 @@ ENV_VARS = {
# General configuration
"CODEXLENS_DATA_DIR": "Custom data directory path",
"CODEXLENS_DEBUG": "Enable debug mode (true/false)",
# Chunking configuration
"CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
"CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
# Reranker tuning
"RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
"RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
}

View File

@@ -1816,12 +1816,22 @@ class ChainSearchEngine:
# Use cross_encoder_rerank from ranking module
from codexlens.search.ranking import cross_encoder_rerank
# Get chunk_type weights and test_file_penalty from config
chunk_type_weights = None
test_file_penalty = 0.0
if self._config is not None:
chunk_type_weights = getattr(self._config, "reranker_chunk_type_weights", None)
test_file_penalty = getattr(self._config, "reranker_test_file_penalty", 0.0)
return cross_encoder_rerank(
query=query,
results=results,
reranker=reranker,
top_k=top_k,
batch_size=32,
chunk_type_weights=chunk_type_weights,
test_file_penalty=test_file_penalty,
)
def search_files_only(self, query: str,

View File

@@ -613,11 +613,24 @@ def cross_encoder_rerank(
reranker: Any,
top_k: int = 50,
batch_size: int = 32,
chunk_type_weights: Optional[Dict[str, float]] = None,
test_file_penalty: float = 0.0,
) -> List[SearchResult]:
"""Second-stage reranking using a cross-encoder model.
This function is dependency-agnostic: callers can pass any object that exposes
a compatible `score_pairs(pairs, batch_size=...)` method.
Args:
query: Search query string
results: List of search results to rerank
reranker: Cross-encoder model with score_pairs or predict method
top_k: Number of top results to rerank
batch_size: Batch size for reranking
chunk_type_weights: Optional weights for different chunk types.
Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
test_file_penalty: Penalty applied to test files (0.0-1.0).
Example: 0.2 means test files get 20% score reduction
"""
if not results:
return []
@@ -667,13 +680,50 @@ def cross_encoder_rerank(
reranked_results: List[SearchResult] = []
# Helper to detect test files
def is_test_file(path: str) -> bool:
if not path:
return False
basename = path.split("/")[-1].split("\\")[-1]
return (
basename.startswith("test_") or
basename.endswith("_test.py") or
basename.endswith(".test.ts") or
basename.endswith(".test.js") or
basename.endswith(".spec.ts") or
basename.endswith(".spec.js") or
"/tests/" in path or
"\\tests\\" in path or
"/test/" in path or
"\\test\\" in path
)
for idx, result in enumerate(results):
if idx < rerank_count:
prev_score = float(result.score)
ce_score = scores[idx]
ce_prob = probs[idx]
# Base combined score
combined_score = 0.5 * prev_score + 0.5 * ce_prob
# Apply chunk_type weight adjustment
if chunk_type_weights:
chunk_type = None
if result.chunk and hasattr(result.chunk, "metadata"):
chunk_type = result.chunk.metadata.get("chunk_type")
elif result.metadata:
chunk_type = result.metadata.get("chunk_type")
if chunk_type and chunk_type in chunk_type_weights:
weight = chunk_type_weights[chunk_type]
# Apply weight to CE contribution only
combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight
# Apply test file penalty
if test_file_penalty > 0 and is_test_file(result.path):
combined_score = combined_score * (1.0 - test_file_penalty)
reranked_results.append(
SearchResult(
path=result.path,

View File

@@ -43,6 +43,250 @@ class ChunkConfig:
strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
min_chunk_size: int = 50 # Minimum chunk size
skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
strip_comments: bool = True # Remove comments from chunk content for embedding
strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
preserve_original: bool = True # Store original content in metadata when stripping
class CommentStripper:
"""Remove comments from source code while preserving structure."""
@staticmethod
def strip_python_comments(content: str) -> str:
"""Strip Python comments (# style) but preserve docstrings.
Args:
content: Python source code
Returns:
Code with comments removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
in_string = False
string_char = None
for line in lines:
new_line = []
i = 0
while i < len(line):
char = line[i]
# Handle string literals
if char in ('"', "'") and not in_string:
# Check for triple quotes
if line[i:i+3] in ('"""', "'''"):
in_string = True
string_char = line[i:i+3]
new_line.append(line[i:i+3])
i += 3
continue
else:
in_string = True
string_char = char
elif in_string:
if string_char and len(string_char) == 3:
if line[i:i+3] == string_char:
in_string = False
new_line.append(line[i:i+3])
i += 3
string_char = None
continue
elif char == string_char:
# Check for escape
if i > 0 and line[i-1] != '\\':
in_string = False
string_char = None
# Handle comments (only outside strings)
if char == '#' and not in_string:
# Rest of line is comment, skip it
new_line.append('\n' if line.endswith('\n') else '')
break
new_line.append(char)
i += 1
result_lines.append(''.join(new_line))
return ''.join(result_lines)
@staticmethod
def strip_c_style_comments(content: str) -> str:
"""Strip C-style comments (// and /* */) from code.
Args:
content: Source code with C-style comments
Returns:
Code with comments removed
"""
result = []
i = 0
in_string = False
string_char = None
in_multiline_comment = False
while i < len(content):
# Handle multi-line comment end
if in_multiline_comment:
if content[i:i+2] == '*/':
in_multiline_comment = False
i += 2
continue
i += 1
continue
char = content[i]
# Handle string literals
if char in ('"', "'", '`') and not in_string:
in_string = True
string_char = char
result.append(char)
i += 1
continue
elif in_string:
result.append(char)
if char == string_char and (i == 0 or content[i-1] != '\\'):
in_string = False
string_char = None
i += 1
continue
# Handle comments
if content[i:i+2] == '//':
# Single line comment - skip to end of line
while i < len(content) and content[i] != '\n':
i += 1
if i < len(content):
result.append('\n')
i += 1
continue
if content[i:i+2] == '/*':
in_multiline_comment = True
i += 2
continue
result.append(char)
i += 1
return ''.join(result)
@classmethod
def strip_comments(cls, content: str, language: str) -> str:
"""Strip comments based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with comments removed
"""
if language == "python":
return cls.strip_python_comments(content)
elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
return cls.strip_c_style_comments(content)
return content
class DocstringStripper:
"""Remove docstrings from source code."""
@staticmethod
def strip_python_docstrings(content: str) -> str:
"""Strip Python docstrings (triple-quoted strings at module/class/function level).
Args:
content: Python source code
Returns:
Code with docstrings removed
"""
lines = content.splitlines(keepends=True)
result_lines: List[str] = []
i = 0
while i < len(lines):
line = lines[i]
stripped = line.strip()
# Check for docstring start
if stripped.startswith('"""') or stripped.startswith("'''"):
quote_type = '"""' if stripped.startswith('"""') else "'''"
# Single line docstring
if stripped.count(quote_type) >= 2:
# Skip this line (docstring)
i += 1
continue
# Multi-line docstring - skip until closing
i += 1
while i < len(lines):
if quote_type in lines[i]:
i += 1
break
i += 1
continue
result_lines.append(line)
i += 1
return ''.join(result_lines)
@staticmethod
def strip_jsdoc_comments(content: str) -> str:
"""Strip JSDoc comments (/** ... */) from code.
Args:
content: JavaScript/TypeScript source code
Returns:
Code with JSDoc comments removed
"""
result = []
i = 0
in_jsdoc = False
while i < len(content):
if in_jsdoc:
if content[i:i+2] == '*/':
in_jsdoc = False
i += 2
continue
i += 1
continue
# Check for JSDoc start (/** but not /*)
if content[i:i+3] == '/**':
in_jsdoc = True
i += 3
continue
result.append(content[i])
i += 1
return ''.join(result)
@classmethod
def strip_docstrings(cls, content: str, language: str) -> str:
"""Strip docstrings based on language.
Args:
content: Source code content
language: Programming language
Returns:
Code with docstrings removed
"""
if language == "python":
return cls.strip_python_docstrings(content)
elif language in {"javascript", "typescript"}:
return cls.strip_jsdoc_comments(content)
return content
class Chunker:
@@ -51,6 +295,33 @@ class Chunker:
def __init__(self, config: ChunkConfig | None = None) -> None:
self.config = config or ChunkConfig()
self._tokenizer = get_default_tokenizer()
self._comment_stripper = CommentStripper()
self._docstring_stripper = DocstringStripper()
def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
"""Process chunk content by stripping comments/docstrings if configured.
Args:
content: Original chunk content
language: Programming language
Returns:
Tuple of (processed_content, original_content_if_preserved)
"""
original = content if self.config.preserve_original else None
processed = content
if self.config.strip_comments:
processed = self._comment_stripper.strip_comments(processed, language)
if self.config.strip_docstrings:
processed = self._docstring_stripper.strip_docstrings(processed, language)
# If nothing changed, don't store original
if processed == content:
original = None
return processed, original
def _estimate_token_count(self, text: str) -> int:
"""Estimate token count based on config.
@@ -120,30 +391,45 @@ class Chunker:
sub_chunk.metadata["symbol_name"] = symbol.name
sub_chunk.metadata["symbol_kind"] = symbol.kind
sub_chunk.metadata["strategy"] = "symbol_split"
sub_chunk.metadata["chunk_type"] = "code"
sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)
chunks.extend(sub_chunks)
else:
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
continue
# Calculate token count if not provided
token_count = None
if symbol_token_counts and symbol.name in symbol_token_counts:
token_count = symbol_token_counts[symbol.name]
else:
token_count = self._estimate_token_count(chunk_content)
token_count = self._estimate_token_count(processed_content)
metadata = {
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=chunk_content,
content=processed_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"symbol_name": symbol.name,
"symbol_kind": symbol.kind,
"start_line": start_line,
"end_line": end_line,
"strategy": "symbol",
"token_count": token_count,
}
metadata=metadata
))
return chunks
@@ -188,7 +474,19 @@ class Chunker:
chunk_content = "".join(lines[start:end])
if len(chunk_content.strip()) >= self.config.min_chunk_size:
token_count = self._estimate_token_count(chunk_content)
# Process content (strip comments/docstrings if configured)
processed_content, original_content = self._process_content(chunk_content, language)
# Skip if processed content is too small
if len(processed_content.strip()) < self.config.min_chunk_size:
# Move window forward
step = lines_per_chunk - overlap_lines
if step <= 0:
step = 1
start += step
continue
token_count = self._estimate_token_count(processed_content)
# Calculate correct line numbers
if line_mapping:
@@ -200,18 +498,25 @@ class Chunker:
start_line = start + 1
end_line = end
metadata = {
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"chunk_type": "code",
"token_count": token_count,
}
# Store original content if it was modified
if original_content is not None:
metadata["original_content"] = original_content
chunks.append(SemanticChunk(
content=chunk_content,
content=processed_content,
embedding=None,
metadata={
"file": str(file_path),
"language": language,
"chunk_index": chunk_idx,
"start_line": start_line,
"end_line": end_line,
"strategy": "sliding_window",
"token_count": token_count,
}
metadata=metadata
))
chunk_idx += 1