mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-27 07:05:23 -04:00
62a23ca4aa
* test: split embedding lane tests * test: preserve embedding focus selector after lane split
125 lines
4.2 KiB
Python
125 lines
4.2 KiB
Python
"""Shared fakes for embedding-lane tests."""
|
|
|
|
|
|
class FakeEmbedder:
    """Deterministic stand-in for a sentence-embedding model.

    Stores the ``dim``, ``model`` and ``url`` it was built with and emits
    constant-valued vectors, so tests can predict the output exactly.
    """

    def __init__(self, dim, model, url):
        self.dim, self.model, self.url = dim, model, url

    def get_sentence_embedding_dimension(self):
        """Report the configured vector width."""
        return self.dim

    def encode(self, texts, normalize_embeddings=True):
        """Return one vector per text: the n-th (1-based) is ``[float(n)] * dim``.

        The text contents and ``normalize_embeddings`` are ignored.
        """
        width = self.dim
        return [[float(ordinal)] * width for ordinal, _text in enumerate(texts, start=1)]
|
|
|
|
|
|
class FailingEmbedder(FakeEmbedder):
    """FakeEmbedder whose ``encode`` always fails, as if rate limited."""

    def encode(self, texts, normalize_embeddings=True):
        """Ignore the arguments and raise ``RuntimeError`` every time."""
        message = "embedding endpoint rate limited"
        raise RuntimeError(message)
|
|
|
|
|
|
class FakeCollection:
    """In-memory stand-in for a Chroma collection.

    Rows live in ``self.rows`` (id -> dict with ``embedding``, ``document``,
    ``metadata``) in insertion order. The first non-empty ``add``/``upsert``
    fixes ``self.dim``; later writes or queries whose vectors have a
    different width raise ``RuntimeError``.
    """

    def __init__(self, name, metadata=None):
        self.name = name
        self.metadata = metadata or {}
        self.rows = {}
        self.dim = None  # set by the first write

    def count(self):
        """Return the number of stored rows."""
        return len(self.rows)

    def add(self, ids, embeddings, documents=None, metadatas=None):
        """Store rows, overwriting any existing row with the same id.

        Missing ``documents`` become ``None`` per row; missing ``metadatas``
        become a fresh empty dict per row.

        Raises:
            RuntimeError: if the vector width differs from ``self.dim``.
        """
        self._check_dim(embeddings)
        documents = documents or [None] * len(ids)
        # Build one dict per row: ``[{}] * n`` would alias a single dict
        # across all rows, so mutating one row's metadata changed them all.
        metadatas = metadatas or [{} for _ in range(len(ids))]
        for row_id, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
            self.rows[row_id] = {"embedding": emb, "document": doc, "metadata": meta}

    def upsert(self, ids, embeddings, documents=None, metadatas=None):
        """Same as ``add`` (``add`` already overwrites existing ids)."""
        self.add(ids, embeddings, documents=documents, metadatas=metadatas)

    def get(self, ids=None, include=None, where=None, limit=None):
        """Return matching rows as parallel lists, in insertion order.

        ``ids`` restricts to those ids; ``where`` keeps rows whose metadata
        equals every given key/value (plain equality only, operator keys such
        as ``$and`` are not interpreted); ``limit`` truncates the result.
        ``include`` is accepted but ignored: all four fields are returned.
        """
        selected = list(self.rows.items())
        if ids is not None:
            id_set = set(ids)
            selected = [(row_id, row) for row_id, row in selected if row_id in id_set]
        if where:
            selected = [
                (row_id, row)
                for row_id, row in selected
                if all(row["metadata"].get(k) == v for k, v in where.items())
            ]
        if limit is not None:
            selected = selected[:limit]
        return {
            "ids": [row_id for row_id, _ in selected],
            "documents": [row["document"] for _, row in selected],
            "metadatas": [row["metadata"] for _, row in selected],
            "embeddings": [row["embedding"] for _, row in selected],
        }

    def query(self, query_embeddings, n_results, where=None, include=None):
        """Return the first ``n_results`` rows matching ``where``.

        No similarity ranking is performed: rows come back in insertion
        order with synthetic distances (0.1, 0.11, ...). Each field is
        wrapped in one outer list, as for a single query vector.

        The query width is checked against an already-established dimension
        only; querying an empty collection does not fix its dimension.

        Raises:
            RuntimeError: if the query width differs from ``self.dim``.
        """
        self._check_dim(query_embeddings, record=False)
        rows = self.get(where=where)
        ids = rows["ids"][:n_results]
        docs = rows["documents"][:n_results]
        metas = rows["metadatas"][:n_results]
        return {
            "ids": [ids],
            "documents": [docs],
            "metadatas": [metas],
            "distances": [[0.1 + i * 0.01 for i in range(len(ids))]],
        }

    def delete(self, ids):
        """Remove the given ids; unknown ids are ignored."""
        for row_id in ids:
            self.rows.pop(row_id, None)

    def _check_dim(self, embeddings, record=True):
        """Validate the width of the first vector in ``embeddings``.

        If no dimension is set yet and ``record`` is true, that width becomes
        ``self.dim``. Empty or ``None`` input is a no-op. Uses ``len()``
        rather than truthiness so numpy arrays are accepted too.
        """
        if embeddings is None or len(embeddings) == 0:
            return
        dim = len(embeddings[0])
        if self.dim is None:
            if record:
                self.dim = dim
        elif self.dim != dim:
            raise RuntimeError(f"Collection expecting embedding with dimension of {self.dim}, got {dim}")
|
|
|
|
|
|
class FakeChroma:
    """In-memory stand-in for a Chroma client.

    ``collections`` maps name -> collection, ``deleted`` logs every name
    passed to ``delete_collection``, and ``fail_next_add_for`` maps a name to
    a pending-failure count: a collection created while its count is
    positive has its next ``add`` raise ``RuntimeError("chroma write failed")``.
    """

    def __init__(self):
        self.collections = {}
        self.deleted = []
        self.fail_next_add_for = {}

    def get_or_create_collection(self, name, metadata=None):
        """Return collection ``name``, creating a ``FakeCollection`` if absent.

        For an existing collection, a non-``None`` ``metadata`` replaces its
        metadata. For a newly created one, a positive
        ``fail_next_add_for[name]`` arms a one-shot failure of ``add``.
        """
        if name in self.collections:
            if metadata is not None:
                self.collections[name].metadata = metadata
            return self.collections[name]

        self.collections[name] = FakeCollection(name, metadata=metadata)
        if self.fail_next_add_for.get(name, 0) > 0:
            real_add = self.collections[name].add

            def _raise_once(*args, **kwargs):
                # Consume one pending failure and put the real ``add`` back
                # before raising, so only this call fails.
                self.fail_next_add_for[name] -= 1
                self.collections[name].add = real_add
                raise RuntimeError("chroma write failed")

            self.collections[name].add = _raise_once
        return self.collections[name]

    def get_collection(self, name):
        """Return collection ``name``; raise ``KeyError(name)`` if absent."""
        if name in self.collections:
            return self.collections[name]
        raise KeyError(name)

    def delete_collection(self, name):
        """Record ``name`` in ``deleted`` and drop it if present."""
        self.deleted.append(name)
        if name in self.collections:
            del self.collections[name]
|
|
|
|
|
|
def patch_chroma(monkeypatch, fake):
    """Make ``src.chroma_client.get_chroma_client()`` return ``fake``.

    The patch goes through pytest's ``monkeypatch`` fixture, which undoes it
    at test teardown.
    """
    import src.chroma_client as chroma_client

    def _client_factory():
        return fake

    monkeypatch.setattr(chroma_client, "get_chroma_client", _client_factory)
|