mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-28 15:45:22 -04:00
test: split embedding lane tests (#4389)
* test: split embedding lane tests
* test: preserve embedding focus selector after lane split
This commit is contained in:
committed by
GitHub
parent
fc1351d0f8
commit
62a23ca4aa
@@ -0,0 +1,124 @@
|
||||
"""Shared fakes for embedding-lane tests."""
|
||||
|
||||
|
||||
class FakeEmbedder:
    """Deterministic stand-in for a sentence-embedding model.

    The constructor arguments are stored unchanged, and ``encode`` produces
    constant-valued vectors so results are predictable in tests.
    """

    def __init__(self, dim, model, url):
        # Keep the configuration verbatim so tests can inspect it.
        self.dim, self.model, self.url = dim, model, url

    def get_sentence_embedding_dimension(self):
        """Return the vector width given at construction."""
        return self.dim

    def encode(self, texts, normalize_embeddings=True):
        """Return one vector per entry in ``texts``.

        The vector for the entry at zero-based position ``i`` contains the
        value ``float(i + 1)`` repeated ``self.dim`` times.
        ``normalize_embeddings`` is accepted but not used.
        """
        vectors = []
        for position, _ in enumerate(texts, start=1):
            vectors.append([float(position)] * self.dim)
        return vectors
|
||||
|
||||
|
||||
class FailingEmbedder(FakeEmbedder):
    """Embedder fake whose ``encode`` always fails."""

    def encode(self, texts, normalize_embeddings=True):
        """Raise ``RuntimeError`` regardless of the arguments."""
        failure = RuntimeError("embedding endpoint rate limited")
        raise failure
|
||||
|
||||
|
||||
class FakeCollection:
    """In-memory stand-in for a Chroma collection.

    Rows are kept in ``self.rows`` keyed by id. Each row is a dict with
    ``"embedding"``, ``"document"`` and ``"metadata"`` entries. The
    embedding dimension is fixed by the first non-empty write or query and
    enforced on later ones.
    """

    def __init__(self, name, metadata=None):
        self.name = name
        self.metadata = metadata or {}
        self.rows = {}
        # Set lazily by _check_dim on the first non-empty batch.
        self.dim = None

    def count(self):
        """Return the number of stored rows."""
        return len(self.rows)

    def add(self, ids, embeddings, documents=None, metadatas=None):
        """Store one row per id, overwriting any row with the same id.

        Missing ``documents`` default to ``None`` per row and missing
        ``metadatas`` default to a fresh empty dict per row.

        Raises:
            RuntimeError: if the first embedding's length differs from the
                dimension already recorded for this collection.
        """
        self._check_dim(embeddings)
        documents = documents or [None] * len(ids)
        # Build a distinct dict per row: ``[{}] * n`` would alias a single
        # dict across every row, so mutating one row's metadata would leak
        # into all the others.
        metadatas = metadatas or [{} for _ in range(len(ids))]
        for row_id, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
            self.rows[row_id] = {"embedding": emb, "document": doc, "metadata": meta}

    def upsert(self, ids, embeddings, documents=None, metadatas=None):
        """Insert or overwrite rows; delegates to ``add``."""
        self.add(ids, embeddings, documents=documents, metadatas=metadatas)

    def get(self, ids=None, include=None, where=None, limit=None):
        """Return stored rows as parallel lists.

        ``ids`` restricts the result to those ids, ``where`` keeps rows
        whose metadata equals every given key/value pair, and ``limit``
        truncates the result. ``include`` is accepted but ignored: all four
        fields are always returned.
        """
        selected = list(self.rows.items())
        if ids is not None:
            id_set = set(ids)
            selected = [(row_id, row) for row_id, row in selected if row_id in id_set]
        if where:
            selected = [
                (row_id, row)
                for row_id, row in selected
                if all(row["metadata"].get(k) == v for k, v in where.items())
            ]
        if limit is not None:
            selected = selected[:limit]
        return {
            "ids": [row_id for row_id, _ in selected],
            "documents": [row["document"] for _, row in selected],
            "metadatas": [row["metadata"] for _, row in selected],
            "embeddings": [row["embedding"] for _, row in selected],
        }

    def query(self, query_embeddings, n_results, where=None, include=None):
        """Return the first ``n_results`` rows matching ``where``.

        No similarity is computed: rows come back in storage order with
        synthetic distances ``0.1, 0.11, 0.12, ...``. A single result list
        is returned regardless of how many query embeddings are passed, and
        ``include`` is ignored.

        Raises:
            RuntimeError: if the first query embedding's length differs
                from the recorded dimension.
        """
        self._check_dim(query_embeddings)
        rows = self.get(where=where)
        ids = rows["ids"][:n_results]
        docs = rows["documents"][:n_results]
        metas = rows["metadatas"][:n_results]
        return {
            "ids": [ids],
            "documents": [docs],
            "metadatas": [metas],
            "distances": [[0.1 + i * 0.01 for i in range(len(ids))]],
        }

    def delete(self, ids):
        """Remove the given ids; unknown ids are silently skipped."""
        for row_id in ids:
            self.rows.pop(row_id, None)

    def _check_dim(self, embeddings):
        """Record or enforce the embedding dimension.

        Only the first vector of the batch is inspected. An empty batch is
        a no-op.
        """
        if not embeddings:
            return
        dim = len(embeddings[0])
        if self.dim is None:
            self.dim = dim
        elif self.dim != dim:
            raise RuntimeError(f"Collection expecting embedding with dimension of {self.dim}, got {dim}")
|
||||
|
||||
|
||||
class FakeChroma:
    """In-memory stand-in for a Chroma client.

    ``collections`` maps names to collection objects, ``deleted`` records
    every name passed to ``delete_collection``, and ``fail_next_add_for``
    maps a collection name to a count used to inject an ``add`` failure.
    """

    def __init__(self):
        self.collections = {}
        self.deleted = []
        self.fail_next_add_for = {}

    def get_or_create_collection(self, name, metadata=None):
        """Return the named collection, creating it when absent.

        For an existing collection, a non-``None`` ``metadata`` replaces
        the stored metadata. For a new collection with a positive
        ``fail_next_add_for`` count, ``add`` is replaced by a hook that
        raises ``RuntimeError`` once and then restores the real method.
        """
        # NOTE(review): nesting reconstructed from a whitespace-stripped
        # source; confirm the failure hook is meant to be armed only when
        # the collection is first created.
        if name in self.collections:
            collection = self.collections[name]
            if metadata is not None:
                collection.metadata = metadata
            return collection

        collection = FakeCollection(name, metadata=metadata)
        self.collections[name] = collection
        if self.fail_next_add_for.get(name, 0) > 0:
            real_add = collection.add

            def fail_once(*args, **kwargs):
                # Consume one pending failure and restore the real method
                # before raising, so the next call goes through.
                self.fail_next_add_for[name] -= 1
                self.collections[name].add = real_add
                raise RuntimeError("chroma write failed")

            collection.add = fail_once
        return collection

    def get_collection(self, name):
        """Return the named collection.

        Raises:
            KeyError: if no collection with that name exists.
        """
        if name in self.collections:
            return self.collections[name]
        raise KeyError(name)

    def delete_collection(self, name):
        """Record the deletion and drop the collection if it exists."""
        self.deleted.append(name)
        self.collections.pop(name, None)
|
||||
|
||||
|
||||
def patch_chroma(monkeypatch, fake):
    """Make ``src.chroma_client.get_chroma_client`` return ``fake``.

    The replacement is installed through ``monkeypatch.setattr``. The
    project module is imported inside the function rather than at module
    import time.
    """
    import src.chroma_client as chroma_client

    def _return_fake():
        return fake

    monkeypatch.setattr(chroma_client, "get_chroma_client", _return_fake)
|
||||
Reference in New Issue
Block a user