Mirror of https://github.com/catlog22/Claude-Code-Workflow.git (synced 2026-02-10 02:24:35 +08:00)
feat: Add configuration options to adjust reranker model weights and the test-file penalty, enhancing semantic search
@@ -141,6 +141,12 @@ class Config:
    reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    reranker_top_k: int = 50
    reranker_max_input_tokens: int = 8192 # Maximum tokens for reranker API batching
    reranker_chunk_type_weights: Optional[Dict[str, float]] = None # Weights for chunk types: {"code": 1.0, "docstring": 0.7}
    reranker_test_file_penalty: float = 0.0 # Penalty for test files (0.0-1.0, e.g., 0.2 = 20% reduction)

    # Chunk stripping configuration (for semantic embedding)
    chunk_strip_comments: bool = True # Strip comments from code chunks
    chunk_strip_docstrings: bool = True # Strip docstrings from code chunks

    # Cascade search configuration (two-stage retrieval)
    enable_cascade_search: bool = False # Enable cascade search (coarse + fine ranking)
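
For illustration (outside the diff), the new tuning knobs could be set directly on a Config instance. The values below are examples only, and the import path is an assumption:

    # Example only - down-weight docstring chunks and penalize test files by 20%.
    from codexlens.config import Config  # assumed module path

    config = Config()
    config.reranker_chunk_type_weights = {"code": 1.0, "docstring": 0.7}
    config.reranker_test_file_penalty = 0.2
    config.chunk_strip_comments = True
    config.chunk_strip_docstrings = True
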
@@ -545,6 +551,35 @@ class Config:
            except ValueError:
                log.warning("Invalid RERANKER_MAX_INPUT_TOKENS in .env: %r", reranker_max_tokens)

        # Reranker tuning from environment
        test_penalty = get_env("RERANKER_TEST_FILE_PENALTY")
        if test_penalty:
            try:
                self.reranker_test_file_penalty = float(test_penalty)
                log.debug("Overriding reranker_test_file_penalty from .env: %s", self.reranker_test_file_penalty)
            except ValueError:
                log.warning("Invalid RERANKER_TEST_FILE_PENALTY in .env: %r", test_penalty)

        docstring_weight = get_env("RERANKER_DOCSTRING_WEIGHT")
        if docstring_weight:
            try:
                weight = float(docstring_weight)
                self.reranker_chunk_type_weights = {"code": 1.0, "docstring": weight}
                log.debug("Overriding reranker docstring weight from .env: %s", weight)
            except ValueError:
                log.warning("Invalid RERANKER_DOCSTRING_WEIGHT in .env: %r", docstring_weight)

        # Chunk stripping from environment
        strip_comments = get_env("CHUNK_STRIP_COMMENTS")
        if strip_comments:
            self.chunk_strip_comments = strip_comments.lower() in ("true", "1", "yes")
            log.debug("Overriding chunk_strip_comments from .env: %s", self.chunk_strip_comments)

        strip_docstrings = get_env("CHUNK_STRIP_DOCSTRINGS")
        if strip_docstrings:
            self.chunk_strip_docstrings = strip_docstrings.lower() in ("true", "1", "yes")
            log.debug("Overriding chunk_strip_docstrings from .env: %s", self.chunk_strip_docstrings)

    @classmethod
    def load(cls) -> "Config":
        """Load config with settings from file."""
@@ -45,6 +45,12 @@ ENV_VARS = {
    # General configuration
    "CODEXLENS_DATA_DIR": "Custom data directory path",
    "CODEXLENS_DEBUG": "Enable debug mode (true/false)",
    # Chunking configuration
    "CHUNK_STRIP_COMMENTS": "Strip comments from code chunks for embedding: true/false (default: true)",
    "CHUNK_STRIP_DOCSTRINGS": "Strip docstrings from code chunks for embedding: true/false (default: true)",
    # Reranker tuning
    "RERANKER_TEST_FILE_PENALTY": "Penalty for test files in reranking: 0.0-1.0 (default: 0.0)",
    "RERANKER_DOCSTRING_WEIGHT": "Weight for docstring chunks in reranking: 0.0-1.0 (default: 1.0)",
}
@@ -1816,12 +1816,22 @@ class ChainSearchEngine:
        # Use cross_encoder_rerank from ranking module
        from codexlens.search.ranking import cross_encoder_rerank

        # Get chunk_type weights and test_file_penalty from config
        chunk_type_weights = None
        test_file_penalty = 0.0

        if self._config is not None:
            chunk_type_weights = getattr(self._config, "reranker_chunk_type_weights", None)
            test_file_penalty = getattr(self._config, "reranker_test_file_penalty", 0.0)

        return cross_encoder_rerank(
            query=query,
            results=results,
            reranker=reranker,
            top_k=top_k,
            batch_size=32,
            chunk_type_weights=chunk_type_weights,
            test_file_penalty=test_file_penalty,
        )

    def search_files_only(self, query: str,
@@ -613,11 +613,24 @@ def cross_encoder_rerank(
    reranker: Any,
    top_k: int = 50,
    batch_size: int = 32,
    chunk_type_weights: Optional[Dict[str, float]] = None,
    test_file_penalty: float = 0.0,
) -> List[SearchResult]:
    """Second-stage reranking using a cross-encoder model.

    This function is dependency-agnostic: callers can pass any object that exposes
    a compatible `score_pairs(pairs, batch_size=...)` method.

    Args:
        query: Search query string
        results: List of search results to rerank
        reranker: Cross-encoder model with score_pairs or predict method
        top_k: Number of top results to rerank
        batch_size: Batch size for reranking
        chunk_type_weights: Optional weights for different chunk types.
            Example: {"code": 1.0, "docstring": 0.7} - reduce docstring influence
        test_file_penalty: Penalty applied to test files (0.0-1.0).
            Example: 0.2 means test files get 20% score reduction
    """
    if not results:
        return []
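
Because the function is dependency-agnostic, any object exposing score_pairs(pairs, batch_size=...) works. A minimal sketch (outside the diff; the toy scorer below is invented for illustration and is not part of the codebase):

    class ToyReranker:
        """Example stand-in for a cross-encoder: scores query/text pairs by word overlap."""

        def score_pairs(self, pairs, batch_size=32):
            scores = []
            for query, text in pairs:
                overlap = len(set(query.lower().split()) & set(text.lower().split()))
                scores.append(float(overlap))
            return scores

    # Usage sketch (candidates would be real SearchResult objects):
    # reranked = cross_encoder_rerank(
    #     query="parse config file",
    #     results=candidates,
    #     reranker=ToyReranker(),
    #     top_k=50,
    #     chunk_type_weights={"code": 1.0, "docstring": 0.7},
    #     test_file_penalty=0.2,
    # )
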
@@ -667,13 +680,50 @@ def cross_encoder_rerank(
    reranked_results: List[SearchResult] = []

    # Helper to detect test files
    def is_test_file(path: str) -> bool:
        if not path:
            return False
        basename = path.split("/")[-1].split("\\")[-1]
        return (
            basename.startswith("test_") or
            basename.endswith("_test.py") or
            basename.endswith(".test.ts") or
            basename.endswith(".test.js") or
            basename.endswith(".spec.ts") or
            basename.endswith(".spec.js") or
            "/tests/" in path or
            "\\tests\\" in path or
            "/test/" in path or
            "\\test\\" in path
        )

    for idx, result in enumerate(results):
        if idx < rerank_count:
            prev_score = float(result.score)
            ce_score = scores[idx]
            ce_prob = probs[idx]

            # Base combined score
            combined_score = 0.5 * prev_score + 0.5 * ce_prob

            # Apply chunk_type weight adjustment
            if chunk_type_weights:
                chunk_type = None
                if result.chunk and hasattr(result.chunk, "metadata"):
                    chunk_type = result.chunk.metadata.get("chunk_type")
                elif result.metadata:
                    chunk_type = result.metadata.get("chunk_type")

                if chunk_type and chunk_type in chunk_type_weights:
                    weight = chunk_type_weights[chunk_type]
                    # Apply weight to CE contribution only
                    combined_score = 0.5 * prev_score + 0.5 * ce_prob * weight

            # Apply test file penalty
            if test_file_penalty > 0 and is_test_file(result.path):
                combined_score = combined_score * (1.0 - test_file_penalty)

            reranked_results.append(
                SearchResult(
                    path=result.path,
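
A worked example of the blending above (numbers invented): with a previous score of 0.8, a cross-encoder probability of 0.6, a docstring weight of 0.7, and a 0.2 test-file penalty:

    prev_score = 0.8
    ce_prob = 0.6
    weight = 0.7             # chunk_type_weights["docstring"]
    test_file_penalty = 0.2

    combined = 0.5 * prev_score + 0.5 * ce_prob * weight  # 0.4 + 0.21 = 0.61
    combined *= (1.0 - test_file_penalty)                 # 0.61 * 0.8 = 0.488
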
@@ -43,6 +43,250 @@ class ChunkConfig:
    strategy: str = "auto" # Chunking strategy: auto, symbol, sliding_window, hybrid
    min_chunk_size: int = 50 # Minimum chunk size
    skip_token_count: bool = False # Skip expensive token counting (use char/4 estimate)
    strip_comments: bool = True # Remove comments from chunk content for embedding
    strip_docstrings: bool = True # Remove docstrings from chunk content for embedding
    preserve_original: bool = True # Store original content in metadata when stripping


class CommentStripper:
    """Remove comments from source code while preserving structure."""

    @staticmethod
    def strip_python_comments(content: str) -> str:
        """Strip Python comments (# style) but preserve docstrings.

        Args:
            content: Python source code

        Returns:
            Code with comments removed
        """
        lines = content.splitlines(keepends=True)
        result_lines: List[str] = []
        in_string = False
        string_char = None

        for line in lines:
            new_line = []
            i = 0
            while i < len(line):
                char = line[i]

                # Handle string literals
                if char in ('"', "'") and not in_string:
                    # Check for triple quotes
                    if line[i:i+3] in ('"""', "'''"):
                        in_string = True
                        string_char = line[i:i+3]
                        new_line.append(line[i:i+3])
                        i += 3
                        continue
                    else:
                        in_string = True
                        string_char = char
                elif in_string:
                    if string_char and len(string_char) == 3:
                        if line[i:i+3] == string_char:
                            in_string = False
                            new_line.append(line[i:i+3])
                            i += 3
                            string_char = None
                            continue
                    elif char == string_char:
                        # Check for escape
                        if i > 0 and line[i-1] != '\\':
                            in_string = False
                            string_char = None

                # Handle comments (only outside strings)
                if char == '#' and not in_string:
                    # Rest of line is comment, skip it
                    new_line.append('\n' if line.endswith('\n') else '')
                    break

                new_line.append(char)
                i += 1

            result_lines.append(''.join(new_line))

        return ''.join(result_lines)

    @staticmethod
    def strip_c_style_comments(content: str) -> str:
        """Strip C-style comments (// and /* */) from code.

        Args:
            content: Source code with C-style comments

        Returns:
            Code with comments removed
        """
        result = []
        i = 0
        in_string = False
        string_char = None
        in_multiline_comment = False

        while i < len(content):
            # Handle multi-line comment end
            if in_multiline_comment:
                if content[i:i+2] == '*/':
                    in_multiline_comment = False
                    i += 2
                    continue
                i += 1
                continue

            char = content[i]

            # Handle string literals
            if char in ('"', "'", '`') and not in_string:
                in_string = True
                string_char = char
                result.append(char)
                i += 1
                continue
            elif in_string:
                result.append(char)
                if char == string_char and (i == 0 or content[i-1] != '\\'):
                    in_string = False
                    string_char = None
                i += 1
                continue

            # Handle comments
            if content[i:i+2] == '//':
                # Single line comment - skip to end of line
                while i < len(content) and content[i] != '\n':
                    i += 1
                if i < len(content):
                    result.append('\n')
                    i += 1
                continue

            if content[i:i+2] == '/*':
                in_multiline_comment = True
                i += 2
                continue

            result.append(char)
            i += 1

        return ''.join(result)

    @classmethod
    def strip_comments(cls, content: str, language: str) -> str:
        """Strip comments based on language.

        Args:
            content: Source code content
            language: Programming language

        Returns:
            Code with comments removed
        """
        if language == "python":
            return cls.strip_python_comments(content)
        elif language in {"javascript", "typescript", "java", "c", "cpp", "go", "rust"}:
            return cls.strip_c_style_comments(content)
        return content


class DocstringStripper:
    """Remove docstrings from source code."""

    @staticmethod
    def strip_python_docstrings(content: str) -> str:
        """Strip Python docstrings (triple-quoted strings at module/class/function level).

        Args:
            content: Python source code

        Returns:
            Code with docstrings removed
        """
        lines = content.splitlines(keepends=True)
        result_lines: List[str] = []
        i = 0

        while i < len(lines):
            line = lines[i]
            stripped = line.strip()

            # Check for docstring start
            if stripped.startswith('"""') or stripped.startswith("'''"):
                quote_type = '"""' if stripped.startswith('"""') else "'''"

                # Single line docstring
                if stripped.count(quote_type) >= 2:
                    # Skip this line (docstring)
                    i += 1
                    continue

                # Multi-line docstring - skip until closing
                i += 1
                while i < len(lines):
                    if quote_type in lines[i]:
                        i += 1
                        break
                    i += 1
                continue

            result_lines.append(line)
            i += 1

        return ''.join(result_lines)

    @staticmethod
    def strip_jsdoc_comments(content: str) -> str:
        """Strip JSDoc comments (/** ... */) from code.

        Args:
            content: JavaScript/TypeScript source code

        Returns:
            Code with JSDoc comments removed
        """
        result = []
        i = 0
        in_jsdoc = False

        while i < len(content):
            if in_jsdoc:
                if content[i:i+2] == '*/':
                    in_jsdoc = False
                    i += 2
                    continue
                i += 1
                continue

            # Check for JSDoc start (/** but not /*)
            if content[i:i+3] == '/**':
                in_jsdoc = True
                i += 3
                continue

            result.append(content[i])
            i += 1

        return ''.join(result)

    @classmethod
    def strip_docstrings(cls, content: str, language: str) -> str:
        """Strip docstrings based on language.

        Args:
            content: Source code content
            language: Programming language

        Returns:
            Code with docstrings removed
        """
        if language == "python":
            return cls.strip_python_docstrings(content)
        elif language in {"javascript", "typescript"}:
            return cls.strip_jsdoc_comments(content)
        return content


class Chunker:
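
An illustrative round trip through the two strippers (outside the diff; the sample snippet is invented, and the output is shown approximately since stripping leaves a blank line where a comment was):

    sample = (
        'def add(a, b):\n'
        '    """Return the sum."""\n'
        '    # simple addition\n'
        '    return a + b\n'
    )

    no_comments = CommentStripper.strip_comments(sample, "python")
    stripped = DocstringStripper.strip_docstrings(no_comments, "python")
    print(stripped)
    # def add(a, b):
    #     (blank line where the comment was)
    #     return a + b
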
@@ -51,6 +295,33 @@ class Chunker:
    def __init__(self, config: ChunkConfig | None = None) -> None:
        self.config = config or ChunkConfig()
        self._tokenizer = get_default_tokenizer()
        self._comment_stripper = CommentStripper()
        self._docstring_stripper = DocstringStripper()

    def _process_content(self, content: str, language: str) -> Tuple[str, Optional[str]]:
        """Process chunk content by stripping comments/docstrings if configured.

        Args:
            content: Original chunk content
            language: Programming language

        Returns:
            Tuple of (processed_content, original_content_if_preserved)
        """
        original = content if self.config.preserve_original else None
        processed = content

        if self.config.strip_comments:
            processed = self._comment_stripper.strip_comments(processed, language)

        if self.config.strip_docstrings:
            processed = self._docstring_stripper.strip_docstrings(processed, language)

        # If nothing changed, don't store original
        if processed == content:
            original = None

        return processed, original

    def _estimate_token_count(self, text: str) -> int:
        """Estimate token count based on config.
@@ -120,30 +391,45 @@ class Chunker:
                    sub_chunk.metadata["symbol_name"] = symbol.name
                    sub_chunk.metadata["symbol_kind"] = symbol.kind
                    sub_chunk.metadata["strategy"] = "symbol_split"
                    sub_chunk.metadata["chunk_type"] = "code"
                    sub_chunk.metadata["parent_symbol_range"] = (start_line, end_line)

                chunks.extend(sub_chunks)
            else:
                # Process content (strip comments/docstrings if configured)
                processed_content, original_content = self._process_content(chunk_content, language)

                # Skip if processed content is too small
                if len(processed_content.strip()) < self.config.min_chunk_size:
                    continue

                # Calculate token count if not provided
                token_count = None
                if symbol_token_counts and symbol.name in symbol_token_counts:
                    token_count = symbol_token_counts[symbol.name]
                else:
-                   token_count = self._estimate_token_count(chunk_content)
+                   token_count = self._estimate_token_count(processed_content)

                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "symbol_name": symbol.name,
                    "symbol_kind": symbol.kind,
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "symbol",
                    "chunk_type": "code",
                    "token_count": token_count,
                }

                # Store original content if it was modified
                if original_content is not None:
                    metadata["original_content"] = original_content

                chunks.append(SemanticChunk(
-                   content=chunk_content,
+                   content=processed_content,
                    embedding=None,
-                   metadata={
-                       "file": str(file_path),
-                       "language": language,
-                       "symbol_name": symbol.name,
-                       "symbol_kind": symbol.kind,
-                       "start_line": start_line,
-                       "end_line": end_line,
-                       "strategy": "symbol",
-                       "token_count": token_count,
-                   }
+                   metadata=metadata
                ))

        return chunks
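
For reference (outside the diff), a sketch of the metadata a symbol chunk would carry when stripping actually changed its content; every value here is a placeholder:

    metadata = {
        "file": "src/example.py",
        "language": "python",
        "symbol_name": "add",
        "symbol_kind": "function",
        "start_line": 10,
        "end_line": 14,
        "strategy": "symbol",
        "chunk_type": "code",
        "token_count": 12,
        # Only present when stripping modified the chunk and preserve_original is True:
        "original_content": "def add(a, b):\n    ...",
    }
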
@@ -188,7 +474,19 @@ class Chunker:
            chunk_content = "".join(lines[start:end])

            if len(chunk_content.strip()) >= self.config.min_chunk_size:
-               token_count = self._estimate_token_count(chunk_content)
                # Process content (strip comments/docstrings if configured)
                processed_content, original_content = self._process_content(chunk_content, language)

                # Skip if processed content is too small
                if len(processed_content.strip()) < self.config.min_chunk_size:
                    # Move window forward
                    step = lines_per_chunk - overlap_lines
                    if step <= 0:
                        step = 1
                    start += step
                    continue

+               token_count = self._estimate_token_count(processed_content)

                # Calculate correct line numbers
                if line_mapping:
@@ -200,18 +498,25 @@ class Chunker:
                    start_line = start + 1
                    end_line = end

                metadata = {
                    "file": str(file_path),
                    "language": language,
                    "chunk_index": chunk_idx,
                    "start_line": start_line,
                    "end_line": end_line,
                    "strategy": "sliding_window",
                    "chunk_type": "code",
                    "token_count": token_count,
                }

                # Store original content if it was modified
                if original_content is not None:
                    metadata["original_content"] = original_content

                chunks.append(SemanticChunk(
-                   content=chunk_content,
+                   content=processed_content,
                    embedding=None,
-                   metadata={
-                       "file": str(file_path),
-                       "language": language,
-                       "chunk_index": chunk_idx,
-                       "start_line": start_line,
-                       "end_line": end_line,
-                       "strategy": "sliding_window",
-                       "token_count": token_count,
-                   }
+                   metadata=metadata
                ))
                chunk_idx += 1
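
A small worked example of the window advance above (numbers invented): 50 lines per chunk with a 10-line overlap gives a step of 40, so windows start at 0, 40, 80, and so on; if the overlap ever reaches the window size, the step clamps to 1.

    lines_per_chunk = 50
    overlap_lines = 10

    step = lines_per_chunk - overlap_lines
    if step <= 0:
        step = 1

    print(step)                       # 40
    print(list(range(0, 200, step)))  # [0, 40, 80, 120, 160]
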