Enforce task chain owner scope (#3006)

This commit is contained in:
Vykos
2026-06-07 12:43:43 +02:00
committed by GitHub
parent 3cff06781e
commit 7b4e6c4c1b
3 changed files with 154 additions and 4 deletions
+17 -2
View File
@@ -429,6 +429,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
except Exception:
return False
def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]:
target_id = (then_task_id or "").strip()
if not target_id:
return None
if current_task_id and target_id == current_task_id:
raise HTTPException(400, "Task cannot chain to itself")
q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id)
if user:
q = q.filter(ScheduledTask.owner == user)
target = q.first()
if not target:
raise HTTPException(404, "Chained task not found")
return target.id
@router.post("")
async def create_task(request: Request, req: TaskCreate):
user = _owner(request)
@@ -492,6 +506,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
task_id = str(uuid.uuid4())
db = SessionLocal()
try:
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
notifications_enabled = (
False if req.task_type == "action" and req.notifications_enabled is None
else bool(req.notifications_enabled) if req.notifications_enabled is not None
@@ -518,7 +533,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
output_target=req.output_target,
model=req.model or None,
endpoint_url=req.endpoint_url or None,
then_task_id=req.then_task_id or None,
then_task_id=then_task_id,
webhook_token=webhook_token,
notifications_enabled=notifications_enabled,
)
@@ -671,7 +686,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
if req.trigger_count is not None:
task.trigger_count = req.trigger_count
if req.then_task_id is not None:
task.then_task_id = req.then_task_id or None
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
if req.notifications_enabled is not None:
task.notifications_enabled = bool(req.notifications_enabled)
if req.cron_expression is not None:
+10 -2
View File
@@ -844,7 +844,13 @@ class TaskScheduler:
# Task chaining — trigger the next task on success
if run.status == "success" and task.then_task_id:
chain_id = task.then_task_id
if not self._has_chain_cycle(db, chain_id):
chain_task = db.query(ScheduledTask).filter(ScheduledTask.id == chain_id).first()
if not chain_task or chain_task.owner != task.owner:
logger.warning(
"Skipping chain from %r: target task %s is missing or not owned by %r",
task.name, chain_id, task.owner,
)
elif not self._has_chain_cycle(db, chain_id, owner=task.owner):
logger.info(f"Chaining: '{task.name}' → task {chain_id}")
asyncio.create_task(self._run_chained(chain_id))
else:
@@ -1791,7 +1797,7 @@ class TaskScheduler:
self._executing.add(task_id)
await self._execute_task(task_id)
def _has_chain_cycle(self, db, start_id: str, max_depth: int = 10) -> bool:
def _has_chain_cycle(self, db, start_id: str, max_depth: int = 10, owner: str | None = None) -> bool:
"""Detect cycles in task chains."""
from core.database import ScheduledTask
visited = set()
@@ -1801,6 +1807,8 @@ class TaskScheduler:
return True
visited.add(current)
task = db.query(ScheduledTask).filter(ScheduledTask.id == current).first()
if owner is not None and task and task.owner != owner:
return True
if not task or not task.then_task_id:
return False
current = task.then_task_id
+127
View File
@@ -0,0 +1,127 @@
"""Task chaining must not cross owner boundaries."""
import tempfile
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from tests.helpers.import_state import clear_fake_database_modules
clear_fake_database_modules()
import core.database as cdb
import routes.task_routes as task_routes
from core.database import ScheduledTask
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
_ENGINE = create_engine(
f"sqlite:///{_TMPDB.name}",
connect_args={"check_same_thread": False},
poolclass=NullPool,
)
cdb.Base.metadata.create_all(_ENGINE)
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
task_routes.SessionLocal = _TS
def _req(user="alice"):
return SimpleNamespace(state=SimpleNamespace(current_user=user))
def _endpoint(method, path):
task_routes.SessionLocal = _TS
router = task_routes.setup_task_routes(MagicMock())
for route in router.routes:
if getattr(route, "path", None) == path and method in getattr(route, "methods", set()):
return route.endpoint
raise RuntimeError(f"{method} {path} not found")
def _seed_task(task_id, owner, *, then_task_id=None):
db = _TS()
try:
task = ScheduledTask(
id=task_id,
owner=owner,
name=task_id,
prompt="do work",
task_type="llm",
trigger_type="webhook",
status="active",
output_target="session",
then_task_id=then_task_id,
)
db.add(task)
db.commit()
finally:
db.close()
@pytest.mark.asyncio
async def test_create_task_rejects_cross_owner_chain_target():
_seed_task("bob-target-create", "bob")
create_task = _endpoint("POST", "/api/tasks")
req = task_routes.TaskCreate(
prompt="alice source",
trigger_type="webhook",
then_task_id="bob-target-create",
)
with pytest.raises(HTTPException) as exc:
await create_task(_req("alice"), req)
assert exc.value.status_code == 404
@pytest.mark.asyncio
async def test_update_task_rejects_cross_owner_chain_target():
_seed_task("alice-source-update", "alice")
_seed_task("bob-target-update", "bob")
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
with pytest.raises(HTTPException) as exc:
await update_task(
_req("alice"),
"alice-source-update",
task_routes.TaskUpdate(then_task_id="bob-target-update"),
)
assert exc.value.status_code == 404
db = _TS()
try:
source = db.query(ScheduledTask).filter(ScheduledTask.id == "alice-source-update").first()
assert source.then_task_id is None
finally:
db.close()
@pytest.mark.asyncio
async def test_update_task_allows_same_owner_chain_target():
_seed_task("alice-source-allow", "alice")
_seed_task("alice-target-allow", "alice")
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
out = await update_task(
_req("alice"),
"alice-source-allow",
task_routes.TaskUpdate(then_task_id="alice-target-allow"),
)
assert out["then_task_id"] == "alice-target-allow"
def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe():
_seed_task("bob-target-cycle", "bob")
from src.task_scheduler import TaskScheduler
scheduler = TaskScheduler.__new__(TaskScheduler)
db = _TS()
try:
assert scheduler._has_chain_cycle(db, "bob-target-cycle", owner="alice") is True
finally:
db.close()