mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 01:35:36 -04:00
fix(auth): clean up rename and null-owner ownership (#4340)
This commit is contained in:
@@ -527,6 +527,7 @@ memory_vector = components.get("memory_vector")
|
||||
upload_handler = components["upload_handler"]
|
||||
app.state.upload_handler = upload_handler
|
||||
personal_docs_mgr = components["personal_docs_manager"]
|
||||
app.state.personal_docs_manager = personal_docs_mgr
|
||||
api_key_manager = components["api_key_manager"]
|
||||
preset_manager = components["preset_manager"]
|
||||
chat_processor = components["chat_processor"]
|
||||
|
||||
+10
-6
@@ -573,16 +573,20 @@ class AuthManager:
|
||||
return None
|
||||
return self.create_session_trusted(username)
|
||||
|
||||
def create_session_trusted(self, username: str) -> str:
|
||||
def create_session_trusted(self, username: str) -> Optional[str]:
|
||||
"""Issue a session token for an already-verified user.
|
||||
Call only after verify_password (and TOTP if enabled) have passed."""
|
||||
username = username.strip().lower()
|
||||
token = secrets.token_hex(32)
|
||||
with self._sessions_lock:
|
||||
self._sessions[token] = {
|
||||
"username": username,
|
||||
"expiry": time.time() + TOKEN_TTL,
|
||||
}
|
||||
with self._config_lock:
|
||||
if username not in self.users:
|
||||
logger.warning("Refused to issue session for missing user '%s'", username)
|
||||
return None
|
||||
with self._sessions_lock:
|
||||
self._sessions[token] = {
|
||||
"username": username,
|
||||
"expiry": time.time() + TOKEN_TTL,
|
||||
}
|
||||
self._save_sessions()
|
||||
return token
|
||||
|
||||
|
||||
@@ -144,6 +144,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
raise HTTPException(401, "Invalid 2FA code")
|
||||
# All checks passed — create session (password already verified above)
|
||||
token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
|
||||
if not token:
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
cookie_kwargs = dict(
|
||||
key=SESSION_COOKIE,
|
||||
value=token,
|
||||
@@ -432,6 +434,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename upload owner references %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# direct personal RAG uploads live in per-owner directories and the
|
||||
# vector metadata also carries the username used for owner-filtered
|
||||
# search. Keep both in sync with the auth rename.
|
||||
try:
|
||||
from routes.personal_routes import rename_personal_upload_owner
|
||||
personal_docs_manager = getattr(request.app.state, "personal_docs_manager", None)
|
||||
if personal_docs_manager is not None:
|
||||
rag_manager = getattr(personal_docs_manager, "rag_manager", None)
|
||||
rename_personal_upload_owner(
|
||||
old_username,
|
||||
new_username,
|
||||
personal_docs_manager=personal_docs_manager,
|
||||
rag_manager=rag_manager,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename personal RAG upload owner references %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# skills: SKILL.md frontmatter carries owner: <username>; the usage
|
||||
# sidecar (_usage.json) keys entries as owner::skill-name. Both must
|
||||
# be updated or the renamed user's Skills panel goes empty.
|
||||
|
||||
@@ -102,8 +102,11 @@ def _owner_session_filter(q, user):
|
||||
|
||||
The owner backfill runs in init_db before the app serves requests, so
|
||||
by the time this filter is live there are no NULL-owner rows to leak;
|
||||
we therefore match the owner strictly."""
|
||||
if user is None:
|
||||
we therefore match the owner strictly for authenticated callers."""
|
||||
if not user:
|
||||
from src.auth_helpers import _auth_disabled
|
||||
if user == "" or _auth_disabled():
|
||||
return q
|
||||
return q.filter(False)
|
||||
return q.filter(Document.owner == user)
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
"""Routes for personal documents management."""
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import List, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
||||
from src.request_models import DirectoryRequest
|
||||
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
|
||||
@@ -18,14 +19,15 @@ UPLOADS_DIR = PERSONAL_UPLOADS_DIR
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _personal_upload_dir_for_owner(owner: str | None) -> str:
|
||||
def _personal_upload_dir_for_owner(owner: str | None, *, create: bool = True) -> str:
|
||||
"""Return the per-owner upload directory used for direct RAG uploads."""
|
||||
owner_segment = secure_filename((owner or "local").strip())[:80] or "local"
|
||||
upload_dir = os.path.abspath(os.path.join(UPLOADS_DIR, owner_segment))
|
||||
base_abs = os.path.abspath(UPLOADS_DIR)
|
||||
if os.path.commonpath([upload_dir, base_abs]) != base_abs:
|
||||
raise ValueError("Unsafe upload owner path")
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
if create:
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
return upload_dir
|
||||
|
||||
|
||||
@@ -44,6 +46,87 @@ def _unique_personal_upload_path(upload_dir: str, original_name: str | None) ->
|
||||
raise ValueError("Unsafe upload filename")
|
||||
return file_path, filename, safe_name
|
||||
|
||||
|
||||
def _unique_existing_target(path: str) -> str:
|
||||
"""Return a non-existing sibling path for rename collision handling."""
|
||||
if not os.path.exists(path):
|
||||
return path
|
||||
stem, ext = os.path.splitext(path)
|
||||
while True:
|
||||
candidate = f"{stem}-{uuid.uuid4().hex[:10]}{ext}"
|
||||
if not os.path.exists(candidate):
|
||||
return candidate
|
||||
|
||||
|
||||
def _remove_empty_tree(path: str) -> None:
|
||||
"""Best-effort removal of empty directories under ``path``."""
|
||||
if not os.path.isdir(path):
|
||||
return
|
||||
for root, dirs, _files in os.walk(path, topdown=False):
|
||||
for dirname in dirs:
|
||||
candidate = os.path.join(root, dirname)
|
||||
try:
|
||||
os.rmdir(candidate)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
os.rmdir(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def rename_personal_upload_owner(
|
||||
old_owner: str,
|
||||
new_owner: str,
|
||||
*,
|
||||
personal_docs_manager: Any = None,
|
||||
rag_manager: Any = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Move direct personal uploads and rewrite RAG owner metadata on user rename."""
|
||||
old_dir = _personal_upload_dir_for_owner(old_owner, create=False)
|
||||
new_dir = _personal_upload_dir_for_owner(new_owner, create=False)
|
||||
path_map: Dict[str, str] = {}
|
||||
moved_files = 0
|
||||
|
||||
if os.path.isdir(old_dir) and old_dir != new_dir:
|
||||
os.makedirs(new_dir, exist_ok=True)
|
||||
for root, _dirs, files in os.walk(old_dir):
|
||||
rel_root = os.path.relpath(root, old_dir)
|
||||
target_root = new_dir if rel_root == "." else os.path.join(new_dir, rel_root)
|
||||
os.makedirs(target_root, exist_ok=True)
|
||||
for filename in files:
|
||||
source = os.path.abspath(os.path.join(root, filename))
|
||||
target = _unique_existing_target(os.path.abspath(os.path.join(target_root, filename)))
|
||||
shutil.move(source, target)
|
||||
path_map[source] = target
|
||||
moved_files += 1
|
||||
_remove_empty_tree(old_dir)
|
||||
|
||||
if personal_docs_manager is not None:
|
||||
rename_directory = getattr(personal_docs_manager, "rename_directory", None)
|
||||
if callable(rename_directory):
|
||||
rename_directory(old_dir, new_dir, path_map=path_map)
|
||||
|
||||
rag_result = None
|
||||
if rag_manager is not None:
|
||||
rename_owner = getattr(rag_manager, "rename_owner", None)
|
||||
if callable(rename_owner):
|
||||
rag_result = rename_owner(
|
||||
old_owner,
|
||||
new_owner,
|
||||
path_map=path_map,
|
||||
path_prefixes=[(old_dir, new_dir)],
|
||||
)
|
||||
|
||||
return {
|
||||
"old_dir": old_dir,
|
||||
"new_dir": new_dir,
|
||||
"moved_files": moved_files,
|
||||
"path_map": path_map,
|
||||
"rag_result": rag_result,
|
||||
}
|
||||
|
||||
|
||||
def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
"""
|
||||
Setup personal documents related routes.
|
||||
|
||||
@@ -1004,6 +1004,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
"""
|
||||
from src.llm_core import llm_call
|
||||
user = effective_user(request)
|
||||
single_user_mode = not user and _auth_disabled()
|
||||
user_sessions = session_manager.get_sessions_for_user(user)
|
||||
|
||||
# Delete empty and throwaway sessions before sorting
|
||||
@@ -1022,7 +1023,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
}
|
||||
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
|
||||
try:
|
||||
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all()
|
||||
rows_q = db.query(DbSession).filter(DbSession.archived == False)
|
||||
if user:
|
||||
rows_q = rows_q.filter(DbSession.owner == user)
|
||||
elif not single_user_mode:
|
||||
rows_q = rows_q.filter(DbSession.owner == user)
|
||||
rows = rows_q.limit(2000).all()
|
||||
folder_map = {r.id: r.folder for r in rows}
|
||||
# Precompute per-session message counts in TWO aggregate queries
|
||||
# instead of 1–3 queries PER session — with many chats the per-row
|
||||
@@ -1242,7 +1248,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for sid, folder_name in assignments.items():
|
||||
db_session = db.query(DbSession).filter(DbSession.id == sid, DbSession.owner == user).first()
|
||||
db_session_q = db.query(DbSession).filter(DbSession.id == sid)
|
||||
if user:
|
||||
db_session_q = db_session_q.filter(DbSession.owner == user)
|
||||
elif not single_user_mode:
|
||||
db_session_q = db_session_q.filter(DbSession.owner == user)
|
||||
db_session = db_session_q.first()
|
||||
if db_session:
|
||||
db_session.folder = folder_name
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
|
||||
@@ -322,6 +322,47 @@ class PersonalDocsManager:
|
||||
else:
|
||||
logger.info(f"Directory not in index: {directory}")
|
||||
|
||||
def rename_directory(self, old_directory: str, new_directory: str, *, path_map: Dict[str, str] = None):
|
||||
"""Rewrite tracked directory and excluded-file paths after an owner rename."""
|
||||
old_directory = os.path.abspath(old_directory)
|
||||
new_directory = os.path.abspath(new_directory)
|
||||
path_map = {os.path.abspath(k): os.path.abspath(v) for k, v in (path_map or {}).items()}
|
||||
|
||||
def rewrite(path: str) -> str:
|
||||
abs_path = os.path.abspath(path)
|
||||
mapped = path_map.get(abs_path)
|
||||
if mapped:
|
||||
return mapped
|
||||
if abs_path == old_directory:
|
||||
return new_directory
|
||||
if abs_path.startswith(old_directory + os.sep):
|
||||
return new_directory + abs_path[len(old_directory):]
|
||||
return abs_path
|
||||
|
||||
changed_dirs = False
|
||||
rewritten_dirs = []
|
||||
for directory in self.indexed_directories:
|
||||
rewritten = rewrite(directory)
|
||||
changed_dirs = changed_dirs or rewritten != os.path.abspath(directory)
|
||||
if rewritten not in rewritten_dirs:
|
||||
rewritten_dirs.append(rewritten)
|
||||
if changed_dirs:
|
||||
self.indexed_directories = rewritten_dirs
|
||||
self.save_directories()
|
||||
|
||||
changed_excluded = False
|
||||
rewritten_excluded = set()
|
||||
for path in self.excluded_files:
|
||||
rewritten = rewrite(path)
|
||||
changed_excluded = changed_excluded or rewritten != os.path.abspath(path)
|
||||
rewritten_excluded.add(rewritten)
|
||||
if changed_excluded:
|
||||
self.excluded_files = rewritten_excluded
|
||||
self._save_excluded()
|
||||
|
||||
if changed_dirs or changed_excluded:
|
||||
self.refresh_index()
|
||||
|
||||
def get_indexed_directories(self):
|
||||
"""Get the list of all indexed directories."""
|
||||
return self.indexed_directories.copy()
|
||||
|
||||
@@ -50,6 +50,23 @@ def _generate_doc_id(text: str, owner: str = "") -> str:
|
||||
return f"doc_{hashlib.sha256(key.encode('utf-8')).hexdigest()[:16]}"
|
||||
|
||||
|
||||
def _rewrite_owner_path(value: str, path_map: Dict[str, str], path_prefixes: List[tuple]) -> str:
|
||||
if not isinstance(value, str) or not value:
|
||||
return value
|
||||
abs_value = os.path.abspath(value)
|
||||
mapped = path_map.get(abs_value)
|
||||
if mapped:
|
||||
return mapped
|
||||
for old_prefix, new_prefix in path_prefixes:
|
||||
old_abs = os.path.abspath(old_prefix)
|
||||
new_abs = os.path.abspath(new_prefix)
|
||||
if abs_value == old_abs:
|
||||
return new_abs
|
||||
if abs_value.startswith(old_abs + os.sep):
|
||||
return new_abs + abs_value[len(old_abs):]
|
||||
return value
|
||||
|
||||
|
||||
class VectorRAG:
|
||||
"""RAG system using ChromaDB vector storage with hybrid search."""
|
||||
|
||||
@@ -250,6 +267,75 @@ class VectorRAG:
|
||||
"failed_count": len(docs) - len(valid),
|
||||
}
|
||||
|
||||
def rename_owner(
|
||||
self,
|
||||
old_owner: str,
|
||||
new_owner: str,
|
||||
*,
|
||||
path_map: Optional[Dict[str, str]] = None,
|
||||
path_prefixes: Optional[List[tuple]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Rewrite existing RAG metadata after an auth username rename."""
|
||||
if not self.healthy:
|
||||
return {"success": False, "updated_count": 0, "message": "Collection not initialized"}
|
||||
|
||||
old_owner = (old_owner or "").strip().lower()
|
||||
new_owner = (new_owner or "").strip().lower()
|
||||
if not old_owner or not new_owner or old_owner == new_owner:
|
||||
return {"success": True, "updated_count": 0, "message": "No owner rename needed"}
|
||||
|
||||
path_map = {os.path.abspath(k): os.path.abspath(v) for k, v in (path_map or {}).items()}
|
||||
path_prefixes = path_prefixes or []
|
||||
updated_ids = set()
|
||||
failed_count = 0
|
||||
|
||||
for lane_name, collection in self._collections_for_delete():
|
||||
try:
|
||||
results = collection.get(
|
||||
where={"owner": old_owner},
|
||||
include=["metadatas"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("rename_owner metadata scan failed in %s lane: %s", lane_name, e)
|
||||
failed_count += 1
|
||||
continue
|
||||
|
||||
ids = results.get("ids") or []
|
||||
metadatas = results.get("metadatas") or []
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
new_metas = []
|
||||
selected_ids = []
|
||||
for doc_id, meta in zip(ids, metadatas):
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
next_meta = dict(meta)
|
||||
if str(next_meta.get("owner", "")).strip().lower() == old_owner:
|
||||
next_meta["owner"] = new_owner
|
||||
for key in ("source", "directory"):
|
||||
next_meta[key] = _rewrite_owner_path(next_meta.get(key), path_map, path_prefixes)
|
||||
selected_ids.append(doc_id)
|
||||
new_metas.append(next_meta)
|
||||
|
||||
if not selected_ids:
|
||||
continue
|
||||
|
||||
try:
|
||||
collection.update(ids=selected_ids, metadatas=new_metas)
|
||||
updated_ids.update(selected_ids)
|
||||
except Exception as e:
|
||||
logger.warning("rename_owner metadata update failed in %s lane: %s", lane_name, e)
|
||||
failed_count += len(selected_ids)
|
||||
|
||||
success = failed_count == 0
|
||||
return {
|
||||
"success": success,
|
||||
"updated_count": len(updated_ids),
|
||||
"failed_count": failed_count,
|
||||
"message": f"Updated {len(updated_ids)} RAG chunk(s)",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Search — hybrid: vector similarity + keyword overlap
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -80,6 +80,16 @@ def test_password_change_allows_new_password_and_blocks_old_password(tmp_path):
|
||||
assert mgr.create_session("alice", "new-password") is not None
|
||||
|
||||
|
||||
def test_create_session_trusted_rejects_username_renamed_after_verification(tmp_path):
|
||||
mgr = _make_manager(tmp_path)
|
||||
assert mgr.create_user("admin", "admin-password", is_admin=True)
|
||||
|
||||
assert mgr.verify_password("alice", "old-password") is True
|
||||
assert mgr.rename_user("alice", "alice2", "admin") is True
|
||||
|
||||
assert mgr.create_session_trusted("alice") is None
|
||||
|
||||
|
||||
def _change_password_endpoint(auth_manager):
|
||||
sys.modules.pop("routes.auth_routes", None)
|
||||
_real_core_package()
|
||||
@@ -92,6 +102,39 @@ def _change_password_endpoint(auth_manager):
|
||||
raise AssertionError("change-password route not found")
|
||||
|
||||
|
||||
def _login_endpoint(auth_manager):
|
||||
sys.modules.pop("routes.auth_routes", None)
|
||||
_real_core_package()
|
||||
from routes.auth_routes import LoginRequest, setup_auth_routes
|
||||
|
||||
router = setup_auth_routes(auth_manager)
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", None) == "/api/auth/login":
|
||||
return route.endpoint, LoginRequest
|
||||
raise AssertionError("login route not found")
|
||||
|
||||
|
||||
def test_login_route_does_not_set_cookie_when_trusted_session_rejects_stale_user(monkeypatch):
|
||||
auth = MagicMock()
|
||||
auth.verify_password.return_value = True
|
||||
auth.totp_enabled.return_value = False
|
||||
auth.create_session_trusted.return_value = None
|
||||
endpoint, LoginRequest = _login_endpoint(auth)
|
||||
monkeypatch.setattr(
|
||||
"routes.auth_routes.asyncio.to_thread",
|
||||
lambda fn, *args, **kwargs: _immediate_to_thread(fn, *args, **kwargs),
|
||||
)
|
||||
request = SimpleNamespace(client=SimpleNamespace(host="127.0.0.1"))
|
||||
response = MagicMock()
|
||||
body = LoginRequest(username="alice", password="old-password")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
asyncio.run(endpoint(body=body, request=request, response=response))
|
||||
|
||||
assert exc.value.status_code == 401
|
||||
response.set_cookie.assert_not_called()
|
||||
|
||||
|
||||
def test_change_password_route_revokes_other_sessions_after_success(monkeypatch):
|
||||
auth = MagicMock()
|
||||
auth.get_username_for_token.return_value = "alice"
|
||||
|
||||
@@ -25,6 +25,7 @@ import routes.document_routes as droutes
|
||||
from core.database import Document
|
||||
from core.database import Session as DbSession
|
||||
from routes.document_helpers import DocumentPatch
|
||||
from routes.document_helpers import _owner_session_filter
|
||||
|
||||
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_ENGINE = create_engine(
|
||||
@@ -141,3 +142,18 @@ async def test_list_documents_filters_foreign_docs_in_visible_session():
|
||||
assert bob_doc not in ids
|
||||
finally:
|
||||
droutes.SessionLocal = previous_session_local
|
||||
|
||||
|
||||
def test_owner_session_filter_noops_for_auth_disabled_single_user(monkeypatch):
|
||||
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||
previous_session_local = _bind_test_db()
|
||||
try:
|
||||
_alice_session, _bob_session, alice_doc, _bob_doc, _legacy_doc = _seed()
|
||||
db = _TS()
|
||||
try:
|
||||
q = db.query(Document).filter(Document.id == alice_doc)
|
||||
assert _owner_session_filter(q, None).first().id == alice_doc
|
||||
finally:
|
||||
db.close()
|
||||
finally:
|
||||
droutes.SessionLocal = previous_session_local
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from routes import personal_routes
|
||||
|
||||
@@ -42,3 +43,44 @@ def test_personal_upload_paths_stay_under_upload_root(tmp_path, monkeypatch):
|
||||
assert os.path.commonpath([file_path, upload_dir]) == upload_dir
|
||||
assert Path(file_path).name == stored_name
|
||||
assert display_name == "env"
|
||||
|
||||
|
||||
def test_rename_personal_upload_owner_moves_files_and_rewrites_rag(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(personal_routes, "UPLOADS_DIR", str(tmp_path))
|
||||
|
||||
old_dir = Path(personal_routes._personal_upload_dir_for_owner("alice"))
|
||||
old_file = old_dir / "note.txt"
|
||||
old_file.write_text("alice private RAG note", encoding="utf-8")
|
||||
|
||||
manager_calls = []
|
||||
rag_calls = []
|
||||
manager = SimpleNamespace(
|
||||
rename_directory=lambda old, new, path_map=None: manager_calls.append((old, new, dict(path_map or {}))),
|
||||
)
|
||||
rag = SimpleNamespace(
|
||||
rename_owner=lambda old, new, path_map=None, path_prefixes=None: rag_calls.append(
|
||||
(old, new, dict(path_map or {}), list(path_prefixes or []))
|
||||
) or {"success": True, "updated_count": 1},
|
||||
)
|
||||
|
||||
result = personal_routes.rename_personal_upload_owner(
|
||||
"alice",
|
||||
"alice2",
|
||||
personal_docs_manager=manager,
|
||||
rag_manager=rag,
|
||||
)
|
||||
|
||||
new_dir = Path(personal_routes._personal_upload_dir_for_owner("alice2"))
|
||||
new_file = new_dir / "note.txt"
|
||||
assert old_file.exists() is False
|
||||
assert new_file.read_text(encoding="utf-8") == "alice private RAG note"
|
||||
assert result["moved_files"] == 1
|
||||
assert manager_calls == [(str(old_dir), str(new_dir), {str(old_file): str(new_file)})]
|
||||
assert rag_calls == [
|
||||
(
|
||||
"alice",
|
||||
"alice2",
|
||||
{str(old_file): str(new_file)},
|
||||
[(str(old_dir), str(new_dir))],
|
||||
)
|
||||
]
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from src.rag_vector import VectorRAG
|
||||
|
||||
|
||||
class _FakeCollection:
|
||||
def __init__(self, docs):
|
||||
self._docs = {
|
||||
doc_id: {"document": document, "metadata": dict(metadata)}
|
||||
for doc_id, document, metadata in docs
|
||||
}
|
||||
|
||||
def count(self):
|
||||
return len(self._docs)
|
||||
|
||||
def get(self, where=None, include=None):
|
||||
rows = []
|
||||
for doc_id, row in self._docs.items():
|
||||
metadata = row["metadata"]
|
||||
if where and any(metadata.get(key) != value for key, value in where.items()):
|
||||
continue
|
||||
rows.append((doc_id, row))
|
||||
return {
|
||||
"ids": [doc_id for doc_id, _row in rows],
|
||||
"documents": [row["document"] for _doc_id, row in rows],
|
||||
"metadatas": [row["metadata"] for _doc_id, row in rows],
|
||||
}
|
||||
|
||||
def update(self, ids, metadatas):
|
||||
for doc_id, metadata in zip(ids, metadatas):
|
||||
self._docs[doc_id]["metadata"] = dict(metadata)
|
||||
|
||||
|
||||
def _store(collection):
|
||||
store = VectorRAG.__new__(VectorRAG)
|
||||
store._collection = collection
|
||||
store._lanes = []
|
||||
store._healthy = True
|
||||
return store
|
||||
|
||||
|
||||
def test_rename_owner_updates_metadata_used_by_owner_filtered_search(tmp_path):
|
||||
old_dir = tmp_path / "alice"
|
||||
new_dir = tmp_path / "alice2"
|
||||
old_file = old_dir / "note.txt"
|
||||
new_file = new_dir / "note.txt"
|
||||
collection = _FakeCollection([
|
||||
(
|
||||
"doc-old",
|
||||
"private vector note",
|
||||
{
|
||||
"owner": "alice",
|
||||
"source": str(old_file),
|
||||
"directory": str(old_dir),
|
||||
},
|
||||
),
|
||||
(
|
||||
"doc-other",
|
||||
"other vector note",
|
||||
{
|
||||
"owner": "bob",
|
||||
"source": str(tmp_path / "bob" / "note.txt"),
|
||||
},
|
||||
),
|
||||
])
|
||||
store = _store(collection)
|
||||
|
||||
result = store.rename_owner(
|
||||
"alice",
|
||||
"alice2",
|
||||
path_map={str(old_file): str(new_file)},
|
||||
path_prefixes=[(str(old_dir), str(new_dir))],
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["updated_count"] == 1
|
||||
assert store._keyword_search_fallback("private", k=10, owner="alice") == []
|
||||
renamed = store._keyword_search_fallback("private", k=10, owner="alice2")
|
||||
assert [row["id"] for row in renamed] == ["doc-old"]
|
||||
assert renamed[0]["metadata"]["owner"] == "alice2"
|
||||
assert renamed[0]["metadata"]["source"] == str(new_file)
|
||||
assert renamed[0]["metadata"]["directory"] == str(new_dir)
|
||||
assert store._keyword_search_fallback("other", k=10, owner="bob")[0]["id"] == "doc-other"
|
||||
@@ -70,12 +70,20 @@ def rename_endpoint(monkeypatch, tmp_path):
|
||||
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
|
||||
|
||||
|
||||
def _request(tmp_path, session_manager=None, token="t", research_handler=None, upload_handler=None):
|
||||
def _request(
|
||||
tmp_path,
|
||||
session_manager=None,
|
||||
token="t",
|
||||
research_handler=None,
|
||||
upload_handler=None,
|
||||
personal_docs_manager=None,
|
||||
):
|
||||
state = SimpleNamespace(
|
||||
invalidate_token_cache=lambda: None,
|
||||
session_manager=session_manager,
|
||||
research_handler=research_handler,
|
||||
upload_handler=upload_handler,
|
||||
personal_docs_manager=personal_docs_manager,
|
||||
)
|
||||
return SimpleNamespace(
|
||||
cookies={"odysseus_session": token},
|
||||
@@ -467,6 +475,52 @@ def test_rename_updates_upload_metadata_owner(rename_endpoint):
|
||||
assert handler.resolve_upload(upload_id, owner="alice") is None
|
||||
|
||||
|
||||
def test_rename_updates_personal_rag_upload_owner(rename_endpoint, monkeypatch):
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
from routes import personal_routes
|
||||
|
||||
monkeypatch.setattr(personal_routes, "UPLOADS_DIR", str(tmp_path / "personal_uploads"))
|
||||
old_dir = Path(personal_routes._personal_upload_dir_for_owner("alice"))
|
||||
old_file = old_dir / "note.txt"
|
||||
old_file.write_text("private RAG note", encoding="utf-8")
|
||||
|
||||
manager_calls = []
|
||||
rag_calls = []
|
||||
rag = SimpleNamespace(
|
||||
rename_owner=lambda old, new, path_map=None, path_prefixes=None: rag_calls.append(
|
||||
(old, new, dict(path_map or {}), list(path_prefixes or []))
|
||||
) or {"success": True, "updated_count": 1},
|
||||
)
|
||||
personal_docs_manager = SimpleNamespace(
|
||||
rag_manager=rag,
|
||||
rename_directory=lambda old, new, path_map=None: manager_calls.append(
|
||||
(old, new, dict(path_map or {}))
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.run(
|
||||
endpoint(
|
||||
"alice",
|
||||
SimpleNamespace(username="alice2"),
|
||||
_request(tmp_path, personal_docs_manager=personal_docs_manager),
|
||||
)
|
||||
)
|
||||
|
||||
new_dir = Path(personal_routes._personal_upload_dir_for_owner("alice2"))
|
||||
new_file = new_dir / "note.txt"
|
||||
assert old_file.exists() is False
|
||||
assert new_file.read_text(encoding="utf-8") == "private RAG note"
|
||||
assert manager_calls == [(str(old_dir), str(new_dir), {str(old_file): str(new_file)})]
|
||||
assert rag_calls == [
|
||||
(
|
||||
"alice",
|
||||
"alice2",
|
||||
{str(old_file): str(new_file)},
|
||||
[(str(old_dir), str(new_dir))],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Skills (SKILL.md frontmatter + _usage.json sidecar)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -7,6 +7,7 @@ import sys
|
||||
import tempfile
|
||||
import types
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
@@ -14,6 +15,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
import core.database as cdb
|
||||
from core.database import ChatMessage as DbMessage
|
||||
from core.database import Session as DbSession
|
||||
|
||||
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
@@ -72,3 +74,60 @@ def test_list_sessions_excludes_other_users_sessions(monkeypatch):
|
||||
returned_ids = {s["id"] for s in result}
|
||||
assert alice_id in returned_ids
|
||||
assert bob_id not in returned_ids
|
||||
|
||||
|
||||
def test_auto_sort_skip_llm_cleans_owner_stamped_sessions_when_auth_disabled(monkeypatch):
|
||||
import routes.session_routes as sr
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
_stub_multipart_if_missing(monkeypatch)
|
||||
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||
monkeypatch.setattr(sr, "SessionLocal", _TS)
|
||||
monkeypatch.setattr(sr, "effective_user", lambda request: None)
|
||||
|
||||
sid = str(uuid.uuid4())
|
||||
old_time = cdb.utcnow_naive() - timedelta(hours=2)
|
||||
db = _TS()
|
||||
try:
|
||||
db.query(DbMessage).delete()
|
||||
db.query(DbSession).delete()
|
||||
db.add(DbSession(
|
||||
id=sid,
|
||||
owner="alice",
|
||||
name="New chat",
|
||||
endpoint_url="http://localhost",
|
||||
model="gpt-4",
|
||||
archived=False,
|
||||
message_count=1,
|
||||
created_at=old_time,
|
||||
updated_at=old_time,
|
||||
last_message_at=old_time,
|
||||
last_accessed=old_time,
|
||||
))
|
||||
db.add(DbMessage(
|
||||
id="m-" + uuid.uuid4().hex,
|
||||
session_id=sid,
|
||||
role="user",
|
||||
content="hi",
|
||||
timestamp=old_time,
|
||||
))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
session = MagicMock(id=sid, name="New chat", model="gpt-4", endpoint_url="http://localhost", rag=False, archived=False)
|
||||
sm = MagicMock()
|
||||
sm.get_sessions_for_user.return_value = {sid: session}
|
||||
router = sr.setup_session_routes(sm, {})
|
||||
endpoint = next(r.endpoint for r in router.routes
|
||||
if getattr(r, "path", "") == "/api/sessions/auto-sort"
|
||||
and "POST" in getattr(r, "methods", set()))
|
||||
|
||||
result = endpoint(request=MagicMock(), skip_llm=True)
|
||||
|
||||
assert result["deleted_throwaway"] == 1
|
||||
db = _TS()
|
||||
try:
|
||||
assert db.query(DbSession).filter(DbSession.id == sid).first() is None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user