diff --git a/routes/session_routes.py b/routes/session_routes.py index 5bd693383..811a40bbe 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -10,8 +10,9 @@ import logging from core.session_manager import SessionManager from core.models import ChatMessage from src.request_models import SessionResponse -from core.database import Session as DbSession, SessionLocal, Document, GalleryImage +from core.database import Session as DbSession, SessionLocal, Document, GalleryImage, utcnow_naive from src.auth_helpers import get_current_user, effective_user, _auth_disabled +from src.session_actions import is_session_recently_active def _sanitize_export_filename(name: str) -> str: @@ -1028,6 +1029,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ db.query(DbMsg.session_id, _sa_func.count(DbMsg.id)) .filter(DbMsg.role == "assistant").group_by(DbMsg.session_id).all() ) + cleanup_now = utcnow_naive() for row in rows: # Never delete important sessions if getattr(row, 'is_important', False): @@ -1040,6 +1042,8 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ if hasattr(session_manager, 'delete_session'): session_manager.delete_session(row.id) continue + if is_session_recently_active(row, now=cleanup_now): + continue msg_count = _counts.get(row.id, 0) should_delete = False if msg_count == 0: diff --git a/src/session_actions.py b/src/session_actions.py index 7376952d1..072bb4c06 100644 --- a/src/session_actions.py +++ b/src/session_actions.py @@ -8,7 +8,7 @@ and the task scheduler / builtin actions system. import json import logging import re -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone logger = logging.getLogger(__name__) @@ -23,6 +23,34 @@ _THROWAWAY_NAMES = { } _THROWAWAY_MAX_MESSAGES = 4 _FRESH_EMPTY_SESSION_GRACE = timedelta(minutes=10) +_FRESH_SESSION_GRACE = _FRESH_EMPTY_SESSION_GRACE + + +def _utcnow_naive() -> datetime: + """Return naive UTC for existing session DateTime columns.""" + return datetime.now(timezone.utc).replace(tzinfo=None) + + +def _as_naive_utc(value): + if value is None: + return None + if getattr(value, "tzinfo", None) is not None: + return value.astimezone(timezone.utc).replace(tzinfo=None) + return value + + +def is_session_recently_active(row, now=None, grace=_FRESH_SESSION_GRACE) -> bool: + """Return True while a new or active session is too fresh to auto-delete.""" + now = _as_naive_utc(now) or _utcnow_naive() + for attr in ("last_message_at", "last_accessed", "updated_at", "created_at"): + value = _as_naive_utc(getattr(row, attr, None)) + if not value: + continue + if value >= now: + return True + if now - value <= grace: + return True + return False async def run_auto_sort(owner: str, skip_llm: bool = False, delete_throwaway: bool = True) -> str: @@ -52,15 +80,18 @@ async def run_auto_sort(owner: str, skip_llm: bool = False, delete_throwaway: bo *([DbSession.owner == owner] if owner else []), ).all() + cleanup_now = _utcnow_naive() for row in rows: if getattr(row, 'is_important', False): continue - created_at = row.created_at or row.updated_at or datetime.utcnow() - is_fresh = (datetime.utcnow() - created_at) < _FRESH_EMPTY_SESSION_GRACE + created_at = _as_naive_utc(row.created_at or row.updated_at) or _utcnow_naive() + is_fresh = (_utcnow_naive() - created_at) < _FRESH_EMPTY_SESSION_GRACE if (row.name or "").strip() == "Incognito": deleted_throwaway += 1 db.delete(row) continue + if is_session_recently_active(row, now=cleanup_now): + continue msg_count = db.query(DbMsg.id).filter( DbMsg.session_id == row.id @@ -208,7 +239,7 @@ async def run_auto_sort(owner: str, skip_llm: bool = False, delete_throwaway: bo db_sess = db.query(DbSession).filter(DbSession.id == full_id).first() if db_sess: db_sess.folder = folder_name - db_sess.updated_at = datetime.utcnow() + db_sess.updated_at = _utcnow_naive() updated += 1 db.commit() diff --git a/tests/test_session_actions_cleanup.py b/tests/test_session_actions_cleanup.py new file mode 100644 index 000000000..221713d33 --- /dev/null +++ b/tests/test_session_actions_cleanup.py @@ -0,0 +1,166 @@ +"""Regression coverage for auto-sort session cleanup. + +Issue #1851 reported fresh chats being deleted immediately after their first +turn, leaving the browser pointed at a session id that no longer exists. +""" + +import asyncio +from datetime import timedelta +import sys +import tempfile +import uuid + +import pytest + +sqlalchemy = pytest.importorskip("sqlalchemy") +if type(sqlalchemy).__name__ == "MagicMock": + pytest.skip("sqlalchemy is stubbed in this environment", allow_module_level=True) + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool + +import core.database as cdb +from core.database import ChatMessage as DbMessage, Session as DbSession, utcnow_naive +import src.session_actions as session_actions + + +def _make_session_factory(): + tmp = tempfile.NamedTemporaryFile(suffix=".db", delete=False) + tmp.close() + engine = create_engine( + f"sqlite:///{tmp.name}", + connect_args={"check_same_thread": False}, + poolclass=NullPool, + ) + DbSession.metadata.create_all(bind=engine) + return sessionmaker(bind=engine, autoflush=False, autocommit=False) + + +def _install_session_factory(monkeypatch, session_factory): + monkeypatch.setitem(sys.modules, "core.database", cdb) + core_pkg = sys.modules.get("core") + if core_pkg is not None: + monkeypatch.setattr(core_pkg, "database", cdb, raising=False) + monkeypatch.setattr(cdb, "SessionLocal", session_factory) + + +def _add_message(db, sid, role, content, timestamp): + db.add( + DbMessage( + id="m-" + uuid.uuid4().hex, + session_id=sid, + role=role, + content=content, + timestamp=timestamp, + ) + ) + + +def test_auto_sort_keeps_fresh_chat_with_completed_first_turn(monkeypatch): + session_factory = _make_session_factory() + _install_session_factory(monkeypatch, session_factory) + + sid = "s-" + uuid.uuid4().hex + db = session_factory() + try: + db.add( + DbSession( + id=sid, + owner="alice", + name="Quick question", + endpoint_url="", + model="", + archived=False, + message_count=2, + last_message_at=utcnow_naive(), + ) + ) + _add_message(db, sid, "user", "hi", utcnow_naive()) + _add_message(db, sid, "assistant", "Hello! How can I help?", utcnow_naive()) + db.commit() + finally: + db.close() + + result = asyncio.run(session_actions.run_auto_sort("alice", skip_llm=True)) + + db = session_factory() + try: + assert db.query(DbSession).filter(DbSession.id == sid).first() is not None + assert db.query(DbMessage).filter(DbMessage.session_id == sid).count() == 2 + assert "Cleaned 0 sessions" in result + finally: + db.close() + + +def test_auto_sort_keeps_fresh_session_while_first_response_is_pending(monkeypatch): + session_factory = _make_session_factory() + _install_session_factory(monkeypatch, session_factory) + + sid = "s-" + uuid.uuid4().hex + db = session_factory() + try: + db.add( + DbSession( + id=sid, + owner="alice", + name="New chat", + endpoint_url="", + model="", + archived=False, + message_count=1, + last_message_at=utcnow_naive(), + ) + ) + _add_message(db, sid, "user", "Tell me a quick joke", utcnow_naive()) + db.commit() + finally: + db.close() + + result = asyncio.run(session_actions.run_auto_sort("alice", skip_llm=True)) + + db = session_factory() + try: + assert db.query(DbSession).filter(DbSession.id == sid).first() is not None + assert db.query(DbMessage).filter(DbMessage.session_id == sid).count() == 1 + assert "Cleaned 0 sessions" in result + finally: + db.close() + + +def test_auto_sort_still_deletes_old_throwaway_sessions(monkeypatch): + session_factory = _make_session_factory() + _install_session_factory(monkeypatch, session_factory) + + old_time = utcnow_naive() - timedelta(hours=2) + sid = "s-" + uuid.uuid4().hex + db = session_factory() + try: + db.add( + DbSession( + id=sid, + owner="alice", + name="New chat", + endpoint_url="", + model="", + archived=False, + message_count=1, + created_at=old_time, + updated_at=old_time, + last_accessed=old_time, + last_message_at=old_time, + ) + ) + _add_message(db, sid, "user", "hi", old_time) + db.commit() + finally: + db.close() + + result = asyncio.run(session_actions.run_auto_sort("alice", skip_llm=True)) + + db = session_factory() + try: + assert db.query(DbSession).filter(DbSession.id == sid).first() is None + assert "Cleaned 1 sessions" in result + finally: + db.close()