diff --git a/routes/task_routes.py b/routes/task_routes.py index dfaed0808..a31d12995 100644 --- a/routes/task_routes.py +++ b/routes/task_routes.py @@ -429,6 +429,20 @@ def setup_task_routes(task_scheduler) -> APIRouter: except Exception: return False + def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]: + target_id = (then_task_id or "").strip() + if not target_id: + return None + if current_task_id and target_id == current_task_id: + raise HTTPException(400, "Task cannot chain to itself") + q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id) + if user: + q = q.filter(ScheduledTask.owner == user) + target = q.first() + if not target: + raise HTTPException(404, "Chained task not found") + return target.id + @router.post("") async def create_task(request: Request, req: TaskCreate): user = _owner(request) @@ -492,6 +506,7 @@ def setup_task_routes(task_scheduler) -> APIRouter: task_id = str(uuid.uuid4()) db = SessionLocal() try: + then_task_id = _validate_then_task_id(db, req.then_task_id, user) notifications_enabled = ( False if req.task_type == "action" and req.notifications_enabled is None else bool(req.notifications_enabled) if req.notifications_enabled is not None @@ -518,7 +533,7 @@ def setup_task_routes(task_scheduler) -> APIRouter: output_target=req.output_target, model=req.model or None, endpoint_url=req.endpoint_url or None, - then_task_id=req.then_task_id or None, + then_task_id=then_task_id, webhook_token=webhook_token, notifications_enabled=notifications_enabled, ) @@ -671,7 +686,7 @@ def setup_task_routes(task_scheduler) -> APIRouter: if req.trigger_count is not None: task.trigger_count = req.trigger_count if req.then_task_id is not None: - task.then_task_id = req.then_task_id or None + task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id) if req.notifications_enabled is not None: task.notifications_enabled = bool(req.notifications_enabled) if req.cron_expression is not None: diff --git a/src/task_scheduler.py b/src/task_scheduler.py index 2fcb5dc09..96b866720 100644 --- a/src/task_scheduler.py +++ b/src/task_scheduler.py @@ -844,7 +844,13 @@ class TaskScheduler: # Task chaining — trigger the next task on success if run.status == "success" and task.then_task_id: chain_id = task.then_task_id - if not self._has_chain_cycle(db, chain_id): + chain_task = db.query(ScheduledTask).filter(ScheduledTask.id == chain_id).first() + if not chain_task or chain_task.owner != task.owner: + logger.warning( + "Skipping chain from %r: target task %s is missing or not owned by %r", + task.name, chain_id, task.owner, + ) + elif not self._has_chain_cycle(db, chain_id, owner=task.owner): logger.info(f"Chaining: '{task.name}' → task {chain_id}") asyncio.create_task(self._run_chained(chain_id)) else: @@ -1791,7 +1797,7 @@ class TaskScheduler: self._executing.add(task_id) await self._execute_task(task_id) - def _has_chain_cycle(self, db, start_id: str, max_depth: int = 10) -> bool: + def _has_chain_cycle(self, db, start_id: str, max_depth: int = 10, owner: str | None = None) -> bool: """Detect cycles in task chains.""" from core.database import ScheduledTask visited = set() @@ -1801,6 +1807,8 @@ class TaskScheduler: return True visited.add(current) task = db.query(ScheduledTask).filter(ScheduledTask.id == current).first() + if owner is not None and task and task.owner != owner: + return True if not task or not task.then_task_id: return False current = task.then_task_id diff --git a/tests/test_task_chain_owner_scope.py b/tests/test_task_chain_owner_scope.py new file mode 100644 index 000000000..d13852663 --- /dev/null +++ b/tests/test_task_chain_owner_scope.py @@ -0,0 +1,127 @@ +"""Task chaining must not cross owner boundaries.""" + +import tempfile +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool + +from tests.helpers.import_state import clear_fake_database_modules + +clear_fake_database_modules() + +import core.database as cdb +import routes.task_routes as task_routes +from core.database import ScheduledTask + +_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False) +_ENGINE = create_engine( + f"sqlite:///{_TMPDB.name}", + connect_args={"check_same_thread": False}, + poolclass=NullPool, +) +cdb.Base.metadata.create_all(_ENGINE) +_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False) +task_routes.SessionLocal = _TS + + +def _req(user="alice"): + return SimpleNamespace(state=SimpleNamespace(current_user=user)) + + +def _endpoint(method, path): + task_routes.SessionLocal = _TS + router = task_routes.setup_task_routes(MagicMock()) + for route in router.routes: + if getattr(route, "path", None) == path and method in getattr(route, "methods", set()): + return route.endpoint + raise RuntimeError(f"{method} {path} not found") + + +def _seed_task(task_id, owner, *, then_task_id=None): + db = _TS() + try: + task = ScheduledTask( + id=task_id, + owner=owner, + name=task_id, + prompt="do work", + task_type="llm", + trigger_type="webhook", + status="active", + output_target="session", + then_task_id=then_task_id, + ) + db.add(task) + db.commit() + finally: + db.close() + + +@pytest.mark.asyncio +async def test_create_task_rejects_cross_owner_chain_target(): + _seed_task("bob-target-create", "bob") + create_task = _endpoint("POST", "/api/tasks") + + req = task_routes.TaskCreate( + prompt="alice source", + trigger_type="webhook", + then_task_id="bob-target-create", + ) + with pytest.raises(HTTPException) as exc: + await create_task(_req("alice"), req) + + assert exc.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_update_task_rejects_cross_owner_chain_target(): + _seed_task("alice-source-update", "alice") + _seed_task("bob-target-update", "bob") + update_task = _endpoint("PUT", "/api/tasks/{task_id}") + + with pytest.raises(HTTPException) as exc: + await update_task( + _req("alice"), + "alice-source-update", + task_routes.TaskUpdate(then_task_id="bob-target-update"), + ) + + assert exc.value.status_code == 404 + db = _TS() + try: + source = db.query(ScheduledTask).filter(ScheduledTask.id == "alice-source-update").first() + assert source.then_task_id is None + finally: + db.close() + + +@pytest.mark.asyncio +async def test_update_task_allows_same_owner_chain_target(): + _seed_task("alice-source-allow", "alice") + _seed_task("alice-target-allow", "alice") + update_task = _endpoint("PUT", "/api/tasks/{task_id}") + + out = await update_task( + _req("alice"), + "alice-source-allow", + task_routes.TaskUpdate(then_task_id="alice-target-allow"), + ) + + assert out["then_task_id"] == "alice-target-allow" + + +def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe(): + _seed_task("bob-target-cycle", "bob") + from src.task_scheduler import TaskScheduler + + scheduler = TaskScheduler.__new__(TaskScheduler) + db = _TS() + try: + assert scheduler._has_chain_cycle(db, "bob-target-cycle", owner="alice") is True + finally: + db.close()