fix(auth): clean up rename and null-owner ownership (#4340)

This commit is contained in:
RaresKeY
2026-06-16 05:33:02 +03:00
committed by GitHub
parent 745c10e0d7
commit 4d10c16d02
14 changed files with 557 additions and 14 deletions
+1
View File
@@ -527,6 +527,7 @@ memory_vector = components.get("memory_vector")
upload_handler = components["upload_handler"] upload_handler = components["upload_handler"]
app.state.upload_handler = upload_handler app.state.upload_handler = upload_handler
personal_docs_mgr = components["personal_docs_manager"] personal_docs_mgr = components["personal_docs_manager"]
app.state.personal_docs_manager = personal_docs_mgr
api_key_manager = components["api_key_manager"] api_key_manager = components["api_key_manager"]
preset_manager = components["preset_manager"] preset_manager = components["preset_manager"]
chat_processor = components["chat_processor"] chat_processor = components["chat_processor"]
+10 -6
View File
@@ -573,16 +573,20 @@ class AuthManager:
return None return None
return self.create_session_trusted(username) return self.create_session_trusted(username)
def create_session_trusted(self, username: str) -> str: def create_session_trusted(self, username: str) -> Optional[str]:
"""Issue a session token for an already-verified user. """Issue a session token for an already-verified user.
Call only after verify_password (and TOTP if enabled) have passed.""" Call only after verify_password (and TOTP if enabled) have passed."""
username = username.strip().lower() username = username.strip().lower()
token = secrets.token_hex(32) token = secrets.token_hex(32)
with self._sessions_lock: with self._config_lock:
self._sessions[token] = { if username not in self.users:
"username": username, logger.warning("Refused to issue session for missing user '%s'", username)
"expiry": time.time() + TOKEN_TTL, return None
} with self._sessions_lock:
self._sessions[token] = {
"username": username,
"expiry": time.time() + TOKEN_TTL,
}
self._save_sessions() self._save_sessions()
return token return token
+19
View File
@@ -144,6 +144,8 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
raise HTTPException(401, "Invalid 2FA code") raise HTTPException(401, "Invalid 2FA code")
# All checks passed — create session (password already verified above) # All checks passed — create session (password already verified above)
token = await asyncio.to_thread(auth_manager.create_session_trusted, username) token = await asyncio.to_thread(auth_manager.create_session_trusted, username)
if not token:
raise HTTPException(401, "Invalid credentials")
cookie_kwargs = dict( cookie_kwargs = dict(
key=SESSION_COOKIE, key=SESSION_COOKIE,
value=token, value=token,
@@ -432,6 +434,23 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
except Exception as e: except Exception as e:
logger.warning("Failed to rename upload owner references %s -> %s: %s", old_username, new_username, e) logger.warning("Failed to rename upload owner references %s -> %s: %s", old_username, new_username, e)
# direct personal RAG uploads live in per-owner directories and the
# vector metadata also carries the username used for owner-filtered
# search. Keep both in sync with the auth rename.
try:
from routes.personal_routes import rename_personal_upload_owner
personal_docs_manager = getattr(request.app.state, "personal_docs_manager", None)
if personal_docs_manager is not None:
rag_manager = getattr(personal_docs_manager, "rag_manager", None)
rename_personal_upload_owner(
old_username,
new_username,
personal_docs_manager=personal_docs_manager,
rag_manager=rag_manager,
)
except Exception as e:
logger.warning("Failed to rename personal RAG upload owner references %s -> %s: %s", old_username, new_username, e)
# skills: SKILL.md frontmatter carries owner: <username>; the usage # skills: SKILL.md frontmatter carries owner: <username>; the usage
# sidecar (_usage.json) keys entries as owner::skill-name. Both must # sidecar (_usage.json) keys entries as owner::skill-name. Both must
# be updated or the renamed user's Skills panel goes empty. # be updated or the renamed user's Skills panel goes empty.
+5 -2
View File
@@ -102,8 +102,11 @@ def _owner_session_filter(q, user):
The owner backfill runs in init_db before the app serves requests, so The owner backfill runs in init_db before the app serves requests, so
by the time this filter is live there are no NULL-owner rows to leak; by the time this filter is live there are no NULL-owner rows to leak;
we therefore match the owner strictly.""" we therefore match the owner strictly for authenticated callers."""
if user is None: if not user:
from src.auth_helpers import _auth_disabled
if user == "" or _auth_disabled():
return q
return q.filter(False) return q.filter(False)
return q.filter(Document.owner == user) return q.filter(Document.owner == user)
+86 -3
View File
@@ -2,8 +2,9 @@
"""Routes for personal documents management.""" """Routes for personal documents management."""
import os import os
import logging import logging
import shutil
import uuid import uuid
from typing import List, Tuple from typing import Any, Dict, List, Tuple
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
from src.request_models import DirectoryRequest from src.request_models import DirectoryRequest
from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR from core.constants import BASE_DIR, PERSONAL_DIR, PERSONAL_UPLOADS_DIR
@@ -18,14 +19,15 @@ UPLOADS_DIR = PERSONAL_UPLOADS_DIR
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _personal_upload_dir_for_owner(owner: str | None) -> str: def _personal_upload_dir_for_owner(owner: str | None, *, create: bool = True) -> str:
"""Return the per-owner upload directory used for direct RAG uploads.""" """Return the per-owner upload directory used for direct RAG uploads."""
owner_segment = secure_filename((owner or "local").strip())[:80] or "local" owner_segment = secure_filename((owner or "local").strip())[:80] or "local"
upload_dir = os.path.abspath(os.path.join(UPLOADS_DIR, owner_segment)) upload_dir = os.path.abspath(os.path.join(UPLOADS_DIR, owner_segment))
base_abs = os.path.abspath(UPLOADS_DIR) base_abs = os.path.abspath(UPLOADS_DIR)
if os.path.commonpath([upload_dir, base_abs]) != base_abs: if os.path.commonpath([upload_dir, base_abs]) != base_abs:
raise ValueError("Unsafe upload owner path") raise ValueError("Unsafe upload owner path")
os.makedirs(upload_dir, exist_ok=True) if create:
os.makedirs(upload_dir, exist_ok=True)
return upload_dir return upload_dir
@@ -44,6 +46,87 @@ def _unique_personal_upload_path(upload_dir: str, original_name: str | None) ->
raise ValueError("Unsafe upload filename") raise ValueError("Unsafe upload filename")
return file_path, filename, safe_name return file_path, filename, safe_name
def _unique_existing_target(path: str) -> str:
"""Return a non-existing sibling path for rename collision handling."""
if not os.path.exists(path):
return path
stem, ext = os.path.splitext(path)
while True:
candidate = f"{stem}-{uuid.uuid4().hex[:10]}{ext}"
if not os.path.exists(candidate):
return candidate
def _remove_empty_tree(path: str) -> None:
"""Best-effort removal of empty directories under ``path``."""
if not os.path.isdir(path):
return
for root, dirs, _files in os.walk(path, topdown=False):
for dirname in dirs:
candidate = os.path.join(root, dirname)
try:
os.rmdir(candidate)
except OSError:
pass
try:
os.rmdir(path)
except OSError:
pass
def rename_personal_upload_owner(
old_owner: str,
new_owner: str,
*,
personal_docs_manager: Any = None,
rag_manager: Any = None,
) -> Dict[str, Any]:
"""Move direct personal uploads and rewrite RAG owner metadata on user rename."""
old_dir = _personal_upload_dir_for_owner(old_owner, create=False)
new_dir = _personal_upload_dir_for_owner(new_owner, create=False)
path_map: Dict[str, str] = {}
moved_files = 0
if os.path.isdir(old_dir) and old_dir != new_dir:
os.makedirs(new_dir, exist_ok=True)
for root, _dirs, files in os.walk(old_dir):
rel_root = os.path.relpath(root, old_dir)
target_root = new_dir if rel_root == "." else os.path.join(new_dir, rel_root)
os.makedirs(target_root, exist_ok=True)
for filename in files:
source = os.path.abspath(os.path.join(root, filename))
target = _unique_existing_target(os.path.abspath(os.path.join(target_root, filename)))
shutil.move(source, target)
path_map[source] = target
moved_files += 1
_remove_empty_tree(old_dir)
if personal_docs_manager is not None:
rename_directory = getattr(personal_docs_manager, "rename_directory", None)
if callable(rename_directory):
rename_directory(old_dir, new_dir, path_map=path_map)
rag_result = None
if rag_manager is not None:
rename_owner = getattr(rag_manager, "rename_owner", None)
if callable(rename_owner):
rag_result = rename_owner(
old_owner,
new_owner,
path_map=path_map,
path_prefixes=[(old_dir, new_dir)],
)
return {
"old_dir": old_dir,
"new_dir": new_dir,
"moved_files": moved_files,
"path_map": path_map,
"rag_result": rag_result,
}
def setup_personal_routes(personal_docs_manager, rag_manager, rag_available): def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
""" """
Setup personal documents related routes. Setup personal documents related routes.
+13 -2
View File
@@ -1004,6 +1004,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
""" """
from src.llm_core import llm_call from src.llm_core import llm_call
user = effective_user(request) user = effective_user(request)
single_user_mode = not user and _auth_disabled()
user_sessions = session_manager.get_sessions_for_user(user) user_sessions = session_manager.get_sessions_for_user(user)
# Delete empty and throwaway sessions before sorting # Delete empty and throwaway sessions before sorting
@@ -1022,7 +1023,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
} }
_THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages _THROWAWAY_MAX_MESSAGES = 4 # only delete if <= this many messages
try: try:
rows = db.query(DbSession).filter(DbSession.archived == False, DbSession.owner == user).limit(2000).all() rows_q = db.query(DbSession).filter(DbSession.archived == False)
if user:
rows_q = rows_q.filter(DbSession.owner == user)
elif not single_user_mode:
rows_q = rows_q.filter(DbSession.owner == user)
rows = rows_q.limit(2000).all()
folder_map = {r.id: r.folder for r in rows} folder_map = {r.id: r.folder for r in rows}
# Precompute per-session message counts in TWO aggregate queries # Precompute per-session message counts in TWO aggregate queries
# instead of 13 queries PER session — with many chats the per-row # instead of 13 queries PER session — with many chats the per-row
@@ -1242,7 +1248,12 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
db = SessionLocal() db = SessionLocal()
try: try:
for sid, folder_name in assignments.items(): for sid, folder_name in assignments.items():
db_session = db.query(DbSession).filter(DbSession.id == sid, DbSession.owner == user).first() db_session_q = db.query(DbSession).filter(DbSession.id == sid)
if user:
db_session_q = db_session_q.filter(DbSession.owner == user)
elif not single_user_mode:
db_session_q = db_session_q.filter(DbSession.owner == user)
db_session = db_session_q.first()
if db_session: if db_session:
db_session.folder = folder_name db_session.folder = folder_name
db_session.updated_at = datetime.utcnow() db_session.updated_at = datetime.utcnow()
+41
View File
@@ -322,6 +322,47 @@ class PersonalDocsManager:
else: else:
logger.info(f"Directory not in index: {directory}") logger.info(f"Directory not in index: {directory}")
def rename_directory(self, old_directory: str, new_directory: str, *, path_map: Dict[str, str] = None):
"""Rewrite tracked directory and excluded-file paths after an owner rename."""
old_directory = os.path.abspath(old_directory)
new_directory = os.path.abspath(new_directory)
path_map = {os.path.abspath(k): os.path.abspath(v) for k, v in (path_map or {}).items()}
def rewrite(path: str) -> str:
abs_path = os.path.abspath(path)
mapped = path_map.get(abs_path)
if mapped:
return mapped
if abs_path == old_directory:
return new_directory
if abs_path.startswith(old_directory + os.sep):
return new_directory + abs_path[len(old_directory):]
return abs_path
changed_dirs = False
rewritten_dirs = []
for directory in self.indexed_directories:
rewritten = rewrite(directory)
changed_dirs = changed_dirs or rewritten != os.path.abspath(directory)
if rewritten not in rewritten_dirs:
rewritten_dirs.append(rewritten)
if changed_dirs:
self.indexed_directories = rewritten_dirs
self.save_directories()
changed_excluded = False
rewritten_excluded = set()
for path in self.excluded_files:
rewritten = rewrite(path)
changed_excluded = changed_excluded or rewritten != os.path.abspath(path)
rewritten_excluded.add(rewritten)
if changed_excluded:
self.excluded_files = rewritten_excluded
self._save_excluded()
if changed_dirs or changed_excluded:
self.refresh_index()
def get_indexed_directories(self): def get_indexed_directories(self):
"""Get the list of all indexed directories.""" """Get the list of all indexed directories."""
return self.indexed_directories.copy() return self.indexed_directories.copy()
+86
View File
@@ -50,6 +50,23 @@ def _generate_doc_id(text: str, owner: str = "") -> str:
return f"doc_{hashlib.sha256(key.encode('utf-8')).hexdigest()[:16]}" return f"doc_{hashlib.sha256(key.encode('utf-8')).hexdigest()[:16]}"
def _rewrite_owner_path(value: str, path_map: Dict[str, str], path_prefixes: List[tuple]) -> str:
if not isinstance(value, str) or not value:
return value
abs_value = os.path.abspath(value)
mapped = path_map.get(abs_value)
if mapped:
return mapped
for old_prefix, new_prefix in path_prefixes:
old_abs = os.path.abspath(old_prefix)
new_abs = os.path.abspath(new_prefix)
if abs_value == old_abs:
return new_abs
if abs_value.startswith(old_abs + os.sep):
return new_abs + abs_value[len(old_abs):]
return value
class VectorRAG: class VectorRAG:
"""RAG system using ChromaDB vector storage with hybrid search.""" """RAG system using ChromaDB vector storage with hybrid search."""
@@ -250,6 +267,75 @@ class VectorRAG:
"failed_count": len(docs) - len(valid), "failed_count": len(docs) - len(valid),
} }
def rename_owner(
self,
old_owner: str,
new_owner: str,
*,
path_map: Optional[Dict[str, str]] = None,
path_prefixes: Optional[List[tuple]] = None,
) -> Dict[str, Any]:
"""Rewrite existing RAG metadata after an auth username rename."""
if not self.healthy:
return {"success": False, "updated_count": 0, "message": "Collection not initialized"}
old_owner = (old_owner or "").strip().lower()
new_owner = (new_owner or "").strip().lower()
if not old_owner or not new_owner or old_owner == new_owner:
return {"success": True, "updated_count": 0, "message": "No owner rename needed"}
path_map = {os.path.abspath(k): os.path.abspath(v) for k, v in (path_map or {}).items()}
path_prefixes = path_prefixes or []
updated_ids = set()
failed_count = 0
for lane_name, collection in self._collections_for_delete():
try:
results = collection.get(
where={"owner": old_owner},
include=["metadatas"],
)
except Exception as e:
logger.warning("rename_owner metadata scan failed in %s lane: %s", lane_name, e)
failed_count += 1
continue
ids = results.get("ids") or []
metadatas = results.get("metadatas") or []
if not ids:
continue
new_metas = []
selected_ids = []
for doc_id, meta in zip(ids, metadatas):
if not isinstance(meta, dict):
continue
next_meta = dict(meta)
if str(next_meta.get("owner", "")).strip().lower() == old_owner:
next_meta["owner"] = new_owner
for key in ("source", "directory"):
next_meta[key] = _rewrite_owner_path(next_meta.get(key), path_map, path_prefixes)
selected_ids.append(doc_id)
new_metas.append(next_meta)
if not selected_ids:
continue
try:
collection.update(ids=selected_ids, metadatas=new_metas)
updated_ids.update(selected_ids)
except Exception as e:
logger.warning("rename_owner metadata update failed in %s lane: %s", lane_name, e)
failed_count += len(selected_ids)
success = failed_count == 0
return {
"success": success,
"updated_count": len(updated_ids),
"failed_count": failed_count,
"message": f"Updated {len(updated_ids)} RAG chunk(s)",
}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Search — hybrid: vector similarity + keyword overlap # Search — hybrid: vector similarity + keyword overlap
# ------------------------------------------------------------------ # ------------------------------------------------------------------
+43
View File
@@ -80,6 +80,16 @@ def test_password_change_allows_new_password_and_blocks_old_password(tmp_path):
assert mgr.create_session("alice", "new-password") is not None assert mgr.create_session("alice", "new-password") is not None
def test_create_session_trusted_rejects_username_renamed_after_verification(tmp_path):
mgr = _make_manager(tmp_path)
assert mgr.create_user("admin", "admin-password", is_admin=True)
assert mgr.verify_password("alice", "old-password") is True
assert mgr.rename_user("alice", "alice2", "admin") is True
assert mgr.create_session_trusted("alice") is None
def _change_password_endpoint(auth_manager): def _change_password_endpoint(auth_manager):
sys.modules.pop("routes.auth_routes", None) sys.modules.pop("routes.auth_routes", None)
_real_core_package() _real_core_package()
@@ -92,6 +102,39 @@ def _change_password_endpoint(auth_manager):
raise AssertionError("change-password route not found") raise AssertionError("change-password route not found")
def _login_endpoint(auth_manager):
sys.modules.pop("routes.auth_routes", None)
_real_core_package()
from routes.auth_routes import LoginRequest, setup_auth_routes
router = setup_auth_routes(auth_manager)
for route in router.routes:
if getattr(route, "path", None) == "/api/auth/login":
return route.endpoint, LoginRequest
raise AssertionError("login route not found")
def test_login_route_does_not_set_cookie_when_trusted_session_rejects_stale_user(monkeypatch):
auth = MagicMock()
auth.verify_password.return_value = True
auth.totp_enabled.return_value = False
auth.create_session_trusted.return_value = None
endpoint, LoginRequest = _login_endpoint(auth)
monkeypatch.setattr(
"routes.auth_routes.asyncio.to_thread",
lambda fn, *args, **kwargs: _immediate_to_thread(fn, *args, **kwargs),
)
request = SimpleNamespace(client=SimpleNamespace(host="127.0.0.1"))
response = MagicMock()
body = LoginRequest(username="alice", password="old-password")
with pytest.raises(HTTPException) as exc:
asyncio.run(endpoint(body=body, request=request, response=response))
assert exc.value.status_code == 401
response.set_cookie.assert_not_called()
def test_change_password_route_revokes_other_sessions_after_success(monkeypatch): def test_change_password_route_revokes_other_sessions_after_success(monkeypatch):
auth = MagicMock() auth = MagicMock()
auth.get_username_for_token.return_value = "alice" auth.get_username_for_token.return_value = "alice"
@@ -25,6 +25,7 @@ import routes.document_routes as droutes
from core.database import Document from core.database import Document
from core.database import Session as DbSession from core.database import Session as DbSession
from routes.document_helpers import DocumentPatch from routes.document_helpers import DocumentPatch
from routes.document_helpers import _owner_session_filter
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False) _TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
_ENGINE = create_engine( _ENGINE = create_engine(
@@ -141,3 +142,18 @@ async def test_list_documents_filters_foreign_docs_in_visible_session():
assert bob_doc not in ids assert bob_doc not in ids
finally: finally:
droutes.SessionLocal = previous_session_local droutes.SessionLocal = previous_session_local
def test_owner_session_filter_noops_for_auth_disabled_single_user(monkeypatch):
monkeypatch.setenv("AUTH_ENABLED", "false")
previous_session_local = _bind_test_db()
try:
_alice_session, _bob_session, alice_doc, _bob_doc, _legacy_doc = _seed()
db = _TS()
try:
q = db.query(Document).filter(Document.id == alice_doc)
assert _owner_session_filter(q, None).first().id == alice_doc
finally:
db.close()
finally:
droutes.SessionLocal = previous_session_local
+42
View File
@@ -1,5 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from routes import personal_routes from routes import personal_routes
@@ -42,3 +43,44 @@ def test_personal_upload_paths_stay_under_upload_root(tmp_path, monkeypatch):
assert os.path.commonpath([file_path, upload_dir]) == upload_dir assert os.path.commonpath([file_path, upload_dir]) == upload_dir
assert Path(file_path).name == stored_name assert Path(file_path).name == stored_name
assert display_name == "env" assert display_name == "env"
def test_rename_personal_upload_owner_moves_files_and_rewrites_rag(tmp_path, monkeypatch):
monkeypatch.setattr(personal_routes, "UPLOADS_DIR", str(tmp_path))
old_dir = Path(personal_routes._personal_upload_dir_for_owner("alice"))
old_file = old_dir / "note.txt"
old_file.write_text("alice private RAG note", encoding="utf-8")
manager_calls = []
rag_calls = []
manager = SimpleNamespace(
rename_directory=lambda old, new, path_map=None: manager_calls.append((old, new, dict(path_map or {}))),
)
rag = SimpleNamespace(
rename_owner=lambda old, new, path_map=None, path_prefixes=None: rag_calls.append(
(old, new, dict(path_map or {}), list(path_prefixes or []))
) or {"success": True, "updated_count": 1},
)
result = personal_routes.rename_personal_upload_owner(
"alice",
"alice2",
personal_docs_manager=manager,
rag_manager=rag,
)
new_dir = Path(personal_routes._personal_upload_dir_for_owner("alice2"))
new_file = new_dir / "note.txt"
assert old_file.exists() is False
assert new_file.read_text(encoding="utf-8") == "alice private RAG note"
assert result["moved_files"] == 1
assert manager_calls == [(str(old_dir), str(new_dir), {str(old_file): str(new_file)})]
assert rag_calls == [
(
"alice",
"alice2",
{str(old_file): str(new_file)},
[(str(old_dir), str(new_dir))],
)
]
+81
View File
@@ -0,0 +1,81 @@
from src.rag_vector import VectorRAG
class _FakeCollection:
def __init__(self, docs):
self._docs = {
doc_id: {"document": document, "metadata": dict(metadata)}
for doc_id, document, metadata in docs
}
def count(self):
return len(self._docs)
def get(self, where=None, include=None):
rows = []
for doc_id, row in self._docs.items():
metadata = row["metadata"]
if where and any(metadata.get(key) != value for key, value in where.items()):
continue
rows.append((doc_id, row))
return {
"ids": [doc_id for doc_id, _row in rows],
"documents": [row["document"] for _doc_id, row in rows],
"metadatas": [row["metadata"] for _doc_id, row in rows],
}
def update(self, ids, metadatas):
for doc_id, metadata in zip(ids, metadatas):
self._docs[doc_id]["metadata"] = dict(metadata)
def _store(collection):
store = VectorRAG.__new__(VectorRAG)
store._collection = collection
store._lanes = []
store._healthy = True
return store
def test_rename_owner_updates_metadata_used_by_owner_filtered_search(tmp_path):
old_dir = tmp_path / "alice"
new_dir = tmp_path / "alice2"
old_file = old_dir / "note.txt"
new_file = new_dir / "note.txt"
collection = _FakeCollection([
(
"doc-old",
"private vector note",
{
"owner": "alice",
"source": str(old_file),
"directory": str(old_dir),
},
),
(
"doc-other",
"other vector note",
{
"owner": "bob",
"source": str(tmp_path / "bob" / "note.txt"),
},
),
])
store = _store(collection)
result = store.rename_owner(
"alice",
"alice2",
path_map={str(old_file): str(new_file)},
path_prefixes=[(str(old_dir), str(new_dir))],
)
assert result["success"] is True
assert result["updated_count"] == 1
assert store._keyword_search_fallback("private", k=10, owner="alice") == []
renamed = store._keyword_search_fallback("private", k=10, owner="alice2")
assert [row["id"] for row in renamed] == ["doc-old"]
assert renamed[0]["metadata"]["owner"] == "alice2"
assert renamed[0]["metadata"]["source"] == str(new_file)
assert renamed[0]["metadata"]["directory"] == str(new_dir)
assert store._keyword_search_fallback("other", k=10, owner="bob")[0]["id"] == "doc-other"
+55 -1
View File
@@ -70,12 +70,20 @@ def rename_endpoint(monkeypatch, tmp_path):
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
def _request(tmp_path, session_manager=None, token="t", research_handler=None, upload_handler=None): def _request(
tmp_path,
session_manager=None,
token="t",
research_handler=None,
upload_handler=None,
personal_docs_manager=None,
):
state = SimpleNamespace( state = SimpleNamespace(
invalidate_token_cache=lambda: None, invalidate_token_cache=lambda: None,
session_manager=session_manager, session_manager=session_manager,
research_handler=research_handler, research_handler=research_handler,
upload_handler=upload_handler, upload_handler=upload_handler,
personal_docs_manager=personal_docs_manager,
) )
return SimpleNamespace( return SimpleNamespace(
cookies={"odysseus_session": token}, cookies={"odysseus_session": token},
@@ -467,6 +475,52 @@ def test_rename_updates_upload_metadata_owner(rename_endpoint):
assert handler.resolve_upload(upload_id, owner="alice") is None assert handler.resolve_upload(upload_id, owner="alice") is None
def test_rename_updates_personal_rag_upload_owner(rename_endpoint, monkeypatch):
endpoint, _am, tmp_path = rename_endpoint
from routes import personal_routes
monkeypatch.setattr(personal_routes, "UPLOADS_DIR", str(tmp_path / "personal_uploads"))
old_dir = Path(personal_routes._personal_upload_dir_for_owner("alice"))
old_file = old_dir / "note.txt"
old_file.write_text("private RAG note", encoding="utf-8")
manager_calls = []
rag_calls = []
rag = SimpleNamespace(
rename_owner=lambda old, new, path_map=None, path_prefixes=None: rag_calls.append(
(old, new, dict(path_map or {}), list(path_prefixes or []))
) or {"success": True, "updated_count": 1},
)
personal_docs_manager = SimpleNamespace(
rag_manager=rag,
rename_directory=lambda old, new, path_map=None: manager_calls.append(
(old, new, dict(path_map or {}))
),
)
asyncio.run(
endpoint(
"alice",
SimpleNamespace(username="alice2"),
_request(tmp_path, personal_docs_manager=personal_docs_manager),
)
)
new_dir = Path(personal_routes._personal_upload_dir_for_owner("alice2"))
new_file = new_dir / "note.txt"
assert old_file.exists() is False
assert new_file.read_text(encoding="utf-8") == "private RAG note"
assert manager_calls == [(str(old_dir), str(new_dir), {str(old_file): str(new_file)})]
assert rag_calls == [
(
"alice",
"alice2",
{str(old_file): str(new_file)},
[(str(old_dir), str(new_dir))],
)
]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# 5. Skills (SKILL.md frontmatter + _usage.json sidecar) # 5. Skills (SKILL.md frontmatter + _usage.json sidecar)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+59
View File
@@ -7,6 +7,7 @@ import sys
import tempfile import tempfile
import types import types
import uuid import uuid
from datetime import timedelta
import pytest import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
@@ -14,6 +15,7 @@ from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
import core.database as cdb import core.database as cdb
from core.database import ChatMessage as DbMessage
from core.database import Session as DbSession from core.database import Session as DbSession
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False) _TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
@@ -72,3 +74,60 @@ def test_list_sessions_excludes_other_users_sessions(monkeypatch):
returned_ids = {s["id"] for s in result} returned_ids = {s["id"] for s in result}
assert alice_id in returned_ids assert alice_id in returned_ids
assert bob_id not in returned_ids assert bob_id not in returned_ids
def test_auto_sort_skip_llm_cleans_owner_stamped_sessions_when_auth_disabled(monkeypatch):
import routes.session_routes as sr
from unittest.mock import MagicMock
_stub_multipart_if_missing(monkeypatch)
monkeypatch.setenv("AUTH_ENABLED", "false")
monkeypatch.setattr(sr, "SessionLocal", _TS)
monkeypatch.setattr(sr, "effective_user", lambda request: None)
sid = str(uuid.uuid4())
old_time = cdb.utcnow_naive() - timedelta(hours=2)
db = _TS()
try:
db.query(DbMessage).delete()
db.query(DbSession).delete()
db.add(DbSession(
id=sid,
owner="alice",
name="New chat",
endpoint_url="http://localhost",
model="gpt-4",
archived=False,
message_count=1,
created_at=old_time,
updated_at=old_time,
last_message_at=old_time,
last_accessed=old_time,
))
db.add(DbMessage(
id="m-" + uuid.uuid4().hex,
session_id=sid,
role="user",
content="hi",
timestamp=old_time,
))
db.commit()
finally:
db.close()
session = MagicMock(id=sid, name="New chat", model="gpt-4", endpoint_url="http://localhost", rag=False, archived=False)
sm = MagicMock()
sm.get_sessions_for_user.return_value = {sid: session}
router = sr.setup_session_routes(sm, {})
endpoint = next(r.endpoint for r in router.routes
if getattr(r, "path", "") == "/api/sessions/auto-sort"
and "POST" in getattr(r, "methods", set()))
result = endpoint(request=MagicMock(), skip_llm=True)
assert result["deleted_throwaway"] == 1
db = _TS()
try:
assert db.query(DbSession).filter(DbSession.id == sid).first() is None
finally:
db.close()