mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
144 lines
4.4 KiB
Python
144 lines
4.4 KiB
Python
"""Document session owner-scope regressions.
|
|
|
|
Route handlers are called directly, matching the pattern used by the existing
|
|
document route tests. This keeps coverage on the real closures without spinning
|
|
up middleware.
|
|
"""
|
|
|
|
import atexit
import os
import tempfile
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
from tests.helpers.import_state import clear_fake_database_modules
|
|
|
|
clear_fake_database_modules()
|
|
|
|
import core.database as cdb
|
|
import routes.document_routes as droutes
|
|
from core.database import Document
|
|
from core.database import Session as DbSession
|
|
from routes.document_helpers import DocumentPatch
|
|
|
|
# Back the tests with a throwaway on-disk SQLite database. Only the path is
# needed (SQLite opens the file itself), so the handle is closed right away
# instead of being held open for the life of the process.
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
_TMPDB.close()
_ENGINE = create_engine(
    f"sqlite:///{_TMPDB.name}",
    connect_args={"check_same_thread": False},
    poolclass=NullPool,
)
cdb.Base.metadata.create_all(_ENGINE)
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)


def _remove_tmpdb():
    """Dispose of the test engine and delete the temporary database file."""
    _ENGINE.dispose()
    try:
        os.unlink(_TMPDB.name)
    except OSError:
        # Best effort: the file may already be gone or still be locked.
        pass


# ``delete=False`` means nothing else removes the file, so clean up at exit.
atexit.register(_remove_tmpdb)
|
|
|
|
|
|
def _req(user="alice"):
|
|
return SimpleNamespace(state=SimpleNamespace(current_user=user))
|
|
|
|
|
|
def _endpoint(method, path):
    """Return the endpoint callable registered for *method* and *path*.

    A fresh router is built with ``droutes.setup_document_routes`` (given a
    ``MagicMock`` and ``None``) and its routes are scanned in order; the
    first route whose ``path`` equals *path* and whose ``methods`` contains
    *method* wins.

    Raises:
        RuntimeError: If no route matches.
    """
    router = droutes.setup_document_routes(MagicMock(), None)
    # Lazy generator: ``endpoint`` is only read for the first matching route.
    matching_endpoints = (
        candidate.endpoint
        for candidate in router.routes
        if getattr(candidate, "path", None) == path
        and method in getattr(candidate, "methods", set())
    )
    for handler in matching_endpoints:
        return handler
    raise RuntimeError(f"{method} {path} not found")
|
|
|
|
|
|
def _bind_test_db():
    """Point ``droutes.SessionLocal`` at the test session factory.

    Returns:
        The value ``droutes.SessionLocal`` held before the swap, so the
        caller can restore it afterwards.
    """
    # The right-hand side is evaluated first, so the old value is captured
    # before the attribute is overwritten with ``_TS``.
    original, droutes.SessionLocal = droutes.SessionLocal, _TS
    return original
|
|
|
|
|
|
def _seed():
    """Insert two sessions and three documents into the test database.

    Rows created:
      * a session owned by ``"alice"`` and a session owned by ``"bob"``;
      * a document owned by ``"alice"`` in alice's session;
      * a document owned by ``"bob"`` in bob's session;
      * a document with ``owner=None`` in alice's session.

    All identifiers are freshly generated on each call.

    Returns:
        ``(alice_session, bob_session, alice_doc, bob_doc, legacy_doc)`` as
        string ids.
    """
    alice_session = "alice-" + uuid.uuid4().hex[:8]
    bob_session = "bob-" + uuid.uuid4().hex[:8]
    alice_doc, bob_doc, legacy_doc = (str(uuid.uuid4()) for _ in range(3))

    # (session id, owner/name) for each session row.
    session_specs = (
        (alice_session, "alice"),
        (bob_session, "bob"),
    )
    # (doc id, session id, title, content, owner) for each document row.
    document_specs = (
        (alice_doc, alice_session, "alice doc", "alice body", "alice"),
        (bob_doc, bob_session, "bob doc", "bob body", "bob"),
        (legacy_doc, alice_session, "legacy doc", "legacy body", None),
    )

    db = _TS()
    try:
        db.add_all(
            [
                DbSession(id=sid, owner=who, name=who, model="m", endpoint_url="http://x")
                for sid, who in session_specs
            ]
        )
        db.add_all(
            [
                Document(
                    id=doc_id,
                    session_id=sid,
                    title=title,
                    language="markdown",
                    current_content=content,
                    version_count=1,
                    is_active=True,
                    owner=who,
                )
                for doc_id, sid, title, content, who in document_specs
            ]
        )
        db.commit()
        return alice_session, bob_session, alice_doc, bob_doc, legacy_doc
    finally:
        db.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_patch_document_rejects_cross_owner_session_link():
    """PATCH as bob, linking bob's document to alice's session, must 404.

    The handler is expected to raise ``HTTPException`` with status 404, and
    the stored document must still reference bob's session afterwards.
    """
    saved_factory = _bind_test_db()
    try:
        handler = _endpoint("PATCH", "/api/document/{doc_id}")
        alice_session, bob_session, _, bob_doc, _ = _seed()

        with pytest.raises(HTTPException) as raised:
            request = _req("bob")
            payload = DocumentPatch(session_id=alice_session)
            await handler(request, bob_doc, payload)

        assert raised.value.status_code == 404

        # Re-read the row through a fresh session to check nothing changed.
        verify_db = _TS()
        try:
            stored = verify_db.query(Document).filter(Document.id == bob_doc).first()
            assert stored.session_id == bob_session
        finally:
            verify_db.close()
    finally:
        # Always restore the route module's original session factory.
        droutes.SessionLocal = saved_factory
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_documents_filters_foreign_docs_in_visible_session():
    """Listing alice's session as alice must omit a bob-owned document.

    Bob's document is moved into alice's session directly in the database;
    the listing must still contain alice's own document and the ownerless
    one, but not bob's.
    """
    saved_factory = _bind_test_db()
    try:
        handler = _endpoint("GET", "/api/documents/{session_id}")
        alice_session, _, alice_doc, bob_doc, legacy_doc = _seed()

        # Re-home bob's document into alice's session behind the API's back.
        setup_db = _TS()
        try:
            bob_rows = setup_db.query(Document).filter(Document.id == bob_doc)
            bob_rows.update({"session_id": alice_session})
            setup_db.commit()
        finally:
            setup_db.close()

        listing = await handler(_req("alice"), alice_session)
        returned_ids = {entry["id"] for entry in listing}

        assert alice_doc in returned_ids
        assert legacy_doc in returned_ids
        assert bob_doc not in returned_ids
    finally:
        # Always restore the route module's original session factory.
        droutes.SessionLocal = saved_factory
|