Files
odysseus/tests/test_task_chain_owner_scope.py
2026-06-07 12:43:43 +02:00

128 lines
3.6 KiB
Python

"""Task chaining must not cross owner boundaries."""
import atexit
import os
import tempfile
from types import SimpleNamespace
from unittest.mock import MagicMock

import pytest
from fastapi import HTTPException
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
# NOTE(review): clear_fake_database_modules() is deliberately called before the
# project imports below. Presumably it evicts stub/fake database modules that
# other test files install in sys.modules, so core.database and
# routes.task_routes bind the real models here. Confirm against
# tests/helpers/import_state.py.
from tests.helpers.import_state import clear_fake_database_modules
clear_fake_database_modules()
import core.database as cdb  # noqa: E402  (must follow the cleanup call)
import routes.task_routes as task_routes  # noqa: E402
from core.database import ScheduledTask  # noqa: E402
# File-backed SQLite rather than ":memory:". NullPool opens a fresh connection
# per checkout, and every in-memory connection would get its own empty DB.
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
# Only the path is needed. Release our handle so SQLite is the sole opener and
# the file can be deleted later (an open handle blocks deletion on Windows).
_TMPDB.close()
_ENGINE = create_engine(
    f"sqlite:///{_TMPDB.name}",
    # Let a connection be used from a thread other than the one that created
    # it (SQLite's default refuses this).
    connect_args={"check_same_thread": False},
    poolclass=NullPool,
)


def _cleanup_tmpdb():
    """Dispose the test engine and delete the temp database file (best effort)."""
    _ENGINE.dispose()
    try:
        os.unlink(_TMPDB.name)
    except OSError:
        # Already removed or still locked: a leftover temp file is not worth
        # an error during interpreter shutdown.
        pass


# delete=False means nothing else ever removes the file, so clean up at exit.
atexit.register(_cleanup_tmpdb)

cdb.Base.metadata.create_all(_ENGINE)
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
# Point the routes module at the test database. _endpoint() re-applies this
# binding before building the routes.
task_routes.SessionLocal = _TS
def _req(user="alice"):
return SimpleNamespace(state=SimpleNamespace(current_user=user))
def _endpoint(method, path):
    """Return the endpoint callable registered for *method* on *path*.

    Rebinds ``task_routes.SessionLocal`` to the test sessionmaker, then builds
    a fresh router via ``setup_task_routes``. A MagicMock stands in for that
    function's argument.

    Raises:
        RuntimeError: if no route matches both the path and the method.
    """
    task_routes.SessionLocal = _TS
    built = task_routes.setup_task_routes(MagicMock())
    for candidate in built.routes:
        if getattr(candidate, "path", None) != path:
            continue
        if method in getattr(candidate, "methods", set()):
            return candidate.endpoint
    raise RuntimeError(f"{method} {path} not found")
def _seed_task(task_id, owner, *, then_task_id=None):
    """Insert and commit one ScheduledTask row in the test database.

    The row is owned by *owner*, uses *task_id* as both its id and its name,
    and chains to *then_task_id* when given. All other fields are fixed: an
    active, webhook-triggered "llm" task with output_target "session".
    """
    fields = dict(
        id=task_id,
        owner=owner,
        name=task_id,
        prompt="do work",
        task_type="llm",
        trigger_type="webhook",
        status="active",
        output_target="session",
        then_task_id=then_task_id,
    )
    session = _TS()
    try:
        session.add(ScheduledTask(**fields))
        session.commit()
    finally:
        session.close()
@pytest.mark.asyncio
async def test_create_task_rejects_cross_owner_chain_target():
    """Creating a task chained to another owner's task fails with 404.

    The test expects 404 rather than 403, presumably so the route does not
    reveal that bob's task exists. It also checks that the rejected request
    left no task in the database pointing at bob's task.
    """
    _seed_task("bob-target-create", "bob")
    create_task = _endpoint("POST", "/api/tasks")
    req = task_routes.TaskCreate(
        prompt="alice source",
        trigger_type="webhook",
        then_task_id="bob-target-create",
    )
    with pytest.raises(HTTPException) as exc:
        await create_task(_req("alice"), req)
    assert exc.value.status_code == 404

    # The rejection must not have persisted a chained task. bob's own task has
    # then_task_id=None, so any match here would be a leaked row.
    db = _TS()
    try:
        leaked = (
            db.query(ScheduledTask)
            .filter(ScheduledTask.then_task_id == "bob-target-create")
            .count()
        )
        assert leaked == 0
    finally:
        db.close()
@pytest.mark.asyncio
async def test_update_task_rejects_cross_owner_chain_target():
    """Pointing alice's task at bob's task yields 404 and stores no link."""
    _seed_task("alice-source-update", "alice")
    _seed_task("bob-target-update", "bob")
    endpoint = _endpoint("PUT", "/api/tasks/{task_id}")

    with pytest.raises(HTTPException) as raised:
        changes = task_routes.TaskUpdate(then_task_id="bob-target-update")
        await endpoint(_req("alice"), "alice-source-update", changes)
    assert raised.value.status_code == 404

    # The failed update must leave the source task unchained.
    session = _TS()
    try:
        row = session.query(ScheduledTask).filter_by(id="alice-source-update").first()
        assert row.then_task_id is None
    finally:
        session.close()
@pytest.mark.asyncio
async def test_update_task_allows_same_owner_chain_target():
    """Chaining to a task with the same owner succeeds and is persisted."""
    _seed_task("alice-source-allow", "alice")
    _seed_task("alice-target-allow", "alice")
    update_task = _endpoint("PUT", "/api/tasks/{task_id}")
    out = await update_task(
        _req("alice"),
        "alice-source-allow",
        task_routes.TaskUpdate(then_task_id="alice-target-allow"),
    )
    assert out["then_task_id"] == "alice-target-allow"

    # Verify the link was committed, not merely echoed back in the response.
    db = _TS()
    try:
        source = (
            db.query(ScheduledTask)
            .filter(ScheduledTask.id == "alice-source-allow")
            .first()
        )
        assert source is not None
        assert source.then_task_id == "alice-target-allow"
    finally:
        db.close()
def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe():
    """The scheduler's chain-cycle guard flags a cross-owner target as unsafe.

    For owner "alice", a chain reaching bob's task must make _has_chain_cycle
    return True, exactly as a real cycle would.
    """
    _seed_task("bob-target-cycle", "bob")
    from src.task_scheduler import TaskScheduler

    # __init__ is bypassed. This assumes _has_chain_cycle needs no state that
    # __init__ would set up; confirm if the scheduler changes.
    guard = TaskScheduler.__new__(TaskScheduler)
    session = _TS()
    try:
        unsafe = guard._has_chain_cycle(session, "bob-target-cycle", owner="alice")
        assert unsafe is True
    finally:
        session.close()