mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
Enforce task chain owner scope (#3006)
This commit is contained in:
+17
-2
@@ -429,6 +429,20 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _validate_then_task_id(db, then_task_id: Optional[str], user: Optional[str], current_task_id: Optional[str] = None) -> Optional[str]:
|
||||
target_id = (then_task_id or "").strip()
|
||||
if not target_id:
|
||||
return None
|
||||
if current_task_id and target_id == current_task_id:
|
||||
raise HTTPException(400, "Task cannot chain to itself")
|
||||
q = db.query(ScheduledTask).filter(ScheduledTask.id == target_id)
|
||||
if user:
|
||||
q = q.filter(ScheduledTask.owner == user)
|
||||
target = q.first()
|
||||
if not target:
|
||||
raise HTTPException(404, "Chained task not found")
|
||||
return target.id
|
||||
|
||||
@router.post("")
|
||||
async def create_task(request: Request, req: TaskCreate):
|
||||
user = _owner(request)
|
||||
@@ -492,6 +506,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
task_id = str(uuid.uuid4())
|
||||
db = SessionLocal()
|
||||
try:
|
||||
then_task_id = _validate_then_task_id(db, req.then_task_id, user)
|
||||
notifications_enabled = (
|
||||
False if req.task_type == "action" and req.notifications_enabled is None
|
||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||
@@ -518,7 +533,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
output_target=req.output_target,
|
||||
model=req.model or None,
|
||||
endpoint_url=req.endpoint_url or None,
|
||||
then_task_id=req.then_task_id or None,
|
||||
then_task_id=then_task_id,
|
||||
webhook_token=webhook_token,
|
||||
notifications_enabled=notifications_enabled,
|
||||
)
|
||||
@@ -671,7 +686,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
if req.trigger_count is not None:
|
||||
task.trigger_count = req.trigger_count
|
||||
if req.then_task_id is not None:
|
||||
task.then_task_id = req.then_task_id or None
|
||||
task.then_task_id = _validate_then_task_id(db, req.then_task_id, user, current_task_id=task.id)
|
||||
if req.notifications_enabled is not None:
|
||||
task.notifications_enabled = bool(req.notifications_enabled)
|
||||
if req.cron_expression is not None:
|
||||
|
||||
+10
-2
@@ -844,7 +844,13 @@ class TaskScheduler:
|
||||
# Task chaining — trigger the next task on success
|
||||
if run.status == "success" and task.then_task_id:
|
||||
chain_id = task.then_task_id
|
||||
if not self._has_chain_cycle(db, chain_id):
|
||||
chain_task = db.query(ScheduledTask).filter(ScheduledTask.id == chain_id).first()
|
||||
if not chain_task or chain_task.owner != task.owner:
|
||||
logger.warning(
|
||||
"Skipping chain from %r: target task %s is missing or not owned by %r",
|
||||
task.name, chain_id, task.owner,
|
||||
)
|
||||
elif not self._has_chain_cycle(db, chain_id, owner=task.owner):
|
||||
logger.info(f"Chaining: '{task.name}' → task {chain_id}")
|
||||
asyncio.create_task(self._run_chained(chain_id))
|
||||
else:
|
||||
@@ -1791,7 +1797,7 @@ class TaskScheduler:
|
||||
self._executing.add(task_id)
|
||||
await self._execute_task(task_id)
|
||||
|
||||
def _has_chain_cycle(self, db, start_id: str, max_depth: int = 10) -> bool:
|
||||
def _has_chain_cycle(self, db, start_id: str, max_depth: int = 10, owner: str | None = None) -> bool:
|
||||
"""Detect cycles in task chains."""
|
||||
from core.database import ScheduledTask
|
||||
visited = set()
|
||||
@@ -1801,6 +1807,8 @@ class TaskScheduler:
|
||||
return True
|
||||
visited.add(current)
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == current).first()
|
||||
if owner is not None and task and task.owner != owner:
|
||||
return True
|
||||
if not task or not task.then_task_id:
|
||||
return False
|
||||
current = task.then_task_id
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
"""Task chaining must not cross owner boundaries."""
|
||||
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from tests.helpers.import_state import clear_fake_database_modules
|
||||
|
||||
clear_fake_database_modules()
|
||||
|
||||
import core.database as cdb
|
||||
import routes.task_routes as task_routes
|
||||
from core.database import ScheduledTask
|
||||
|
||||
_TMPDB = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
_ENGINE = create_engine(
|
||||
f"sqlite:///{_TMPDB.name}",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=NullPool,
|
||||
)
|
||||
cdb.Base.metadata.create_all(_ENGINE)
|
||||
_TS = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False)
|
||||
task_routes.SessionLocal = _TS
|
||||
|
||||
|
||||
def _req(user="alice"):
|
||||
return SimpleNamespace(state=SimpleNamespace(current_user=user))
|
||||
|
||||
|
||||
def _endpoint(method, path):
|
||||
task_routes.SessionLocal = _TS
|
||||
router = task_routes.setup_task_routes(MagicMock())
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", None) == path and method in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise RuntimeError(f"{method} {path} not found")
|
||||
|
||||
|
||||
def _seed_task(task_id, owner, *, then_task_id=None):
|
||||
db = _TS()
|
||||
try:
|
||||
task = ScheduledTask(
|
||||
id=task_id,
|
||||
owner=owner,
|
||||
name=task_id,
|
||||
prompt="do work",
|
||||
task_type="llm",
|
||||
trigger_type="webhook",
|
||||
status="active",
|
||||
output_target="session",
|
||||
then_task_id=then_task_id,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_rejects_cross_owner_chain_target():
|
||||
_seed_task("bob-target-create", "bob")
|
||||
create_task = _endpoint("POST", "/api/tasks")
|
||||
|
||||
req = task_routes.TaskCreate(
|
||||
prompt="alice source",
|
||||
trigger_type="webhook",
|
||||
then_task_id="bob-target-create",
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await create_task(_req("alice"), req)
|
||||
|
||||
assert exc.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_rejects_cross_owner_chain_target():
|
||||
_seed_task("alice-source-update", "alice")
|
||||
_seed_task("bob-target-update", "bob")
|
||||
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await update_task(
|
||||
_req("alice"),
|
||||
"alice-source-update",
|
||||
task_routes.TaskUpdate(then_task_id="bob-target-update"),
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 404
|
||||
db = _TS()
|
||||
try:
|
||||
source = db.query(ScheduledTask).filter(ScheduledTask.id == "alice-source-update").first()
|
||||
assert source.then_task_id is None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_allows_same_owner_chain_target():
|
||||
_seed_task("alice-source-allow", "alice")
|
||||
_seed_task("alice-target-allow", "alice")
|
||||
update_task = _endpoint("PUT", "/api/tasks/{task_id}")
|
||||
|
||||
out = await update_task(
|
||||
_req("alice"),
|
||||
"alice-source-allow",
|
||||
task_routes.TaskUpdate(then_task_id="alice-target-allow"),
|
||||
)
|
||||
|
||||
assert out["then_task_id"] == "alice-target-allow"
|
||||
|
||||
|
||||
def test_scheduler_cycle_guard_treats_cross_owner_chain_as_unsafe():
|
||||
_seed_task("bob-target-cycle", "bob")
|
||||
from src.task_scheduler import TaskScheduler
|
||||
|
||||
scheduler = TaskScheduler.__new__(TaskScheduler)
|
||||
db = _TS()
|
||||
try:
|
||||
assert scheduler._has_chain_cycle(db, "bob-target-cycle", owner="alice") is True
|
||||
finally:
|
||||
db.close()
|
||||
Reference in New Issue
Block a user