mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-29 08:02:06 -04:00
Retry oversized embedding requests (#1106)
This commit is contained in:
+41
-16
@@ -55,6 +55,8 @@ class EmbeddingClient:
|
||||
# of stalling startup ~30s per probe. Read stays generous for a real
|
||||
# endpoint (embedding a short string returns in well under a second).
|
||||
self._client = httpx.Client(timeout=httpx.Timeout(connect=3.0, read=10.0, write=5.0, pool=3.0))
|
||||
self._batch_size = max(1, int(os.getenv("EMBEDDING_BATCH_SIZE", "8")))
|
||||
self._max_chars = max(200, int(os.getenv("EMBEDDING_MAX_CHARS", "900")))
|
||||
|
||||
def get_sentence_embedding_dimension(self) -> int:
|
||||
"""Probe the endpoint for embedding dimension if not yet known."""
|
||||
@@ -73,23 +75,10 @@ class EmbeddingClient:
|
||||
if not texts:
|
||||
return np.array([], dtype="float32")
|
||||
|
||||
# Batch in chunks of 64 to avoid oversized requests
|
||||
all_vecs = []
|
||||
for i in range(0, len(texts), 64):
|
||||
batch = texts[i : i + 64]
|
||||
resp = self._client.post(
|
||||
self.url,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||
json={"input": batch, "model": self.model},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||
embeddings = data.get("data", [])
|
||||
embeddings.sort(key=lambda e: e.get("index", 0))
|
||||
for emb in embeddings:
|
||||
all_vecs.append(emb["embedding"])
|
||||
for i in range(0, len(texts), self._batch_size):
|
||||
batch = texts[i : i + self._batch_size]
|
||||
all_vecs.extend(self._embed_batch(batch))
|
||||
|
||||
vecs = np.array(all_vecs, dtype="float32")
|
||||
|
||||
@@ -103,6 +92,42 @@ class EmbeddingClient:
|
||||
|
||||
return vecs
|
||||
|
||||
def _embed_batch(self, batch: List[str]) -> List[List[float]]:
|
||||
try:
|
||||
return self._post_embeddings(batch)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else None
|
||||
if status != 400:
|
||||
raise
|
||||
if len(batch) > 1:
|
||||
vecs = []
|
||||
for text in batch:
|
||||
vecs.extend(self._embed_batch([text]))
|
||||
return vecs
|
||||
text = batch[0]
|
||||
trimmed = text[: self._max_chars]
|
||||
if trimmed != text:
|
||||
logger.warning(
|
||||
"Embedding input exceeded endpoint context; retrying with %d chars",
|
||||
len(trimmed),
|
||||
)
|
||||
return self._post_embeddings([trimmed])
|
||||
raise
|
||||
|
||||
def _post_embeddings(self, batch: List[str]) -> List[List[float]]:
|
||||
resp = self._client.post(
|
||||
self.url,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"} if self.api_key else {},
|
||||
json={"input": batch, "model": self.model},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# OpenAI format: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||
embeddings = data.get("data", [])
|
||||
embeddings.sort(key=lambda e: e.get("index", 0))
|
||||
return [emb["embedding"] for emb in embeddings]
|
||||
|
||||
|
||||
class FastEmbedClient:
|
||||
"""Local embedding client using fastembed (ONNX). No external service needed."""
|
||||
|
||||
Reference in New Issue
Block a user