mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 17:55:26 -04:00
Enforce task chain owner scope (#3006)
This commit is contained in:
@@ -0,0 +1,127 @@
|
||||
"""Task chaining must not cross owner boundaries."""
|
||||
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from tests.helpers.import_state import clear_fake_database_modules
|
||||
|
||||
clear_fake_database_modules()
|
||||
|
||||
import core.database as cdb
|
||||
import routes.task_routes as task_routes
|
||||
from core.database import ScheduledTask
|
||||
|
||||
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_ENGINE = create_engine(
|
||||
f"sqlite:///{_TMPDB.name}",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=NullPool,
|
||||
)
|
||||
cdb.Base.metadata.create_all(_ENGINE)
|
||||
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
|
||||
task_routes.SessionLocal = _TS
|
||||
|
||||
|
||||
def _req(user="alice"):
|
||||
return SimpleNamespace(state=SimpleNamespace(current_user=user))
|
||||
|
||||
|
||||
def _endpoint(method, path):
|
||||
task_routes.SessionLocal = _TS
|
||||
router = task_routes.setup_task_routes(MagicMock())
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", None) == path and method in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise RuntimeError(f"{method} {path} not found")
|
||||
|
||||
|
||||
def _seed_task(task_id, owner, *, then_task_id=None):
|
||||
db = _TS()
|
||||
try:
|
||||
task = ScheduledTask(
|
||||
id=task_id,
|
||||
owner=owner,
|
||||
name=task_id,
|
||||
prompt="do work",
|
||||
task_type="llm",
|
||||
trigger_type="webhook",
|
||||
status="active",
|
||||
output_target="session",
|
||||
then_task_id=then_task_id,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_rejects_cross_owner_chain_target():
|
||||
_seed_task("bob-target-create", "bob")
|
||||
create_task = _endpoint("POST", "/api/tasks")
|
||||
|
||||
req = task_routes.TaskCreate(
|
||||
prompt="alice source",
|
||||
trigger_type="webhook",
|
||||
then_task_id="bob-target-create",
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await create_task(_req("alice"), req)
|
||||
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_rejects_cross_owner_chain_target():
|
||||
_seed_task("alice-source-update", "alice")
|
||||
_seed_task("bob-target-update", "bob")
|
||||
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_task(
|
||||
_req("alice"),
|
||||
"alice-source-update",
|
||||
task_routes.TaskUpdate(then_task_id="bob-target-update"),
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 404
|
||||
db = _TS()
|
||||
try:
|
||||
source = db.query(ScheduledTask).filter(ScheduledTask.id == "alice-source-update").first()
|
||||
assert source.then_task_id is None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_allows_same_owner_chain_target():
|
||||
_seed_task("alice-source-allow", "alice")
|
||||
_seed_task("alice-target-allow", "alice")
|
||||
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
|
||||
|
||||
out = await update_task(
|
||||
_req("alice"),
|
||||
"alice-source-allow",
|
||||
task_routes.TaskUpdate(then_task_id="alice-target-allow"),
|
||||
)
|
||||
|
||||
assert out["then_task_id"] == "alice-target-allow"
|
||||
|
||||
|
||||
def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe():
|
||||
_seed_task("bob-target-cycle", "bob")
|
||||
from src.task_scheduler import TaskScheduler
|
||||
|
||||
scheduler = TaskScheduler.__new__(TaskScheduler)
|
||||
db = _TS()
|
||||
try:
|
||||
assert scheduler._has_chain_cycle(db, "bob-target-cycle", owner="alice") is True
|
||||
finally:
|
||||
db.close()
|
||||
Reference in New Issue
Block a user