mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-29 16:12:06 -04:00
Retry oversized embedding requests (#1106)
This commit is contained in:
+41
-16
@@ -55,6 +55,8 @@ class EmbeddingClient:
|
|||||||
# of stalling startup ~30s per probe. Read stays generous for a real
|
# of stalling startup ~30s per probe. Read stays generous for a real
|
||||||
# endpoint (embedding a short string returns in well under a second).
|
# endpoint (embedding a short string returns in well under a second).
|
||||||
self._client = httpx.Client(timeout=httpx.Timeout(connect=3.0, read=10.0, write=5.0, pool=3.0))
|
self._client = httpx.Client(timeout=httpx.Timeout(connect=3.0, read=10.0, write=5.0, pool=3.0))
|
||||||
|
self._batch_size = max(1, int(os.getenv("EMBEDDING_BATCH_SIZE", "8")))
|
||||||
|
self._max_chars = max(200, int(os.getenv("EMBEDDING_MAX_CHARS", "900")))
|
||||||
|
|
||||||
def get_sentence_embedding_dimension(self) -> int:
|
def get_sentence_embedding_dimension(self) -> int:
|
||||||
"""Probe the endpoint for embedding dimension if not yet known."""
|
"""Probe the endpoint for embedding dimension if not yet known."""
|
||||||
@@ -73,23 +75,10 @@ class EmbeddingClient:
|
|||||||
if not texts:
|
if not texts:
|
||||||
return np.array([], dtype="float32")
|
return np.array([], dtype="float32")
|
||||||
|
|
||||||
# Batch in chunks of 64 to avoid oversized requests
|
|
||||||
all_vecs = []
|
all_vecs = []
|
||||||
for i in range(0, len(texts), 64):
|
for i in range(0, len(texts), self._batch_size):
|
||||||
batch = texts[i : i + 64]
|
batch = texts[i : i + self._batch_size]
|
||||||
resp = self._client.post(
|
all_vecs.extend(self._embed_batch(batch))
|
||||||
self.url,
|
|
||||||
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
|
||||||
json={"input": batch, "model": self.model},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
|
|
||||||
# OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
|
|
||||||
embeddings = data.get("data", [])
|
|
||||||
embeddings.sort(key=lambda e: e.get("index", 0))
|
|
||||||
for emb in embeddings:
|
|
||||||
all_vecs.append(emb["embedding"])
|
|
||||||
|
|
||||||
vecs = np.array(all_vecs, dtype="float32")
|
vecs = np.array(all_vecs, dtype="float32")
|
||||||
|
|
||||||
@@ -103,6 +92,42 @@ class EmbeddingClient:
|
|||||||
|
|
||||||
return vecs
|
return vecs
|
||||||
|
|
||||||
|
def _embed_batch(self, batch: List[str]) -> List[List[float]]:
|
||||||
|
try:
|
||||||
|
return self._post_embeddings(batch)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
status = e.response.status_code if e.response is not None else None
|
||||||
|
if status != 400:
|
||||||
|
raise
|
||||||
|
if len(batch) > 1:
|
||||||
|
vecs = []
|
||||||
|
for text in batch:
|
||||||
|
vecs.extend(self._embed_batch([text]))
|
||||||
|
return vecs
|
||||||
|
text = batch[0]
|
||||||
|
trimmed = text[: self._max_chars]
|
||||||
|
if trimmed != text:
|
||||||
|
logger.warning(
|
||||||
|
"Embedding input exceeded endpoint context; retrying with %d chars",
|
||||||
|
len(trimmed),
|
||||||
|
)
|
||||||
|
return self._post_embeddings([trimmed])
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _post_embeddings(self, batch: List[str]) -> List[List[float]]:
|
||||||
|
resp = self._client.post(
|
||||||
|
self.url,
|
||||||
|
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||||
|
json={"input": batch, "model": self.model},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
# OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||||
|
embeddings = data.get("data", [])
|
||||||
|
embeddings.sort(key=lambda e: e.get("index", 0))
|
||||||
|
return [emb["embedding"] for emb in embeddings]
|
||||||
|
|
||||||
|
|
||||||
class FastEmbedClient:
|
class FastEmbedClient:
|
||||||
"""Local embedding client using fastembed (ONNX). No external service needed."""
|
"""Local embedding client using fastembed (ONNX). No external service needed."""
|
||||||
|
|||||||
@@ -0,0 +1,95 @@
|
|||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.embeddings import EmbeddingClient
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeEmbeddingHttpClient:
    """Test double for the HTTP client used by ``EmbeddingClient``.

    Each ``post`` call records the headers it was given and hands the JSON
    payload to ``handler``, which returns a ``(status, body)`` pair that is
    wrapped in an ``httpx.Response``.
    """

    def __init__(self, handler):
        # ``headers`` collects one dict per POST, in call order.
        self.headers = []
        self.handler = handler

    def post(self, url, headers=None, json=None):
        """Record ``headers`` and answer with the handler's status and body."""
        self.headers.append(headers if headers else {})
        fake_request = httpx.Request("POST", url)
        status_code, payload = self.handler(json)
        return httpx.Response(status_code, request=fake_request, json=payload)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_400_batch_retry_falls_back_to_single_inputs(monkeypatch):
    """A 400 on a multi-text batch makes the client resend each text alone."""
    monkeypatch.setenv("EMBEDDING_BATCH_SIZE", "8")
    requested = []

    def handler(payload):
        inputs = payload["input"]
        requested.append(list(inputs))
        # Reject any multi-text request; accept single texts.
        if len(inputs) > 1:
            return 400, {"error": "batch too large"}
        only = inputs[0]
        embedding = [float(len(only)), 1.0]
        return 200, {"data": [{"index": 0, "embedding": embedding}]}

    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
    client._client = _FakeEmbeddingHttpClient(handler)

    result = client.encode(["a", "bbbb"], normalize_embeddings=False)

    # One rejected batch request, then one request per text.
    assert requested == [["a", "bbbb"], ["a"], ["bbbb"]]
    assert result.tolist() == [[1.0, 1.0], [4.0, 1.0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_400_single_input_retries_with_truncated_text(monkeypatch):
    """A 400 on a single over-long text triggers one retry with trimmed text."""
    monkeypatch.setenv("EMBEDDING_MAX_CHARS", "200")
    seen_lengths = []

    def handler(payload):
        size = len(payload["input"][0])
        seen_lengths.append(size)
        # Accept only texts of at most 200 characters.
        if size > 200:
            return 400, {"error": "context length exceeded"}
        return 200, {"data": [{"index": 0, "embedding": [2.0, 0.0]}]}

    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
    client._client = _FakeEmbeddingHttpClient(handler)

    result = client.encode(["x" * 250], normalize_embeddings=False)

    # First attempt carries the full text, the retry carries the trimmed one.
    assert seen_lengths == [250, 200]
    assert result.tolist() == [[2.0, 0.0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_non_400_errors_are_not_retried_or_swallowed():
    """A 500 response propagates to the caller after exactly one request."""
    attempts = []

    def handler(payload):
        attempts.append(payload)
        return 500, {"error": "server error"}

    client = EmbeddingClient(url="http://embeddings.test/v1/embeddings", model="embed-test")
    client._client = _FakeEmbeddingHttpClient(handler)

    with pytest.raises(httpx.HTTPStatusError):
        client.encode(["a"], normalize_embeddings=False)

    # No second request means the 500 was not retried.
    assert len(attempts) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_retry_path_preserves_api_key_header(monkeypatch):
    """Requests issued by the 400-retry path still carry the API key header.

    The handler rejects the multi-text batch with a 400 so the client has to
    fall back to one request per text. The Authorization header must be
    present on the rejected request and on every retry.
    """
    # Pin the batch size so both texts travel in one initial request.
    monkeypatch.setenv("EMBEDDING_BATCH_SIZE", "8")

    def handler(payload):
        if len(payload["input"]) > 1:
            return 400, {"error": "batch too large"}
        return 200, {"data": [{"index": 0, "embedding": [1.0, 0.0]}]}

    client = EmbeddingClient(
        url="http://embeddings.test/v1/embeddings",
        model="embed-test",
        api_key="secret-key",
    )
    fake = _FakeEmbeddingHttpClient(handler)
    client._client = fake

    vecs = client.encode(["a", "b"], normalize_embeddings=False)

    assert vecs.tolist() == [[1.0, 0.0], [1.0, 0.0]]
    # One rejected batch request plus one retry per text, all authenticated.
    assert fake.headers == [{"Authorization": "Bearer secret-key"}] * 3
|
||||||
Reference in New Issue
Block a user