Files
odysseus/src/embedding_lanes.py
T
Mazen Tamer Salah 3c4ec8828b fix(embeddings): survive numpy embeddings when restoring a reset lane (#3410)
When a lane reset fails to rewrite the recreated collection, the recovery path
re-adds the preserved rows. It read the embeddings with
`preserved.get("embeddings") or []` and gated the loop with
`if ids and docs and old_embeddings:`. chromadb returns embeddings as a numpy
ndarray, whose truth value is ambiguous, so both expressions raise ValueError
inside the except block — the restore is abandoned and every preserved row is
lost (the collection was already deleted), exactly when the code is trying to
avoid data loss.

Use an explicit `is None` check and `len(...)`, and convert ndarray batches to
lists before re-adding.

Adds tests/test_embedding_lane_ndarray_restore.py (preserved embeddings come
back as np.ndarray); existing test_embedding_lanes.py still passes.
2026-06-09 10:40:17 +02:00

390 lines
14 KiB
Python

"""
embedding_lanes.py
Helpers for keeping FastEmbed fallback vectors separate from user-configured
embedding vectors. ChromaDB fixes a collection's dimension on first insert, so
different embedding models must never share one collection.
"""
from __future__ import annotations
from dataclasses import dataclass
import hashlib
import logging
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lane identifiers. Within this module they are used as the suffix of the
# lane's collection name and stored in the collection metadata.
LANE_FASTEMBED = "fastembed"
LANE_CUSTOM = "custom"
@dataclass
class EmbeddingLane:
    """One embedding client paired with the collection that stores its vectors.

    Carries the encoder client, the collection handle and the identifying
    details of the lane (model, URL, vector dimension, fingerprint).
    """

    name: str
    client: Any
    collection: Any
    collection_name: str
    model: str
    url: str
    dimension: int
    fingerprint: str

    @property
    def healthy(self) -> bool:
        """Return True when both the collection and the client are set."""
        return not (self.collection is None or self.client is None)

    def encode(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed *texts* with this lane's client and return nested lists."""
        vectors = self.client.encode(list(texts), normalize_embeddings=True)
        if hasattr(vectors, "tolist"):
            # Results exposing ``tolist`` (e.g. arrays) convert themselves.
            return vectors.tolist()
        return [list(vector) for vector in vectors]

    def count(self) -> int:
        """Return the collection's row count, or 0 if it cannot be read."""
        try:
            total = self.collection.count()
            return int(total)
        except Exception:
            # Best-effort: an unreadable collection is reported as empty.
            return 0

    def stats(self) -> Dict[str, Any]:
        """Return a dict describing this lane, its row count and its health."""
        summary: Dict[str, Any] = {
            "name": self.name,
            "collection": self.collection_name,
            "model": self.model,
            "url": self.url,
            "dimension": self.dimension,
            "fingerprint": self.fingerprint,
        }
        summary["count"] = self.count()
        summary["healthy"] = self.healthy
        return summary
def reset_embedding_lane_state() -> None:
    """Reset process-local embedding lane state after endpoint config changes.

    Delegates to ``src.embeddings.reset_http_embed_state``. The reset is
    best-effort: an import or reset failure never propagates to the caller,
    but it is logged at debug level (with traceback) instead of being
    silently discarded.
    """
    try:
        from src.embeddings import reset_http_embed_state

        reset_http_embed_state()
    except Exception:
        logger.debug("Could not reset HTTP embedding state", exc_info=True)
def collection_name(base_name: str, lane_name: str) -> str:
    """Return the lane-specific collection name ``<base_name>_<lane_name>``."""
    return "{}_{}".format(base_name, lane_name)
def _fingerprint(lane_name: str, url: str, model: str, dimension: int) -> str:
raw = f"{lane_name}\n{url}\n{model}\n{dimension}"
return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]
def _metadata(lane_name: str, url: str, model: str, dimension: int, fingerprint: str) -> Dict[str, Any]:
return {
"hnsw:space": "cosine",
"embedding_lane": lane_name,
"embedding_url": url,
"embedding_model": model,
"embedding_dimension": dimension,
"embedding_fingerprint": fingerprint,
}
def _load_custom_endpoint() -> Dict[str, str]:
try:
from src.embeddings import _load_persisted_endpoint
persisted = _load_persisted_endpoint()
except Exception:
persisted = {}
url = persisted.get("url") or os.environ.get("EMBEDDING_URL", "")
if not url:
return {}
model = persisted.get("model") or os.environ.get("EMBEDDING_MODEL", "")
api_key = persisted.get("api_key") or os.environ.get("EMBEDDING_API_KEY", "")
if persisted.get("api_key"):
try:
from src.secret_storage import decrypt
api_key = decrypt(api_key)
except Exception:
logger.warning("Could not decrypt saved embedding endpoint API key")
api_key = ""
return {"url": url, "model": model, "api_key": api_key}
def _build_fastembed_client():
    """Create a ``FastEmbedClient`` and query its embedding dimension once.

    The dimension result is discarded; presumably the call is made so an
    unusable client fails here rather than later (confirm against
    ``FastEmbedClient``).
    """
    from src.embeddings import FastEmbedClient

    fastembed = FastEmbedClient()
    fastembed.get_sentence_embedding_dimension()
    return fastembed
def _build_custom_client():
    """Return the configured embedding client if it is an ``EmbeddingClient``.

    Raises:
        RuntimeError: if ``get_embedding_client()`` returns anything else.
    """
    from src.embeddings import EmbeddingClient, get_embedding_client

    candidate = get_embedding_client()
    if not isinstance(candidate, EmbeddingClient):
        raise RuntimeError("HTTP embedding lane unavailable")
    return candidate
def _encode_with_client(client: Any, texts: Sequence[str]) -> List[List[float]]:
vecs = client.encode(list(texts), normalize_embeddings=True)
return vecs.tolist() if hasattr(vecs, "tolist") else [list(v) for v in vecs]
def _get_or_reset_collection(chroma_client, name: str, metadata: Dict[str, Any], client: Any):
try:
collection = chroma_client.get_collection(name)
except Exception:
return chroma_client.get_or_create_collection(name=name, metadata=metadata)
current = collection.metadata or {}
if not (
current.get("embedding_fingerprint") not in (None, metadata["embedding_fingerprint"])
or current.get("embedding_dimension") not in (None, metadata["embedding_dimension"])
or current.get("embedding_lane") not in (None, metadata["embedding_lane"])
):
return collection
logger.info(
"Recreating Chroma collection %s for embedding lane change (%s -> %s)",
name,
current.get("embedding_fingerprint"),
metadata["embedding_fingerprint"],
)
preserved = {"ids": [], "documents": [], "metadatas": [], "embeddings": []}
try:
preserved = collection.get(include=["documents", "metadatas", "embeddings"]) or preserved
except Exception as e:
raise RuntimeError(f"Could not preserve documents before resetting {name}: {e}") from e
ids = preserved.get("ids") or []
docs = preserved.get("documents") or []
metas = preserved.get("metadatas") or []
prepared_batches = []
if ids and docs:
try:
for start in range(0, len(ids), 100):
batch_ids = ids[start:start + 100]
batch_docs = docs[start:start + 100]
batch_metas = metas[start:start + 100]
if len(batch_metas) < len(batch_ids):
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
prepared_batches.append((
batch_ids,
batch_docs,
batch_metas,
_encode_with_client(client, batch_docs),
))
except Exception as e:
raise RuntimeError(f"Could not re-embed preserved rows for {name}: {e}") from e
chroma_client.delete_collection(name)
collection = chroma_client.get_or_create_collection(name=name, metadata=metadata)
try:
for batch_ids, batch_docs, batch_metas, embeddings in prepared_batches:
collection.add(
ids=batch_ids,
documents=batch_docs,
metadatas=batch_metas,
embeddings=embeddings,
)
except Exception as e:
logger.warning("Could not write reset collection %s; restoring previous rows: %s", name, e)
try:
chroma_client.delete_collection(name)
restored = chroma_client.get_or_create_collection(name=name, metadata=current)
# chromadb returns embeddings as a numpy ndarray, whose truth value
# is ambiguous — `preserved.get("embeddings") or []` and a bare
# `if ... and old_embeddings:` both raise ValueError, which aborts
# the restore and loses the rows the reset was supposed to keep.
# Use explicit None/len checks instead.
old_embeddings = preserved.get("embeddings")
if old_embeddings is None:
old_embeddings = []
if ids and docs and len(old_embeddings):
for start in range(0, len(ids), 100):
batch_ids = ids[start:start + 100]
batch_docs = docs[start:start + 100]
batch_metas = metas[start:start + 100]
batch_embeddings = old_embeddings[start:start + 100]
if hasattr(batch_embeddings, "tolist"):
batch_embeddings = batch_embeddings.tolist()
if len(batch_metas) < len(batch_ids):
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
restored.add(
ids=batch_ids,
documents=batch_docs,
metadatas=batch_metas,
embeddings=batch_embeddings,
)
except Exception as restore_error:
logger.warning("Could not restore previous collection %s: %s", name, restore_error)
raise RuntimeError(f"Could not write reset collection {name}: {e}") from e
if prepared_batches:
logger.info("Re-embedded %s rows after resetting %s", len(ids), name)
return collection
def _create_lane(chroma_client, base_name: str, lane_name: str, client: Any) -> EmbeddingLane:
    """Assemble an ``EmbeddingLane`` for *client* backed by its own collection.

    Reads the vector dimension from the client, and the ``model`` and ``url``
    attributes when present (empty string otherwise). From those it derives
    the lane fingerprint, collection name and collection metadata, then
    obtains the collection via ``_get_or_reset_collection``.
    """
    dim = int(client.get_sentence_embedding_dimension())
    model_id = getattr(client, "model", "")
    endpoint = getattr(client, "url", "")
    lane_fp = _fingerprint(lane_name, endpoint, model_id, dim)
    target = collection_name(base_name, lane_name)
    lane_metadata = _metadata(lane_name, endpoint, model_id, dim, lane_fp)
    lane_collection = _get_or_reset_collection(chroma_client, target, lane_metadata, client)
    return EmbeddingLane(
        name=lane_name,
        client=client,
        collection=lane_collection,
        collection_name=target,
        model=model_id,
        url=endpoint,
        dimension=dim,
        fingerprint=lane_fp,
    )
def build_embedding_lanes(base_name: str) -> List[EmbeddingLane]:
    """Return healthy lanes in retrieval preference order: custom, fastembed.

    Each lane is built independently; a lane whose client or collection
    cannot be set up is logged as a warning and left out of the result.
    """
    from src.chroma_client import get_chroma_client

    store = get_chroma_client()
    available: List[EmbeddingLane] = []

    # Custom (HTTP) lane first, so it is preferred at retrieval time.
    try:
        http_client = _build_custom_client()
        if http_client is not None:
            custom_lane = _create_lane(store, base_name, LANE_CUSTOM, http_client)
            available.append(custom_lane)
    except Exception as e:
        logger.warning("Custom embedding lane unavailable for %s: %s", base_name, e)

    # FastEmbed lane second.
    try:
        local_client = _build_fastembed_client()
        fastembed_lane = _create_lane(store, base_name, LANE_FASTEMBED, local_client)
        available.append(fastembed_lane)
    except Exception as e:
        logger.warning("FastEmbed lane unavailable for %s: %s", base_name, e)

    return available
def migrate_legacy_collection(base_name: str, lanes: Sequence[EmbeddingLane]) -> None:
    """Copy rows from the legacy unsuffixed collection into lanes lacking them.

    Does nothing when *lanes* is empty, when the legacy collection cannot be
    read, or when it holds no ids or documents. For each lane, rows whose id
    is not already present are re-embedded with that lane and added in
    batches of 100. The first failing batch stops the backfill for that lane
    (logged as a warning); other lanes are still processed.
    """
    if not lanes:
        return
    try:
        from src.chroma_client import get_chroma_client

        legacy = get_chroma_client().get_collection(base_name)
        legacy_rows = legacy.get(include=["documents", "metadatas"])
    except Exception:
        return

    row_ids = legacy_rows.get("ids") or []
    documents = legacy_rows.get("documents") or []
    metadatas = legacy_rows.get("metadatas") or []
    if not (row_ids and documents):
        return

    # Pad metadata to one entry per id so the three sequences zip cleanly.
    padded_metas = list(metadatas)
    shortfall = len(row_ids) - len(padded_metas)
    if shortfall > 0:
        padded_metas.extend([{}] * shortfall)

    for lane in lanes:
        try:
            found = lane.collection.get(ids=row_ids)
            present = set(found.get("ids") or [])
        except Exception:
            # Lookup failed: treat every legacy row as missing.
            present = set()

        pending = [row for row in zip(row_ids, documents, padded_metas) if row[0] not in present]
        if not pending:
            continue

        for offset in range(0, len(pending), 100):
            chunk_ids, chunk_docs, chunk_metas = [], [], []
            for row_id, doc, meta in pending[offset:offset + 100]:
                chunk_ids.append(row_id)
                chunk_docs.append(doc)
                chunk_metas.append(meta or {})
            try:
                vectors = lane.encode(chunk_docs)
                lane.collection.add(
                    ids=chunk_ids,
                    documents=chunk_docs,
                    metadatas=chunk_metas,
                    embeddings=vectors,
                )
            except Exception as e:
                logger.warning(
                    "Could not backfill %s lane from legacy collection %s: %s",
                    lane.name,
                    base_name,
                    e,
                )
                break
        else:
            # Reached only when every batch for this lane was written.
            logger.info("Backfilled %s %s lane rows from legacy collection %s", len(pending), lane.name, base_name)
def lane_count(lanes: Sequence[EmbeddingLane]) -> int:
    """Return the largest row count among *lanes*, or 0 when there are none."""
    counts = [lane.count() for lane in lanes]
    return max(counts) if counts else 0
def dedupe_results(results: Iterable[Dict[str, Any]], id_key: str = "id", limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Return *results* with repeated ids removed, keeping first occurrences.

    Rows whose ``id_key`` value is missing or falsy are dropped. Input order
    is preserved.

    Args:
        results: rows to filter; each must support ``.get``.
        id_key: key holding the row identifier.
        limit: maximum number of rows to return; ``None`` means no limit. A
            limit of zero or less yields an empty list.

    Returns:
        The first-seen row for each distinct id, at most *limit* rows.
    """
    unique: List[Dict[str, Any]] = []
    # The in-loop check runs only after an append, so a non-positive limit
    # must be rejected up front or it would still let one row through.
    if limit is not None and limit <= 0:
        return unique
    seen = set()
    for row in results:
        row_id = row.get(id_key)
        if not row_id or row_id in seen:
            continue
        seen.add(row_id)
        unique.append(row)
        if limit is not None and len(unique) >= limit:
            break
    return unique
def query_lanes(
    lanes: Sequence[EmbeddingLane],
    query: str,
    n_results: Callable[[EmbeddingLane], int],
    include: Sequence[str],
    where: Optional[Dict[str, Any]] = None,
    raise_if_all_failed: bool = False,
) -> List[tuple[EmbeddingLane, Dict[str, Any]]]:
    """Run *query* against every non-empty lane and collect the raw results.

    Args:
        lanes: lanes to query, in order.
        query: text embedded separately by each lane.
        n_results: callable giving the number of results wanted from a lane;
            the value is capped at the lane's row count.
        include: fields requested from the collection query.
        where: optional filter passed through to the collection query.
        raise_if_all_failed: raise instead of returning an empty list when at
            least one non-empty lane was tried, none produced results and at
            least one raised.

    Returns:
        ``(lane, results)`` pairs for the lanes that answered.

    Raises:
        RuntimeError: see *raise_if_all_failed*; the message joins the
            per-lane errors with ``"; "``.
    """
    answered: List[tuple[EmbeddingLane, Dict[str, Any]]] = []
    errors: List[str] = []
    tried = 0

    for lane in lanes:
        try:
            available = lane.count()
            if available == 0:
                # Empty lanes are skipped and do not count as attempted.
                continue
            tried += 1
            wanted = min(n_results(lane), available)
            if wanted > 0:
                run_query = lane.collection.query
                hits = run_query(
                    query_embeddings=lane.encode([query]),
                    n_results=wanted,
                    where=where,
                    include=list(include),
                )
                answered.append((lane, hits))
        except Exception as e:
            errors.append(f"{lane.name}: {e}")
            logger.warning("%s lane query failed for %s: %s", lane.name, lane.collection_name, e)

    everything_failed = bool(tried) and not answered and bool(errors)
    if raise_if_all_failed and everything_failed:
        raise RuntimeError("; ".join(errors))
    return answered