Files
odysseus/tests/helpers/embedding_lanes.py
T
Alexandre Teixeira 62a23ca4aa test: split embedding lane tests (#4389)
* test: split embedding lane tests

* test: preserve embedding focus selector after lane split
2026-06-26 18:28:40 +02:00

125 lines
4.2 KiB
Python

"""Shared fakes for embedding-lane tests."""
class FakeEmbedder:
    """Deterministic embedder stub for tests.

    ``model`` and ``url`` are only stored; nothing in this class reads them.
    """

    def __init__(self, dim, model, url):
        self.dim = dim
        self.model = model
        self.url = url

    def get_sentence_embedding_dimension(self):
        """Return the configured vector width."""
        return self.dim

    def encode(self, texts, normalize_embeddings=True):
        """Return one vector per text, in input order.

        The vector for the text at 0-based position ``i`` is ``dim`` copies
        of ``float(i + 1)``. ``normalize_embeddings`` is accepted but ignored.
        """
        vectors = []
        for position, _text in enumerate(texts, start=1):
            vectors.append([float(position)] * self.dim)
        return vectors
class FailingEmbedder(FakeEmbedder):
    """Embedder stub whose ``encode`` always fails."""

    def encode(self, texts, normalize_embeddings=True):
        """Raise ``RuntimeError`` unconditionally; both arguments are ignored."""
        raise RuntimeError("embedding endpoint rate limited")
class FakeCollection:
    """In-memory stand-in for a vector-store collection used by tests.

    Rows live in ``self.rows`` (a dict keyed by row id, so insertion order is
    preserved). Each row is a dict with ``"embedding"``, ``"document"`` and
    ``"metadata"`` keys. The embedding width is fixed by the first non-empty
    batch passed to ``add``/``upsert``/``query`` and enforced afterwards.
    """

    def __init__(self, name, metadata=None):
        self.name = name
        self.metadata = metadata or {}
        self.rows = {}
        # Width of the embeddings this collection accepts; None until the
        # first non-empty batch is seen by _check_dim.
        self.dim = None

    def count(self):
        """Return the number of stored rows."""
        return len(self.rows)

    def add(self, ids, embeddings, documents=None, metadatas=None):
        """Store one row per id, overwriting any row that has the same id.

        Missing ``documents`` default to ``None`` per row and missing
        ``metadatas`` default to an empty dict per row.

        Raises:
            RuntimeError: if the embedding width differs from the width this
                collection has already locked in.
        """
        self._check_dim(embeddings)
        documents = documents or [None] * len(ids)
        # Build a distinct dict per row: ``[{}] * n`` would alias a single
        # dict across the whole batch, so mutating one row's metadata would
        # silently change every other row added in the same call.
        metadatas = metadatas or [{} for _ in range(len(ids))]
        for row_id, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
            self.rows[row_id] = {"embedding": emb, "document": doc, "metadata": meta}

    def upsert(self, ids, embeddings, documents=None, metadatas=None):
        """Same as :meth:`add`, which already overwrites existing ids."""
        self.add(ids, embeddings, documents=documents, metadatas=metadatas)

    def get(self, ids=None, include=None, where=None, limit=None):
        """Return stored rows, optionally filtered, in insertion order.

        Args:
            ids: if not ``None``, keep only rows whose id is in this iterable.
            include: accepted but ignored; all four keys are always returned.
            where: if truthy, keep only rows whose metadata equals every
                key/value pair in this mapping (exact ``==`` match).
            limit: if not ``None``, keep at most this many rows, applied
                after the other filters.

        Returns:
            A dict with parallel lists under ``"ids"``, ``"documents"``,
            ``"metadatas"`` and ``"embeddings"``.
        """
        selected = list(self.rows.items())
        if ids is not None:
            id_set = set(ids)
            selected = [(row_id, row) for row_id, row in selected if row_id in id_set]
        if where:
            selected = [
                (row_id, row)
                for row_id, row in selected
                if all(row["metadata"].get(k) == v for k, v in where.items())
            ]
        if limit is not None:
            selected = selected[:limit]
        return {
            "ids": [row_id for row_id, _ in selected],
            "documents": [row["document"] for _, row in selected],
            "metadatas": [row["metadata"] for _, row in selected],
            "embeddings": [row["embedding"] for _, row in selected],
        }

    def query(self, query_embeddings, n_results, where=None, include=None):
        """Return the first ``n_results`` rows matching ``where``.

        No similarity is computed: ``query_embeddings`` is only used for the
        dimension check, and the distances are synthetic
        (``0.1 + 0.01 * rank``). A single result list is returned regardless
        of how many query embeddings are passed. ``include`` is ignored.

        Raises:
            RuntimeError: if the query embedding width differs from the width
                this collection has already locked in.
        """
        self._check_dim(query_embeddings)
        rows = self.get(where=where)
        ids = rows["ids"][:n_results]
        docs = rows["documents"][:n_results]
        metas = rows["metadatas"][:n_results]
        return {
            "ids": [ids],
            "documents": [docs],
            "metadatas": [metas],
            "distances": [[0.1 + i * 0.01 for i in range(len(ids))]],
        }

    def delete(self, ids):
        """Remove the given ids; ids that are not stored are ignored."""
        for row_id in ids:
            self.rows.pop(row_id, None)

    def _check_dim(self, embeddings):
        """Lock in or verify the embedding width from the first embedding.

        Only ``embeddings[0]`` is inspected; an empty batch is a no-op.
        """
        if not embeddings:
            return
        dim = len(embeddings[0])
        if self.dim is None:
            self.dim = dim
        elif self.dim != dim:
            raise RuntimeError(f"Collection expecting embedding with dimension of {self.dim}, got {dim}")
class FakeChroma:
    """In-memory stand-in for a Chroma client, with one-shot write failures.

    Attributes are plain containers that tests may inspect or pre-populate:
    ``collections`` maps name -> collection, ``deleted`` records every name
    passed to ``delete_collection`` in call order, and ``fail_next_add_for``
    maps a collection name to a pending-failure counter.
    """

    def __init__(self):
        self.collections = {}
        self.deleted = []
        self.fail_next_add_for = {}

    def get_or_create_collection(self, name, metadata=None):
        """Return the collection called ``name``, creating it if needed.

        For an existing collection, a non-``None`` ``metadata`` replaces the
        stored metadata. For a newly created collection, a positive
        ``fail_next_add_for[name]`` makes its next ``add`` call raise
        ``RuntimeError("chroma write failed")``.
        """
        # NOTE(review): the metadata overwrite is taken to apply only to
        # collections that already existed — confirm against the original
        # layout of this method.
        if name in self.collections:
            if metadata is not None:
                self.collections[name].metadata = metadata
            return self.collections[name]

        created = FakeCollection(name, metadata=metadata)
        self.collections[name] = created
        if self.fail_next_add_for.get(name, 0) > 0:
            real_add = created.add

            def fail_once(*args, **kwargs):
                # Consume one pending failure and put the real ``add`` back
                # before raising, so the following call goes through.
                self.fail_next_add_for[name] -= 1
                self.collections[name].add = real_add
                raise RuntimeError("chroma write failed")

            created.add = fail_once
        return created

    def get_collection(self, name):
        """Return the existing collection ``name`` or raise ``KeyError``."""
        if name in self.collections:
            return self.collections[name]
        raise KeyError(name)

    def delete_collection(self, name):
        """Record ``name`` in ``deleted`` and drop the collection if present."""
        self.deleted.append(name)
        self.collections.pop(name, None)
def patch_chroma(monkeypatch, fake):
    """Replace ``src.chroma_client.get_chroma_client`` with a function returning *fake*.

    ``monkeypatch`` is presumably pytest's ``monkeypatch`` fixture; only its
    ``setattr`` method is used here.
    """
    import src.chroma_client as chroma_client

    def _return_fake():
        return fake

    monkeypatch.setattr(chroma_client, "get_chroma_client", _return_fake)