Files
odysseus/tests/test_document_session_owner_scope.py

144 lines
4.4 KiB
Python

"""Document session owner-scope regressions.
Route handlers are called directly, matching the pattern used by the existing
document route tests. This keeps coverage on the real closures without spinning
up middleware.
"""
import tempfile
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from tests.helpers.import_state import clear_fake_database_modules
clear_fake_database_modules()
import core.database as cdb
import routes.document_routes as droutes
from core.database import Document
from core.database import Session as DbSession
from routes.document_helpers import DocumentPatch
# Back the tests with a real on-disk SQLite file rather than an in-memory
# database, so every connection opened through the engine addresses the same
# file by path.
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
# Only the path is needed from here on.  Close the Python-level handle right
# away so the descriptor is not held open for the life of the test process;
# delete=False keeps the file on disk for SQLite to open by name.
_TMPDB.close()
_ENGINE = create_engine(
    f"sqlite:///{_TMPDB.name}",
    # Let a connection be used from a thread other than the one that made it.
    connect_args={"check_same_thread": False},
    # No connection pooling: connections are opened on demand and discarded.
    poolclass=NullPool,
)
# Create every table registered on the project's declarative Base.
cdb.Base.metadata.create_all(_ENGINE)
# Session factory that the tests swap in for ``droutes.SessionLocal``.
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
def _req(user="alice"):
return SimpleNamespace(state=SimpleNamespace(current_user=user))
def _endpoint(method, path):
    """Return the handler registered for *method* and *path*.

    A router is built on each call via ``droutes.setup_document_routes`` (with
    a ``MagicMock`` and ``None`` as its arguments) and its routes are scanned
    in order; the first route whose ``path`` equals *path* and whose
    ``methods`` contains *method* wins.

    Args:
        method: HTTP method name to look for, e.g. ``"GET"``.
        path: Route path template exactly as registered.

    Returns:
        The matching route's ``endpoint`` callable.

    Raises:
        RuntimeError: If no route matches.
    """
    router = droutes.setup_document_routes(MagicMock(), None)
    for candidate in router.routes:
        # Routes without a ``path`` attribute compare as None and are skipped.
        if getattr(candidate, "path", None) != path:
            continue
        # Routes without a ``methods`` attribute are treated as matching nothing.
        if method in getattr(candidate, "methods", set()):
            return candidate.endpoint
    raise RuntimeError(f"{method} {path} not found")
def _bind_test_db():
    """Point the document routes at the test session factory.

    Replaces ``droutes.SessionLocal`` with ``_TS``.

    Returns:
        The value ``droutes.SessionLocal`` held before the swap, so the caller
        can restore it afterwards.
    """
    # The right-hand side is evaluated in full before either target is
    # assigned, so the old factory is captured before being overwritten.
    original_factory, droutes.SessionLocal = droutes.SessionLocal, _TS
    return original_factory
def _seed():
    """Insert a fresh set of fixture rows and return their identifiers.

    Two sessions are created, one owned by ``"alice"`` and one by ``"bob"``,
    together with three documents: one per owner inside that owner's session,
    plus a document with ``owner=None`` placed in alice's session.  All
    identifiers are derived from ``uuid.uuid4()``.

    Returns:
        tuple: ``(alice_session, bob_session, alice_doc, bob_doc, legacy_doc)``,
        each a string identifier.
    """
    alice_session = f"alice-{uuid.uuid4().hex[:8]}"
    bob_session = f"bob-{uuid.uuid4().hex[:8]}"
    alice_doc, bob_doc, legacy_doc = (str(uuid.uuid4()) for _ in range(3))

    # (session id, owner) -- the owner string doubles as the session name.
    session_specs = (
        (alice_session, "alice"),
        (bob_session, "bob"),
    )
    # (document id, session id, title, content, owner)
    document_specs = (
        (alice_doc, alice_session, "alice doc", "alice body", "alice"),
        (bob_doc, bob_session, "bob doc", "bob body", "bob"),
        (legacy_doc, alice_session, "legacy doc", "legacy body", None),
    )

    db = _TS()
    try:
        for session_id, who in session_specs:
            db.add(
                DbSession(
                    id=session_id,
                    owner=who,
                    name=who,
                    model="m",
                    endpoint_url="http://x",
                )
            )
        for doc_id, session_id, title, content, who in document_specs:
            db.add(
                Document(
                    id=doc_id,
                    session_id=session_id,
                    title=title,
                    language="markdown",
                    current_content=content,
                    version_count=1,
                    is_active=True,
                    owner=who,
                )
            )
        db.commit()
    finally:
        # Always release the session, whether or not the commit succeeded.
        db.close()
    return alice_session, bob_session, alice_doc, bob_doc, legacy_doc
@pytest.mark.asyncio
async def test_patch_document_rejects_cross_owner_session_link():
    """PATCHing bob's document onto alice's session is refused.

    Calling the PATCH handler as ``"bob"`` with a ``session_id`` that belongs
    to alice must raise ``HTTPException`` with status 404, and the stored
    document must still reference bob's own session afterwards.
    """
    saved_factory = _bind_test_db()
    try:
        handler = _endpoint("PATCH", "/api/document/{doc_id}")
        seeded = _seed()
        alice_session, bob_session, bob_doc = seeded[0], seeded[1], seeded[3]

        with pytest.raises(HTTPException) as raised:
            await handler(
                _req("bob"),
                bob_doc,
                DocumentPatch(session_id=alice_session),
            )
        # Checked once the context manager has exited so the assertion runs.
        assert raised.value.status_code == 404

        # The rejected patch must not have modified the stored row.
        verify_db = _TS()
        try:
            stored = verify_db.query(Document).filter(Document.id == bob_doc).first()
            assert stored.session_id == bob_session
        finally:
            verify_db.close()
    finally:
        # Restore the routes module's original session factory.
        droutes.SessionLocal = saved_factory
@pytest.mark.asyncio
async def test_list_documents_filters_foreign_docs_in_visible_session():
    """Listing a session omits documents owned by somebody else.

    Bob's document is moved into alice's session directly in the database.
    Listing that session as ``"alice"`` must return alice's own document and
    the ownerless one, but not bob's.
    """
    saved_factory = _bind_test_db()
    try:
        handler = _endpoint("GET", "/api/documents/{session_id}")
        alice_session, _, alice_doc, bob_doc, legacy_doc = _seed()

        # Re-link bob's document to alice's session without going through the
        # route handlers.
        setup_db = _TS()
        try:
            bobs_row = setup_db.query(Document).filter(Document.id == bob_doc)
            bobs_row.update({"session_id": alice_session})
            setup_db.commit()
        finally:
            setup_db.close()

        listing = await handler(_req("alice"), alice_session)
        returned_ids = set()
        for entry in listing:
            returned_ids.add(entry["id"])

        assert alice_doc in returned_ids
        assert legacy_doc in returned_ids
        assert bob_doc not in returned_ids
    finally:
        # Restore the routes module's original session factory.
        droutes.SessionLocal = saved_factory