mirror of
https://github.com/catlog22/Claude-Code-Workflow.git
synced 2026-03-18 18:48:48 +08:00
feat: add APIEmbedder for remote embedding with multi-endpoint support
- Introduced APIEmbedder class to handle embeddings via a remote HTTP API.
- Implemented token packing to optimize batch sizes based on token limits.
- Added support for multiple API endpoints with round-robin dispatching.
- Included retry logic for API calls with exponential backoff on failures.
- Enhanced indexing pipeline with file exclusion checks and smart chunking strategies.
- Updated tests to cover new APIEmbedder functionality and ensure robustness.
This commit is contained in:
@@ -44,6 +44,9 @@ faiss-cpu = [
|
||||
faiss-gpu = [
|
||||
"faiss-gpu>=1.7.4",
|
||||
]
|
||||
embed-api = [
|
||||
"httpx>=0.25",
|
||||
]
|
||||
reranker-api = [
|
||||
"httpx>=0.25",
|
||||
]
|
||||
|
||||
@@ -57,6 +57,73 @@ def _create_config(args: argparse.Namespace) -> "Config":
|
||||
kwargs: dict = {}
|
||||
if hasattr(args, "embed_model") and args.embed_model:
|
||||
kwargs["embed_model"] = args.embed_model
|
||||
# API embedding overrides
|
||||
if hasattr(args, "embed_api_url") and args.embed_api_url:
|
||||
kwargs["embed_api_url"] = args.embed_api_url
|
||||
if hasattr(args, "embed_api_key") and args.embed_api_key:
|
||||
kwargs["embed_api_key"] = args.embed_api_key
|
||||
if hasattr(args, "embed_api_model") and args.embed_api_model:
|
||||
kwargs["embed_api_model"] = args.embed_api_model
|
||||
# Also check env vars as fallback
|
||||
if "embed_api_url" not in kwargs and os.environ.get("CODEXLENS_EMBED_API_URL"):
|
||||
kwargs["embed_api_url"] = os.environ["CODEXLENS_EMBED_API_URL"]
|
||||
if "embed_api_key" not in kwargs and os.environ.get("CODEXLENS_EMBED_API_KEY"):
|
||||
kwargs["embed_api_key"] = os.environ["CODEXLENS_EMBED_API_KEY"]
|
||||
if "embed_api_model" not in kwargs and os.environ.get("CODEXLENS_EMBED_API_MODEL"):
|
||||
kwargs["embed_api_model"] = os.environ["CODEXLENS_EMBED_API_MODEL"]
|
||||
# Multi-endpoint: CODEXLENS_EMBED_API_ENDPOINTS=url1|key1|model1,url2|key2|model2
|
||||
endpoints_env = os.environ.get("CODEXLENS_EMBED_API_ENDPOINTS", "")
|
||||
if endpoints_env:
|
||||
endpoints = []
|
||||
for entry in endpoints_env.split(","):
|
||||
parts = entry.strip().split("|")
|
||||
if len(parts) >= 2:
|
||||
ep = {"url": parts[0], "key": parts[1]}
|
||||
if len(parts) >= 3:
|
||||
ep["model"] = parts[2]
|
||||
endpoints.append(ep)
|
||||
if endpoints:
|
||||
kwargs["embed_api_endpoints"] = endpoints
|
||||
# Embed dimension and concurrency from env
|
||||
if os.environ.get("CODEXLENS_EMBED_DIM"):
|
||||
kwargs["embed_dim"] = int(os.environ["CODEXLENS_EMBED_DIM"])
|
||||
if os.environ.get("CODEXLENS_EMBED_BATCH_SIZE"):
|
||||
kwargs["embed_batch_size"] = int(os.environ["CODEXLENS_EMBED_BATCH_SIZE"])
|
||||
if os.environ.get("CODEXLENS_EMBED_API_CONCURRENCY"):
|
||||
kwargs["embed_api_concurrency"] = int(os.environ["CODEXLENS_EMBED_API_CONCURRENCY"])
|
||||
if os.environ.get("CODEXLENS_EMBED_API_MAX_TOKENS"):
|
||||
kwargs["embed_api_max_tokens_per_batch"] = int(os.environ["CODEXLENS_EMBED_API_MAX_TOKENS"])
|
||||
# Reranker API env vars
|
||||
if os.environ.get("CODEXLENS_RERANKER_API_URL"):
|
||||
kwargs["reranker_api_url"] = os.environ["CODEXLENS_RERANKER_API_URL"]
|
||||
if os.environ.get("CODEXLENS_RERANKER_API_KEY"):
|
||||
kwargs["reranker_api_key"] = os.environ["CODEXLENS_RERANKER_API_KEY"]
|
||||
if os.environ.get("CODEXLENS_RERANKER_API_MODEL"):
|
||||
kwargs["reranker_api_model"] = os.environ["CODEXLENS_RERANKER_API_MODEL"]
|
||||
# Search pipeline params from env
|
||||
if os.environ.get("CODEXLENS_RERANKER_TOP_K"):
|
||||
kwargs["reranker_top_k"] = int(os.environ["CODEXLENS_RERANKER_TOP_K"])
|
||||
if os.environ.get("CODEXLENS_RERANKER_BATCH_SIZE"):
|
||||
kwargs["reranker_batch_size"] = int(os.environ["CODEXLENS_RERANKER_BATCH_SIZE"])
|
||||
if os.environ.get("CODEXLENS_BINARY_TOP_K"):
|
||||
kwargs["binary_top_k"] = int(os.environ["CODEXLENS_BINARY_TOP_K"])
|
||||
if os.environ.get("CODEXLENS_ANN_TOP_K"):
|
||||
kwargs["ann_top_k"] = int(os.environ["CODEXLENS_ANN_TOP_K"])
|
||||
if os.environ.get("CODEXLENS_FTS_TOP_K"):
|
||||
kwargs["fts_top_k"] = int(os.environ["CODEXLENS_FTS_TOP_K"])
|
||||
if os.environ.get("CODEXLENS_FUSION_K"):
|
||||
kwargs["fusion_k"] = int(os.environ["CODEXLENS_FUSION_K"])
|
||||
# Indexing params from env
|
||||
if os.environ.get("CODEXLENS_CODE_AWARE_CHUNKING"):
|
||||
kwargs["code_aware_chunking"] = os.environ["CODEXLENS_CODE_AWARE_CHUNKING"].lower() == "true"
|
||||
if os.environ.get("CODEXLENS_INDEX_WORKERS"):
|
||||
kwargs["index_workers"] = int(os.environ["CODEXLENS_INDEX_WORKERS"])
|
||||
if os.environ.get("CODEXLENS_MAX_FILE_SIZE"):
|
||||
kwargs["max_file_size_bytes"] = int(os.environ["CODEXLENS_MAX_FILE_SIZE"])
|
||||
if os.environ.get("CODEXLENS_HNSW_EF"):
|
||||
kwargs["hnsw_ef"] = int(os.environ["CODEXLENS_HNSW_EF"])
|
||||
if os.environ.get("CODEXLENS_HNSW_M"):
|
||||
kwargs["hnsw_M"] = int(os.environ["CODEXLENS_HNSW_M"])
|
||||
db_path = Path(args.db_path).resolve()
|
||||
kwargs["metadata_db_path"] = str(db_path / "metadata.db")
|
||||
return Config(**kwargs)
|
||||
@@ -72,22 +139,43 @@ def _create_pipeline(
|
||||
"""
|
||||
from codexlens_search.config import Config
|
||||
from codexlens_search.core.factory import create_ann_index, create_binary_index
|
||||
from codexlens_search.embed.local import FastEmbedEmbedder
|
||||
from codexlens_search.indexing.metadata import MetadataStore
|
||||
from codexlens_search.indexing.pipeline import IndexingPipeline
|
||||
from codexlens_search.rerank.local import FastEmbedReranker
|
||||
from codexlens_search.search.fts import FTSEngine
|
||||
from codexlens_search.search.pipeline import SearchPipeline
|
||||
|
||||
config = _create_config(args)
|
||||
db_path = _resolve_db_path(args)
|
||||
|
||||
embedder = FastEmbedEmbedder(config)
|
||||
# Select embedder: API if configured, otherwise local fastembed
|
||||
if config.embed_api_url:
|
||||
from codexlens_search.embed.api import APIEmbedder
|
||||
embedder = APIEmbedder(config)
|
||||
log.info("Using API embedder: %s", config.embed_api_url)
|
||||
# Auto-detect embed_dim from API if still at default
|
||||
if config.embed_dim == 384:
|
||||
probe_vec = embedder.embed_single("dimension probe")
|
||||
detected_dim = probe_vec.shape[0]
|
||||
if detected_dim != config.embed_dim:
|
||||
log.info("Auto-detected embed_dim=%d from API (was %d)", detected_dim, config.embed_dim)
|
||||
config.embed_dim = detected_dim
|
||||
else:
|
||||
from codexlens_search.embed.local import FastEmbedEmbedder
|
||||
embedder = FastEmbedEmbedder(config)
|
||||
|
||||
binary_store = create_binary_index(db_path, config.embed_dim, config)
|
||||
ann_index = create_ann_index(db_path, config.embed_dim, config)
|
||||
fts = FTSEngine(db_path / "fts.db")
|
||||
metadata = MetadataStore(db_path / "metadata.db")
|
||||
reranker = FastEmbedReranker(config)
|
||||
|
||||
# Select reranker: API if configured, otherwise local fastembed
|
||||
if config.reranker_api_url:
|
||||
from codexlens_search.rerank.api import APIReranker
|
||||
reranker = APIReranker(config)
|
||||
log.info("Using API reranker: %s", config.reranker_api_url)
|
||||
else:
|
||||
from codexlens_search.rerank.local import FastEmbedReranker
|
||||
reranker = FastEmbedReranker(config)
|
||||
|
||||
indexing = IndexingPipeline(
|
||||
embedder=embedder,
|
||||
@@ -181,6 +269,19 @@ def cmd_remove_file(args: argparse.Namespace) -> None:
|
||||
})
|
||||
|
||||
|
||||
# Directory names excluded from `sync` scans by default (overridable with
# repeated --exclude flags).
# NOTE(review): "*.egg-info" is a glob-style entry — verify that the
# component matcher handles wildcards and not just exact names.
_DEFAULT_EXCLUDES = frozenset({
    "node_modules", ".git", "__pycache__", "dist", "build",
    ".venv", "venv", ".tox", ".mypy_cache", ".pytest_cache",
    ".next", ".nuxt", "coverage", ".eggs", "*.egg-info",
})
|
||||
|
||||
|
||||
def _should_exclude(path: Path, exclude_dirs: frozenset[str]) -> bool:
|
||||
"""Check if any path component matches an exclude pattern."""
|
||||
parts = path.parts
|
||||
return any(part in exclude_dirs for part in parts)
|
||||
|
||||
|
||||
def cmd_sync(args: argparse.Namespace) -> None:
|
||||
"""Sync index with files under --root matching --glob pattern."""
|
||||
indexing, _, _ = _create_pipeline(args)
|
||||
@@ -189,12 +290,15 @@ def cmd_sync(args: argparse.Namespace) -> None:
|
||||
if not root.is_dir():
|
||||
_error_exit(f"Root directory not found: {root}")
|
||||
|
||||
exclude_dirs = frozenset(args.exclude) if args.exclude else _DEFAULT_EXCLUDES
|
||||
pattern = args.glob or "**/*"
|
||||
file_paths = [
|
||||
p for p in root.glob(pattern)
|
||||
if p.is_file()
|
||||
if p.is_file() and not _should_exclude(p.relative_to(root), exclude_dirs)
|
||||
]
|
||||
|
||||
log.debug("Sync: %d files after exclusion (root=%s, pattern=%s)", len(file_paths), root, pattern)
|
||||
|
||||
stats = indexing.sync(file_paths, root=root)
|
||||
_json_output({
|
||||
"status": "synced",
|
||||
@@ -331,6 +435,23 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
help="Enable debug logging to stderr",
|
||||
)
|
||||
|
||||
# API embedding overrides (also read from CODEXLENS_EMBED_API_* env vars)
|
||||
parser.add_argument(
|
||||
"--embed-api-url",
|
||||
default="",
|
||||
help="Remote embedding API URL (OpenAI-compatible, e.g. https://api.openai.com/v1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embed-api-key",
|
||||
default="",
|
||||
help="API key for remote embedding",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embed-api-model",
|
||||
default="",
|
||||
help="Model name for remote embedding (e.g. text-embedding-3-small)",
|
||||
)
|
||||
|
||||
sub = parser.add_subparsers(dest="command")
|
||||
|
||||
# init
|
||||
@@ -354,6 +475,11 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
p_sync = sub.add_parser("sync", help="Sync index with directory")
|
||||
p_sync.add_argument("--root", "-r", required=True, help="Root directory to sync")
|
||||
p_sync.add_argument("--glob", "-g", default="**/*", help="Glob pattern (default: **/*)")
|
||||
p_sync.add_argument(
|
||||
"--exclude", "-e", action="append", default=None,
|
||||
help="Directory names to exclude (repeatable). "
|
||||
"Defaults: node_modules, .git, __pycache__, dist, build, .venv, venv, .tox, .mypy_cache",
|
||||
)
|
||||
|
||||
# watch
|
||||
p_watch = sub.add_parser("watch", help="Watch directory for changes (JSONL output)")
|
||||
|
||||
@@ -12,6 +12,15 @@ class Config:
|
||||
embed_dim: int = 384
|
||||
embed_batch_size: int = 64
|
||||
|
||||
# API embedding (optional — overrides local fastembed when set)
|
||||
embed_api_url: str = "" # e.g. "https://api.openai.com/v1"
|
||||
embed_api_key: str = ""
|
||||
embed_api_model: str = "" # e.g. "text-embedding-3-small"
|
||||
# Multi-endpoint: list of {"url": "...", "key": "...", "model": "..."} dicts
|
||||
embed_api_endpoints: list[dict[str, str]] = None # type: ignore[assignment]
|
||||
embed_api_concurrency: int = 4
|
||||
embed_api_max_tokens_per_batch: int = 8192
|
||||
|
||||
# Model download / cache
|
||||
model_cache_dir: str = "" # empty = fastembed default cache
|
||||
hf_mirror: str = "" # HuggingFace mirror URL, e.g. "https://hf-mirror.com"
|
||||
@@ -20,6 +29,21 @@ class Config:
|
||||
device: str = "auto" # 'auto', 'cuda', 'cpu'
|
||||
embed_providers: list[str] | None = None # explicit ONNX providers override
|
||||
|
||||
# File filtering
|
||||
max_file_size_bytes: int = 1_000_000 # 1MB
|
||||
exclude_extensions: frozenset[str] = None # type: ignore[assignment] # set in __post_init__
|
||||
binary_detect_sample_bytes: int = 2048
|
||||
binary_null_threshold: float = 0.10 # >10% null bytes = binary
|
||||
generated_code_markers: tuple[str, ...] = ("@generated", "DO NOT EDIT", "auto-generated", "AUTO GENERATED")
|
||||
|
||||
# Code-aware chunking
|
||||
code_aware_chunking: bool = True
|
||||
code_extensions: frozenset[str] = frozenset({
|
||||
".py", ".js", ".ts", ".jsx", ".tsx", ".go", ".java", ".cpp", ".c",
|
||||
".h", ".hpp", ".cs", ".rs", ".rb", ".php", ".scala", ".kt", ".swift",
|
||||
".lua", ".sh", ".bash", ".zsh", ".ps1", ".vue", ".svelte",
|
||||
})
|
||||
|
||||
# Backend selection: 'auto', 'faiss', 'hnswlib'
|
||||
ann_backend: str = "auto"
|
||||
binary_backend: str = "auto"
|
||||
@@ -64,6 +88,29 @@ class Config:
|
||||
"graph": 0.15,
|
||||
})
|
||||
|
||||
_DEFAULT_EXCLUDE_EXTENSIONS: frozenset[str] = frozenset({
|
||||
# binaries / images
|
||||
".png", ".jpg", ".jpeg", ".gif", ".webp", ".ico", ".bmp", ".svg",
|
||||
".zip", ".gz", ".tar", ".rar", ".7z", ".bz2",
|
||||
".bin", ".exe", ".dll", ".so", ".dylib", ".a", ".o", ".obj",
|
||||
".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
|
||||
# build / generated
|
||||
".min.js", ".min.css", ".map", ".lock",
|
||||
".pyc", ".pyo", ".class", ".wasm",
|
||||
# data
|
||||
".sqlite", ".db", ".npy", ".npz", ".pkl", ".pickle",
|
||||
".parquet", ".arrow", ".feather",
|
||||
# media
|
||||
".mp3", ".mp4", ".wav", ".avi", ".mov", ".flv",
|
||||
".ttf", ".otf", ".woff", ".woff2", ".eot",
|
||||
})
|
||||
|
||||
def __post_init__(self) -> None:
    """Materialize field defaults that dataclass syntax cannot express.

    Mutable defaults (list, and the frozenset class constant) cannot be
    used directly as dataclass field defaults, so None is declared as a
    sentinel and replaced here after __init__ runs.

    NOTE(review): object.__setattr__ is used, which suggests the class is
    declared @dataclass(frozen=True) — confirm on the class definition.
    """
    if self.exclude_extensions is None:
        object.__setattr__(self, "exclude_extensions", self._DEFAULT_EXCLUDE_EXTENSIONS)
    if self.embed_api_endpoints is None:
        object.__setattr__(self, "embed_api_endpoints", [])
|
||||
|
||||
def resolve_embed_providers(self) -> list[str]:
|
||||
"""Return ONNX execution providers based on device config.
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .base import BaseEmbedder
|
||||
from .local import FastEmbedEmbedder, EMBED_PROFILES
|
||||
from .api import APIEmbedder
|
||||
|
||||
__all__ = ["BaseEmbedder", "FastEmbedEmbedder", "EMBED_PROFILES"]
|
||||
__all__ = ["BaseEmbedder", "FastEmbedEmbedder", "APIEmbedder", "EMBED_PROFILES"]
|
||||
|
||||
232
codex-lens-v2/src/codexlens_search/embed/api.py
Normal file
232
codex-lens-v2/src/codexlens_search/embed/api.py
Normal file
@@ -0,0 +1,232 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
from ..config import Config
|
||||
from .base import BaseEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _Endpoint:
    """A single API endpoint with its own client and rate-limit tracking."""

    __slots__ = ("url", "key", "model", "client", "failures", "lock")

    def __init__(self, url: str, key: str, model: str) -> None:
        # Normalize the URL so both base URLs ("https://host/v1") and full
        # endpoint URLs ("https://host/v1/embeddings") are accepted.
        base = url.rstrip("/")
        self.url = base if base.endswith("/embeddings") else base + "/embeddings"
        self.key = key
        self.model = model
        # Dedicated client per endpoint: carries this endpoint's auth header.
        auth_headers = {
            "Authorization": f"Bearer {key}",
            "Content-Type": "application/json",
        }
        self.client = httpx.Client(headers=auth_headers, timeout=60.0)
        # Consecutive-failure counter, guarded by `lock` for thread safety.
        self.failures = 0
        self.lock = threading.Lock()
|
||||
|
||||
|
||||
class APIEmbedder(BaseEmbedder):
    """Embedder backed by remote HTTP API (OpenAI /v1/embeddings format).

    Features:
    - Token packing: packs small chunks into batches up to max_tokens_per_batch
    - Multi-endpoint: round-robins across multiple (url, key) pairs
    - Concurrent dispatch: parallel API calls via ThreadPoolExecutor
    - Per-endpoint failure tracking and retry with backoff
    """

    def __init__(self, config: Config) -> None:
        """Initialize endpoints, round-robin cycler, and worker pool.

        Raises:
            ValueError: (from _build_endpoints) when no endpoint is configured.
        """
        self._config = config
        self._endpoints = self._build_endpoints(config)
        self._cycler = itertools.cycle(range(len(self._endpoints)))
        self._cycler_lock = threading.Lock()
        # Cap workers at 2x endpoint count (more would just queue against the
        # same hosts); floor at 1 so a misconfigured concurrency of 0 cannot
        # make ThreadPoolExecutor raise ValueError.
        self._executor = ThreadPoolExecutor(
            max_workers=max(1, min(config.embed_api_concurrency, len(self._endpoints) * 2)),
        )

    @staticmethod
    def _build_endpoints(config: Config) -> list[_Endpoint]:
        """Build endpoint list from config. Supports both single and multi configs.

        Raises:
            ValueError: when neither embed_api_endpoints nor embed_api_url is set.
        """
        endpoints: list[_Endpoint] = []

        # Multi-endpoint config takes priority; per-entry fields fall back to
        # the top-level single-endpoint settings.
        if config.embed_api_endpoints:
            for ep in config.embed_api_endpoints:
                endpoints.append(_Endpoint(
                    url=ep.get("url", config.embed_api_url),
                    key=ep.get("key", config.embed_api_key),
                    model=ep.get("model", config.embed_api_model),
                ))

        # Fallback: single endpoint from top-level config
        if not endpoints and config.embed_api_url:
            endpoints.append(_Endpoint(
                url=config.embed_api_url,
                key=config.embed_api_key,
                model=config.embed_api_model,
            ))

        if not endpoints:
            raise ValueError("No API embedding endpoints configured")

        return endpoints

    def _next_endpoint(self) -> _Endpoint:
        """Return the next endpoint round-robin; lock makes it thread-safe."""
        with self._cycler_lock:
            idx = next(self._cycler)
        return self._endpoints[idx]

    # -- Token packing ------------------------------------------------

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Rough token estimate: ~4 chars per token for code."""
        return max(1, len(text) // 4)

    def _pack_batches(
        self, texts: list[str]
    ) -> list[list[tuple[int, str]]]:
        """Pack texts into batches respecting max_tokens_per_batch.

        Returns list of batches, each batch is list of (original_index, text).
        Also respects embed_batch_size as max items per batch. Limits are
        checked before appending, so an oversized single text still gets its
        own (non-empty) batch.
        """
        max_tokens = self._config.embed_api_max_tokens_per_batch
        max_items = self._config.embed_batch_size
        batches: list[list[tuple[int, str]]] = []
        current: list[tuple[int, str]] = []
        current_tokens = 0

        for i, text in enumerate(texts):
            tokens = self._estimate_tokens(text)
            # Start new batch if adding this text would exceed either limit
            if current and (
                current_tokens + tokens > max_tokens
                or len(current) >= max_items
            ):
                batches.append(current)
                current = []
                current_tokens = 0
            current.append((i, text))
            current_tokens += tokens

        if current:
            batches.append(current)

        return batches

    # -- API call with retry ------------------------------------------

    def _call_api(
        self,
        texts: list[str],
        endpoint: _Endpoint,
        max_retries: int = 3,
    ) -> list[np.ndarray]:
        """Call a single endpoint with retry logic.

        Retries transport errors and HTTP 429/503 with exponential backoff.
        Other HTTP errors propagate immediately via raise_for_status (after
        bumping the endpoint's failure counter).

        Raises:
            RuntimeError: after max_retries failed attempts, or when the
                response contains a different number of embeddings than inputs.
        """
        payload: dict = {"input": texts}
        if endpoint.model:
            payload["model"] = endpoint.model

        # Human-readable description of the most recent failure; previously
        # a retryable HTTP status left this as None, producing the useless
        # message "Last error: None".
        last_error = ""
        for attempt in range(max_retries):
            try:
                response = endpoint.client.post(endpoint.url, json=payload)
            except Exception as exc:
                last_error = str(exc)
                logger.warning(
                    "API embed %s failed (attempt %d/%d): %s",
                    endpoint.url, attempt + 1, max_retries, exc,
                )
                time.sleep((2 ** attempt) * 0.5)
                continue

            if response.status_code in (429, 503):
                last_error = f"HTTP {response.status_code}"
                logger.warning(
                    "API embed %s returned HTTP %s (attempt %d/%d), retrying...",
                    endpoint.url, response.status_code, attempt + 1, max_retries,
                )
                time.sleep((2 ** attempt) * 0.5)
                continue

            try:
                response.raise_for_status()
            except Exception:
                # Non-retryable HTTP error: still count it against this
                # endpoint before letting the original exception propagate.
                with endpoint.lock:
                    endpoint.failures += 1
                raise

            data = response.json()

            items = data.get("data", [])
            items.sort(key=lambda x: x["index"])
            vectors = [
                np.array(item["embedding"], dtype=np.float32)
                for item in items
            ]

            if len(vectors) != len(texts):
                # A truncated/malformed response would otherwise surface much
                # later as an opaque KeyError during result assembly.
                with endpoint.lock:
                    endpoint.failures += 1
                raise RuntimeError(
                    f"API embed at {endpoint.url} returned {len(vectors)} "
                    f"embeddings for {len(texts)} inputs"
                )

            # Reset failure counter on success
            with endpoint.lock:
                endpoint.failures = 0

            return vectors

        # Track failures
        with endpoint.lock:
            endpoint.failures += 1

        raise RuntimeError(
            f"API embed failed at {endpoint.url} after {max_retries} attempts. "
            f"Last error: {last_error}"
        )

    # -- Public interface ---------------------------------------------

    def embed_single(self, text: str) -> np.ndarray:
        """Embed one text; returns a float32 vector."""
        endpoint = self._next_endpoint()
        vecs = self._call_api([text], endpoint)
        return vecs[0]

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed texts, returning float32 vectors in input order."""
        if not texts:
            return []

        # 1. Pack into token-aware batches
        packed = self._pack_batches(texts)

        results: dict[int, np.ndarray] = {}

        if len(packed) == 1:
            # Single batch — no concurrency overhead needed
            batch = packed[0]
            vecs = self._call_api([t for _, t in batch], self._next_endpoint())
            for (idx, _), vec in zip(batch, vecs):
                results[idx] = vec
        else:
            # 2. Dispatch batches concurrently across endpoints
            futures = []
            batch_index_map: list[list[int]] = []
            for batch in packed:
                batch_index_map.append([i for i, _ in batch])
                futures.append(self._executor.submit(
                    self._call_api, [t for _, t in batch], self._next_endpoint(),
                ))
            for future, indices in zip(futures, batch_index_map):
                vecs = future.result()  # propagates worker exceptions
                for idx, vec in zip(indices, vecs):
                    results[idx] = vec

        # Reassemble in original input order
        return [results[i] for i in range(len(texts))]
|
||||
@@ -8,6 +8,7 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -32,6 +33,52 @@ _DEFAULT_MAX_CHUNK_CHARS = 800
|
||||
_DEFAULT_CHUNK_OVERLAP = 100
|
||||
|
||||
|
||||
def is_file_excluded(file_path: Path, config: Config) -> str | None:
    """Check if a file should be excluded from indexing.

    Checks, in order: excluded extension (matched against the full file
    name, so compound extensions like ".min.js" work), unreadable file,
    size limits, binary content (null-byte ratio in a leading sample),
    and generated-code markers in the first 1024 characters.

    Returns exclusion reason string, or None if file should be indexed.
    """
    # Extension check — compare against the whole lowercased name so
    # compound extensions like ".min.js" are caught (Path.suffix would
    # only see ".js").
    name_lower = file_path.name.lower()
    for ext in config.exclude_extensions:
        if name_lower.endswith(ext):
            return f"excluded extension: {ext}"

    # File size check
    try:
        size = file_path.stat().st_size
    except OSError:
        return "cannot stat file"
    if size > config.max_file_size_bytes:
        return f"exceeds max size ({size} > {config.max_file_size_bytes})"
    if size == 0:
        return "empty file"

    # Binary detection: sample first N bytes
    try:
        with open(file_path, "rb") as f:
            sample = f.read(config.binary_detect_sample_bytes)
    except OSError:
        return "cannot read file"
    if sample:
        null_ratio = sample.count(b"\x00") / len(sample)
        if null_ratio > config.binary_null_threshold:
            return f"binary file (null ratio: {null_ratio:.2%})"

    # Generated code markers (check first 1KB of text). Read only a bounded
    # prefix instead of the whole file: 4096 bytes always covers 1024
    # characters, even for 4-byte UTF-8 sequences.
    try:
        with open(file_path, "rb") as f:
            head_bytes = f.read(4096)
    except OSError:
        return None  # can't check, let it through
    head = head_bytes.decode("utf-8", errors="replace")[:1024]
    for marker in config.generated_code_markers:
        if marker in head:
            return f"generated code marker: {marker}"

    return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexStats:
|
||||
"""Statistics returned after indexing completes."""
|
||||
@@ -126,16 +173,19 @@ class IndexingPipeline:
|
||||
chunks_created = 0
|
||||
|
||||
for fpath in files:
|
||||
# Noise file filter
|
||||
exclude_reason = is_file_excluded(fpath, self._config)
|
||||
if exclude_reason:
|
||||
logger.debug("Skipping %s: %s", fpath, exclude_reason)
|
||||
continue
|
||||
try:
|
||||
if fpath.stat().st_size > max_file_size:
|
||||
continue
|
||||
text = fpath.read_text(encoding="utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
logger.debug("Skipping %s: %s", fpath, exc)
|
||||
continue
|
||||
|
||||
rel_path = str(fpath.relative_to(root)) if root else str(fpath)
|
||||
file_chunks = self._chunk_text(text, rel_path, max_chunk_chars, chunk_overlap)
|
||||
file_chunks = self._smart_chunk(text, rel_path, max_chunk_chars, chunk_overlap)
|
||||
|
||||
if not file_chunks:
|
||||
continue
|
||||
@@ -290,6 +340,106 @@ class IndexingPipeline:
|
||||
|
||||
return chunks
|
||||
|
||||
# Pattern matching top-level definitions across languages.
# Used by _chunk_code to find split points; re.MULTILINE so ^ anchors at
# each line start (the method matches against the lstripped line anyway).
_CODE_BOUNDARY_RE = re.compile(
    r"^(?:"
    r"(?:export\s+)?(?:async\s+)?(?:def|class|function)\s+"  # Python/JS/TS
    r"|(?:pub\s+)?(?:fn|struct|impl|enum|trait|mod)\s+"  # Rust
    r"|(?:func|type)\s+"  # Go
    r"|(?:public|private|protected|internal)?\s*(?:static\s+)?(?:class|interface|enum|record)\s+"  # Java/C#
    r"|(?:namespace|template)\s+"  # C++
    r")",
    re.MULTILINE,
)

def _chunk_code(
    self,
    text: str,
    path: str,
    max_chars: int,
    overlap: int,
) -> list[tuple[str, str, int, int]]:
    """Split code at function/class boundaries with fallback to _chunk_text.

    Strategy:
    1. Find all top-level definition boundaries via regex.
    2. Split text into segments at those boundaries.
    3. Merge small adjacent segments up to max_chars.
    4. If a segment exceeds max_chars, fall back to _chunk_text for that segment.

    Returns (chunk_text, path, start_line, end_line) tuples.
    """
    lines = text.splitlines(keepends=True)
    if not lines:
        return []

    # Find boundary line numbers (0-based)
    boundaries: list[int] = [0]  # always start at line 0
    for i, line in enumerate(lines):
        if i == 0:
            continue
        # Only match lines with no or minimal indentation (top-level);
        # indent <= 4 also admits one-level-indented members (e.g. methods).
        stripped = line.lstrip()
        indent = len(line) - len(stripped)
        if indent <= 4 and self._CODE_BOUNDARY_RE.match(stripped):
            boundaries.append(i)

    if len(boundaries) <= 1:
        # No boundaries found, fall back to text chunking
        return self._chunk_text(text, path, max_chars, overlap)

    # Build raw segments between boundaries
    raw_segments: list[tuple[int, int]] = []  # (start_line, end_line) 0-based, end exclusive
    for idx in range(len(boundaries)):
        start = boundaries[idx]
        end = boundaries[idx + 1] if idx + 1 < len(boundaries) else len(lines)
        raw_segments.append((start, end))

    # Merge small adjacent segments up to max_chars (greedy left-to-right)
    merged: list[tuple[int, int]] = []
    cur_start, cur_end = raw_segments[0]
    cur_len = sum(len(lines[i]) for i in range(cur_start, cur_end))

    for seg_start, seg_end in raw_segments[1:]:
        seg_len = sum(len(lines[i]) for i in range(seg_start, seg_end))
        if cur_len + seg_len <= max_chars:
            cur_end = seg_end
            cur_len += seg_len
        else:
            merged.append((cur_start, cur_end))
            cur_start, cur_end = seg_start, seg_end
            cur_len = seg_len
    merged.append((cur_start, cur_end))

    # Build chunks, falling back to _chunk_text for oversized segments
    chunks: list[tuple[str, str, int, int]] = []
    for seg_start, seg_end in merged:
        seg_text = "".join(lines[seg_start:seg_end])
        if len(seg_text) > max_chars:
            # Oversized: sub-chunk with text splitter
            sub_chunks = self._chunk_text(seg_text, path, max_chars, overlap)
            # Adjust line numbers relative to segment start.
            # NOTE(review): assumes _chunk_text returns 1-based line numbers,
            # so adding the 0-based seg_start offset keeps them 1-based and
            # absolute — confirm against _chunk_text's contract.
            for chunk_text, p, sl, el in sub_chunks:
                chunks.append((chunk_text, p, sl + seg_start, el + seg_start))
        else:
            # 1-based inclusive range: seg_end is 0-based exclusive, which
            # equals the 1-based index of the segment's last line.
            chunks.append((seg_text, path, seg_start + 1, seg_end))

    return chunks
|
||||
|
||||
def _smart_chunk(
    self,
    text: str,
    path: str,
    max_chars: int,
    overlap: int,
) -> list[tuple[str, str, int, int]]:
    """Pick a chunking strategy: code-aware for known code files, else plain text."""
    wants_code_chunks = (
        self._config.code_aware_chunking
        and Path(path).suffix.lower() in self._config.code_extensions
    )
    if wants_code_chunks:
        code_chunks = self._chunk_code(text, path, max_chars, overlap)
        if code_chunks:
            return code_chunks
    # Fallback: plain text chunking (also used when code chunking yields nothing)
    return self._chunk_text(text, path, max_chars, overlap)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incremental API
|
||||
# ------------------------------------------------------------------
|
||||
@@ -342,11 +492,14 @@ class IndexingPipeline:
|
||||
meta = self._require_metadata()
|
||||
t0 = time.monotonic()
|
||||
|
||||
# Noise file filter
|
||||
exclude_reason = is_file_excluded(file_path, self._config)
|
||||
if exclude_reason:
|
||||
logger.debug("Skipping %s: %s", file_path, exclude_reason)
|
||||
return IndexStats(duration_seconds=round(time.monotonic() - t0, 2))
|
||||
|
||||
# Read file
|
||||
try:
|
||||
if file_path.stat().st_size > max_file_size:
|
||||
logger.debug("Skipping %s: exceeds max_file_size", file_path)
|
||||
return IndexStats(duration_seconds=round(time.monotonic() - t0, 2))
|
||||
text = file_path.read_text(encoding="utf-8", errors="replace")
|
||||
except Exception as exc:
|
||||
logger.debug("Skipping %s: %s", file_path, exc)
|
||||
@@ -366,7 +519,7 @@ class IndexingPipeline:
|
||||
self._fts.delete_by_path(rel_path)
|
||||
|
||||
# Chunk
|
||||
file_chunks = self._chunk_text(text, rel_path, max_chunk_chars, chunk_overlap)
|
||||
file_chunks = self._smart_chunk(text, rel_path, max_chunk_chars, chunk_overlap)
|
||||
if not file_chunks:
|
||||
# Register file with no chunks
|
||||
meta.register_file(rel_path, content_hash, file_path.stat().st_mtime)
|
||||
|
||||
@@ -21,6 +21,7 @@ _make_fastembed_mock()
|
||||
from codexlens_search.config import Config # noqa: E402
|
||||
from codexlens_search.embed.base import BaseEmbedder # noqa: E402
|
||||
from codexlens_search.embed.local import EMBED_PROFILES, FastEmbedEmbedder # noqa: E402
|
||||
from codexlens_search.embed.api import APIEmbedder # noqa: E402
|
||||
|
||||
|
||||
class TestEmbedSingle(unittest.TestCase):
|
||||
@@ -76,5 +77,182 @@ class TestBaseEmbedderAbstract(unittest.TestCase):
|
||||
BaseEmbedder() # type: ignore[abstract]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# APIEmbedder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_api_config(**overrides) -> Config:
    """Build a Config preset for API-embedding tests; keyword overrides win."""
    base = {
        "embed_api_url": "https://api.example.com/v1",
        "embed_api_key": "test-key",
        "embed_api_model": "text-embedding-3-small",
        "embed_dim": 384,
        "embed_batch_size": 2,
        "embed_api_max_tokens_per_batch": 8192,
        "embed_api_concurrency": 2,
    }
    return Config(**{**base, **overrides})
|
||||
|
||||
|
||||
def _mock_200(count=1, dim=384):
|
||||
r = MagicMock()
|
||||
r.status_code = 200
|
||||
r.json.return_value = {
|
||||
"data": [{"index": j, "embedding": [0.1 * (j + 1)] * dim} for j in range(count)]
|
||||
}
|
||||
r.raise_for_status = MagicMock()
|
||||
return r
|
||||
|
||||
|
||||
class TestAPIEmbedderSingle(unittest.TestCase):
    def test_embed_single_returns_float32(self):
        """A mocked 200 response yields a float32 vector of the configured dim."""
        config = _make_api_config()
        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            fake_client.post.return_value = _mock_200(1, 384)

            embedder = APIEmbedder(config)
            vec = embedder.embed_single("hello")

            self.assertIsInstance(vec, np.ndarray)
            self.assertEqual(vec.dtype, np.float32)
            self.assertEqual(vec.shape, (384,))
|
||||
|
||||
|
||||
class TestAPIEmbedderBatch(unittest.TestCase):
    def test_embed_batch_splits_by_batch_size(self):
        """Three texts with batch_size=2 should come back as three vectors."""
        config = _make_api_config(embed_batch_size=2)

        with patch("httpx.Client") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            # First call answers the 2-item batch, second the 1-item remainder.
            fake_client.post.side_effect = [_mock_200(2, 384), _mock_200(1, 384)]

            embedder = APIEmbedder(config)
            vectors = embedder.embed_batch(["a", "b", "c"])

            self.assertEqual(len(vectors), 3)
            for vec in vectors:
                self.assertIsInstance(vec, np.ndarray)
                self.assertEqual(vec.dtype, np.float32)

    def test_embed_batch_empty_returns_empty(self):
        """An empty input list short-circuits to an empty result list."""
        config = _make_api_config()
        with patch("httpx.Client"):
            embedder = APIEmbedder(config)
            self.assertEqual(embedder.embed_batch([]), [])
|
||||
|
||||
|
||||
class TestAPIEmbedderRetry(unittest.TestCase):
    """Retry behaviour of the low-level _call_api helper."""

    @staticmethod
    def _rate_limited():
        # Minimal stand-in for an HTTP 429 (Too Many Requests) response.
        resp = MagicMock()
        resp.status_code = 429
        return resp

    def test_retry_on_429(self):
        cfg = _make_api_config()

        with patch("httpx.Client") as client_cls:
            http = MagicMock()
            client_cls.return_value = http
            # First call throttled, second succeeds.
            http.post.side_effect = [self._rate_limited(), _mock_200(1, 384)]

            embedder = APIEmbedder(cfg)
            endpoint = embedder._endpoints[0]
            with patch("time.sleep"):
                vectors = embedder._call_api(["test"], endpoint)

        self.assertEqual(len(vectors), 1)
        self.assertEqual(http.post.call_count, 2)

    def test_raises_after_max_retries(self):
        cfg = _make_api_config()

        with patch("httpx.Client") as client_cls:
            http = MagicMock()
            client_cls.return_value = http
            # Every attempt is throttled, so the retry budget must run out.
            http.post.return_value = self._rate_limited()

            embedder = APIEmbedder(cfg)
            endpoint = embedder._endpoints[0]
            with patch("time.sleep"), self.assertRaises(RuntimeError):
                embedder._call_api(["test"], endpoint, max_retries=2)
|
||||
|
||||
|
||||
class TestAPIEmbedderTokenPacking(unittest.TestCase):
    """Token-budget packing performed by _pack_batches."""

    def test_packs_small_texts_together(self):
        cfg = _make_api_config(
            embed_batch_size=100,
            embed_api_max_tokens_per_batch=100,  # roughly 400 chars
        )
        with patch("httpx.Client"):
            embedder = APIEmbedder(cfg)

        # Five 80-char texts (~20 tokens apiece) sit right at the 100-token cap.
        batches = embedder._pack_batches(["x" * 80] * 5)

        # Packing must respect the cap while losing no items.
        self.assertGreaterEqual(len(batches), 1)
        self.assertEqual(sum(len(batch) for batch in batches), 5)

    def test_large_text_gets_own_batch(self):
        cfg = _make_api_config(
            embed_batch_size=100,
            embed_api_max_tokens_per_batch=50,  # roughly 200 chars
        )
        with patch("httpx.Client"):
            embedder = APIEmbedder(cfg)

        # The 800-char text (~200 tokens) blows the 50-token budget and
        # therefore cannot share a batch with its small neighbours.
        batches = embedder._pack_batches(["small" * 10, "x" * 800, "tiny"])

        self.assertGreaterEqual(len(batches), 2)
|
||||
|
||||
|
||||
class TestAPIEmbedderMultiEndpoint(unittest.TestCase):
    """Endpoint list construction from the config."""

    def test_multi_endpoint_config(self):
        cfg = _make_api_config(
            embed_api_endpoints=[
                {"url": "https://ep1.example.com/v1", "key": "k1", "model": "m1"},
                {"url": "https://ep2.example.com/v1", "key": "k2", "model": "m2"},
            ]
        )
        with patch("httpx.Client"):
            endpoints = APIEmbedder(cfg)._endpoints

        self.assertEqual(len(endpoints), 2)
        for endpoint in endpoints:
            self.assertTrue(endpoint.url.endswith("/embeddings"))

    def test_single_endpoint_fallback(self):
        # With no embed_api_endpoints list, the single-URL config fields
        # should still produce exactly one endpoint.
        cfg = _make_api_config()
        with patch("httpx.Client"):
            self.assertEqual(len(APIEmbedder(cfg)._endpoints), 1)
|
||||
|
||||
|
||||
class TestAPIEmbedderUrlNormalization(unittest.TestCase):
    """Base URLs gain exactly one trailing /embeddings segment."""

    def test_appends_embeddings_path(self):
        cfg = _make_api_config(embed_api_url="https://api.example.com/v1")
        with patch("httpx.Client") as client_cls:
            http = MagicMock()
            client_cls.return_value = http
            http.post.return_value = _mock_200(1, 384)
            endpoint = APIEmbedder(cfg)._endpoints[0]
        self.assertTrue(endpoint.url.endswith("/embeddings"))

    def test_does_not_double_append(self):
        # A URL that already ends in /embeddings must be left untouched.
        cfg = _make_api_config(embed_api_url="https://api.example.com/v1/embeddings")
        with patch("httpx.Client"):
            endpoint = APIEmbedder(cfg)._endpoints[0]
        self.assertFalse(endpoint.url.endswith("/embeddings/embeddings"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly from the command line.
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user