Tasks: clean up queued cancellation state

This commit is contained in:
mechramc
2026-06-02 06:51:21 -05:00
committed by GitHub
parent f975279b26
commit 8e87d3002b
2 changed files with 153 additions and 23 deletions
+48 -23
View File
@@ -241,6 +241,35 @@ class TaskScheduler:
except Exception:
logger.debug("Task progress update failed", exc_info=True)
def _mark_run_aborted(self, task_id: str, run_id: str | None = None, message: str = "Stopped by user") -> bool:
"""Mark an active run as aborted. Used by stop/cancel paths."""
try:
from core.database import SessionLocal, TaskRun
db = SessionLocal()
try:
q = db.query(TaskRun)
if run_id:
q = q.filter(TaskRun.id == run_id)
else:
q = q.filter(
TaskRun.task_id == task_id,
TaskRun.status.in_(("queued", "running")),
).order_by(TaskRun.started_at.desc())
run = q.first()
if not run or run.status not in ("queued", "running"):
return False
run.status = "aborted"
run.error = message
run.result = run.result or message
run.finished_at = datetime.utcnow()
db.commit()
return True
finally:
db.close()
except Exception:
logger.debug("Task abort marker failed for %s", task_id, exc_info=True)
return False
def add_notification(self, task_name: str, status: str, task_id: str = None, owner: str = None, body: str = None):
"""Store a notification about a completed task run. Tagged with the
task's owner so `pop_notifications` can return only that user's
@@ -581,12 +610,25 @@ class TaskScheduler:
finally:
_q_db.close()
if bypass_model_slot or not self._task_needs_model_slot(task_id):
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
return
try:
if bypass_model_slot or not self._task_needs_model_slot(task_id):
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
return
async with self._run_semaphore:
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
async with self._run_semaphore:
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
except asyncio.CancelledError:
# If cancellation happens while queued behind the semaphore,
# _execute_task_locked never runs and cannot update the Activity row.
self._mark_run_aborted(task_id, run_id)
raise
finally:
handle = self._task_handles.get(task_id)
if handle is current:
self._task_handles.pop(task_id, None)
if release_executing:
async with self._executing_lock:
self._executing.discard(task_id)
async def _execute_task_locked(self, task_id: str, run_id: str, *, release_executing: bool = True):
from core.database import SessionLocal, ScheduledTask, TaskRun
@@ -1839,24 +1881,7 @@ class TaskScheduler:
self._executing.discard(task_id)
stopped = True
from core.database import SessionLocal, TaskRun
db = SessionLocal()
try:
run = (
db.query(TaskRun)
.filter(TaskRun.task_id == task_id, TaskRun.status.in_(("queued", "running")))
.order_by(TaskRun.started_at.desc())
.first()
)
if run:
run.status = "aborted"
run.error = "Stopped by user"
run.result = run.result or "Stopped by user"
run.finished_at = datetime.utcnow()
db.commit()
stopped = True
finally:
db.close()
stopped = self._mark_run_aborted(task_id) or stopped
return stopped
async def ensure_defaults(self, owner: str):