Files
Claude-Code-Workflow/codex-lens-v2/tests/unit/test_embed.py
catlog22 f37189dc64 feat: add APIEmbedder for remote embedding with multi-endpoint support
- Introduced APIEmbedder class to handle embeddings via a remote HTTP API.
- Implemented token packing to optimize batch sizes based on token limits.
- Added support for multiple API endpoints with round-robin dispatching.
- Included retry logic for API calls with exponential backoff on failures.
- Enhanced indexing pipeline with file exclusion checks and smart chunking strategies.
- Updated tests to cover new APIEmbedder functionality and ensure robustness.
2026-03-17 17:17:24 +08:00

259 lines
9.0 KiB
Python

from __future__ import annotations
import sys
import types
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
def _make_fastembed_mock():
    """Install a minimal ``fastembed`` stub so imports succeed without the real package.

    The stub only exposes ``TextEmbedding`` (a ``MagicMock``). Registration uses
    ``sys.modules.setdefault``, so a ``fastembed`` module that is already present
    in ``sys.modules`` is left in place and the freshly built stub is discarded.

    Returns:
        The module object ``import fastembed`` will resolve to after this call:
        the new stub if it was installed, otherwise the pre-existing entry.
    """
    fastembed_mod = types.ModuleType("fastembed")
    fastembed_mod.TextEmbedding = MagicMock()
    # Return what is actually registered rather than the local stub; when an
    # entry already exists, setdefault keeps it and ignores fastembed_mod.
    return sys.modules.setdefault("fastembed", fastembed_mod)


# Must run before the project imports below so they can resolve ``fastembed``.
_make_fastembed_mock()
from codexlens_search.config import Config # noqa: E402
from codexlens_search.embed.base import BaseEmbedder # noqa: E402
from codexlens_search.embed.local import EMBED_PROFILES, FastEmbedEmbedder # noqa: E402
from codexlens_search.embed.api import APIEmbedder # noqa: E402
class TestEmbedSingle(unittest.TestCase):
def test_embed_single_returns_float32_ndarray(self):
config = Config()
embedder = FastEmbedEmbedder(config)
mock_model = MagicMock()
mock_model.embed.return_value = iter([np.ones(384, dtype=np.float64)])
# Inject mock model directly to bypass lazy load (no real fastembed needed)
embedder._model = mock_model
result = embedder.embed_single("hello world")
self.assertIsInstance(result, np.ndarray)
self.assertEqual(result.dtype, np.float32)
self.assertEqual(result.shape, (384,))
class TestEmbedBatch(unittest.TestCase):
def test_embed_batch_returns_list(self):
config = Config()
embedder = FastEmbedEmbedder(config)
vecs = [np.ones(384, dtype=np.float64) * i for i in range(3)]
mock_model = MagicMock()
mock_model.embed.return_value = iter(vecs)
embedder._model = mock_model
result = embedder.embed_batch(["a", "b", "c"])
self.assertIsInstance(result, list)
self.assertEqual(len(result), 3)
for arr in result:
self.assertIsInstance(arr, np.ndarray)
self.assertEqual(arr.dtype, np.float32)
class TestEmbedProfiles(unittest.TestCase):
def test_embed_profiles_all_have_valid_keys(self):
expected_keys = {"small", "base", "large", "code"}
self.assertEqual(set(EMBED_PROFILES.keys()), expected_keys)
def test_embed_profiles_model_ids_non_empty(self):
for key, model_id in EMBED_PROFILES.items():
self.assertIsInstance(model_id, str, msg=f"{key} model id should be str")
self.assertTrue(len(model_id) > 0, msg=f"{key} model id should be non-empty")
class TestBaseEmbedderAbstract(unittest.TestCase):
def test_base_embedder_is_abstract(self):
with self.assertRaises(TypeError):
BaseEmbedder() # type: ignore[abstract]
# ---------------------------------------------------------------------------
# APIEmbedder
# ---------------------------------------------------------------------------


def _make_api_config(**overrides) -> Config:
    """Build a Config that points APIEmbedder at a fake remote endpoint.

    Any keyword argument replaces the matching default below; all settings are
    forwarded to ``Config`` as keyword arguments.
    """
    baseline = {
        "embed_api_url": "https://api.example.com/v1",
        "embed_api_key": "test-key",
        "embed_api_model": "text-embedding-3-small",
        "embed_dim": 384,
        "embed_batch_size": 2,
        "embed_api_max_tokens_per_batch": 8192,
        "embed_api_concurrency": 2,
    }
    return Config(**{**baseline, **overrides})
def _mock_200(count=1, dim=384):
    """Return a MagicMock standing in for a successful (HTTP 200) embeddings reply.

    ``json()`` yields ``{"data": [...]}`` with ``count`` items; item ``j`` has
    ``"index": j`` and an ``"embedding"`` of length ``dim`` filled with
    ``0.1 * (j + 1)``. ``raise_for_status`` is a no-op mock.
    """
    items = [
        {"index": j, "embedding": [0.1 * (j + 1)] * dim}
        for j in range(count)
    ]
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {"data": items}
    response.raise_for_status = MagicMock()
    return response
class TestAPIEmbedderSingle(unittest.TestCase):
def test_embed_single_returns_float32(self):
config = _make_api_config()
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.return_value = _mock_200(1, 384)
embedder = APIEmbedder(config)
result = embedder.embed_single("hello")
self.assertIsInstance(result, np.ndarray)
self.assertEqual(result.dtype, np.float32)
self.assertEqual(result.shape, (384,))
class TestAPIEmbedderBatch(unittest.TestCase):
def test_embed_batch_splits_by_batch_size(self):
config = _make_api_config(embed_batch_size=2)
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.side_effect = [_mock_200(2, 384), _mock_200(1, 384)]
embedder = APIEmbedder(config)
result = embedder.embed_batch(["a", "b", "c"])
self.assertEqual(len(result), 3)
for arr in result:
self.assertIsInstance(arr, np.ndarray)
self.assertEqual(arr.dtype, np.float32)
def test_embed_batch_empty_returns_empty(self):
config = _make_api_config()
with patch("httpx.Client"):
embedder = APIEmbedder(config)
result = embedder.embed_batch([])
self.assertEqual(result, [])
class TestAPIEmbedderRetry(unittest.TestCase):
def test_retry_on_429(self):
config = _make_api_config()
mock_429 = MagicMock()
mock_429.status_code = 429
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.side_effect = [mock_429, _mock_200(1, 384)]
embedder = APIEmbedder(config)
ep = embedder._endpoints[0]
with patch("time.sleep"):
result = embedder._call_api(["test"], ep)
self.assertEqual(len(result), 1)
self.assertEqual(mock_client.post.call_count, 2)
def test_raises_after_max_retries(self):
config = _make_api_config()
mock_429 = MagicMock()
mock_429.status_code = 429
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.return_value = mock_429
embedder = APIEmbedder(config)
ep = embedder._endpoints[0]
with patch("time.sleep"):
with self.assertRaises(RuntimeError):
embedder._call_api(["test"], ep, max_retries=2)
class TestAPIEmbedderTokenPacking(unittest.TestCase):
def test_packs_small_texts_together(self):
config = _make_api_config(
embed_batch_size=100,
embed_api_max_tokens_per_batch=100, # ~400 chars
)
with patch("httpx.Client"):
embedder = APIEmbedder(config)
# 5 texts of 80 chars each (~20 tokens) -> 100 tokens = 1 batch at limit
texts = ["x" * 80] * 5
batches = embedder._pack_batches(texts)
# Should pack as many as fit under 100 tokens
self.assertTrue(len(batches) >= 1)
total_items = sum(len(b) for b in batches)
self.assertEqual(total_items, 5)
def test_large_text_gets_own_batch(self):
config = _make_api_config(
embed_batch_size=100,
embed_api_max_tokens_per_batch=50, # ~200 chars
)
with patch("httpx.Client"):
embedder = APIEmbedder(config)
# Mix of small and large texts
texts = ["small" * 10, "x" * 800, "tiny"]
batches = embedder._pack_batches(texts)
# Large text (200 tokens) exceeds 50 limit, should be separate
self.assertTrue(len(batches) >= 2)
class TestAPIEmbedderMultiEndpoint(unittest.TestCase):
def test_multi_endpoint_config(self):
config = _make_api_config(
embed_api_endpoints=[
{"url": "https://ep1.example.com/v1", "key": "k1", "model": "m1"},
{"url": "https://ep2.example.com/v1", "key": "k2", "model": "m2"},
]
)
with patch("httpx.Client"):
embedder = APIEmbedder(config)
self.assertEqual(len(embedder._endpoints), 2)
self.assertTrue(embedder._endpoints[0].url.endswith("/embeddings"))
self.assertTrue(embedder._endpoints[1].url.endswith("/embeddings"))
def test_single_endpoint_fallback(self):
config = _make_api_config() # no embed_api_endpoints
with patch("httpx.Client"):
embedder = APIEmbedder(config)
self.assertEqual(len(embedder._endpoints), 1)
class TestAPIEmbedderUrlNormalization(unittest.TestCase):
def test_appends_embeddings_path(self):
config = _make_api_config(embed_api_url="https://api.example.com/v1")
with patch("httpx.Client") as mock_client_cls:
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.post.return_value = _mock_200(1, 384)
embedder = APIEmbedder(config)
ep = embedder._endpoints[0]
self.assertTrue(ep.url.endswith("/embeddings"))
def test_does_not_double_append(self):
config = _make_api_config(embed_api_url="https://api.example.com/v1/embeddings")
with patch("httpx.Client"):
embedder = APIEmbedder(config)
ep = embedder._endpoints[0]
self.assertFalse(ep.url.endswith("/embeddings/embeddings"))
if __name__ == "__main__":
    # Allow running this module directly (python <this file>) in addition to
    # discovery by a test runner.
    unittest.main()