# Mirror of https://github.com/pewdiepie-archdaemon/odysseus.git
# Synced 2026-06-16 01:35:36 -04:00 (381 lines, 13 KiB, Python)
"""
embedding_lanes.py

Helpers for keeping FastEmbed fallback vectors separate from user-configured
embedding vectors. ChromaDB fixes a collection's dimension on first insert, so
different embedding models must never share one collection.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
|
|
|
|
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)

# Lane names. Each one is also used as the suffix of the lane's collection
# name (see ``collection_name``) and is mixed into the lane fingerprint.
LANE_FASTEMBED = "fastembed"
LANE_CUSTOM = "custom"
|
|
|
|
|
|
@dataclass
class EmbeddingLane:
    """One embedding client paired with the collection that stores its vectors."""

    # Lane identifier (``LANE_CUSTOM`` or ``LANE_FASTEMBED`` in this module).
    name: str
    # Embedding client; must expose ``encode(texts, normalize_embeddings=...)``.
    client: Any
    # Collection handle; must expose ``count()`` (and ``query``/``add``/``get``
    # for the module-level helpers that operate on lanes).
    collection: Any
    # Name of ``collection`` inside the vector store.
    collection_name: str
    # Model identifier reported by the client ("" when it reports none).
    model: str
    # Endpoint URL reported by the client ("" when it reports none).
    url: str
    # Length of the vectors produced by ``client``.
    dimension: int
    # Short hash identifying the (lane, url, model, dimension) combination.
    fingerprint: str

    @property
    def healthy(self) -> bool:
        """Return True when both the collection and the client are set."""
        return not (self.collection is None or self.client is None)

    def encode(self, texts: Sequence[str]) -> List[List[float]]:
        """Embed *texts* with normalization enabled and return plain lists."""
        vectors = self.client.encode(list(texts), normalize_embeddings=True)
        # Array-like results convert themselves; anything else is copied
        # row by row into lists.
        if hasattr(vectors, "tolist"):
            return vectors.tolist()
        return [list(vector) for vector in vectors]

    def count(self) -> int:
        """Return the collection's row count, or 0 if it cannot be read."""
        try:
            total = self.collection.count()
            return int(total)
        except Exception:
            # Best effort: an unreadable collection is reported as empty.
            return 0

    def stats(self) -> Dict[str, Any]:
        """Return a dictionary snapshot of this lane's identity and state."""
        summary: Dict[str, Any] = {}
        summary["name"] = self.name
        summary["collection"] = self.collection_name
        summary["model"] = self.model
        summary["url"] = self.url
        summary["dimension"] = self.dimension
        summary["fingerprint"] = self.fingerprint
        summary["count"] = self.count()
        summary["healthy"] = self.healthy
        return summary
|
|
|
|
|
|
def reset_embedding_lane_state() -> None:
    """Reset process-local embedding lane state after endpoint config changes.

    Delegates to ``src.embeddings.reset_http_embed_state``. This is a
    best-effort call: any failure to import or run it is ignored.
    """
    try:
        from src.embeddings import reset_http_embed_state as _reset_http_state

        _reset_http_state()
    except Exception:
        # Deliberately silent; callers do not depend on the reset succeeding.
        return
|
|
|
|
|
|
def collection_name(base_name: str, lane_name: str) -> str:
    """Return the lane-specific collection name: ``<base_name>_<lane_name>``."""
    return "{}_{}".format(base_name, lane_name)
|
|
|
|
|
|
def _fingerprint(lane_name: str, url: str, model: str, dimension: int) -> str:
    """Return a 16-hex-character identifier for a lane configuration.

    The identifier is the start of the SHA-256 digest of the four values
    joined by newlines, so changing any one of them changes the fingerprint.
    """
    payload = "\n".join(str(part) for part in (lane_name, url, model, dimension))
    digest = hashlib.sha256(payload.encode("utf-8"))
    return digest.hexdigest()[:16]
|
|
|
|
|
|
def _metadata(lane_name: str, url: str, model: str, dimension: int, fingerprint: str) -> Dict[str, Any]:
    """Build the collection metadata recorded for an embedding lane.

    Besides the lane's identifying values, the metadata sets ``hnsw:space``
    to ``cosine``.
    """
    lane_metadata: Dict[str, Any] = {"hnsw:space": "cosine"}
    lane_metadata["embedding_lane"] = lane_name
    lane_metadata["embedding_url"] = url
    lane_metadata["embedding_model"] = model
    lane_metadata["embedding_dimension"] = dimension
    lane_metadata["embedding_fingerprint"] = fingerprint
    return lane_metadata
|
|
|
|
|
|
def _load_custom_endpoint() -> Dict[str, str]:
    """Return the configured custom embedding endpoint, or ``{}`` if none.

    Each value comes from the persisted endpoint settings when present and
    otherwise from the ``EMBEDDING_URL`` / ``EMBEDDING_MODEL`` /
    ``EMBEDDING_API_KEY`` environment variables. A persisted API key is
    passed through ``src.secret_storage.decrypt``; if that fails the key is
    returned as an empty string and a warning is logged.
    """
    try:
        from src.embeddings import _load_persisted_endpoint

        persisted = _load_persisted_endpoint()
    except Exception:
        # No readable persisted settings: rely on the environment alone.
        persisted = {}

    def _setting(key: str, env_var: str) -> str:
        return persisted.get(key) or os.environ.get(env_var, "")

    url = _setting("url", "EMBEDDING_URL")
    if not url:
        return {}

    model = _setting("model", "EMBEDDING_MODEL")
    saved_key = persisted.get("api_key")
    api_key = saved_key or os.environ.get("EMBEDDING_API_KEY", "")
    if saved_key:
        # Only a persisted key is decrypted; an environment key is used as is.
        try:
            from src.secret_storage import decrypt

            api_key = decrypt(saved_key)
        except Exception:
            logger.warning("Could not decrypt saved embedding endpoint API key")
            api_key = ""

    return {"url": url, "model": model, "api_key": api_key}
|
|
|
|
|
|
def _build_fastembed_client():
    """Create a ``FastEmbedClient`` and query its dimension before returning it.

    The dimension call's result is discarded; presumably it is made so that a
    client that cannot report a dimension fails here rather than at first
    use (confirm against ``FastEmbedClient``).
    """
    from src.embeddings import FastEmbedClient

    fastembed_client = FastEmbedClient()
    fastembed_client.get_sentence_embedding_dimension()
    return fastembed_client
|
|
|
|
|
|
def _build_custom_client():
    """Return the configured HTTP embedding client.

    Raises:
        RuntimeError: if ``get_embedding_client()`` returns anything other
            than an ``EmbeddingClient`` instance.
    """
    from src.embeddings import EmbeddingClient, get_embedding_client

    candidate = get_embedding_client()
    if not isinstance(candidate, EmbeddingClient):
        raise RuntimeError("HTTP embedding lane unavailable")
    return candidate
|
|
|
|
|
|
def _encode_with_client(client: Any, texts: Sequence[str]) -> List[List[float]]:
    """Embed *texts* with *client* (normalization on) and return plain lists."""
    vectors = client.encode(list(texts), normalize_embeddings=True)
    # Array-like results convert themselves; anything else is copied row by row.
    if hasattr(vectors, "tolist"):
        return vectors.tolist()
    return [list(vector) for vector in vectors]
|
|
|
|
|
|
# Number of rows written to a collection per ``add`` call during a reset.
_RESET_BATCH_SIZE = 100


def _iter_row_batches(ids, docs, metas, batch_size=_RESET_BATCH_SIZE):
    """Yield ``(window, ids, documents, metadatas)`` for consecutive row batches.

    ``window`` is the ``slice`` that selected the batch, so callers can cut a
    parallel sequence (such as stored embeddings) the same way. A metadata
    batch shorter than its id batch is padded with empty dicts.
    """
    for start in range(0, len(ids), batch_size):
        window = slice(start, start + batch_size)
        batch_ids = ids[window]
        batch_metas = list(metas[window])
        batch_metas.extend({} for _ in range(len(batch_ids) - len(batch_metas)))
        yield window, batch_ids, docs[window], batch_metas


def _restore_previous_rows(chroma_client, name, previous_metadata, ids, docs, metas, old_embeddings):
    """Recreate collection *name* with its previous metadata, rows and vectors."""
    chroma_client.delete_collection(name)
    restored = chroma_client.get_or_create_collection(name=name, metadata=previous_metadata)
    # Test for emptiness with ``is None`` / ``len`` rather than truthiness:
    # an array-like embeddings result (e.g. a numpy array) raises on bool(),
    # which would abort the restore and lose the previous rows.
    if old_embeddings is None or len(old_embeddings) == 0:
        return
    if not (ids and docs):
        return
    for window, batch_ids, batch_docs, batch_metas in _iter_row_batches(ids, docs, metas):
        restored.add(
            ids=batch_ids,
            documents=batch_docs,
            metadatas=batch_metas,
            embeddings=old_embeddings[window],
        )


def _get_or_reset_collection(chroma_client, name: str, metadata: Dict[str, Any], client: Any):
    """Return collection *name*, rebuilding it if its embedding lane changed.

    A collection that cannot be fetched is created with *metadata*. An
    existing collection is returned untouched unless its stored
    ``embedding_fingerprint``, ``embedding_dimension`` or ``embedding_lane``
    is set and differs from *metadata*. In that case its documents are
    re-embedded with *client*, the collection is deleted and recreated with
    *metadata*, and the rows are written back.

    Args:
        chroma_client: Client exposing ``get_collection``,
            ``get_or_create_collection`` and ``delete_collection``.
        name: Collection name.
        metadata: Target collection metadata; must contain the three
            ``embedding_*`` keys named above.
        client: Embedding client used to re-embed preserved documents.

    Returns:
        The existing, newly created, or rebuilt collection.

    Raises:
        RuntimeError: if the old rows cannot be read or re-embedded (the old
            collection is left in place), or if writing the rebuilt
            collection fails (a restore of the previous rows is attempted
            first).
    """
    try:
        collection = chroma_client.get_collection(name)
    except Exception:
        # Any fetch failure is treated as "no such collection yet".
        return chroma_client.get_or_create_collection(name=name, metadata=metadata)

    current = collection.metadata or {}
    # A missing (None) stored value is not a change, so collections without
    # the metadata keys are kept as they are.
    lane_changed = any(
        current.get(key) not in (None, metadata[key])
        for key in ("embedding_fingerprint", "embedding_dimension", "embedding_lane")
    )
    if not lane_changed:
        return collection

    logger.info(
        "Recreating Chroma collection %s for embedding lane change (%s -> %s)",
        name,
        current.get("embedding_fingerprint"),
        metadata["embedding_fingerprint"],
    )
    preserved = {"ids": [], "documents": [], "metadatas": [], "embeddings": []}
    try:
        preserved = collection.get(include=["documents", "metadatas", "embeddings"]) or preserved
    except Exception as e:
        raise RuntimeError(f"Could not preserve documents before resetting {name}: {e}") from e

    ids = preserved.get("ids") or []
    docs = preserved.get("documents") or []
    metas = preserved.get("metadatas") or []

    # Compute every new embedding before deleting anything, so an embedding
    # failure leaves the existing collection intact.
    prepared_batches = []
    if ids and docs:
        try:
            for _window, batch_ids, batch_docs, batch_metas in _iter_row_batches(ids, docs, metas):
                prepared_batches.append((
                    batch_ids,
                    batch_docs,
                    batch_metas,
                    _encode_with_client(client, batch_docs),
                ))
        except Exception as e:
            raise RuntimeError(f"Could not re-embed preserved rows for {name}: {e}") from e

    chroma_client.delete_collection(name)
    collection = chroma_client.get_or_create_collection(name=name, metadata=metadata)

    try:
        for batch_ids, batch_docs, batch_metas, embeddings in prepared_batches:
            collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metas,
                embeddings=embeddings,
            )
    except Exception as e:
        logger.warning("Could not write reset collection %s; restoring previous rows: %s", name, e)
        try:
            _restore_previous_rows(
                chroma_client,
                name,
                current,
                ids,
                docs,
                metas,
                preserved.get("embeddings"),
            )
        except Exception as restore_error:
            logger.warning("Could not restore previous collection %s: %s", name, restore_error)
        raise RuntimeError(f"Could not write reset collection {name}: {e}") from e
    if prepared_batches:
        logger.info("Re-embedded %s rows after resetting %s", len(ids), name)

    return collection
|
|
|
|
|
|
def _create_lane(chroma_client, base_name: str, lane_name: str, client: Any) -> EmbeddingLane:
    """Assemble an ``EmbeddingLane`` for *client*, preparing its collection.

    The collection is named ``<base_name>_<lane_name>`` and is fetched,
    created, or rebuilt by ``_get_or_reset_collection`` according to the
    fingerprint derived from the client's url, model and dimension.
    """
    vector_size = int(client.get_sentence_embedding_dimension())
    # Clients that expose no model/url attribute contribute empty strings.
    model_id = getattr(client, "model", "")
    endpoint = getattr(client, "url", "")
    lane_fingerprint = _fingerprint(lane_name, endpoint, model_id, vector_size)
    lane_collection_name = collection_name(base_name, lane_name)
    lane_metadata = _metadata(lane_name, endpoint, model_id, vector_size, lane_fingerprint)
    lane_collection = _get_or_reset_collection(
        chroma_client,
        lane_collection_name,
        lane_metadata,
        client,
    )
    return EmbeddingLane(
        name=lane_name,
        client=client,
        collection=lane_collection,
        collection_name=lane_collection_name,
        model=model_id,
        url=endpoint,
        dimension=vector_size,
        fingerprint=lane_fingerprint,
    )
|
|
|
|
|
|
def build_embedding_lanes(base_name: str) -> List[EmbeddingLane]:
    """Return healthy lanes in retrieval preference order: custom, fastembed.

    Each lane is built independently. A lane whose client or collection
    cannot be set up is logged as a warning and left out, so the result may
    hold two, one, or zero lanes.
    """
    from src.chroma_client import get_chroma_client

    store = get_chroma_client()
    built: List[EmbeddingLane] = []

    # Custom (user-configured) lane goes first so it is preferred.
    try:
        custom_client = _build_custom_client()
        if custom_client is not None:
            custom_lane = _create_lane(store, base_name, LANE_CUSTOM, custom_client)
            built.append(custom_lane)
    except Exception as e:
        logger.warning("Custom embedding lane unavailable for %s: %s", base_name, e)

    # FastEmbed fallback lane.
    try:
        fastembed_client = _build_fastembed_client()
        fastembed_lane = _create_lane(store, base_name, LANE_FASTEMBED, fastembed_client)
        built.append(fastembed_lane)
    except Exception as e:
        logger.warning("FastEmbed lane unavailable for %s: %s", base_name, e)

    return built
|
|
|
|
|
|
def migrate_legacy_collection(base_name: str, lanes: Sequence[EmbeddingLane]) -> None:
    """Copy rows from the legacy unsuffixed collection into each lane.

    For every lane, legacy rows whose ids are not already in the lane's
    collection are re-embedded with that lane's client and added in batches
    of 100. The function returns silently when there are no lanes, when the
    legacy collection cannot be read, or when it holds no ids or documents.
    A failed batch logs a warning and stops the backfill for that lane only.
    """
    if not lanes:
        return

    try:
        from src.chroma_client import get_chroma_client

        store = get_chroma_client()
        legacy_collection = store.get_collection(base_name)
        legacy_rows = legacy_collection.get(include=["documents", "metadatas"])
    except Exception:
        # No readable legacy collection means there is nothing to migrate.
        return

    ids = legacy_rows.get("ids") or []
    docs = legacy_rows.get("documents") or []
    metas = legacy_rows.get("metadatas") or []
    if not (ids and docs):
        return

    # The padded metadata list is the same for every lane, so build it once.
    padded_metas = list(metas)
    shortfall = len(ids) - len(padded_metas)
    if shortfall > 0:
        padded_metas += [{}] * shortfall

    for lane in lanes:
        try:
            found = lane.collection.get(ids=ids)
            present_ids = set(found.get("ids") or [])
        except Exception:
            # If the lookup fails, treat the lane as holding none of the rows.
            present_ids = set()

        missing = []
        for row_id, doc, meta in zip(ids, docs, padded_metas):
            if row_id not in present_ids:
                missing.append((row_id, doc, meta))
        if not missing:
            continue

        for offset in range(0, len(missing), 100):
            chunk = missing[offset:offset + 100]
            chunk_ids = []
            chunk_docs = []
            chunk_metas = []
            for row_id, doc, meta in chunk:
                chunk_ids.append(row_id)
                chunk_docs.append(doc)
                chunk_metas.append(meta or {})
            try:
                embeddings = lane.encode(chunk_docs)
                lane.collection.add(
                    ids=chunk_ids,
                    documents=chunk_docs,
                    metadatas=chunk_metas,
                    embeddings=embeddings,
                )
            except Exception as e:
                logger.warning(
                    "Could not backfill %s lane from legacy collection %s: %s",
                    lane.name,
                    base_name,
                    e,
                )
                break
        else:
            # Reached only when no batch failed for this lane.
            logger.info("Backfilled %s %s lane rows from legacy collection %s", len(missing), lane.name, base_name)
|
|
|
|
|
|
def lane_count(lanes: Sequence[EmbeddingLane]) -> int:
    """Return the largest row count among *lanes*, or 0 when there are none."""
    counts = [lane.count() for lane in lanes]
    return max(counts) if counts else 0
|
|
|
|
|
|
def dedupe_results(results: Iterable[Dict[str, Any]], id_key: str = "id", limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Return *results* with duplicate ids removed, keeping first occurrences.

    Args:
        results: Rows to filter, in preference order.
        id_key: Key under which each row stores its id.
        limit: Maximum number of rows to return; ``None`` means no limit. A
            limit of zero or less yields an empty list.

    Returns:
        The first row seen for each distinct id, in input order. Rows whose
        id is missing or falsy are skipped.
    """
    out: List[Dict[str, Any]] = []
    # The limit is otherwise only checked after an append, which would let
    # one row through even when the caller asked for none.
    if limit is not None and limit <= 0:
        return out

    seen = set()
    for row in results:
        row_id = row.get(id_key)
        if not row_id or row_id in seen:
            continue
        seen.add(row_id)
        out.append(row)
        if limit is not None and len(out) >= limit:
            break
    return out
|
|
|
|
|
|
def query_lanes(
    lanes: Sequence[EmbeddingLane],
    query: str,
    n_results: Callable[[EmbeddingLane], int],
    include: Sequence[str],
    where: Optional[Dict[str, Any]] = None,
    raise_if_all_failed: bool = False,
) -> List[tuple[EmbeddingLane, Dict[str, Any]]]:
    """Run *query* against every non-empty lane and collect the raw results.

    Args:
        lanes: Lanes to query, in order.
        query: Text to embed with each lane's own client.
        n_results: Callable giving the desired result count for a lane; the
            value is capped at the lane's row count.
        include: Fields to request from the collection query.
        where: Optional filter passed through to the collection query.
        raise_if_all_failed: Raise instead of returning an empty list when
            at least one non-empty lane was tried, nothing succeeded, and at
            least one lane failed.

    Returns:
        ``(lane, results)`` pairs for each lane whose query succeeded.

    Raises:
        RuntimeError: only under the ``raise_if_all_failed`` condition; the
            message joins the per-lane failures with ``"; "``.
    """
    hits: List[tuple[EmbeddingLane, Dict[str, Any]]] = []
    tried = 0
    errors: List[str] = []

    for lane in lanes:
        try:
            available = lane.count()
            if available == 0:
                # Empty lanes are skipped and do not count as attempted.
                continue
            tried += 1
            wanted = min(n_results(lane), available)
            if wanted <= 0:
                continue
            query_vectors = lane.encode([query])
            response = lane.collection.query(
                query_embeddings=query_vectors,
                n_results=wanted,
                where=where,
                include=list(include),
            )
            hits.append((lane, response))
        except Exception as e:
            # One failing lane must not prevent the others from answering.
            errors.append(f"{lane.name}: {e}")
            logger.warning("%s lane query failed for %s: %s", lane.name, lane.collection_name, e)

    if raise_if_all_failed:
        if tried and not hits and errors:
            raise RuntimeError("; ".join(errors))
    return hits
|