Enforce task chain owner scope (#3006)

This commit is contained in:
Vykos
2026-06-07 12:43:43 +02:00
committed by GitHub
parent 3cff06781e
commit 7b4e6c4c1b
3 changed files with 154 additions and 4 deletions
+127
View File
@@ -0,0 +1,127 @@
"""Task chaining must not cross owner boundaries."""
import tempfile
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from tests.helpers.import_state import clear_fake_database_modules
clear_fake_database_modules()
import core.database as cdb
import routes.task_routes as task_routes
from core.database import ScheduledTask
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
_ENGINE = create_engine(
f"sqlite:///{_TMPDB.name}",
connect_args={"check_same_thread": False},
poolclass=NullPool,
)
cdb.Base.metadata.create_all(_ENGINE)
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
task_routes.SessionLocal = _TS
def _req(user="alice"):
return SimpleNamespace(state=SimpleNamespace(current_user=user))
def _endpoint(method, path):
task_routes.SessionLocal = _TS
router = task_routes.setup_task_routes(MagicMock())
for route in router.routes:
if getattr(route, "path", None) == path and method in getattr(route, "methods", set()):
return route.endpoint
raise RuntimeError(f"{method} {path} not found")
def _seed_task(task_id, owner, *, then_task_id=None):
db = _TS()
try:
task = ScheduledTask(
id=task_id,
owner=owner,
name=task_id,
prompt="do work",
task_type="llm",
trigger_type="webhook",
status="active",
output_target="session",
then_task_id=then_task_id,
)
db.add(task)
db.commit()
finally:
db.close()
@pytest.mark.asyncio
async def test_create_task_rejects_cross_owner_chain_target():
_seed_task("bob-target-create", "bob")
create_task = _endpoint("POST", "/api/tasks")
req = task_routes.TaskCreate(
prompt="alice source",
trigger_type="webhook",
then_task_id="bob-target-create",
)
with pytest.raises(HTTPException) as exc:
await create_task(_req("alice"), req)
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_update_task_rejects_cross_owner_chain_target():
_seed_task("alice-source-update", "alice")
_seed_task("bob-target-update", "bob")
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
with pytest.raises(HTTPException) as exc:
await update_task(
_req("alice"),
"alice-source-update",
task_routes.TaskUpdate(then_task_id="bob-target-update"),
)
assert exc.value.status_code == 404
db = _TS()
try:
source = db.query(ScheduledTask).filter(ScheduledTask.id == "alice-source-update").first()
assert source.then_task_id is None
finally:
db.close()
@pytest.mark.asyncio
async def test_update_task_allows_same_owner_chain_target():
_seed_task("alice-source-allow", "alice")
_seed_task("alice-target-allow", "alice")
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
out = await update_task(
_req("alice"),
"alice-source-allow",
task_routes.TaskUpdate(then_task_id="alice-target-allow"),
)
assert out["then_task_id"] == "alice-target-allow"
def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe():
_seed_task("bob-target-cycle", "bob")
from src.task_scheduler import TaskScheduler
scheduler = TaskScheduler.__new__(TaskScheduler)
db = _TS()
try:
assert scheduler._has_chain_cycle(db, "bob-target-cycle", owner="alice") is True
finally:
db.close()