From 6cd489f79d240963f11202c9bb4c3ac228f8c3c0 Mon Sep 17 00:00:00 2001
From: nikakhalatiani <124719066+nikakhalatiani@users.noreply.github.com>
Date: Fri, 26 Jun 2026 15:21:27 +0200
Subject: [PATCH] Retry oversized embedding requests (#1106)

---
 src/embeddings.py               | 57 ++++++++++++++------
 tests/test_embeddings_client.py | 95 +++++++++++++++++++++++++++++++++
 2 files changed, 136 insertions(+), 16 deletions(-)
 create mode 100644 tests/test_embeddings_client.py

diff --git a/src/embeddings.py b/src/embeddings.py
index 746044c47..19cd09e43 100644
--- a/src/embeddings.py
+++ b/src/embeddings.py
@@ -55,6 +55,8 @@ class EmbeddingClient:
         # of stalling startup ~30s per probe. Read stays generous for a real
         # endpoint (embedding a short string returns in well under a second).
         self._client = httpx.Client(timeout=httpx.Timeout(connect=3.0, read=10.0, write=5.0, pool=3.0))
+        self._batch_size = max(1, int(os.getenv("EMBEDDING_BATCH_SIZE", "8")))
+        self._max_chars = max(200, int(os.getenv("EMBEDDING_MAX_CHARS", "900")))
 
     def get_sentence_embedding_dimension(self) -> int:
         """Probe the endpoint for embedding dimension if not yet known."""
@@ -73,23 +75,10 @@ class EmbeddingClient:
         if not texts:
             return np.array([], dtype="float32")
 
-        # Batch in chunks of 64 to avoid oversized requests
         all_vecs = []
-        for i in range(0, len(texts), 64):
-            batch = texts[i : i + 64]
-            resp = self._client.post(
-                self.url,
-                headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
-                json={"input": batch, "model": self.model},
-            )
-            resp.raise_for_status()
-            data = resp.json()
-
-            # OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
-            embeddings = data.get("data", [])
-            embeddings.sort(key=lambda e: e.get("index", 0))
-            for emb in embeddings:
-                all_vecs.append(emb["embedding"])
+        for i in range(0, len(texts), self._batch_size):
+            batch = texts[i : i + self._batch_size]
+            all_vecs.extend(self._embed_batch(batch))
 
         vecs = np.array(all_vecs, dtype="float32")
 
@@ -103,6 +92,42 @@ class EmbeddingClient:
 
         return vecs
 
+    def _embed_batch(self, batch: List[str]) -> List[List[float]]:
+        try:
+            return self._post_embeddings(batch)
+        except httpx.HTTPStatusError as e:
+            status = e.response.status_code if e.response is not None else None
+            if status != 400:
+                raise
+            if len(batch) > 1:
+                vecs = []
+                for text in batch:
+                    vecs.extend(self._embed_batch([text]))
+                return vecs
+            text = batch[0]
+            trimmed = text[: self._max_chars]
+            if trimmed != text:
+                logger.warning(
+                    "Embedding input exceeded endpoint context; retrying with %d chars",
+                    len(trimmed),
+                )
+                return self._post_embeddings([trimmed])
+            raise
+
+    def _post_embeddings(self, batch: List[str]) -> List[List[float]]:
+        resp = self._client.post(
+            self.url,
+            headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
+            json={"input": batch, "model": self.model},
+        )
+        resp.raise_for_status()
+        data = resp.json()
+
+        # OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
+        embeddings = data.get("data", [])
+        embeddings.sort(key=lambda e: e.get("index", 0))
+        return [emb["embedding"] for emb in embeddings]
+
 
 class FastEmbedClient:
     """Local embedding client using fastembed (ONNX). No external service needed."""
diff --git a/tests/test_embeddings_client.py b/tests/test_embeddings_client.py
new file mode 100644
index 000000000..c3039bc44
--- /dev/null
+++ b/tests/test_embeddings_client.py
@@ -0,0 +1,95 @@
+import httpx
+import pytest
+
+from src.embeddings import EmbeddingClient
+
+
+class _FakeEmbeddingHttpClient:
+    def __init__(self, handler):
+        self.handler = handler
+        self.headers = []
+
+    def post(self, url, headers=None, json=None):
+        self.headers.append(headers or {})
+        request = httpx.Request("POST", url)
+        status, body = self.handler(json)
+        return httpx.Response(status, request=request, json=body)
+
+
+def test_embedding_400_batch_retry_falls_back_to_single_inputs(monkeypatch):
+    monkeypatch.setenv("EMBEDDING_BATCH_SIZE", "8")
+    calls = []
+
+    def handler(payload):
+        texts = payload["input"]
+        calls.append(list(texts))
+        if len(texts) > 1:
+            return 400, {"error": "batch too large"}
+        text = texts[0]
+        return 200, {"data": [{"index": 0, "embedding": [float(len(text)), 1.0]}]}
+
+    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
+    client._client = _FakeEmbeddingHttpClient(handler)
+
+    vecs = client.encode(["a", "bbbb"], normalize_embeddings=False)
+
+    assert calls == [["a", "bbbb"], ["a"], ["bbbb"]]
+    assert vecs.tolist() == [[1.0, 1.0], [4.0, 1.0]]
+
+
+def test_embedding_400_single_input_retries_with_truncated_text(monkeypatch):
+    monkeypatch.setenv("EMBEDDING_MAX_CHARS", "200")
+    lengths = []
+
+    def handler(payload):
+        text = payload["input"][0]
+        lengths.append(len(text))
+        if len(text) > 200:
+            return 400, {"error": "context length exceeded"}
+        return 200, {"data": [{"index": 0, "embedding": [2.0, 0.0]}]}
+
+    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
+    client._client = _FakeEmbeddingHttpClient(handler)
+
+    vecs = client.encode(["x" * 250], normalize_embeddings=False)
+
+    assert lengths == [250, 200]
+    assert vecs.tolist() == [[2.0, 0.0]]
+
+
+def test_embedding_non_400_errors_are_not_retried_or_swallowed():
+    calls = 0
+
+    def handler(payload):
+        nonlocal calls
+        calls += 1
+        return 500, {"error": "server error"}
+
+    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
+    client._client = _FakeEmbeddingHttpClient(handler)
+
+    with pytest.raises(httpx.HTTPStatusError):
+        client.encode(["a"], normalize_embeddings=False)
+
+    assert calls == 1
+
+
+def test_embedding_retry_path_preserves_api_key_header():
+    seen_headers = []
+
+    def handler(payload):
+        return 200, {"data": [{"index": 0, "embedding": [1.0, 0.0]}]}
+
+    client = EmbeddingClient(
+        url="http://embeddings.test/v1/embeddings",
+        model="embed-test",
+        api_key="secret-key",
+    )
+    fake = _FakeEmbeddingHttpClient(handler)
+    client._client = fake
+
+    vecs = client.encode(["a"], normalize_embeddings=False)
+    seen_headers.extend(fake.headers)
+
+    assert vecs.tolist() == [[1.0, 0.0]]
+    assert seen_headers == [{"Authorization": "Bearer secret-key"}]