mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 18:25:26 -04:00
fix: split Chroma embedding lanes (#3046)
This commit is contained in:
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
embedding_lanes.py
|
||||
|
||||
Helpers for keeping FastEmbed fallback vectors separate from user-configured
|
||||
embedding vectors. ChromaDB fixes a collection's dimension on first insert, so
|
||||
different embedding models must never share one collection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LANE_FASTEMBED = "fastembed"
|
||||
LANE_CUSTOM = "custom"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingLane:
|
||||
name: str
|
||||
client: Any
|
||||
collection: Any
|
||||
collection_name: str
|
||||
model: str
|
||||
url: str
|
||||
dimension: int
|
||||
fingerprint: str
|
||||
|
||||
@property
|
||||
def healthy(self) -> bool:
|
||||
return self.collection is not None and self.client is not None
|
||||
|
||||
def encode(self, texts: Sequence[str]) -> List[List[float]]:
|
||||
vecs = self.client.encode(list(texts), normalize_embeddings=True)
|
||||
return vecs.tolist() if hasattr(vecs, "tolist") else [list(v) for v in vecs]
|
||||
|
||||
def count(self) -> int:
|
||||
try:
|
||||
return int(self.collection.count())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"collection": self.collection_name,
|
||||
"model": self.model,
|
||||
"url": self.url,
|
||||
"dimension": self.dimension,
|
||||
"fingerprint": self.fingerprint,
|
||||
"count": self.count(),
|
||||
"healthy": self.healthy,
|
||||
}
|
||||
|
||||
|
||||
def reset_embedding_lane_state() -> None:
|
||||
"""Reset process-local embedding lane state after endpoint config changes."""
|
||||
try:
|
||||
from src.embeddings import reset_http_embed_state
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def collection_name(base_name: str, lane_name: str) -> str:
|
||||
return f"{base_name}_{lane_name}"
|
||||
|
||||
|
||||
def _fingerprint(lane_name: str, url: str, model: str, dimension: int) -> str:
|
||||
raw = f"{lane_name}\n{url}\n{model}\n{dimension}"
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
|
||||
def _metadata(lane_name: str, url: str, model: str, dimension: int, fingerprint: str) -> Dict[str, Any]:
|
||||
return {
|
||||
"hnsw:space": "cosine",
|
||||
"embedding_lane": lane_name,
|
||||
"embedding_url": url,
|
||||
"embedding_model": model,
|
||||
"embedding_dimension": dimension,
|
||||
"embedding_fingerprint": fingerprint,
|
||||
}
|
||||
|
||||
|
||||
def _load_custom_endpoint() -> Dict[str, str]:
|
||||
try:
|
||||
from src.embeddings import _load_persisted_endpoint
|
||||
persisted = _load_persisted_endpoint()
|
||||
except Exception:
|
||||
persisted = {}
|
||||
|
||||
url = persisted.get("url") or os.environ.get("EMBEDDING_URL", "")
|
||||
if not url:
|
||||
return {}
|
||||
|
||||
model = persisted.get("model") or os.environ.get("EMBEDDING_MODEL", "")
|
||||
api_key = persisted.get("api_key") or os.environ.get("EMBEDDING_API_KEY", "")
|
||||
if persisted.get("api_key"):
|
||||
try:
|
||||
from src.secret_storage import decrypt
|
||||
api_key = decrypt(api_key)
|
||||
except Exception:
|
||||
logger.warning("Could not decrypt saved embedding endpoint API key")
|
||||
api_key = ""
|
||||
|
||||
return {"url": url, "model": model, "api_key": api_key}
|
||||
|
||||
|
||||
def _build_fastembed_client():
|
||||
from src.embeddings import FastEmbedClient
|
||||
|
||||
client = FastEmbedClient()
|
||||
client.get_sentence_embedding_dimension()
|
||||
return client
|
||||
|
||||
|
||||
def _build_custom_client():
|
||||
from src.embeddings import EmbeddingClient, get_embedding_client
|
||||
|
||||
client = get_embedding_client()
|
||||
if isinstance(client, EmbeddingClient):
|
||||
return client
|
||||
raise RuntimeError("HTTP embedding lane unavailable")
|
||||
|
||||
|
||||
def _encode_with_client(client: Any, texts: Sequence[str]) -> List[List[float]]:
|
||||
vecs = client.encode(list(texts), normalize_embeddings=True)
|
||||
return vecs.tolist() if hasattr(vecs, "tolist") else [list(v) for v in vecs]
|
||||
|
||||
|
||||
def _get_or_reset_collection(chroma_client, name: str, metadata: Dict[str, Any], client: Any):
|
||||
try:
|
||||
collection = chroma_client.get_collection(name)
|
||||
except Exception:
|
||||
return chroma_client.get_or_create_collection(name=name, metadata=metadata)
|
||||
|
||||
current = collection.metadata or {}
|
||||
if not (
|
||||
current.get("embedding_fingerprint") not in (None, metadata["embedding_fingerprint"])
|
||||
or current.get("embedding_dimension") not in (None, metadata["embedding_dimension"])
|
||||
or current.get("embedding_lane") not in (None, metadata["embedding_lane"])
|
||||
):
|
||||
return collection
|
||||
|
||||
logger.info(
|
||||
"Recreating Chroma collection %s for embedding lane change (%s -> %s)",
|
||||
name,
|
||||
current.get("embedding_fingerprint"),
|
||||
metadata["embedding_fingerprint"],
|
||||
)
|
||||
preserved = {"ids": [], "documents": [], "metadatas": [], "embeddings": []}
|
||||
try:
|
||||
preserved = collection.get(include=["documents", "metadatas", "embeddings"]) or preserved
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Could not preserve documents before resetting {name}: {e}") from e
|
||||
|
||||
ids = preserved.get("ids") or []
|
||||
docs = preserved.get("documents") or []
|
||||
metas = preserved.get("metadatas") or []
|
||||
prepared_batches = []
|
||||
if ids and docs:
|
||||
try:
|
||||
for start in range(0, len(ids), 100):
|
||||
batch_ids = ids[start:start + 100]
|
||||
batch_docs = docs[start:start + 100]
|
||||
batch_metas = metas[start:start + 100]
|
||||
if len(batch_metas) < len(batch_ids):
|
||||
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||
prepared_batches.append((
|
||||
batch_ids,
|
||||
batch_docs,
|
||||
batch_metas,
|
||||
_encode_with_client(client, batch_docs),
|
||||
))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Could not re-embed preserved rows for {name}: {e}") from e
|
||||
|
||||
chroma_client.delete_collection(name)
|
||||
collection = chroma_client.get_or_create_collection(name=name, metadata=metadata)
|
||||
|
||||
try:
|
||||
for batch_ids, batch_docs, batch_metas, embeddings in prepared_batches:
|
||||
collection.add(
|
||||
ids=batch_ids,
|
||||
documents=batch_docs,
|
||||
metadatas=batch_metas,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Could not write reset collection %s; restoring previous rows: %s", name, e)
|
||||
try:
|
||||
chroma_client.delete_collection(name)
|
||||
restored = chroma_client.get_or_create_collection(name=name, metadata=current)
|
||||
old_embeddings = preserved.get("embeddings") or []
|
||||
if ids and docs and old_embeddings:
|
||||
for start in range(0, len(ids), 100):
|
||||
batch_ids = ids[start:start + 100]
|
||||
batch_docs = docs[start:start + 100]
|
||||
batch_metas = metas[start:start + 100]
|
||||
batch_embeddings = old_embeddings[start:start + 100]
|
||||
if len(batch_metas) < len(batch_ids):
|
||||
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||
restored.add(
|
||||
ids=batch_ids,
|
||||
documents=batch_docs,
|
||||
metadatas=batch_metas,
|
||||
embeddings=batch_embeddings,
|
||||
)
|
||||
except Exception as restore_error:
|
||||
logger.warning("Could not restore previous collection %s: %s", name, restore_error)
|
||||
raise RuntimeError(f"Could not write reset collection {name}: {e}") from e
|
||||
if prepared_batches:
|
||||
logger.info("Re-embedded %s rows after resetting %s", len(ids), name)
|
||||
|
||||
return collection
|
||||
|
||||
|
||||
def _create_lane(chroma_client, base_name: str, lane_name: str, client: Any) -> EmbeddingLane:
|
||||
dimension = int(client.get_sentence_embedding_dimension())
|
||||
model = getattr(client, "model", "")
|
||||
url = getattr(client, "url", "")
|
||||
fp = _fingerprint(lane_name, url, model, dimension)
|
||||
name = collection_name(base_name, lane_name)
|
||||
metadata = _metadata(lane_name, url, model, dimension, fp)
|
||||
collection = _get_or_reset_collection(chroma_client, name, metadata, client)
|
||||
return EmbeddingLane(
|
||||
name=lane_name,
|
||||
client=client,
|
||||
collection=collection,
|
||||
collection_name=name,
|
||||
model=model,
|
||||
url=url,
|
||||
dimension=dimension,
|
||||
fingerprint=fp,
|
||||
)
|
||||
|
||||
|
||||
def build_embedding_lanes(base_name: str) -> List[EmbeddingLane]:
|
||||
"""Return healthy lanes in retrieval preference order: custom, fastembed."""
|
||||
from src.chroma_client import get_chroma_client
|
||||
|
||||
chroma_client = get_chroma_client()
|
||||
lanes: List[EmbeddingLane] = []
|
||||
|
||||
try:
|
||||
custom = _build_custom_client()
|
||||
if custom is not None:
|
||||
lanes.append(_create_lane(chroma_client, base_name, LANE_CUSTOM, custom))
|
||||
except Exception as e:
|
||||
logger.warning("Custom embedding lane unavailable for %s: %s", base_name, e)
|
||||
|
||||
try:
|
||||
fastembed = _build_fastembed_client()
|
||||
lanes.append(_create_lane(chroma_client, base_name, LANE_FASTEMBED, fastembed))
|
||||
except Exception as e:
|
||||
logger.warning("FastEmbed lane unavailable for %s: %s", base_name, e)
|
||||
|
||||
return lanes
|
||||
|
||||
|
||||
def migrate_legacy_collection(base_name: str, lanes: Sequence[EmbeddingLane]) -> None:
|
||||
"""Backfill empty lanes from a legacy unsuffixed collection, if present."""
|
||||
if not lanes:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chroma_client import get_chroma_client
|
||||
|
||||
chroma_client = get_chroma_client()
|
||||
legacy = chroma_client.get_collection(base_name)
|
||||
data = legacy.get(include=["documents", "metadatas"])
|
||||
except Exception:
|
||||
return
|
||||
|
||||
ids = data.get("ids") or []
|
||||
docs = data.get("documents") or []
|
||||
metas = data.get("metadatas") or []
|
||||
if not ids or not docs:
|
||||
return
|
||||
|
||||
for lane in lanes:
|
||||
try:
|
||||
existing = lane.collection.get(ids=ids)
|
||||
existing_ids = set(existing.get("ids") or [])
|
||||
except Exception:
|
||||
existing_ids = set()
|
||||
all_metas = list(metas or [])
|
||||
if len(all_metas) < len(ids):
|
||||
all_metas += [{}] * (len(ids) - len(all_metas))
|
||||
missing = [
|
||||
(row_id, doc, meta)
|
||||
for row_id, doc, meta in zip(ids, docs, all_metas)
|
||||
if row_id not in existing_ids
|
||||
]
|
||||
if not missing:
|
||||
continue
|
||||
|
||||
for start in range(0, len(missing), 100):
|
||||
batch = missing[start:start + 100]
|
||||
batch_ids = [row_id for row_id, _doc, _meta in batch]
|
||||
batch_docs = [doc for _row_id, doc, _meta in batch]
|
||||
batch_metas = [meta or {} for _row_id, _doc, meta in batch]
|
||||
if len(batch_metas) < len(batch_ids):
|
||||
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||
try:
|
||||
embeddings = lane.encode(batch_docs)
|
||||
lane.collection.add(
|
||||
ids=batch_ids,
|
||||
documents=batch_docs,
|
||||
metadatas=batch_metas,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not backfill %s lane from legacy collection %s: %s",
|
||||
lane.name,
|
||||
base_name,
|
||||
e,
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.info("Backfilled %s %s lane rows from legacy collection %s", len(missing), lane.name, base_name)
|
||||
|
||||
|
||||
def lane_count(lanes: Sequence[EmbeddingLane]) -> int:
|
||||
return max((lane.count() for lane in lanes), default=0)
|
||||
|
||||
|
||||
def dedupe_results(results: Iterable[Dict[str, Any]], id_key: str = "id", limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
seen = set()
|
||||
out: List[Dict[str, Any]] = []
|
||||
for row in results:
|
||||
row_id = row.get(id_key)
|
||||
if not row_id or row_id in seen:
|
||||
continue
|
||||
seen.add(row_id)
|
||||
out.append(row)
|
||||
if limit is not None and len(out) >= limit:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def query_lanes(
|
||||
lanes: Sequence[EmbeddingLane],
|
||||
query: str,
|
||||
n_results: Callable[[EmbeddingLane], int],
|
||||
include: Sequence[str],
|
||||
where: Optional[Dict[str, Any]] = None,
|
||||
raise_if_all_failed: bool = False,
|
||||
) -> List[tuple[EmbeddingLane, Dict[str, Any]]]:
|
||||
out: List[tuple[EmbeddingLane, Dict[str, Any]]] = []
|
||||
attempted = 0
|
||||
failures: List[str] = []
|
||||
for lane in lanes:
|
||||
try:
|
||||
count = lane.count()
|
||||
if count == 0:
|
||||
continue
|
||||
attempted += 1
|
||||
n = min(n_results(lane), count)
|
||||
if n <= 0:
|
||||
continue
|
||||
results = lane.collection.query(
|
||||
query_embeddings=lane.encode([query]),
|
||||
n_results=n,
|
||||
where=where,
|
||||
include=list(include),
|
||||
)
|
||||
out.append((lane, results))
|
||||
except Exception as e:
|
||||
failures.append(f"{lane.name}: {e}")
|
||||
logger.warning("%s lane query failed for %s: %s", lane.name, lane.collection_name, e)
|
||||
if raise_if_all_failed and attempted and not out and failures:
|
||||
raise RuntimeError("; ".join(failures))
|
||||
return out
|
||||
Reference in New Issue
Block a user