feat: add APIEmbedder for remote embedding with multi-endpoint support

- Introduced APIEmbedder class to handle embeddings via a remote HTTP API.
- Implemented token packing to optimize batch sizes based on token limits.
- Added support for multiple API endpoints with round-robin dispatching.
- Included retry logic for API calls with exponential backoff on failures.
- Enhanced indexing pipeline with file exclusion checks and smart chunking strategies.
- Updated tests to cover new APIEmbedder functionality and ensure robustness.
Commit f37189dc64 (parent 34749d2fad), authored by catlog22 on 2026-03-17 17:17:24 +08:00.
18 changed files with 1633 additions and 476 deletions

View File

@@ -44,6 +44,9 @@ faiss-cpu = [
faiss-gpu = [
"faiss-gpu>=1.7.4",
]
embed-api = [
"httpx>=0.25",
]
reranker-api = [
"httpx>=0.25",
]

View File

def _create_config(args: argparse.Namespace) -> "Config":
    """Build a Config from CLI arguments with CODEXLENS_* env fallbacks.

    Precedence: explicit CLI flags beat environment variables, which beat
    Config defaults. ``metadata_db_path`` is always derived from
    ``args.db_path``.

    Raises:
        ValueError: if a numeric CODEXLENS_* variable is not an integer.
    """
    kwargs: dict = {}
    if hasattr(args, "embed_model") and args.embed_model:
        kwargs["embed_model"] = args.embed_model
    # API embedding overrides: CLI flag wins, env var is the fallback.
    for key, env in (
        ("embed_api_url", "CODEXLENS_EMBED_API_URL"),
        ("embed_api_key", "CODEXLENS_EMBED_API_KEY"),
        ("embed_api_model", "CODEXLENS_EMBED_API_MODEL"),
    ):
        if getattr(args, key, ""):
            kwargs[key] = getattr(args, key)
        elif os.environ.get(env):
            kwargs[key] = os.environ[env]
    # Multi-endpoint: CODEXLENS_EMBED_API_ENDPOINTS=url1|key1|model1,url2|key2|model2
    endpoints_env = os.environ.get("CODEXLENS_EMBED_API_ENDPOINTS", "")
    if endpoints_env:
        endpoints = []
        for entry in endpoints_env.split(","):
            parts = entry.strip().split("|")
            if len(parts) >= 2:  # url and key are mandatory; model is optional
                ep = {"url": parts[0], "key": parts[1]}
                if len(parts) >= 3:
                    ep["model"] = parts[2]
                endpoints.append(ep)
        if endpoints:
            kwargs["embed_api_endpoints"] = endpoints
    # Reranker API env vars (string-valued).
    for key, env in (
        ("reranker_api_url", "CODEXLENS_RERANKER_API_URL"),
        ("reranker_api_key", "CODEXLENS_RERANKER_API_KEY"),
        ("reranker_api_model", "CODEXLENS_RERANKER_API_MODEL"),
    ):
        if os.environ.get(env):
            kwargs[key] = os.environ[env]
    # Integer-valued env overrides, table-driven to avoid the copy-pasted
    # if-blocks the original had (embedding, search-pipeline, and indexing
    # parameters all follow the same "set if env var present" pattern).
    for key, env in (
        ("embed_dim", "CODEXLENS_EMBED_DIM"),
        ("embed_batch_size", "CODEXLENS_EMBED_BATCH_SIZE"),
        ("embed_api_concurrency", "CODEXLENS_EMBED_API_CONCURRENCY"),
        ("embed_api_max_tokens_per_batch", "CODEXLENS_EMBED_API_MAX_TOKENS"),
        ("reranker_top_k", "CODEXLENS_RERANKER_TOP_K"),
        ("reranker_batch_size", "CODEXLENS_RERANKER_BATCH_SIZE"),
        ("binary_top_k", "CODEXLENS_BINARY_TOP_K"),
        ("ann_top_k", "CODEXLENS_ANN_TOP_K"),
        ("fts_top_k", "CODEXLENS_FTS_TOP_K"),
        ("fusion_k", "CODEXLENS_FUSION_K"),
        ("index_workers", "CODEXLENS_INDEX_WORKERS"),
        ("max_file_size_bytes", "CODEXLENS_MAX_FILE_SIZE"),
        ("hnsw_ef", "CODEXLENS_HNSW_EF"),
        ("hnsw_M", "CODEXLENS_HNSW_M"),
    ):
        if os.environ.get(env):
            kwargs[key] = int(os.environ[env])
    # Boolean flag: only the literal string "true" (case-insensitive) enables it.
    if os.environ.get("CODEXLENS_CODE_AWARE_CHUNKING"):
        kwargs["code_aware_chunking"] = (
            os.environ["CODEXLENS_CODE_AWARE_CHUNKING"].lower() == "true"
        )
    db_path = Path(args.db_path).resolve()
    kwargs["metadata_db_path"] = str(db_path / "metadata.db")
    return Config(**kwargs)
@@ -72,22 +139,43 @@ def _create_pipeline(
"""
from codexlens_search.config import Config
from codexlens_search.core.factory import create_ann_index, create_binary_index
from codexlens_search.embed.local import FastEmbedEmbedder
from codexlens_search.indexing.metadata import MetadataStore
from codexlens_search.indexing.pipeline import IndexingPipeline
from codexlens_search.rerank.local import FastEmbedReranker
from codexlens_search.search.fts import FTSEngine
from codexlens_search.search.pipeline import SearchPipeline
config = _create_config(args)
db_path = _resolve_db_path(args)
embedder = FastEmbedEmbedder(config)
# Select embedder: API if configured, otherwise local fastembed
if config.embed_api_url:
from codexlens_search.embed.api import APIEmbedder
embedder = APIEmbedder(config)
log.info("Using API embedder: %s", config.embed_api_url)
# Auto-detect embed_dim from API if still at default
if config.embed_dim == 384:
probe_vec = embedder.embed_single("dimension probe")
detected_dim = probe_vec.shape[0]
if detected_dim != config.embed_dim:
log.info("Auto-detected embed_dim=%d from API (was %d)", detected_dim, config.embed_dim)
config.embed_dim = detected_dim
else:
from codexlens_search.embed.local import FastEmbedEmbedder
embedder = FastEmbedEmbedder(config)
binary_store = create_binary_index(db_path, config.embed_dim, config)
ann_index = create_ann_index(db_path, config.embed_dim, config)
fts = FTSEngine(db_path / "fts.db")
metadata = MetadataStore(db_path / "metadata.db")
reranker = FastEmbedReranker(config)
# Select reranker: API if configured, otherwise local fastembed
if config.reranker_api_url:
from codexlens_search.rerank.api import APIReranker
reranker = APIReranker(config)
log.info("Using API reranker: %s", config.reranker_api_url)
else:
from codexlens_search.rerank.local import FastEmbedReranker
reranker = FastEmbedReranker(config)
indexing = IndexingPipeline(
embedder=embedder,
@@ -181,6 +269,19 @@ def cmd_remove_file(args: argparse.Namespace) -> None:
})
_DEFAULT_EXCLUDES = frozenset({
"node_modules", ".git", "__pycache__", "dist", "build",
".venv", "venv", ".tox", ".mypy_cache", ".pytest_cache",
".next", ".nuxt", "coverage", ".eggs", "*.egg-info",
})
def _should_exclude(path: Path, exclude_dirs: frozenset[str]) -> bool:
"""Check if any path component matches an exclude pattern."""
parts = path.parts
return any(part in exclude_dirs for part in parts)
def cmd_sync(args: argparse.Namespace) -> None:
"""Sync index with files under --root matching --glob pattern."""
indexing, _, _ = _create_pipeline(args)
@@ -189,12 +290,15 @@ def cmd_sync(args: argparse.Namespace) -> None:
if not root.is_dir():
_error_exit(f"Root directory not found: {root}")
exclude_dirs = frozenset(args.exclude) if args.exclude else _DEFAULT_EXCLUDES
pattern = args.glob or "**/*"
file_paths = [
p for p in root.glob(pattern)
if p.is_file()
if p.is_file() and not _should_exclude(p.relative_to(root), exclude_dirs)
]
log.debug("Sync: %d files after exclusion (root=%s, pattern=%s)", len(file_paths), root, pattern)
stats = indexing.sync(file_paths, root=root)
_json_output({
"status": "synced",
@@ -331,6 +435,23 @@ def _build_parser() -> argparse.ArgumentParser:
help="Enable debug logging to stderr",
)
# API embedding overrides (also read from CODEXLENS_EMBED_API_* env vars)
parser.add_argument(
"--embed-api-url",
default="",
help="Remote embedding API URL (OpenAI-compatible, e.g. https://api.openai.com/v1)",
)
parser.add_argument(
"--embed-api-key",
default="",
help="API key for remote embedding",
)
parser.add_argument(
"--embed-api-model",
default="",
help="Model name for remote embedding (e.g. text-embedding-3-small)",
)
sub = parser.add_subparsers(dest="command")
# init
@@ -354,6 +475,11 @@ def _build_parser() -> argparse.ArgumentParser:
p_sync = sub.add_parser("sync", help="Sync index with directory")
p_sync.add_argument("--root", "-r", required=True, help="Root directory to sync")
p_sync.add_argument("--glob", "-g", default="**/*", help="Glob pattern (default: **/*)")
p_sync.add_argument(
"--exclude", "-e", action="append", default=None,
help="Directory names to exclude (repeatable). "
"Defaults: node_modules, .git, __pycache__, dist, build, .venv, venv, .tox, .mypy_cache",
)
# watch
p_watch = sub.add_parser("watch", help="Watch directory for changes (JSONL output)")

View File

@@ -12,6 +12,15 @@ class Config:
embed_dim: int = 384
embed_batch_size: int = 64
# API embedding (optional — overrides local fastembed when set)
embed_api_url: str = "" # e.g. "https://api.openai.com/v1"
embed_api_key: str = ""
embed_api_model: str = "" # e.g. "text-embedding-3-small"
# Multi-endpoint: list of {"url": "...", "key": "...", "model": "..."} dicts
embed_api_endpoints: list[dict[str, str]] = None # type: ignore[assignment]
embed_api_concurrency: int = 4
embed_api_max_tokens_per_batch: int = 8192
# Model download / cache
model_cache_dir: str = "" # empty = fastembed default cache
hf_mirror: str = "" # HuggingFace mirror URL, e.g. "https://hf-mirror.com"
@@ -20,6 +29,21 @@ class Config:
device: str = "auto" # 'auto', 'cuda', 'cpu'
embed_providers: list[str] | None = None # explicit ONNX providers override
# File filtering
max_file_size_bytes: int = 1_000_000 # 1MB
exclude_extensions: frozenset[str] = None # type: ignore[assignment] # set in __post_init__
binary_detect_sample_bytes: int = 2048
binary_null_threshold: float = 0.10 # >10% null bytes = binary
generated_code_markers: tuple[str, ...] = ("@generated", "DO NOT EDIT", "auto-generated", "AUTO GENERATED")
# Code-aware chunking
code_aware_chunking: bool = True
code_extensions: frozenset[str] = frozenset({
".py", ".js", ".ts", ".jsx", ".tsx", ".go", ".java", ".cpp", ".c",
".h", ".hpp", ".cs", ".rs", ".rb", ".php", ".scala", ".kt", ".swift",
".lua", ".sh", ".bash", ".zsh", ".ps1", ".vue", ".svelte",
})
# Backend selection: 'auto', 'faiss', 'hnswlib'
ann_backend: str = "auto"
binary_backend: str = "auto"
@@ -64,6 +88,29 @@ class Config:
"graph": 0.15,
})
# File suffixes skipped during indexing. Installed as the default for
# exclude_extensions in __post_init__ when no explicit set is supplied.
# NOTE(review): matching elsewhere appears to use name.lower().endswith(),
# which is what makes compound suffixes like ".min.js" work — confirm
# against the indexing-side check.
_DEFAULT_EXCLUDE_EXTENSIONS: frozenset[str] = frozenset({
    # binaries / images
    ".png", ".jpg", ".jpeg", ".gif", ".webp", ".ico", ".bmp", ".svg",
    ".zip", ".gz", ".tar", ".rar", ".7z", ".bz2",
    ".bin", ".exe", ".dll", ".so", ".dylib", ".a", ".o", ".obj",
    ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
    # build / generated
    ".min.js", ".min.css", ".map", ".lock",
    ".pyc", ".pyo", ".class", ".wasm",
    # data
    ".sqlite", ".db", ".npy", ".npz", ".pkl", ".pickle",
    ".parquet", ".arrow", ".feather",
    # media
    ".mp3", ".mp4", ".wav", ".avi", ".mov", ".flv",
    ".ttf", ".otf", ".woff", ".woff2", ".eot",
})
def __post_init__(self) -> None:
    """Install per-instance defaults for fields declared as None.

    object.__setattr__ is used so this also works when the dataclass is
    frozen (which the original code's use of it suggests — confirm).
    """
    lazy_defaults = {
        "exclude_extensions": self._DEFAULT_EXCLUDE_EXTENSIONS,
        "embed_api_endpoints": [],
    }
    for field_name, default in lazy_defaults.items():
        if getattr(self, field_name) is None:
            object.__setattr__(self, field_name, default)
def resolve_embed_providers(self) -> list[str]:
"""Return ONNX execution providers based on device config.

View File

@@ -1,4 +1,5 @@
from .base import BaseEmbedder
from .local import FastEmbedEmbedder, EMBED_PROFILES
from .api import APIEmbedder
# Public API of the embed package. (The previous duplicate assignment —
# diff residue where the old list was immediately overwritten — is removed.)
__all__ = ["BaseEmbedder", "FastEmbedEmbedder", "APIEmbedder", "EMBED_PROFILES"]

View File

@@ -0,0 +1,232 @@
from __future__ import annotations
import itertools
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import httpx
import numpy as np
from ..config import Config
from .base import BaseEmbedder
logger = logging.getLogger(__name__)
class _Endpoint:
    """A single API endpoint with its own HTTP client and failure tracking.

    The URL is normalized to end with "/embeddings" so callers may pass
    either a base URL (e.g. ".../v1") or the full endpoint path.
    """

    __slots__ = ("url", "key", "model", "client", "failures", "lock")

    def __init__(self, url: str, key: str, model: str) -> None:
        self.url = url.rstrip("/")
        if not self.url.endswith("/embeddings"):
            self.url += "/embeddings"
        self.key = key
        self.model = model
        # Dedicated client per endpoint: connection pooling plus
        # per-endpoint bearer-token auth.
        self.client = httpx.Client(
            headers={
                "Authorization": f"Bearer {key}",
                "Content-Type": "application/json",
            },
            timeout=60.0,
        )
        # Consecutive-failure counter; guarded by `lock` since worker
        # threads update it concurrently.
        self.failures = 0
        self.lock = threading.Lock()

    def close(self) -> None:
        """Release the underlying HTTP connection pool.

        Fix: the original never closed the httpx.Client, leaking
        connections for the process lifetime.
        """
        self.client.close()
class APIEmbedder(BaseEmbedder):
    """Embedder backed by remote HTTP API (OpenAI /v1/embeddings format).

    Features:
    - Token packing: packs small chunks into batches up to max_tokens_per_batch
    - Multi-endpoint: round-robins across multiple (url, key) pairs
    - Concurrent dispatch: parallel API calls via ThreadPoolExecutor
    - Per-endpoint failure tracking and retry with exponential backoff
    """

    # HTTP statuses treated as transient and worth retrying: rate limiting
    # plus common server-side errors. (Fix: the original only retried
    # 429/503; a 500/502/504 escaped via raise_for_status with no retry.)
    _RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504})

    def __init__(self, config: Config) -> None:
        self._config = config
        self._endpoints = self._build_endpoints(config)
        # Round-robin index generator; the lock makes next() atomic across
        # worker threads.
        self._cycler = itertools.cycle(range(len(self._endpoints)))
        self._cycler_lock = threading.Lock()
        # Cap workers at 2x endpoint count, but never below 1 even when
        # embed_api_concurrency is misconfigured to 0 (fix: the original
        # could pass max_workers=0, which ThreadPoolExecutor rejects).
        self._executor = ThreadPoolExecutor(
            max_workers=max(1, min(config.embed_api_concurrency,
                                   len(self._endpoints) * 2)),
        )

    def close(self) -> None:
        """Shut down the worker pool and close every endpoint's HTTP client.

        Fix: the original leaked the executor threads and httpx clients.
        """
        self._executor.shutdown(wait=True)
        for ep in self._endpoints:
            ep.client.close()

    @staticmethod
    def _build_endpoints(config: Config) -> list[_Endpoint]:
        """Build endpoint list from config. Supports single and multi configs.

        Raises:
            ValueError: when neither embed_api_endpoints nor embed_api_url
                is configured.
        """
        endpoints: list[_Endpoint] = []
        # Multi-endpoint config takes priority; missing per-endpoint fields
        # fall back to the top-level single-endpoint values.
        if config.embed_api_endpoints:
            for ep in config.embed_api_endpoints:
                endpoints.append(_Endpoint(
                    url=ep.get("url", config.embed_api_url),
                    key=ep.get("key", config.embed_api_key),
                    model=ep.get("model", config.embed_api_model),
                ))
        # Fallback: single endpoint from top-level config
        if not endpoints and config.embed_api_url:
            endpoints.append(_Endpoint(
                url=config.embed_api_url,
                key=config.embed_api_key,
                model=config.embed_api_model,
            ))
        if not endpoints:
            raise ValueError("No API embedding endpoints configured")
        return endpoints

    def _next_endpoint(self) -> _Endpoint:
        """Return the next endpoint in round-robin order (thread-safe)."""
        with self._cycler_lock:
            idx = next(self._cycler)
        return self._endpoints[idx]

    # -- Token packing ------------------------------------------------

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Rough token estimate: ~4 chars per token for code."""
        return max(1, len(text) // 4)

    def _pack_batches(
        self, texts: list[str]
    ) -> list[list[tuple[int, str]]]:
        """Pack texts into batches respecting max_tokens_per_batch.

        Returns list of batches, each batch a list of (original_index, text).
        Also respects embed_batch_size as max items per batch. A single text
        larger than the token budget still gets its own batch.
        """
        max_tokens = self._config.embed_api_max_tokens_per_batch
        max_items = self._config.embed_batch_size
        batches: list[list[tuple[int, str]]] = []
        current: list[tuple[int, str]] = []
        current_tokens = 0
        for i, text in enumerate(texts):
            tokens = self._estimate_tokens(text)
            # Start a new batch if adding this text would exceed limits.
            if current and (
                current_tokens + tokens > max_tokens
                or len(current) >= max_items
            ):
                batches.append(current)
                current = []
                current_tokens = 0
            current.append((i, text))
            current_tokens += tokens
        if current:
            batches.append(current)
        return batches

    # -- API call with retry ------------------------------------------

    def _call_api(
        self,
        texts: list[str],
        endpoint: _Endpoint,
        max_retries: int = 3,
    ) -> list[np.ndarray]:
        """Call one endpoint, retrying transient failures with backoff.

        Connection errors and retryable HTTP statuses back off
        exponentially; other HTTP errors (e.g. 4xx) raise immediately via
        raise_for_status().

        Raises:
            RuntimeError: after max_retries transient failures, or when the
                response embedding count does not match the input count.
        """
        payload: dict = {"input": texts}
        if endpoint.model:
            payload["model"] = endpoint.model
        # Last failure seen — an exception or an HTTP-status string. (Fix:
        # the original only tracked exceptions, so exhausting retries on
        # 429s reported the misleading "Last error: None".)
        last_error: object = None
        for attempt in range(max_retries):
            try:
                response = endpoint.client.post(endpoint.url, json=payload)
            except Exception as exc:
                last_error = exc
                logger.warning(
                    "API embed %s failed (attempt %d/%d): %s",
                    endpoint.url, attempt + 1, max_retries, exc,
                )
                if attempt < max_retries - 1:  # no point sleeping after final try
                    time.sleep((2 ** attempt) * 0.5)
                continue
            if response.status_code in self._RETRYABLE_STATUSES:
                last_error = f"HTTP {response.status_code}"
                logger.warning(
                    "API embed %s returned HTTP %s (attempt %d/%d), retrying...",
                    endpoint.url, response.status_code, attempt + 1, max_retries,
                )
                if attempt < max_retries - 1:
                    time.sleep((2 ** attempt) * 0.5)
                continue
            # Non-retryable HTTP errors raise immediately.
            response.raise_for_status()
            data = response.json()
            items = data.get("data", [])
            # The API may return items out of order; restore input order.
            items.sort(key=lambda x: x["index"])
            vectors = [
                np.array(item["embedding"], dtype=np.float32)
                for item in items
            ]
            # Guard against silent misalignment between inputs and outputs.
            if len(vectors) != len(texts):
                raise RuntimeError(
                    f"API embed at {endpoint.url} returned {len(vectors)} "
                    f"embeddings for {len(texts)} inputs"
                )
            # Reset failure counter on success
            with endpoint.lock:
                endpoint.failures = 0
            return vectors
        # All attempts exhausted — record the failure for this endpoint.
        with endpoint.lock:
            endpoint.failures += 1
        raise RuntimeError(
            f"API embed failed at {endpoint.url} after {max_retries} attempts. "
            f"Last error: {last_error}"
        )

    # -- Public interface ---------------------------------------------

    def embed_single(self, text: str) -> np.ndarray:
        """Embed one text via the next round-robin endpoint."""
        endpoint = self._next_endpoint()
        vecs = self._call_api([text], endpoint)
        return vecs[0]

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed texts, returning vectors in the same order as the input."""
        if not texts:
            return []
        # 1. Pack into token-aware batches.
        packed = self._pack_batches(texts)
        results: dict[int, np.ndarray] = {}
        if len(packed) == 1:
            # Single batch — no concurrency overhead needed.
            batch = packed[0]
            endpoint = self._next_endpoint()
            vecs = self._call_api([t for _, t in batch], endpoint)
            for (idx, _), vec in zip(batch, vecs):
                results[idx] = vec
        else:
            # 2. Dispatch batches concurrently across endpoints.
            futures = []
            batch_index_map: list[list[int]] = []
            for batch in packed:
                endpoint = self._next_endpoint()
                futures.append(self._executor.submit(
                    self._call_api, [t for _, t in batch], endpoint,
                ))
                batch_index_map.append([i for i, _ in batch])
            for future, indices in zip(futures, batch_index_map):
                vecs = future.result()  # propagates worker exceptions
                for idx, vec in zip(indices, vecs):
                    results[idx] = vec
        return [results[i] for i in range(len(texts))]

View File

@@ -8,6 +8,7 @@ from __future__ import annotations
import hashlib
import logging
import queue
import re
import threading
import time
from dataclasses import dataclass
@@ -32,6 +33,52 @@ _DEFAULT_MAX_CHUNK_CHARS = 800
_DEFAULT_CHUNK_OVERLAP = 100
def is_file_excluded(file_path: Path, config: Config) -> str | None:
    """Check if a file should be excluded from indexing.

    Checks, in order: excluded extension (including compound suffixes like
    ".min.js"), stat failure, size limits, binary content (null-byte ratio
    in a leading sample), and generated-code markers in the first 1 KB.

    Returns:
        A human-readable exclusion reason string, or None if the file
        should be indexed.
    """
    # Extension check against the full lowercased name so compound
    # extensions such as ".min.js" match (Path.suffix only sees ".js").
    # Fix: the original also computed `suffix = file_path.suffix.lower()`
    # but never used it — dead local removed.
    name_lower = file_path.name.lower()
    for ext in config.exclude_extensions:
        if name_lower.endswith(ext):
            return f"excluded extension: {ext}"
    # File size check
    try:
        size = file_path.stat().st_size
    except OSError:
        return "cannot stat file"
    if size > config.max_file_size_bytes:
        return f"exceeds max size ({size} > {config.max_file_size_bytes})"
    if size == 0:
        return "empty file"
    # Binary detection: sample first N bytes and measure null-byte ratio.
    try:
        with open(file_path, "rb") as f:
            sample = f.read(config.binary_detect_sample_bytes)
    except OSError:
        return "cannot read file"
    if sample:
        null_ratio = sample.count(b"\x00") / len(sample)
        if null_ratio > config.binary_null_threshold:
            return f"binary file (null ratio: {null_ratio:.2%})"
    # Generated code markers (check first 1KB of decoded text)
    try:
        head = file_path.read_text(encoding="utf-8", errors="replace")[:1024]
    except OSError:
        return None  # can't check, let it through
    for marker in config.generated_code_markers:
        if marker in head:
            return f"generated code marker: {marker}"
    return None
@dataclass
class IndexStats:
"""Statistics returned after indexing completes."""
@@ -126,16 +173,19 @@ class IndexingPipeline:
chunks_created = 0
for fpath in files:
# Noise file filter
exclude_reason = is_file_excluded(fpath, self._config)
if exclude_reason:
logger.debug("Skipping %s: %s", fpath, exclude_reason)
continue
try:
if fpath.stat().st_size > max_file_size:
continue
text = fpath.read_text(encoding="utf-8", errors="replace")
except Exception as exc:
logger.debug("Skipping %s: %s", fpath, exc)
continue
rel_path = str(fpath.relative_to(root)) if root else str(fpath)
file_chunks = self._chunk_text(text, rel_path, max_chunk_chars, chunk_overlap)
file_chunks = self._smart_chunk(text, rel_path, max_chunk_chars, chunk_overlap)
if not file_chunks:
continue
@@ -290,6 +340,106 @@ class IndexingPipeline:
return chunks
# Pattern matching top-level definitions across languages.
# NOTE(review): _chunk_code applies this with .match() to an already
# lstripped line, so the leading ^ and the MULTILINE flag are effectively
# redundant there — harmless, but worth confirming no other caller relies
# on multiline behavior.
_CODE_BOUNDARY_RE = re.compile(
    r"^(?:"
    r"(?:export\s+)?(?:async\s+)?(?:def|class|function)\s+"  # Python/JS/TS
    r"|(?:pub\s+)?(?:fn|struct|impl|enum|trait|mod)\s+"  # Rust
    r"|(?:func|type)\s+"  # Go
    r"|(?:public|private|protected|internal)?\s*(?:static\s+)?(?:class|interface|enum|record)\s+"  # Java/C#
    r"|(?:namespace|template)\s+"  # C++
    r")",
    re.MULTILINE,
)
def _chunk_code(
    self,
    text: str,
    path: str,
    max_chars: int,
    overlap: int,
) -> list[tuple[str, str, int, int]]:
    """Split code at function/class boundaries with fallback to _chunk_text.

    Strategy:
    1. Find all top-level definition boundaries via regex.
    2. Split text into segments at those boundaries.
    3. Merge small adjacent segments up to max_chars.
    4. If a segment exceeds max_chars, fall back to _chunk_text for that segment.

    Returns (chunk_text, path, start_line, end_line) tuples. Line numbers
    are reported 1-based for directly-built chunks; sub-chunks assume
    _chunk_text also emits 1-based numbers — TODO confirm against
    _chunk_text's contract.
    """
    lines = text.splitlines(keepends=True)
    if not lines:
        return []
    # Find boundary line numbers (0-based)
    boundaries: list[int] = [0]  # always start at line 0
    for i, line in enumerate(lines):
        if i == 0:
            continue
        # Only match lines with no or minimal indentation (top-level);
        # deeper indentation means a nested definition, which stays
        # inside its parent's chunk.
        stripped = line.lstrip()
        indent = len(line) - len(stripped)
        if indent <= 4 and self._CODE_BOUNDARY_RE.match(stripped):
            boundaries.append(i)
    if len(boundaries) <= 1:
        # No boundaries found, fall back to text chunking
        return self._chunk_text(text, path, max_chars, overlap)
    # Build raw segments between boundaries
    raw_segments: list[tuple[int, int]] = []  # (start_line, end_line) 0-based
    for idx in range(len(boundaries)):
        start = boundaries[idx]
        end = boundaries[idx + 1] if idx + 1 < len(boundaries) else len(lines)
        raw_segments.append((start, end))
    # Merge small adjacent segments up to max_chars (greedy, in order:
    # a segment that doesn't fit starts a new merged run).
    merged: list[tuple[int, int]] = []
    cur_start, cur_end = raw_segments[0]
    cur_len = sum(len(lines[i]) for i in range(cur_start, cur_end))
    for seg_start, seg_end in raw_segments[1:]:
        seg_len = sum(len(lines[i]) for i in range(seg_start, seg_end))
        if cur_len + seg_len <= max_chars:
            cur_end = seg_end
            cur_len += seg_len
        else:
            merged.append((cur_start, cur_end))
            cur_start, cur_end = seg_start, seg_end
            cur_len = seg_len
    merged.append((cur_start, cur_end))
    # Build chunks, falling back to _chunk_text for oversized segments
    chunks: list[tuple[str, str, int, int]] = []
    for seg_start, seg_end in merged:
        seg_text = "".join(lines[seg_start:seg_end])
        if len(seg_text) > max_chars:
            # Oversized: sub-chunk with text splitter
            sub_chunks = self._chunk_text(seg_text, path, max_chars, overlap)
            # Adjust line numbers relative to segment start
            for chunk_text, p, sl, el in sub_chunks:
                chunks.append((chunk_text, p, sl + seg_start, el + seg_start))
        else:
            # seg_start is 0-based; emit a 1-based inclusive line range.
            chunks.append((seg_text, path, seg_start + 1, seg_end))
    return chunks
def _smart_chunk(
    self,
    text: str,
    path: str,
    max_chars: int,
    overlap: int,
) -> list[tuple[str, str, int, int]]:
    """Pick a chunking strategy: code-aware for known code files, else text."""
    wants_code_chunking = (
        self._config.code_aware_chunking
        and Path(path).suffix.lower() in self._config.code_extensions
    )
    if wants_code_chunking:
        code_chunks = self._chunk_code(text, path, max_chars, overlap)
        # An empty result means the code chunker punted; use the text splitter.
        if code_chunks:
            return code_chunks
    return self._chunk_text(text, path, max_chars, overlap)
# ------------------------------------------------------------------
# Incremental API
# ------------------------------------------------------------------
@@ -342,11 +492,14 @@ class IndexingPipeline:
meta = self._require_metadata()
t0 = time.monotonic()
# Noise file filter
exclude_reason = is_file_excluded(file_path, self._config)
if exclude_reason:
logger.debug("Skipping %s: %s", file_path, exclude_reason)
return IndexStats(duration_seconds=round(time.monotonic() - t0, 2))
# Read file
try:
if file_path.stat().st_size > max_file_size:
logger.debug("Skipping %s: exceeds max_file_size", file_path)
return IndexStats(duration_seconds=round(time.monotonic() - t0, 2))
text = file_path.read_text(encoding="utf-8", errors="replace")
except Exception as exc:
logger.debug("Skipping %s: %s", file_path, exc)
@@ -366,7 +519,7 @@ class IndexingPipeline:
self._fts.delete_by_path(rel_path)
# Chunk
file_chunks = self._chunk_text(text, rel_path, max_chunk_chars, chunk_overlap)
file_chunks = self._smart_chunk(text, rel_path, max_chunk_chars, chunk_overlap)
if not file_chunks:
# Register file with no chunks
meta.register_file(rel_path, content_hash, file_path.stat().st_mtime)

View File

@@ -21,6 +21,7 @@ _make_fastembed_mock()
from codexlens_search.config import Config # noqa: E402
from codexlens_search.embed.base import BaseEmbedder # noqa: E402
from codexlens_search.embed.local import EMBED_PROFILES, FastEmbedEmbedder # noqa: E402
from codexlens_search.embed.api import APIEmbedder # noqa: E402
class TestEmbedSingle(unittest.TestCase):
@@ -76,5 +77,182 @@ class TestBaseEmbedderAbstract(unittest.TestCase):
BaseEmbedder() # type: ignore[abstract]
# ---------------------------------------------------------------------------
# APIEmbedder
# ---------------------------------------------------------------------------
def _make_api_config(**overrides) -> Config:
    """Build a Config preloaded with API-embedding test defaults.

    Keyword arguments override the defaults before Config is constructed.
    """
    params = {
        "embed_api_url": "https://api.example.com/v1",
        "embed_api_key": "test-key",
        "embed_api_model": "text-embedding-3-small",
        "embed_dim": 384,
        "embed_batch_size": 2,
        "embed_api_max_tokens_per_batch": 8192,
        "embed_api_concurrency": 2,
        **overrides,
    }
    return Config(**params)
def _mock_200(count=1, dim=384):
r = MagicMock()
r.status_code = 200
r.json.return_value = {
"data": [{"index": j, "embedding": [0.1 * (j + 1)] * dim} for j in range(count)]
}
r.raise_for_status = MagicMock()
return r
class TestAPIEmbedderSingle(unittest.TestCase):
    """embed_single returns a float32 vector of the configured dimension."""

    def test_embed_single_returns_float32(self):
        cfg = _make_api_config()
        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            fake_client.post.return_value = _mock_200(1, 384)
            vec = APIEmbedder(cfg).embed_single("hello")
            self.assertIsInstance(vec, np.ndarray)
            self.assertEqual(vec.dtype, np.float32)
            self.assertEqual(vec.shape, (384,))
class TestAPIEmbedderBatch(unittest.TestCase):
    """Batch embedding: splitting by batch size and the empty-input edge."""

    def test_embed_batch_splits_by_batch_size(self):
        cfg = _make_api_config(embed_batch_size=2)
        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            # Two calls expected: a batch of 2, then a batch of 1.
            fake_client.post.side_effect = [_mock_200(2, 384), _mock_200(1, 384)]
            vectors = APIEmbedder(cfg).embed_batch(["a", "b", "c"])
            self.assertEqual(len(vectors), 3)
            for vec in vectors:
                self.assertIsInstance(vec, np.ndarray)
                self.assertEqual(vec.dtype, np.float32)

    def test_embed_batch_empty_returns_empty(self):
        cfg = _make_api_config()
        with patch("httpx.Client"):
            self.assertEqual(APIEmbedder(cfg).embed_batch([]), [])
class TestAPIEmbedderRetry(unittest.TestCase):
    """Retry behaviour for rate-limited (HTTP 429) responses."""

    @staticmethod
    def _rate_limited():
        # A minimal mock response carrying only the status code.
        resp = MagicMock()
        resp.status_code = 429
        return resp

    def test_retry_on_429(self):
        cfg = _make_api_config()
        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            fake_client.post.side_effect = [self._rate_limited(), _mock_200(1, 384)]
            embedder = APIEmbedder(cfg)
            with patch("time.sleep"):
                vectors = embedder._call_api(["test"], embedder._endpoints[0])
            self.assertEqual(len(vectors), 1)
            self.assertEqual(fake_client.post.call_count, 2)

    def test_raises_after_max_retries(self):
        cfg = _make_api_config()
        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            fake_client.post.return_value = self._rate_limited()
            embedder = APIEmbedder(cfg)
            with patch("time.sleep"), self.assertRaises(RuntimeError):
                embedder._call_api(["test"], embedder._endpoints[0], max_retries=2)
class TestAPIEmbedderTokenPacking(unittest.TestCase):
    """_pack_batches groups texts under the estimated-token budget."""

    def test_packs_small_texts_together(self):
        cfg = _make_api_config(
            embed_batch_size=100,
            embed_api_max_tokens_per_batch=100,  # ~400 chars
        )
        with patch("httpx.Client"):
            embedder = APIEmbedder(cfg)
            # 5 texts of 80 chars (~20 tokens each): 100 tokens, at the limit.
            batches = embedder._pack_batches(["x" * 80] * 5)
            # All inputs must land in some batch, however they are grouped.
            self.assertTrue(len(batches) >= 1)
            self.assertEqual(sum(len(b) for b in batches), 5)

    def test_large_text_gets_own_batch(self):
        cfg = _make_api_config(
            embed_batch_size=100,
            embed_api_max_tokens_per_batch=50,  # ~200 chars
        )
        with patch("httpx.Client"):
            embedder = APIEmbedder(cfg)
            # The 800-char text (~200 tokens) blows the 50-token budget and
            # cannot share a batch with its neighbours.
            batches = embedder._pack_batches(["small" * 10, "x" * 800, "tiny"])
            self.assertTrue(len(batches) >= 2)
class TestAPIEmbedderMultiEndpoint(unittest.TestCase):
    """Endpoint list construction from multi- and single-endpoint configs."""

    def test_multi_endpoint_config(self):
        cfg = _make_api_config(
            embed_api_endpoints=[
                {"url": "https://ep1.example.com/v1", "key": "k1", "model": "m1"},
                {"url": "https://ep2.example.com/v1", "key": "k2", "model": "m2"},
            ]
        )
        with patch("httpx.Client"):
            endpoints = APIEmbedder(cfg)._endpoints
            self.assertEqual(len(endpoints), 2)
            for ep in endpoints:
                self.assertTrue(ep.url.endswith("/embeddings"))

    def test_single_endpoint_fallback(self):
        # Without embed_api_endpoints, the top-level URL yields one endpoint.
        cfg = _make_api_config()
        with patch("httpx.Client"):
            self.assertEqual(len(APIEmbedder(cfg)._endpoints), 1)
class TestAPIEmbedderUrlNormalization(unittest.TestCase):
    """URL normalization: the '/embeddings' suffix is appended exactly once."""

    def test_appends_embeddings_path(self):
        cfg = _make_api_config(embed_api_url="https://api.example.com/v1")
        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            fake_client.post.return_value = _mock_200(1, 384)
            endpoint = APIEmbedder(cfg)._endpoints[0]
            self.assertTrue(endpoint.url.endswith("/embeddings"))

    def test_does_not_double_append(self):
        cfg = _make_api_config(embed_api_url="https://api.example.com/v1/embeddings")
        with patch("httpx.Client"):
            endpoint = APIEmbedder(cfg)._endpoints[0]
            self.assertFalse(endpoint.url.endswith("/embeddings/embeddings"))
# Allow running this test module directly (`python <file>.py`) in addition
# to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()