mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
128 lines
3.6 KiB
Python
128 lines
3.6 KiB
Python
"""Task chaining must not cross owner boundaries."""
|
|
|
|
import tempfile
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
from tests.helpers.import_state import clear_fake_database_modules
|
|
|
|
clear_fake_database_modules()
|
|
|
|
import core.database as cdb
|
|
import routes.task_routes as task_routes
|
|
from core.database import ScheduledTask
|
|
|
|
# Module-level fixture: a SQLite file on disk that backs every session used in
# this module.  ``delete=False`` keeps the file in place after the handle is
# closed, so the engine below can open it by name.
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
# Only the path is needed from here on.  Closing the handle avoids leaking an
# open descriptor for the whole test run and lets SQLite open the file by name
# on platforms (Windows) where a still-open NamedTemporaryFile cannot be
# reopened.  ``_TMPDB.name`` stays valid after close().
_TMPDB.close()
_ENGINE = create_engine(
    f"sqlite:///{_TMPDB.name}",
    connect_args={"check_same_thread": False},
    poolclass=NullPool,
)
# Create every table registered on the project's declarative base.
cdb.Base.metadata.create_all(_ENGINE)
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
# Make the routes module hand out sessions bound to the test engine.
task_routes.SessionLocal = _TS
|
|
|
|
|
|
def _req(user="alice"):
|
|
return SimpleNamespace(state=SimpleNamespace(current_user=user))
|
|
|
|
|
|
def _endpoint(method, path):
    """Return the endpoint callable registered for *method* and *path*.

    Re-points ``task_routes.SessionLocal`` at the test session factory, builds
    a fresh router via ``task_routes.setup_task_routes`` (given a ``MagicMock``
    argument) and scans its routes for an exact path match that accepts
    *method*.

    Args:
        method: HTTP method name to look for in the route's ``methods``.
        path: Route path template, compared for exact equality.

    Returns:
        The ``endpoint`` attribute of the first matching route.

    Raises:
        RuntimeError: If no route matches both *path* and *method*.
    """
    task_routes.SessionLocal = _TS
    router = task_routes.setup_task_routes(MagicMock())
    for route in router.routes:
        if getattr(route, "path", None) != path:
            continue
        # ``methods`` can be missing or None on some route objects; treat both
        # as "accepts nothing" rather than failing on ``method in None``.
        if method in (getattr(route, "methods", None) or ()):
            return route.endpoint
    raise RuntimeError(f"{method} {path} not found")
|
|
|
|
|
|
def _seed_task(task_id, owner, *, then_task_id=None):
    """Insert one ``ScheduledTask`` row into the test database and commit it.

    Args:
        task_id: Used for both the row's ``id`` and its ``name``.
        owner: Value stored in the row's ``owner`` column.
        then_task_id: Optional id stored in ``then_task_id``; ``None`` by
            default.
    """
    db = _TS()
    try:
        task = ScheduledTask(
            id=task_id,
            owner=owner,
            name=task_id,
            prompt="do work",
            task_type="llm",
            trigger_type="webhook",
            status="active",
            output_target="session",
            then_task_id=then_task_id,
        )
        db.add(task)
        db.commit()
    finally:
        # Always release the session, even if the insert or commit fails.
        db.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_task_rejects_cross_owner_chain_target():
    """Creating a task as alice that chains to bob's task raises HTTP 404."""
    _seed_task("bob-target-create", "bob")
    create_task = _endpoint("POST", "/api/tasks")

    req = task_routes.TaskCreate(
        prompt="alice source",
        trigger_type="webhook",
        then_task_id="bob-target-create",
    )
    with pytest.raises(HTTPException) as exc:
        await create_task(_req("alice"), req)

    assert exc.value.status_code == 404
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_task_rejects_cross_owner_chain_target():
    """Updating alice's task to chain to bob's task raises HTTP 404.

    Also checks that the rejected update left the stored ``then_task_id`` of
    alice's task unset.
    """
    _seed_task("alice-source-update", "alice")
    _seed_task("bob-target-update", "bob")
    update_task = _endpoint("PUT", "/api/tasks/{task_id}")

    with pytest.raises(HTTPException) as exc:
        await update_task(
            _req("alice"),
            "alice-source-update",
            task_routes.TaskUpdate(then_task_id="bob-target-update"),
        )

    assert exc.value.status_code == 404
    db = _TS()
    try:
        source = db.query(ScheduledTask).filter(ScheduledTask.id == "alice-source-update").first()
        # Fail with a clear assertion (not AttributeError on None) if the
        # seeded row has gone missing.
        assert source is not None
        assert source.then_task_id is None
    finally:
        db.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_task_allows_same_owner_chain_target():
    """Updating alice's task to chain to another of alice's tasks succeeds.

    The endpoint's return value must report the new ``then_task_id``.
    """
    _seed_task("alice-source-allow", "alice")
    _seed_task("alice-target-allow", "alice")
    update_task = _endpoint("PUT", "/api/tasks/{task_id}")

    out = await update_task(
        _req("alice"),
        "alice-source-allow",
        task_routes.TaskUpdate(then_task_id="alice-target-allow"),
    )

    assert out["then_task_id"] == "alice-target-allow"
|
|
|
|
|
|
def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe():
    """``_has_chain_cycle`` returns True for a target owned by someone else.

    A task owned by bob is seeded, then the guard is asked about it with
    ``owner="alice"``; the expected answer is ``True``.
    """
    _seed_task("bob-target-cycle", "bob")
    from src.task_scheduler import TaskScheduler

    # Allocate the scheduler without running ``__init__``; only the guard
    # method is exercised here.
    scheduler = TaskScheduler.__new__(TaskScheduler)
    db = _TS()
    try:
        assert scheduler._has_chain_cycle(db, "bob-target-cycle", owner="alice") is True
    finally:
        db.close()
|