mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
fix: split Chroma embedding lanes (#3046)
This commit is contained in:
@@ -316,6 +316,16 @@ def setup_embedding_routes():
|
|||||||
reset_http_embed_state()
|
reset_http_embed_state()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from src.embedding_lanes import reset_embedding_lane_state
|
||||||
|
reset_embedding_lane_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from src.tool_index import reset_tool_index
|
||||||
|
reset_tool_index()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
||||||
try:
|
try:
|
||||||
@@ -347,6 +357,16 @@ def setup_embedding_routes():
|
|||||||
reset_http_embed_state()
|
reset_http_embed_state()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
from src.embedding_lanes import reset_embedding_lane_state
|
||||||
|
reset_embedding_lane_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
from src.tool_index import reset_tool_index
|
||||||
|
reset_tool_index()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Reset ChromaDB client
|
# Reset ChromaDB client
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -0,0 +1,380 @@
|
|||||||
|
"""
|
||||||
|
embedding_lanes.py
|
||||||
|
|
||||||
|
Helpers for keeping FastEmbed fallback vectors separate from user-configured
|
||||||
|
embedding vectors. ChromaDB fixes a collection's dimension on first insert, so
|
||||||
|
different embedding models must never share one collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
LANE_FASTEMBED = "fastembed"
|
||||||
|
LANE_CUSTOM = "custom"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmbeddingLane:
|
||||||
|
name: str
|
||||||
|
client: Any
|
||||||
|
collection: Any
|
||||||
|
collection_name: str
|
||||||
|
model: str
|
||||||
|
url: str
|
||||||
|
dimension: int
|
||||||
|
fingerprint: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def healthy(self) -> bool:
|
||||||
|
return self.collection is not None and self.client is not None
|
||||||
|
|
||||||
|
def encode(self, texts: Sequence[str]) -> List[List[float]]:
|
||||||
|
vecs = self.client.encode(list(texts), normalize_embeddings=True)
|
||||||
|
return vecs.tolist() if hasattr(vecs, "tolist") else [list(v) for v in vecs]
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
try:
|
||||||
|
return int(self.collection.count())
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def stats(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"collection": self.collection_name,
|
||||||
|
"model": self.model,
|
||||||
|
"url": self.url,
|
||||||
|
"dimension": self.dimension,
|
||||||
|
"fingerprint": self.fingerprint,
|
||||||
|
"count": self.count(),
|
||||||
|
"healthy": self.healthy,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def reset_embedding_lane_state() -> None:
|
||||||
|
"""Reset process-local embedding lane state after endpoint config changes."""
|
||||||
|
try:
|
||||||
|
from src.embeddings import reset_http_embed_state
|
||||||
|
reset_http_embed_state()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def collection_name(base_name: str, lane_name: str) -> str:
|
||||||
|
return f"{base_name}_{lane_name}"
|
||||||
|
|
||||||
|
|
||||||
|
def _fingerprint(lane_name: str, url: str, model: str, dimension: int) -> str:
|
||||||
|
raw = f"{lane_name}\n{url}\n{model}\n{dimension}"
|
||||||
|
return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
def _metadata(lane_name: str, url: str, model: str, dimension: int, fingerprint: str) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"hnsw:space": "cosine",
|
||||||
|
"embedding_lane": lane_name,
|
||||||
|
"embedding_url": url,
|
||||||
|
"embedding_model": model,
|
||||||
|
"embedding_dimension": dimension,
|
||||||
|
"embedding_fingerprint": fingerprint,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_custom_endpoint() -> Dict[str, str]:
|
||||||
|
try:
|
||||||
|
from src.embeddings import _load_persisted_endpoint
|
||||||
|
persisted = _load_persisted_endpoint()
|
||||||
|
except Exception:
|
||||||
|
persisted = {}
|
||||||
|
|
||||||
|
url = persisted.get("url") or os.environ.get("EMBEDDING_URL", "")
|
||||||
|
if not url:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
model = persisted.get("model") or os.environ.get("EMBEDDING_MODEL", "")
|
||||||
|
api_key = persisted.get("api_key") or os.environ.get("EMBEDDING_API_KEY", "")
|
||||||
|
if persisted.get("api_key"):
|
||||||
|
try:
|
||||||
|
from src.secret_storage import decrypt
|
||||||
|
api_key = decrypt(api_key)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Could not decrypt saved embedding endpoint API key")
|
||||||
|
api_key = ""
|
||||||
|
|
||||||
|
return {"url": url, "model": model, "api_key": api_key}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fastembed_client():
|
||||||
|
from src.embeddings import FastEmbedClient
|
||||||
|
|
||||||
|
client = FastEmbedClient()
|
||||||
|
client.get_sentence_embedding_dimension()
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def _build_custom_client():
|
||||||
|
from src.embeddings import EmbeddingClient, get_embedding_client
|
||||||
|
|
||||||
|
client = get_embedding_client()
|
||||||
|
if isinstance(client, EmbeddingClient):
|
||||||
|
return client
|
||||||
|
raise RuntimeError("HTTP embedding lane unavailable")
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_with_client(client: Any, texts: Sequence[str]) -> List[List[float]]:
|
||||||
|
vecs = client.encode(list(texts), normalize_embeddings=True)
|
||||||
|
return vecs.tolist() if hasattr(vecs, "tolist") else [list(v) for v in vecs]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_reset_collection(chroma_client, name: str, metadata: Dict[str, Any], client: Any):
|
||||||
|
try:
|
||||||
|
collection = chroma_client.get_collection(name)
|
||||||
|
except Exception:
|
||||||
|
return chroma_client.get_or_create_collection(name=name, metadata=metadata)
|
||||||
|
|
||||||
|
current = collection.metadata or {}
|
||||||
|
if not (
|
||||||
|
current.get("embedding_fingerprint") not in (None, metadata["embedding_fingerprint"])
|
||||||
|
or current.get("embedding_dimension") not in (None, metadata["embedding_dimension"])
|
||||||
|
or current.get("embedding_lane") not in (None, metadata["embedding_lane"])
|
||||||
|
):
|
||||||
|
return collection
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Recreating Chroma collection %s for embedding lane change (%s -> %s)",
|
||||||
|
name,
|
||||||
|
current.get("embedding_fingerprint"),
|
||||||
|
metadata["embedding_fingerprint"],
|
||||||
|
)
|
||||||
|
preserved = {"ids": [], "documents": [], "metadatas": [], "embeddings": []}
|
||||||
|
try:
|
||||||
|
preserved = collection.get(include=["documents", "metadatas", "embeddings"]) or preserved
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Could not preserve documents before resetting {name}: {e}") from e
|
||||||
|
|
||||||
|
ids = preserved.get("ids") or []
|
||||||
|
docs = preserved.get("documents") or []
|
||||||
|
metas = preserved.get("metadatas") or []
|
||||||
|
prepared_batches = []
|
||||||
|
if ids and docs:
|
||||||
|
try:
|
||||||
|
for start in range(0, len(ids), 100):
|
||||||
|
batch_ids = ids[start:start + 100]
|
||||||
|
batch_docs = docs[start:start + 100]
|
||||||
|
batch_metas = metas[start:start + 100]
|
||||||
|
if len(batch_metas) < len(batch_ids):
|
||||||
|
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||||
|
prepared_batches.append((
|
||||||
|
batch_ids,
|
||||||
|
batch_docs,
|
||||||
|
batch_metas,
|
||||||
|
_encode_with_client(client, batch_docs),
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Could not re-embed preserved rows for {name}: {e}") from e
|
||||||
|
|
||||||
|
chroma_client.delete_collection(name)
|
||||||
|
collection = chroma_client.get_or_create_collection(name=name, metadata=metadata)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for batch_ids, batch_docs, batch_metas, embeddings in prepared_batches:
|
||||||
|
collection.add(
|
||||||
|
ids=batch_ids,
|
||||||
|
documents=batch_docs,
|
||||||
|
metadatas=batch_metas,
|
||||||
|
embeddings=embeddings,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not write reset collection %s; restoring previous rows: %s", name, e)
|
||||||
|
try:
|
||||||
|
chroma_client.delete_collection(name)
|
||||||
|
restored = chroma_client.get_or_create_collection(name=name, metadata=current)
|
||||||
|
old_embeddings = preserved.get("embeddings") or []
|
||||||
|
if ids and docs and old_embeddings:
|
||||||
|
for start in range(0, len(ids), 100):
|
||||||
|
batch_ids = ids[start:start + 100]
|
||||||
|
batch_docs = docs[start:start + 100]
|
||||||
|
batch_metas = metas[start:start + 100]
|
||||||
|
batch_embeddings = old_embeddings[start:start + 100]
|
||||||
|
if len(batch_metas) < len(batch_ids):
|
||||||
|
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||||
|
restored.add(
|
||||||
|
ids=batch_ids,
|
||||||
|
documents=batch_docs,
|
||||||
|
metadatas=batch_metas,
|
||||||
|
embeddings=batch_embeddings,
|
||||||
|
)
|
||||||
|
except Exception as restore_error:
|
||||||
|
logger.warning("Could not restore previous collection %s: %s", name, restore_error)
|
||||||
|
raise RuntimeError(f"Could not write reset collection {name}: {e}") from e
|
||||||
|
if prepared_batches:
|
||||||
|
logger.info("Re-embedded %s rows after resetting %s", len(ids), name)
|
||||||
|
|
||||||
|
return collection
|
||||||
|
|
||||||
|
|
||||||
|
def _create_lane(chroma_client, base_name: str, lane_name: str, client: Any) -> EmbeddingLane:
|
||||||
|
dimension = int(client.get_sentence_embedding_dimension())
|
||||||
|
model = getattr(client, "model", "")
|
||||||
|
url = getattr(client, "url", "")
|
||||||
|
fp = _fingerprint(lane_name, url, model, dimension)
|
||||||
|
name = collection_name(base_name, lane_name)
|
||||||
|
metadata = _metadata(lane_name, url, model, dimension, fp)
|
||||||
|
collection = _get_or_reset_collection(chroma_client, name, metadata, client)
|
||||||
|
return EmbeddingLane(
|
||||||
|
name=lane_name,
|
||||||
|
client=client,
|
||||||
|
collection=collection,
|
||||||
|
collection_name=name,
|
||||||
|
model=model,
|
||||||
|
url=url,
|
||||||
|
dimension=dimension,
|
||||||
|
fingerprint=fp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_embedding_lanes(base_name: str) -> List[EmbeddingLane]:
|
||||||
|
"""Return healthy lanes in retrieval preference order: custom, fastembed."""
|
||||||
|
from src.chroma_client import get_chroma_client
|
||||||
|
|
||||||
|
chroma_client = get_chroma_client()
|
||||||
|
lanes: List[EmbeddingLane] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
custom = _build_custom_client()
|
||||||
|
if custom is not None:
|
||||||
|
lanes.append(_create_lane(chroma_client, base_name, LANE_CUSTOM, custom))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Custom embedding lane unavailable for %s: %s", base_name, e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fastembed = _build_fastembed_client()
|
||||||
|
lanes.append(_create_lane(chroma_client, base_name, LANE_FASTEMBED, fastembed))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("FastEmbed lane unavailable for %s: %s", base_name, e)
|
||||||
|
|
||||||
|
return lanes
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_legacy_collection(base_name: str, lanes: Sequence[EmbeddingLane]) -> None:
|
||||||
|
"""Backfill empty lanes from a legacy unsuffixed collection, if present."""
|
||||||
|
if not lanes:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.chroma_client import get_chroma_client
|
||||||
|
|
||||||
|
chroma_client = get_chroma_client()
|
||||||
|
legacy = chroma_client.get_collection(base_name)
|
||||||
|
data = legacy.get(include=["documents", "metadatas"])
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
ids = data.get("ids") or []
|
||||||
|
docs = data.get("documents") or []
|
||||||
|
metas = data.get("metadatas") or []
|
||||||
|
if not ids or not docs:
|
||||||
|
return
|
||||||
|
|
||||||
|
for lane in lanes:
|
||||||
|
try:
|
||||||
|
existing = lane.collection.get(ids=ids)
|
||||||
|
existing_ids = set(existing.get("ids") or [])
|
||||||
|
except Exception:
|
||||||
|
existing_ids = set()
|
||||||
|
all_metas = list(metas or [])
|
||||||
|
if len(all_metas) < len(ids):
|
||||||
|
all_metas += [{}] * (len(ids) - len(all_metas))
|
||||||
|
missing = [
|
||||||
|
(row_id, doc, meta)
|
||||||
|
for row_id, doc, meta in zip(ids, docs, all_metas)
|
||||||
|
if row_id not in existing_ids
|
||||||
|
]
|
||||||
|
if not missing:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for start in range(0, len(missing), 100):
|
||||||
|
batch = missing[start:start + 100]
|
||||||
|
batch_ids = [row_id for row_id, _doc, _meta in batch]
|
||||||
|
batch_docs = [doc for _row_id, doc, _meta in batch]
|
||||||
|
batch_metas = [meta or {} for _row_id, _doc, meta in batch]
|
||||||
|
if len(batch_metas) < len(batch_ids):
|
||||||
|
batch_metas += [{}] * (len(batch_ids) - len(batch_metas))
|
||||||
|
try:
|
||||||
|
embeddings = lane.encode(batch_docs)
|
||||||
|
lane.collection.add(
|
||||||
|
ids=batch_ids,
|
||||||
|
documents=batch_docs,
|
||||||
|
metadatas=batch_metas,
|
||||||
|
embeddings=embeddings,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Could not backfill %s lane from legacy collection %s: %s",
|
||||||
|
lane.name,
|
||||||
|
base_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.info("Backfilled %s %s lane rows from legacy collection %s", len(missing), lane.name, base_name)
|
||||||
|
|
||||||
|
|
||||||
|
def lane_count(lanes: Sequence[EmbeddingLane]) -> int:
|
||||||
|
return max((lane.count() for lane in lanes), default=0)
|
||||||
|
|
||||||
|
|
||||||
|
def dedupe_results(results: Iterable[Dict[str, Any]], id_key: str = "id", limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||||
|
seen = set()
|
||||||
|
out: List[Dict[str, Any]] = []
|
||||||
|
for row in results:
|
||||||
|
row_id = row.get(id_key)
|
||||||
|
if not row_id or row_id in seen:
|
||||||
|
continue
|
||||||
|
seen.add(row_id)
|
||||||
|
out.append(row)
|
||||||
|
if limit is not None and len(out) >= limit:
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def query_lanes(
|
||||||
|
lanes: Sequence[EmbeddingLane],
|
||||||
|
query: str,
|
||||||
|
n_results: Callable[[EmbeddingLane], int],
|
||||||
|
include: Sequence[str],
|
||||||
|
where: Optional[Dict[str, Any]] = None,
|
||||||
|
raise_if_all_failed: bool = False,
|
||||||
|
) -> List[tuple[EmbeddingLane, Dict[str, Any]]]:
|
||||||
|
out: List[tuple[EmbeddingLane, Dict[str, Any]]] = []
|
||||||
|
attempted = 0
|
||||||
|
failures: List[str] = []
|
||||||
|
for lane in lanes:
|
||||||
|
try:
|
||||||
|
count = lane.count()
|
||||||
|
if count == 0:
|
||||||
|
continue
|
||||||
|
attempted += 1
|
||||||
|
n = min(n_results(lane), count)
|
||||||
|
if n <= 0:
|
||||||
|
continue
|
||||||
|
results = lane.collection.query(
|
||||||
|
query_embeddings=lane.encode([query]),
|
||||||
|
n_results=n,
|
||||||
|
where=where,
|
||||||
|
include=list(include),
|
||||||
|
)
|
||||||
|
out.append((lane, results))
|
||||||
|
except Exception as e:
|
||||||
|
failures.append(f"{lane.name}: {e}")
|
||||||
|
logger.warning("%s lane query failed for %s: %s", lane.name, lane.collection_name, e)
|
||||||
|
if raise_if_all_failed and attempted and not out and failures:
|
||||||
|
raise RuntimeError("; ".join(failures))
|
||||||
|
return out
|
||||||
+153
-77
@@ -9,6 +9,16 @@ Stores pre-computed embeddings (ChromaDB does not manage embedding).
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
from src.embedding_lanes import (
|
||||||
|
LANE_CUSTOM,
|
||||||
|
LANE_FASTEMBED,
|
||||||
|
build_embedding_lanes,
|
||||||
|
collection_name,
|
||||||
|
dedupe_results,
|
||||||
|
lane_count,
|
||||||
|
migrate_legacy_collection,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -20,30 +30,28 @@ class MemoryVectorStore:
|
|||||||
def __init__(self, data_dir: str, embedding_model=None):
|
def __init__(self, data_dir: str, embedding_model=None):
|
||||||
self._model = embedding_model
|
self._model = embedding_model
|
||||||
self._collection = None
|
self._collection = None
|
||||||
|
self._lanes = []
|
||||||
self._healthy = False
|
self._healthy = False
|
||||||
|
|
||||||
self._initialize()
|
self._initialize()
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
try:
|
try:
|
||||||
from src.chroma_client import get_chroma_client
|
self._lanes = build_embedding_lanes(self.COLLECTION_NAME)
|
||||||
|
if not self._lanes:
|
||||||
if self._model is None:
|
raise RuntimeError("No embedding lanes available")
|
||||||
from src.embeddings import get_embedding_client
|
|
||||||
self._model = get_embedding_client()
|
|
||||||
if self._model is None:
|
|
||||||
raise RuntimeError("No embedding backend available")
|
|
||||||
logger.info(f"MemoryVectorStore using embeddings: {self._model.url}")
|
|
||||||
|
|
||||||
client = get_chroma_client()
|
|
||||||
self._collection = client.get_or_create_collection(
|
|
||||||
name=self.COLLECTION_NAME,
|
|
||||||
metadata={"hnsw:space": "cosine"},
|
|
||||||
)
|
|
||||||
|
|
||||||
self._healthy = True
|
self._healthy = True
|
||||||
count = self._collection.count()
|
self._collection = next(
|
||||||
logger.info(f"MemoryVectorStore ready (entries={count})")
|
(lane.collection for lane in self._lanes if lane.name == LANE_FASTEMBED),
|
||||||
|
self._lanes[0].collection,
|
||||||
|
)
|
||||||
|
migrate_legacy_collection(self.COLLECTION_NAME, self._lanes)
|
||||||
|
logger.info(
|
||||||
|
"MemoryVectorStore ready (lanes=%s entries=%s)",
|
||||||
|
[lane.name for lane in self._lanes],
|
||||||
|
self.count(),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"MemoryVectorStore init failed: {e}")
|
logger.error(f"MemoryVectorStore init failed: {e}")
|
||||||
@@ -53,39 +61,73 @@ class MemoryVectorStore:
|
|||||||
return self._healthy
|
return self._healthy
|
||||||
|
|
||||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
vecs = self._model.encode(texts, normalize_embeddings=True)
|
if not self._lanes:
|
||||||
return vecs.tolist()
|
return []
|
||||||
|
return self._lanes[0].encode(texts)
|
||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""Return the number of stored vectors."""
|
"""Return the number of stored vectors."""
|
||||||
if not self._healthy:
|
if not self._healthy:
|
||||||
return 0
|
return 0
|
||||||
return self._collection.count()
|
return lane_count(self._lanes)
|
||||||
|
|
||||||
|
def _collections_for_delete(self):
|
||||||
|
collections = []
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
def add(collection) -> None:
|
||||||
|
if collection is None:
|
||||||
|
return
|
||||||
|
key = getattr(collection, "name", None) or id(collection)
|
||||||
|
if key in seen:
|
||||||
|
return
|
||||||
|
seen.add(key)
|
||||||
|
collections.append(collection)
|
||||||
|
|
||||||
|
for lane in self._lanes:
|
||||||
|
add(lane.collection)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.chroma_client import get_chroma_client
|
||||||
|
|
||||||
|
client = get_chroma_client()
|
||||||
|
for lane_name in (LANE_CUSTOM, LANE_FASTEMBED):
|
||||||
|
try:
|
||||||
|
add(client.get_collection(collection_name(self.COLLECTION_NAME, lane_name)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return collections
|
||||||
|
|
||||||
def add(self, memory_id: str, text: str):
|
def add(self, memory_id: str, text: str):
|
||||||
"""Add a single memory entry to the vector index."""
|
"""Add a single memory entry to the vector index."""
|
||||||
if not self._healthy:
|
if not self._healthy:
|
||||||
return
|
return
|
||||||
# Skip if already exists
|
for lane in self._lanes:
|
||||||
existing = self._collection.get(ids=[memory_id])
|
try:
|
||||||
if existing["ids"]:
|
existing = lane.collection.get(ids=[memory_id])
|
||||||
return
|
if existing["ids"]:
|
||||||
embeddings = self._embed([text])
|
continue
|
||||||
self._collection.add(
|
lane.collection.add(
|
||||||
ids=[memory_id],
|
ids=[memory_id],
|
||||||
embeddings=embeddings,
|
embeddings=lane.encode([text]),
|
||||||
documents=[text],
|
documents=[text],
|
||||||
metadatas=[{"source": "memory"}],
|
metadatas=[{"source": "memory"}],
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("memory add failed in %s lane for %s: %s", lane.name, memory_id, e)
|
||||||
|
|
||||||
def remove(self, memory_id: str):
|
def remove(self, memory_id: str):
|
||||||
"""Remove a memory entry. O(1) — no rebuild needed."""
|
"""Remove a memory entry. O(1) — no rebuild needed."""
|
||||||
if not self._healthy:
|
if not self._healthy:
|
||||||
return
|
return
|
||||||
try:
|
for collection in self._collections_for_delete():
|
||||||
self._collection.delete(ids=[memory_id])
|
try:
|
||||||
except Exception as e:
|
collection.delete(ids=[memory_id])
|
||||||
logger.warning(f"memory remove {memory_id}: {e}")
|
except Exception as e:
|
||||||
|
logger.warning(f"memory remove {memory_id}: {e}")
|
||||||
|
|
||||||
def search(self, query: str, k: int = 8) -> List[Dict]:
|
def search(self, query: str, k: int = 8) -> List[Dict]:
|
||||||
"""Search for the most relevant memory IDs by semantic similarity.
|
"""Search for the most relevant memory IDs by semantic similarity.
|
||||||
@@ -94,41 +136,53 @@ class MemoryVectorStore:
|
|||||||
ChromaDB cosine distance = 1 - cosine_similarity.
|
ChromaDB cosine distance = 1 - cosine_similarity.
|
||||||
We convert back: similarity = 1.0 - distance.
|
We convert back: similarity = 1.0 - distance.
|
||||||
"""
|
"""
|
||||||
if not self._healthy or self._collection.count() == 0:
|
if not self._healthy or self.count() == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
embeddings = self._embed([query])
|
|
||||||
actual_k = min(k, self._collection.count())
|
|
||||||
results = self._collection.query(
|
|
||||||
query_embeddings=embeddings,
|
|
||||||
n_results=actual_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = []
|
out = []
|
||||||
for idx, mid in enumerate(results["ids"][0]):
|
lane_priority = {LANE_CUSTOM: 0, LANE_FASTEMBED: 1}
|
||||||
distance = results["distances"][0][idx]
|
for lane in self._lanes:
|
||||||
out.append({
|
try:
|
||||||
"memory_id": mid,
|
if lane.count() == 0:
|
||||||
"score": round(1.0 - distance, 4),
|
continue
|
||||||
})
|
results = lane.collection.query(
|
||||||
return out
|
query_embeddings=lane.encode([query]),
|
||||||
|
n_results=min(k, lane.count()),
|
||||||
|
include=["distances"],
|
||||||
|
)
|
||||||
|
for idx, mid in enumerate(results["ids"][0]):
|
||||||
|
distance = results["distances"][0][idx]
|
||||||
|
out.append({
|
||||||
|
"memory_id": mid,
|
||||||
|
"score": round(1.0 - distance, 4),
|
||||||
|
"embedding_lane": lane.name,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("memory search failed in %s lane: %s", lane.name, e)
|
||||||
|
out.sort(key=lambda row: (-row["score"], lane_priority.get(row["embedding_lane"], 99)))
|
||||||
|
return dedupe_results(out, id_key="memory_id", limit=k)
|
||||||
|
|
||||||
def find_similar(self, text: str, threshold: float = 0.92) -> Optional[str]:
|
def find_similar(self, text: str, threshold: float = 0.92) -> Optional[str]:
|
||||||
"""Check if a near-duplicate exists. Returns memory_id if found, else None."""
|
"""Check if a near-duplicate exists. Returns memory_id if found, else None."""
|
||||||
if not self._healthy or self._collection.count() == 0:
|
if not self._healthy or self.count() == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
embeddings = self._embed([text])
|
for lane in self._lanes:
|
||||||
results = self._collection.query(
|
try:
|
||||||
query_embeddings=embeddings,
|
if lane.count() == 0:
|
||||||
n_results=1,
|
continue
|
||||||
)
|
results = lane.collection.query(
|
||||||
|
query_embeddings=lane.encode([text]),
|
||||||
if results["ids"][0]:
|
n_results=1,
|
||||||
distance = results["distances"][0][0]
|
include=["distances"],
|
||||||
similarity = 1.0 - distance
|
)
|
||||||
if similarity >= threshold:
|
if results["ids"][0]:
|
||||||
return results["ids"][0][0]
|
distance = results["distances"][0][0]
|
||||||
|
similarity = 1.0 - distance
|
||||||
|
if similarity >= threshold:
|
||||||
|
return results["ids"][0][0]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("memory similarity search failed in %s lane: %s", lane.name, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def rebuild(self, memories: List[Dict]):
|
def rebuild(self, memories: List[Dict]):
|
||||||
@@ -139,15 +193,23 @@ class MemoryVectorStore:
|
|||||||
|
|
||||||
from src.chroma_client import get_chroma_client
|
from src.chroma_client import get_chroma_client
|
||||||
|
|
||||||
# Delete and recreate collection for a clean rebuild
|
|
||||||
client = get_chroma_client()
|
client = get_chroma_client()
|
||||||
try:
|
lane_names = [
|
||||||
client.delete_collection(self.COLLECTION_NAME)
|
self.COLLECTION_NAME,
|
||||||
except Exception:
|
collection_name(self.COLLECTION_NAME, LANE_CUSTOM),
|
||||||
pass
|
collection_name(self.COLLECTION_NAME, LANE_FASTEMBED),
|
||||||
self._collection = client.get_or_create_collection(
|
]
|
||||||
name=self.COLLECTION_NAME,
|
for name in lane_names:
|
||||||
metadata={"hnsw:space": "cosine"},
|
try:
|
||||||
|
client.delete_collection(name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Explicit rebuilds must start from the supplied memory list, so clear
|
||||||
|
# legacy unsuffixed collections too.
|
||||||
|
self._lanes = build_embedding_lanes(self.COLLECTION_NAME)
|
||||||
|
self._collection = next(
|
||||||
|
(lane.collection for lane in self._lanes if lane.name == LANE_FASTEMBED),
|
||||||
|
self._lanes[0].collection if self._lanes else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
texts = []
|
texts = []
|
||||||
@@ -161,15 +223,29 @@ class MemoryVectorStore:
|
|||||||
|
|
||||||
if texts:
|
if texts:
|
||||||
# Batch in chunks of 100 to avoid oversized requests
|
# Batch in chunks of 100 to avoid oversized requests
|
||||||
|
failed_lanes = set()
|
||||||
for i in range(0, len(texts), 100):
|
for i in range(0, len(texts), 100):
|
||||||
batch_texts = texts[i:i + 100]
|
batch_texts = texts[i:i + 100]
|
||||||
batch_ids = ids[i:i + 100]
|
batch_ids = ids[i:i + 100]
|
||||||
embeddings = self._embed(batch_texts)
|
for lane in self._lanes:
|
||||||
self._collection.add(
|
if lane.name in failed_lanes:
|
||||||
ids=batch_ids,
|
continue
|
||||||
embeddings=embeddings,
|
try:
|
||||||
documents=batch_texts,
|
lane.collection.add(
|
||||||
metadatas=[{"source": "memory"}] * len(batch_ids),
|
ids=batch_ids,
|
||||||
)
|
embeddings=lane.encode(batch_texts),
|
||||||
|
documents=batch_texts,
|
||||||
|
metadatas=[{"source": "memory"}] * len(batch_ids),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
failed_lanes.add(lane.name)
|
||||||
|
logger.warning("memory rebuild failed in %s lane: %s", lane.name, e)
|
||||||
|
|
||||||
logger.info(f"MemoryVectorStore rebuilt with {len(ids)} entries")
|
logger.info(f"MemoryVectorStore rebuilt with {len(ids)} entries across {len(self._lanes)} lanes")
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict:
|
||||||
|
return {
|
||||||
|
"healthy": self.healthy,
|
||||||
|
"count": self.count(),
|
||||||
|
"lanes": [lane.stats() for lane in self._lanes],
|
||||||
|
}
|
||||||
|
|||||||
+232
-153
@@ -14,6 +14,17 @@ import numpy as np
|
|||||||
from typing import List, Dict, Any, Optional, Set
|
from typing import List, Dict, Any, Optional, Set
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.embedding_lanes import (
|
||||||
|
LANE_CUSTOM,
|
||||||
|
LANE_FASTEMBED,
|
||||||
|
build_embedding_lanes,
|
||||||
|
collection_name,
|
||||||
|
dedupe_results,
|
||||||
|
lane_count,
|
||||||
|
migrate_legacy_collection,
|
||||||
|
query_lanes,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_FILE_EXTENSIONS: Set[str] = {
|
DEFAULT_FILE_EXTENSIONS: Set[str] = {
|
||||||
@@ -44,6 +55,7 @@ class VectorRAG:
|
|||||||
self.persist_directory = persist_directory
|
self.persist_directory = persist_directory
|
||||||
self._collection = None
|
self._collection = None
|
||||||
self._model = None
|
self._model = None
|
||||||
|
self._lanes = []
|
||||||
self._healthy = False
|
self._healthy = False
|
||||||
|
|
||||||
Path(self.persist_directory).mkdir(parents=True, exist_ok=True)
|
Path(self.persist_directory).mkdir(parents=True, exist_ok=True)
|
||||||
@@ -55,22 +67,20 @@ class VectorRAG:
|
|||||||
|
|
||||||
def _initialize_system(self) -> bool:
|
def _initialize_system(self) -> bool:
|
||||||
try:
|
try:
|
||||||
from src.chroma_client import get_chroma_client
|
self._lanes = build_embedding_lanes(COLLECTION_NAME)
|
||||||
from src.embeddings import get_embedding_client
|
if not self._lanes:
|
||||||
|
raise RuntimeError("No embedding lanes available")
|
||||||
self._model = get_embedding_client()
|
self._collection = next(
|
||||||
if self._model is None:
|
(lane.collection for lane in self._lanes if lane.name == LANE_FASTEMBED),
|
||||||
raise RuntimeError("No embedding backend available")
|
self._lanes[0].collection,
|
||||||
logger.info(f"Embedding: {self._model.url} model={self._model.model}")
|
)
|
||||||
|
self._model = self._lanes[0].client
|
||||||
client = get_chroma_client()
|
migrate_legacy_collection(COLLECTION_NAME, self._lanes)
|
||||||
self._collection = client.get_or_create_collection(
|
logger.info(
|
||||||
name=COLLECTION_NAME,
|
"VectorRAG ready (lanes=%s docs=%s)",
|
||||||
metadata={"hnsw:space": "cosine"},
|
[lane.name for lane in self._lanes],
|
||||||
|
lane_count(self._lanes),
|
||||||
)
|
)
|
||||||
|
|
||||||
count = self._collection.count()
|
|
||||||
logger.info(f"VectorRAG ready ({count} docs)")
|
|
||||||
self._healthy = True
|
self._healthy = True
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -80,8 +90,9 @@ class VectorRAG:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
vecs = self._model.encode(texts, normalize_embeddings=True)
|
if not self._lanes:
|
||||||
return np.array(vecs, dtype=np.float32).tolist()
|
return []
|
||||||
|
return np.array(self._lanes[0].encode(texts), dtype=np.float32).tolist()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Properties
|
# Properties
|
||||||
@@ -89,13 +100,57 @@ class VectorRAG:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def healthy(self) -> bool:
|
def healthy(self) -> bool:
|
||||||
return self._healthy and self._collection is not None
|
if getattr(self, "_lanes", None):
|
||||||
|
return self._healthy and bool(self._lanes)
|
||||||
|
return self._healthy and getattr(self, "_collection", None) is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def collection(self):
|
def collection(self):
|
||||||
"""Expose the ChromaDB collection for direct access by personal_routes etc."""
|
"""Expose the ChromaDB collection for direct access by personal_routes etc."""
|
||||||
return self._collection
|
return self._collection
|
||||||
|
|
||||||
|
def _active_collections(self):
|
||||||
|
lanes = getattr(self, "_lanes", None)
|
||||||
|
if lanes:
|
||||||
|
return [(lane.name, lane.collection) for lane in lanes]
|
||||||
|
collection = getattr(self, "_collection", None)
|
||||||
|
return [("legacy", collection)] if collection is not None else []
|
||||||
|
|
||||||
|
def _collections_for_delete(self):
|
||||||
|
collections = []
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
def add(lane_name: str, collection) -> None:
|
||||||
|
if collection is None:
|
||||||
|
return
|
||||||
|
key = getattr(collection, "name", None) or id(collection)
|
||||||
|
if key in seen:
|
||||||
|
return
|
||||||
|
seen.add(key)
|
||||||
|
collections.append((lane_name, collection))
|
||||||
|
|
||||||
|
for lane_name, collection in self._active_collections():
|
||||||
|
add(lane_name, collection)
|
||||||
|
|
||||||
|
if getattr(self, "_lanes", None):
|
||||||
|
try:
|
||||||
|
from src.chroma_client import get_chroma_client
|
||||||
|
|
||||||
|
client = get_chroma_client()
|
||||||
|
try:
|
||||||
|
add("legacy", client.get_collection(COLLECTION_NAME))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
for lane_name in (LANE_CUSTOM, LANE_FASTEMBED):
|
||||||
|
try:
|
||||||
|
add(lane_name, client.get_collection(collection_name(COLLECTION_NAME, lane_name)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return collections
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Document operations
|
# Document operations
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -109,23 +164,24 @@ class VectorRAG:
|
|||||||
if not metadata or not isinstance(metadata, dict):
|
if not metadata or not isinstance(metadata, dict):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
doc_id = _generate_doc_id(text, metadata.get("owner") or "")
|
||||||
doc_id = _generate_doc_id(text, metadata.get("owner") or "")
|
wrote = False
|
||||||
# Check if already exists
|
for lane in self._lanes:
|
||||||
existing = self._collection.get(ids=[doc_id])
|
try:
|
||||||
if existing["ids"]:
|
existing = lane.collection.get(ids=[doc_id])
|
||||||
return True # already exists
|
if existing["ids"]:
|
||||||
embeddings = self._embed([text])
|
wrote = True
|
||||||
self._collection.add(
|
continue
|
||||||
ids=[doc_id],
|
lane.collection.add(
|
||||||
embeddings=embeddings,
|
ids=[doc_id],
|
||||||
documents=[text],
|
embeddings=lane.encode([text]),
|
||||||
metadatas=[metadata],
|
documents=[text],
|
||||||
)
|
metadatas=[metadata],
|
||||||
return True
|
)
|
||||||
except Exception as e:
|
wrote = True
|
||||||
logger.error(f"add_document failed: {e}")
|
except Exception as e:
|
||||||
return False
|
logger.warning("add_document failed in %s lane: %s", lane.name, e)
|
||||||
|
return wrote
|
||||||
|
|
||||||
def add_documents_batch(self, docs: List[tuple]) -> Dict[str, Any]:
|
def add_documents_batch(self, docs: List[tuple]) -> Dict[str, Any]:
|
||||||
if not self.healthy:
|
if not self.healthy:
|
||||||
@@ -140,42 +196,57 @@ class VectorRAG:
|
|||||||
if not valid:
|
if not valid:
|
||||||
return {"success": False, "message": "No valid documents"}
|
return {"success": False, "message": "No valid documents"}
|
||||||
|
|
||||||
try:
|
added_ids = set()
|
||||||
# Get existing IDs to avoid duplicates
|
attempted_new = False
|
||||||
|
write_failed = False
|
||||||
|
for lane in self._lanes:
|
||||||
|
all_ids = [_generate_doc_id(t, m.get("owner") or "") for t, m in valid]
|
||||||
|
try:
|
||||||
|
existing = lane.collection.get(ids=all_ids)
|
||||||
|
existing_ids = set(existing.get("ids") or [])
|
||||||
|
except Exception:
|
||||||
|
existing_ids = set()
|
||||||
|
|
||||||
new_texts = []
|
new_texts = []
|
||||||
new_metas = []
|
new_metas = []
|
||||||
new_ids = []
|
new_ids = []
|
||||||
for t, m in valid:
|
for (text, meta), doc_id in zip(valid, all_ids):
|
||||||
doc_id = _generate_doc_id(t, m.get("owner") or "")
|
if doc_id not in existing_ids:
|
||||||
existing = self._collection.get(ids=[doc_id])
|
new_texts.append(text)
|
||||||
if not existing["ids"]:
|
new_metas.append(meta)
|
||||||
new_texts.append(t)
|
|
||||||
new_metas.append(m)
|
|
||||||
new_ids.append(doc_id)
|
new_ids.append(doc_id)
|
||||||
|
|
||||||
if new_texts:
|
if new_texts:
|
||||||
# Batch in chunks of 100
|
attempted_new = True
|
||||||
|
lane_failed = False
|
||||||
for i in range(0, len(new_texts), 100):
|
for i in range(0, len(new_texts), 100):
|
||||||
batch_texts = new_texts[i:i + 100]
|
batch_texts = new_texts[i:i + 100]
|
||||||
batch_ids = new_ids[i:i + 100]
|
batch_ids = new_ids[i:i + 100]
|
||||||
batch_metas = new_metas[i:i + 100]
|
batch_metas = new_metas[i:i + 100]
|
||||||
embeddings = self._embed(batch_texts)
|
try:
|
||||||
self._collection.add(
|
lane.collection.add(
|
||||||
ids=batch_ids,
|
ids=batch_ids,
|
||||||
embeddings=embeddings,
|
embeddings=lane.encode(batch_texts),
|
||||||
documents=batch_texts,
|
documents=batch_texts,
|
||||||
metadatas=batch_metas,
|
metadatas=batch_metas,
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
lane_failed = True
|
||||||
|
write_failed = True
|
||||||
|
logger.warning("add_documents_batch failed in %s lane: %s", lane.name, e)
|
||||||
|
break
|
||||||
|
if not lane_failed:
|
||||||
|
added_ids.update(new_ids)
|
||||||
|
|
||||||
return {
|
if attempted_new and write_failed and not added_ids:
|
||||||
"success": True,
|
return {"success": False, "message": "No embedding lane accepted the batch"}
|
||||||
"added_count": len(new_texts),
|
|
||||||
"total_count": len(docs),
|
return {
|
||||||
"failed_count": len(docs) - len(valid),
|
"success": True,
|
||||||
}
|
"added_count": len(added_ids),
|
||||||
except Exception as e:
|
"total_count": len(docs),
|
||||||
logger.error(f"add_documents_batch failed: {e}")
|
"failed_count": len(docs) - len(valid),
|
||||||
return {"success": False, "message": str(e)}
|
}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Search — hybrid: vector similarity + keyword overlap
|
# Search — hybrid: vector similarity + keyword overlap
|
||||||
@@ -186,58 +257,51 @@ class VectorRAG:
|
|||||||
return []
|
return []
|
||||||
if not query or not isinstance(query, str):
|
if not query or not isinstance(query, str):
|
||||||
return []
|
return []
|
||||||
if self._collection.count() == 0:
|
if lane_count(self._lanes) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch extra candidates when owner-filtering
|
|
||||||
fetch_k = min(k * 3, max(k, 20), self._collection.count())
|
|
||||||
if owner:
|
|
||||||
fetch_k = min(fetch_k * 2, self._collection.count())
|
|
||||||
|
|
||||||
query_embeddings = self._embed([query])
|
|
||||||
|
|
||||||
# Use ChromaDB where filter for owner if specified
|
|
||||||
where_filter = {"owner": owner} if owner else None
|
where_filter = {"owner": owner} if owner else None
|
||||||
|
|
||||||
results = self._collection.query(
|
|
||||||
query_embeddings=query_embeddings,
|
|
||||||
n_results=fetch_k,
|
|
||||||
where=where_filter,
|
|
||||||
include=["documents", "metadatas", "distances"],
|
|
||||||
)
|
|
||||||
|
|
||||||
query_words = set(query.lower().split())
|
query_words = set(query.lower().split())
|
||||||
candidates = []
|
candidates = []
|
||||||
|
|
||||||
for idx in range(len(results["ids"][0])):
|
for lane, results in query_lanes(
|
||||||
doc_id = results["ids"][0][idx]
|
self._lanes,
|
||||||
distance = results["distances"][0][idx]
|
query,
|
||||||
doc_text = results["documents"][0][idx]
|
n_results=lambda lane: min(
|
||||||
meta = results["metadatas"][0][idx]
|
(k * 6 if owner else k * 3),
|
||||||
|
max(k, 20),
|
||||||
|
lane.count(),
|
||||||
|
),
|
||||||
|
where=where_filter,
|
||||||
|
include=["documents", "metadatas", "distances"],
|
||||||
|
raise_if_all_failed=True,
|
||||||
|
):
|
||||||
|
for idx in range(len(results["ids"][0])):
|
||||||
|
doc_id = results["ids"][0][idx]
|
||||||
|
distance = results["distances"][0][idx]
|
||||||
|
doc_text = results["documents"][0][idx]
|
||||||
|
meta = results["metadatas"][0][idx]
|
||||||
|
|
||||||
# ChromaDB cosine distance = 1 - cosine_similarity
|
vector_sim = 1.0 - distance
|
||||||
vector_sim = 1.0 - distance
|
doc_words = set(doc_text.lower().split())
|
||||||
|
overlap = len(query_words & doc_words)
|
||||||
|
keyword_score = overlap / len(query_words) if query_words else 0.0
|
||||||
|
hybrid_score = (VECTOR_WEIGHT * vector_sim) + (KEYWORD_WEIGHT * keyword_score)
|
||||||
|
|
||||||
# Keyword overlap score
|
candidates.append({
|
||||||
doc_words = set(doc_text.lower().split())
|
"id": doc_id,
|
||||||
overlap = len(query_words & doc_words)
|
"document": doc_text,
|
||||||
keyword_score = overlap / len(query_words) if query_words else 0.0
|
"metadata": meta,
|
||||||
|
"distance": round(distance, 4),
|
||||||
hybrid_score = (VECTOR_WEIGHT * vector_sim) + (KEYWORD_WEIGHT * keyword_score)
|
"similarity": round(hybrid_score, 4),
|
||||||
|
"vector_similarity": round(vector_sim, 4),
|
||||||
candidates.append({
|
"keyword_score": round(keyword_score, 4),
|
||||||
"id": doc_id,
|
"embedding_lane": lane.name,
|
||||||
"document": doc_text,
|
})
|
||||||
"metadata": meta,
|
|
||||||
"distance": round(distance, 4),
|
|
||||||
"similarity": round(hybrid_score, 4),
|
|
||||||
"vector_similarity": round(vector_sim, 4),
|
|
||||||
"keyword_score": round(keyword_score, 4),
|
|
||||||
})
|
|
||||||
|
|
||||||
candidates.sort(key=lambda c: c["similarity"], reverse=True)
|
candidates.sort(key=lambda c: c["similarity"], reverse=True)
|
||||||
top = candidates[:k]
|
top = dedupe_results(candidates, limit=k)
|
||||||
logger.info(f"Hybrid search for '{query[:60]}': {len(top)} results")
|
logger.info(f"Hybrid search for '{query[:60]}': {len(top)} results")
|
||||||
return top
|
return top
|
||||||
|
|
||||||
@@ -247,39 +311,36 @@ class VectorRAG:
|
|||||||
|
|
||||||
def _keyword_search_fallback(self, query: str, k: int = 5, owner: Optional[str] = None) -> List[Dict[str, Any]]:
|
def _keyword_search_fallback(self, query: str, k: int = 5, owner: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
if self._collection.count() == 0:
|
if not self._active_collections():
|
||||||
return []
|
|
||||||
|
|
||||||
# Fetch all documents for keyword search fallback
|
|
||||||
all_docs = self._collection.get(include=["documents", "metadatas"])
|
|
||||||
if not all_docs["ids"]:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_words = query.lower().split()
|
query_words = query.lower().split()
|
||||||
scored = []
|
scored = []
|
||||||
for i, doc in enumerate(all_docs["documents"]):
|
for lane_name, collection in self._active_collections():
|
||||||
meta = all_docs["metadatas"][i]
|
if collection.count() == 0:
|
||||||
if owner:
|
continue
|
||||||
# Match the primary path's strict where={"owner": owner}
|
all_docs = collection.get(include=["documents", "metadatas"])
|
||||||
# filter. The old `if doc_owner and doc_owner != owner`
|
if not all_docs["ids"]:
|
||||||
# let docs with a missing/empty owner fall through, leaking
|
continue
|
||||||
# owner-less documents into another user's results.
|
for i, doc in enumerate(all_docs["documents"]):
|
||||||
if meta.get("owner") != owner:
|
meta = all_docs["metadatas"][i]
|
||||||
|
if owner and meta.get("owner") != owner:
|
||||||
continue
|
continue
|
||||||
doc_lower = doc.lower()
|
doc_lower = doc.lower()
|
||||||
score = sum(1 for w in query_words if w in doc_lower)
|
score = sum(1 for w in query_words if w in doc_lower)
|
||||||
if score > 0:
|
if score > 0:
|
||||||
scored.append({
|
scored.append({
|
||||||
"id": all_docs["ids"][i],
|
"id": all_docs["ids"][i],
|
||||||
"document": doc,
|
"document": doc,
|
||||||
"metadata": meta,
|
"metadata": meta,
|
||||||
"distance": 0,
|
"distance": 0,
|
||||||
"similarity": score,
|
"similarity": score,
|
||||||
"search_type": "keyword_fallback",
|
"search_type": "keyword_fallback",
|
||||||
})
|
"embedding_lane": lane_name,
|
||||||
|
})
|
||||||
|
|
||||||
scored.sort(key=lambda x: x["similarity"], reverse=True)
|
scored.sort(key=lambda x: x["similarity"], reverse=True)
|
||||||
return scored[:k]
|
return dedupe_results(scored, limit=k)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"keyword fallback failed: {e}")
|
logger.error(f"keyword fallback failed: {e}")
|
||||||
return []
|
return []
|
||||||
@@ -296,9 +357,20 @@ class VectorRAG:
|
|||||||
client.delete_collection(COLLECTION_NAME)
|
client.delete_collection(COLLECTION_NAME)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._collection = client.get_or_create_collection(
|
for name in (
|
||||||
name=COLLECTION_NAME,
|
collection_name(COLLECTION_NAME, LANE_CUSTOM),
|
||||||
metadata={"hnsw:space": "cosine"},
|
collection_name(COLLECTION_NAME, LANE_FASTEMBED),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
client.delete_collection(name)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Rebuild means empty current lanes. Clear the legacy unsuffixed
|
||||||
|
# collection too so startup migration cannot resurrect stale docs.
|
||||||
|
self._lanes = build_embedding_lanes(COLLECTION_NAME)
|
||||||
|
self._collection = next(
|
||||||
|
(lane.collection for lane in self._lanes if lane.name == LANE_FASTEMBED),
|
||||||
|
self._lanes[0].collection if self._lanes else None,
|
||||||
)
|
)
|
||||||
self._healthy = True
|
self._healthy = True
|
||||||
return True
|
return True
|
||||||
@@ -312,10 +384,11 @@ class VectorRAG:
|
|||||||
return {"error": "Collection not initialized"}
|
return {"error": "Collection not initialized"}
|
||||||
try:
|
try:
|
||||||
return {
|
return {
|
||||||
"document_count": self._collection.count(),
|
"document_count": lane_count(self._lanes),
|
||||||
"embedding_model": f"{self._model.model} @ {self._model.url}" if self._model else "N/A",
|
"embedding_model": f"{self._lanes[0].model} @ {self._lanes[0].url}" if self._lanes else "N/A",
|
||||||
"persist_directory": self.persist_directory,
|
"persist_directory": self.persist_directory,
|
||||||
"collection_name": COLLECTION_NAME,
|
"collection_name": COLLECTION_NAME,
|
||||||
|
"embedding_lanes": [lane.stats() for lane in self._lanes],
|
||||||
"healthy": True,
|
"healthy": True,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -400,19 +473,23 @@ class VectorRAG:
|
|||||||
return {"success": False, "message": "Collection not initialized"}
|
return {"success": False, "message": "Collection not initialized"}
|
||||||
directory = os.path.abspath(directory)
|
directory = os.path.abspath(directory)
|
||||||
try:
|
try:
|
||||||
results = self._collection.get(include=["metadatas"])
|
removed_ids = set()
|
||||||
ids = [
|
for _lane_name, collection in self._collections_for_delete():
|
||||||
results["ids"][i]
|
results = collection.get(include=["metadatas"])
|
||||||
for i, m in enumerate(results["metadatas"])
|
ids = [
|
||||||
if isinstance(m, dict)
|
results["ids"][i]
|
||||||
and isinstance(m.get("source"), str)
|
for i, m in enumerate(results["metadatas"])
|
||||||
and (m["source"] == directory or m["source"].startswith(directory + os.sep))
|
if isinstance(m, dict)
|
||||||
]
|
and isinstance(m.get("source"), str)
|
||||||
if not ids:
|
and (m["source"] == directory or m["source"].startswith(directory + os.sep))
|
||||||
|
]
|
||||||
|
if ids:
|
||||||
|
collection.delete(ids=ids)
|
||||||
|
removed_ids.update(ids)
|
||||||
|
if not removed_ids:
|
||||||
return {"success": True, "removed_count": 0, "message": "No docs found"}
|
return {"success": True, "removed_count": 0, "message": "No docs found"}
|
||||||
|
|
||||||
self._collection.delete(ids=ids)
|
n = len(removed_ids)
|
||||||
n = len(ids)
|
|
||||||
logger.info(f"Removed {n} chunks from {directory}")
|
logger.info(f"Removed {n} chunks from {directory}")
|
||||||
return {"success": True, "removed_count": n, "message": f"Removed {n} chunks"}
|
return {"success": True, "removed_count": n, "message": f"Removed {n} chunks"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -504,16 +581,18 @@ class VectorRAG:
|
|||||||
if not self.healthy:
|
if not self.healthy:
|
||||||
return 0
|
return 0
|
||||||
try:
|
try:
|
||||||
results = self._collection.get(
|
removed_ids = set()
|
||||||
where={"source": source},
|
for _lane_name, collection in self._collections_for_delete():
|
||||||
include=[],
|
results = collection.get(
|
||||||
)
|
where={"source": source},
|
||||||
ids = results.get("ids", [])
|
include=[],
|
||||||
if not ids:
|
)
|
||||||
return 0
|
ids = results.get("ids", [])
|
||||||
self._collection.delete(ids=ids)
|
if ids:
|
||||||
logger.info(f"Deleted {len(ids)} chunks for source={source}")
|
collection.delete(ids=ids)
|
||||||
return len(ids)
|
removed_ids.update(ids)
|
||||||
|
logger.info(f"Deleted {len(removed_ids)} chunks for source={source}")
|
||||||
|
return len(removed_ids)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"delete_by_source failed: {e}")
|
logger.error(f"delete_by_source failed: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
+106
-64
@@ -12,6 +12,14 @@ import re
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Set
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from src.embedding_lanes import (
|
||||||
|
LANE_CUSTOM,
|
||||||
|
LANE_FASTEMBED,
|
||||||
|
build_embedding_lanes,
|
||||||
|
dedupe_results,
|
||||||
|
migrate_legacy_collection,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -155,32 +163,30 @@ class ToolIndex:
|
|||||||
"""ChromaDB-backed tool index for RAG-based tool selection."""
|
"""ChromaDB-backed tool index for RAG-based tool selection."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from src.chroma_client import get_chroma_client
|
self._lanes = build_embedding_lanes(COLLECTION_NAME)
|
||||||
from src.embeddings import get_embedding_client
|
if not self._lanes:
|
||||||
|
raise RuntimeError("No embedding lanes available")
|
||||||
self._embedder = get_embedding_client()
|
self._embedder = self._lanes[0].client
|
||||||
if not self._embedder:
|
self._collection = next(
|
||||||
raise RuntimeError("No embedding client available")
|
(lane.collection for lane in self._lanes if lane.name == LANE_FASTEMBED),
|
||||||
|
self._lanes[0].collection,
|
||||||
client = get_chroma_client()
|
|
||||||
self._collection = client.get_or_create_collection(
|
|
||||||
name=COLLECTION_NAME,
|
|
||||||
metadata={"hnsw:space": "cosine"},
|
|
||||||
)
|
)
|
||||||
|
migrate_legacy_collection(COLLECTION_NAME, self._lanes)
|
||||||
self._fingerprint = ""
|
self._fingerprint = ""
|
||||||
self._mcp_generation = -1
|
self._mcp_generation = -1
|
||||||
self._healthy = True
|
self._healthy = True
|
||||||
logger.info("ToolIndex initialized")
|
logger.info("ToolIndex initialized (lanes=%s)", [lane.name for lane in self._lanes])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def healthy(self):
|
def healthy(self):
|
||||||
return self._healthy
|
return self._healthy
|
||||||
|
|
||||||
def _embed(self, texts: List[str]) -> List[List[float]]:
|
def _embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
vecs = self._embedder.encode(texts, normalize_embeddings=True)
|
if not self._lanes:
|
||||||
|
return []
|
||||||
|
vecs = self._lanes[0].encode(texts)
|
||||||
if np is not None:
|
if np is not None:
|
||||||
return np.array(vecs, dtype=np.float32).tolist()
|
return np.array(vecs, dtype=np.float32).tolist()
|
||||||
# Fallback without numpy
|
|
||||||
return [list(v) for v in vecs]
|
return [list(v) for v in vecs]
|
||||||
|
|
||||||
def index_builtin_tools(self):
|
def index_builtin_tools(self):
|
||||||
@@ -201,23 +207,31 @@ class ToolIndex:
|
|||||||
# registry (e.g. removed tools like the old vault_* set).
|
# registry (e.g. removed tools like the old vault_* set).
|
||||||
# Without this, upsert leaves them in place and RAG keeps
|
# Without this, upsert leaves them in place and RAG keeps
|
||||||
# surfacing tools that no longer exist.
|
# surfacing tools that no longer exist.
|
||||||
try:
|
indexed = False
|
||||||
existing = self._collection.get(where={"tool_type": "builtin"})
|
for lane in self._lanes:
|
||||||
existing_ids = (existing or {}).get("ids") or []
|
try:
|
||||||
stale = [i for i in existing_ids if i not in set(ids)]
|
existing = lane.collection.get(where={"tool_type": "builtin"})
|
||||||
if stale:
|
existing_ids = (existing or {}).get("ids") or []
|
||||||
self._collection.delete(ids=stale)
|
stale = [i for i in existing_ids if i not in set(ids)]
|
||||||
logger.info(f"Pruned {len(stale)} stale builtin tool entries from index")
|
if stale:
|
||||||
except Exception as e:
|
lane.collection.delete(ids=stale)
|
||||||
logger.debug(f"Stale-pruning skipped: {e}")
|
logger.info(f"Pruned {len(stale)} stale builtin tool entries from {lane.name} index")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Stale-pruning skipped for {lane.name}: {e}")
|
||||||
|
|
||||||
embeddings = self._embed(docs)
|
try:
|
||||||
self._collection.upsert(
|
lane.collection.upsert(
|
||||||
ids=ids,
|
ids=ids,
|
||||||
documents=docs,
|
documents=docs,
|
||||||
embeddings=embeddings,
|
embeddings=lane.encode(docs),
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
)
|
)
|
||||||
|
indexed = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Builtin tool indexing failed in %s lane: %s", lane.name, e)
|
||||||
|
if not indexed:
|
||||||
|
self._healthy = False
|
||||||
|
raise RuntimeError("Builtin tool indexing failed in all embedding lanes")
|
||||||
self._fingerprint = hashlib.sha256(
|
self._fingerprint = hashlib.sha256(
|
||||||
",".join(sorted(BUILTIN_TOOL_DESCRIPTIONS.keys())).encode()
|
",".join(sorted(BUILTIN_TOOL_DESCRIPTIONS.keys())).encode()
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
@@ -232,15 +246,15 @@ class ToolIndex:
|
|||||||
gen = getattr(mcp_mgr, '_generation', 0)
|
gen = getattr(mcp_mgr, '_generation', 0)
|
||||||
if gen == self._mcp_generation:
|
if gen == self._mcp_generation:
|
||||||
return
|
return
|
||||||
self._mcp_generation = gen
|
|
||||||
|
|
||||||
# Remove old MCP entries
|
# Remove old MCP entries
|
||||||
try:
|
for lane in self._lanes:
|
||||||
existing = self._collection.get(where={"tool_type": "mcp"})
|
try:
|
||||||
if existing and existing["ids"]:
|
existing = lane.collection.get(where={"tool_type": "mcp"})
|
||||||
self._collection.delete(ids=existing["ids"])
|
if existing and existing["ids"]:
|
||||||
except Exception:
|
lane.collection.delete(ids=existing["ids"])
|
||||||
pass
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Get current MCP tools
|
# Get current MCP tools
|
||||||
try:
|
try:
|
||||||
@@ -249,6 +263,7 @@ class ToolIndex:
|
|||||||
all_tools = ""
|
all_tools = ""
|
||||||
|
|
||||||
if not all_tools:
|
if not all_tools:
|
||||||
|
self._mcp_generation = gen
|
||||||
return
|
return
|
||||||
|
|
||||||
# Parse MCP tool descriptions from the prompt text
|
# Parse MCP tool descriptions from the prompt text
|
||||||
@@ -276,39 +291,59 @@ class ToolIndex:
|
|||||||
metadatas.append({"tool_name": name, "tool_type": "mcp"})
|
metadatas.append({"tool_name": name, "tool_type": "mcp"})
|
||||||
|
|
||||||
if not docs:
|
if not docs:
|
||||||
|
self._mcp_generation = gen
|
||||||
return
|
return
|
||||||
|
|
||||||
embeddings = self._embed(docs)
|
indexed = False
|
||||||
self._collection.upsert(
|
for lane in self._lanes:
|
||||||
ids=ids,
|
try:
|
||||||
documents=docs,
|
lane.collection.upsert(
|
||||||
embeddings=embeddings,
|
ids=ids,
|
||||||
metadatas=metadatas,
|
documents=docs,
|
||||||
)
|
embeddings=lane.encode(docs),
|
||||||
|
metadatas=metadatas,
|
||||||
|
)
|
||||||
|
indexed = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("MCP tool indexing failed in %s lane: %s", lane.name, e)
|
||||||
|
if not indexed:
|
||||||
|
logger.warning("MCP tool indexing failed in all embedding lanes")
|
||||||
|
return
|
||||||
|
self._mcp_generation = gen
|
||||||
logger.info(f"Indexed {len(docs)} MCP tools")
|
logger.info(f"Indexed {len(docs)} MCP tools")
|
||||||
|
|
||||||
def retrieve(self, query: str, k: int = 8) -> List[str]:
|
def retrieve(self, query: str, k: int = 8) -> List[str]:
|
||||||
"""Retrieve the top-K most relevant tool names for a query."""
|
"""Retrieve the top-K most relevant tool names for a query."""
|
||||||
try:
|
rows = []
|
||||||
query_embedding = self._embed([query])
|
lane_priority = {LANE_CUSTOM: 0, LANE_FASTEMBED: 1}
|
||||||
results = self._collection.query(
|
for lane in self._lanes:
|
||||||
query_embeddings=query_embedding,
|
try:
|
||||||
n_results=min(k, self._collection.count() or k),
|
count = lane.count()
|
||||||
include=["metadatas", "distances"],
|
if count == 0:
|
||||||
)
|
continue
|
||||||
if not results or not results.get("metadatas"):
|
results = lane.collection.query(
|
||||||
return []
|
query_embeddings=lane.encode([query]),
|
||||||
|
n_results=min(k, count),
|
||||||
tool_names = []
|
include=["metadatas", "distances"],
|
||||||
for meta_list in results["metadatas"]:
|
)
|
||||||
for meta in meta_list:
|
if not results or not results.get("metadatas"):
|
||||||
name = meta.get("tool_name", "")
|
continue
|
||||||
if name and name not in tool_names:
|
distances = results.get("distances") or []
|
||||||
tool_names.append(name)
|
for list_idx, meta_list in enumerate(results["metadatas"]):
|
||||||
return tool_names
|
distance_list = distances[list_idx] if list_idx < len(distances) else []
|
||||||
except Exception as e:
|
for idx, meta in enumerate(meta_list):
|
||||||
logger.warning(f"Tool retrieval failed: {e}")
|
name = meta.get("tool_name", "")
|
||||||
return []
|
if name:
|
||||||
|
distance = distance_list[idx] if idx < len(distance_list) else 1.0
|
||||||
|
rows.append({
|
||||||
|
"tool_name": name,
|
||||||
|
"score": round(1.0 - distance, 4),
|
||||||
|
"embedding_lane": lane.name,
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Tool retrieval failed in %s lane: %s", lane.name, e)
|
||||||
|
rows.sort(key=lambda row: (-row["score"], lane_priority.get(row["embedding_lane"], 99)))
|
||||||
|
return [row["tool_name"] for row in dedupe_results(rows, id_key="tool_name", limit=k)]
|
||||||
|
|
||||||
# Structural recurring-schedule intent. Typo-resilient (matches "every dya"
|
# Structural recurring-schedule intent. Typo-resilient (matches "every dya"
|
||||||
# via "every <word>"), and catches bare clock times ("at 7:30 am", "7am").
|
# via "every <word>"), and catches bare clock times ("at 7:30 am", "7am").
|
||||||
@@ -511,3 +546,10 @@ def get_tool_index() -> Optional[ToolIndex]:
|
|||||||
logger.warning(f"ToolIndex init failed (will retry in {_RETRY_INTERVAL}s): {e}")
|
logger.warning(f"ToolIndex init failed (will retry in {_RETRY_INTERVAL}s): {e}")
|
||||||
_tool_index = None
|
_tool_index = None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def reset_tool_index() -> None:
|
||||||
|
"""Clear the singleton so embedding endpoint changes rebuild tool lanes."""
|
||||||
|
global _tool_index, _last_attempt
|
||||||
|
_tool_index = None
|
||||||
|
_last_attempt = 0.0
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user