diff --git a/src/agent_tools/__init__.py b/src/agent_tools/__init__.py index 372765cec..ba39b4f53 100644 --- a/src/agent_tools/__init__.py +++ b/src/agent_tools/__init__.py @@ -24,6 +24,7 @@ from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool from .model_interaction_tools import ChatWithModelTool, AskTeacherTool, ListModelsTool from .bg_job_tools import ManageBgJobsTool +from .session_tools import CreateSessionTool, ListSessionsTool, SendToSessionTool, ManageSessionTool TOOL_HANDLERS = { "bash": BashTool().execute, @@ -46,6 +47,10 @@ TOOL_HANDLERS = { "ask_teacher": AskTeacherTool().execute, "list_models": ListModelsTool().execute, "manage_bg_jobs": ManageBgJobsTool().execute, + "create_session": CreateSessionTool().execute, + "list_sessions": ListSessionsTool().execute, + "send_to_session": SendToSessionTool().execute, + "manage_session": ManageSessionTool().execute, } # --------------------------------------------------------------------------- diff --git a/src/agent_tools/session_tools.py b/src/agent_tools/session_tools.py new file mode 100644 index 000000000..797235c5d --- /dev/null +++ b/src/agent_tools/session_tools.py @@ -0,0 +1,464 @@ +"""session_tools.py - agent tools for AI-to-AI session management. + +Owns create_session, list_sessions, send_to_session and manage_session, moved +out of src.ai_interaction as part of the tool -> registry migration (#3629), and +their handler classes registered in TOOL_HANDLERS. + +The session manager is a runtime-set singleton in src.ai_interaction, so each +function fetches it via get_session_manager() (imported here); _resolve_model and +AI_CHAT_TIMEOUT are reused from there too. +""" +import json +import logging +import uuid +from typing import Dict, Optional + +from src.ai_interaction import get_session_manager, _resolve_model, AI_CHAT_TIMEOUT + +logger = logging.getLogger(__name__) + + +async def create_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """Create a new chat session. + + Content format: + Line 1: session name + Line 2: model_name (or model_name@endpoint_name) + """ + _session_manager = get_session_manager() + if not _session_manager: + return {"error": "Session manager not available"} + + lines = content.strip().split("\n") + if len(lines) < 2: + return {"error": "Need 2 lines: session name, then model spec"} + + name = lines[0].strip() + model_spec = lines[1].strip() + + if not name: + return {"error": "Session name cannot be empty"} + + try: + url, model, headers = _resolve_model(model_spec, owner=owner) + except ValueError as e: + return {"error": str(e)} + + sid = str(uuid.uuid4())[:8] + try: + _session_manager.create_session( + session_id=sid, + name=name, + endpoint_url=url, + model=model, + rag=False, + owner=owner, + ) + # Store headers on session for future calls + sess = _session_manager.get_session(sid) + if sess and headers: + sess.headers = headers + try: + from src.event_bus import fire_event + fire_event("session_created", owner) + except Exception: + logger.debug("session_created event dispatch failed", exc_info=True) + + return {"session_id": sid, "name": name, "model": model, "endpoint_url": url} + except Exception as e: + logger.error(f"create_session failed: {e}") + return {"error": f"Failed to create session: {e}"} + +async def list_sessions(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """List sessions sorted by most-recently-active first. + + Output includes a relative "last active" timestamp per row so the + agent can answer "open my last chat" without guessing from titles. + The most-recent session is always first in the list. + + Content = optional filter keyword (matches session name). + """ + _session_manager = get_session_manager() + if not _session_manager: + return {"error": "Session manager not available"} + + keyword = content.strip().lower() if content.strip() else None + + try: + from core.database import SessionLocal, Session as DbSession + from datetime import datetime, timezone + + # Pull every session's last_accessed from the DB so we can sort + # by recency. In-memory sessions hold name + model + msg_count; + # the DB row holds the timestamps. + db = SessionLocal() + try: + db_rows = {r.id: r for r in db.query(DbSession).all()} + finally: + db.close() + + # SECURITY: scope to the caller's sessions. Passing None returned + # every user's sessions, which the agent tool then exposed via the + # "list my chats" reply. + sessions = _session_manager.get_sessions_for_user(owner) + rows = [] + for sid, sess in sessions.items(): + if keyword and keyword not in (sess.name or "").lower(): + continue + db_row = db_rows.get(sid) + # Prefer last_accessed; fall back to updated_at, then created_at. + ts = None + if db_row: + ts = getattr(db_row, 'last_accessed', None) or getattr(db_row, 'updated_at', None) or getattr(db_row, 'created_at', None) + rows.append((ts, sid, sess)) + + # Sort by timestamp DESC; rows without a timestamp sink to the bottom. + rows.sort(key=lambda r: r[0] or datetime.min, reverse=True) + + def _rel(ts): + if not ts: + return 'never' + now = datetime.utcnow() + try: + if ts.tzinfo is not None: + now = datetime.now(timezone.utc) + diff = (now - ts).total_seconds() + except Exception: + return 'unknown' + if diff < 60: return 'just now' + if diff < 3600: return f'{int(diff / 60)}m ago' + if diff < 86400: return f'{int(diff / 3600)}h ago' + if diff < 86400 * 7: return f'{int(diff / 86400)}d ago' + return ts.strftime('%Y-%m-%d') + + lines = [] + for i, (ts, sid, sess) in enumerate(rows): + if i >= 50: + lines.append(f"... and {len(rows) - 50} more (showing first 50)") + break + safe_name = (sess.name or "Untitled").replace("[", "\\[").replace("]", "\\]") + msg_count = getattr(sess, "message_count", 0) or 0 + model = getattr(sess, "model", "unknown") + marker = " ← most recent" if i == 0 else "" + lines.append(f"- **[{safe_name}](#session-{sid})** (id: `{sid}`, model: {model}, {msg_count} msgs, last active {_rel(ts)}){marker}") + + if not lines: + return {"results": "No sessions found" + (f" matching '{keyword}'" if keyword else "") + "."} + + return { + "results": ( + f"Found {len(rows)} session(s), sorted most-recent first:\n" + + "\n".join(lines) + + "\n\nAssistant: when replying to the user, preserve the chat-title markdown links exactly as shown, e.g. `[Chat](#session-id)`. Do not rewrite this as a plain, non-clickable table." + ) + } + except Exception as e: + logger.error(f"list_sessions failed: {e}") + return {"error": str(e)} + +async def send_to_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """Send a message to an existing session and get a response. + + Content format: + Line 1: session_id + Line 2+: message + """ + _session_manager = get_session_manager() + from src.llm_core import llm_call_async + from core.models import ChatMessage + + if not _session_manager: + return {"error": "Session manager not available"} + + lines = content.strip().split("\n", 1) + if len(lines) < 2: + return {"error": "Need 2 lines: session_id, then message"} + + target_sid = lines[0].strip() + message = lines[1].strip() + + sess = _session_manager.get_session(target_sid) + if not sess: + return {"error": f"Session '{target_sid}' not found"} + + # Owner-scope: reject access to another user's session + if owner and getattr(sess, "owner", None) and sess.owner != owner: + return {"error": f"Session '{target_sid}' not found"} + + if not message: + return {"error": "No message provided"} + + try: + # Build context from session history + context = sess.get_context_messages() + context.append({"role": "user", "content": message}) + + response = await llm_call_async( + sess.endpoint_url, sess.model, context, + headers=sess.headers, + timeout=AI_CHAT_TIMEOUT, + ) + + # Save both messages to session + sess.add_message(ChatMessage("user", message)) + sess.add_message(ChatMessage("assistant", response)) + + # Truncate for tool output + if len(response) > 10000: + response = response[:10000] + "\n... (truncated)" + + return { + "session_id": target_sid, + "session_name": sess.name, + "response": response, + } + except Exception as e: + logger.error(f"send_to_session failed: {e}") + return {"error": f"Failed to send to session: {e}"} + +async def manage_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """Manage sessions: rename, archive, delete, important, truncate, fork. + + Content format: + Line 1: action (rename|archive|unarchive|delete|important|unimportant|truncate|fork) + Line 2: target session_id (or "current" to use the active session) + Line 3+: action-specific params (e.g. new name for rename, keep_count for truncate) + """ + _session_manager = get_session_manager() + if not _session_manager: + return {"error": "Session manager not available"} + + from src.database import SessionLocal, Session as DbSession + + # Accept BOTH the structured JSON args the tool schema advertises + # ({action, session_id, value}) AND the legacy line-based format + # (line1=action, line2=session_id, line3=value). Native function-calling + # models send JSON; fenced-block callers send lines. Previously only the + # line format was parsed, so a model that followed the schema (JSON) got + # "Need at least 2 lines" / "Rename needs line 3" and couldn't drive it. + _raw = (content or "").strip() + action = "" + target_sid = "" + value = None # the action param: new name (rename) / keep_count (truncate, fork) + _list_filter = "" + _parsed = None + if _raw.startswith("{"): + try: + _parsed = json.loads(_raw) + except Exception: + _parsed = None + if isinstance(_parsed, dict): + action = str(_parsed.get("action") or "").strip().lower() + target_sid = str(_parsed.get("session_id") or _parsed.get("session") or _parsed.get("id") or "").strip() + _v = _parsed.get("value") + if _v is None: + _v = (_parsed.get("name") or _parsed.get("new_name") + or _parsed.get("title") or _parsed.get("keep_count")) + value = None if _v is None else str(_v).strip() + _list_filter = str(_parsed.get("filter") or "").strip() + else: + lines = _raw.split("\n") + if not lines or not lines[0].strip(): + return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"} + action = lines[0].strip().lower() + target_sid = lines[1].strip() if len(lines) >= 2 else "" + value = lines[2].strip() if len(lines) >= 3 else None + _list_filter = "\n".join(lines[1:]).strip() + + if not action: + return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"} + + # `list` alias - dispatch to list_sessions so the agent's natural + # first guess (every other manage_* tool has a `list` action) works. + if action == "list": + return await list_sessions(_list_filter, session_id, owner=owner) + + if not target_sid: + return {"error": "Need a session_id (or 'current' for the active chat)"} + + # Allow "current" to refer to the active session + if target_sid.lower() == "current" and session_id: + target_sid = session_id + + # `switch` / `open` / `select` / `view` - the agent reaches for + # these when the user asks to "open" or "switch to" a session. + # There's no server-side way to make the browser navigate, so we + # just return a clickable anchor link the user can click. The + # frontend's chat-history click delegate routes `#session-` + # to selectSession(). The agent's reply naturally embeds this + # result so the user sees a single clickable line. + def _session_query(db): + query = db.query(DbSession).filter(DbSession.id == target_sid) + if owner is not None: + query = query.filter(DbSession.owner == owner) + return query + + if action in ("switch", "open", "select", "view"): + db = SessionLocal() + try: + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + name = db_sess.name or target_sid + finally: + db.close() + return { + "action": action, + "session_id": target_sid, + "name": name, + "results": f"[{name}](#session-{target_sid}) - click to open.", + } + + db = SessionLocal() + try: + if action == "rename": + if not value: + return {"error": "rename needs a new name (the `value` arg, or line 3 in the legacy format)"} + new_name = value + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + db_sess.name = new_name + db.commit() + _session_manager.update_session_name(target_sid, new_name) + return {"action": "rename", "session_id": target_sid, "name": new_name, + "results": f"Session renamed to '{new_name}'"} + + elif action == "archive": + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + db_sess.archived = True + db.commit() + return {"action": "archive", "session_id": target_sid, + "results": f"Session '{db_sess.name}' archived"} + + elif action == "unarchive": + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + db_sess.archived = False + db.commit() + return {"action": "unarchive", "session_id": target_sid, + "results": f"Session '{db_sess.name}' unarchived"} + + elif action == "delete": + if target_sid == session_id: + return {"error": "Cannot delete the current session while chatting in it. Delete other sessions first."} + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Refusing to delete an unknown chat id; use the exact id from list_sessions."} + if db_sess and db_sess.is_important: + return {"error": f"Session '{db_sess.name}' is starred/favorited. Unstar it first before deleting."} + try: + ok = _session_manager.delete_session(target_sid) + if not ok: + return {"error": f"Session '{target_sid}' was not deleted because it no longer exists."} + return {"action": "delete", "session_id": target_sid, + "results": f"Session '{db_sess.name or target_sid}' deleted"} + except Exception as e: + return {"error": f"Failed to delete session: {e}"} + + elif action in ("important", "unimportant"): + is_important = action == "important" + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + # Prevent AI from unstarring sessions - only the user can do that manually + if not is_important and db_sess.is_important: + return {"error": f"Session '{db_sess.name}' is starred by the user. Only the user can unstar sessions manually."} + db_sess.is_important = is_important + db.commit() + status = "marked as important" if is_important else "unmarked as important" + return {"action": action, "session_id": target_sid, + "results": f"Session '{db_sess.name}' {status}"} + + elif action == "truncate": + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + keep_count = 10 + if value: + try: + keep_count = int(value) + except ValueError: + pass + success = _session_manager.truncate_messages(target_sid, keep_count) + if success: + return {"action": "truncate", "session_id": target_sid, + "results": f"Session truncated to last {keep_count} messages"} + return {"error": f"Failed to truncate session '{target_sid}'"} + + elif action == "fork": + db_sess = _session_query(db).first() + if not db_sess: + return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} + keep_count = 0 # 0 = all messages + if value: + try: + keep_count = int(value) + except ValueError: + pass + + source = _session_manager.get_session(target_sid) + if not source: + return {"error": f"Session '{target_sid}' not found"} + + new_sid = str(uuid.uuid4())[:8] + _session_manager.create_session( + session_id=new_sid, + name=f"Fork: {source.name}", + endpoint_url=source.endpoint_url, + model=source.model, + rag=False, + owner=owner, + ) + # Copy messages + history = source.get_context_messages() + if keep_count > 0: + history = history[:keep_count] + from core.models import ChatMessage as InMemoryMsg + new_sess = _session_manager.get_session(new_sid) + for msg in history: + new_sess.add_message(InMemoryMsg(msg["role"], msg["content"])) + try: + from src.event_bus import fire_event + fire_event("session_created", owner) + except Exception: + logger.debug("session_created event dispatch failed", exc_info=True) + + return {"action": "fork", "session_id": new_sid, + "source_session": target_sid, "messages_copied": len(history), + "results": f"Forked session '{source.name}' -> new session {new_sid} ({len(history)} messages)"} + + else: + return {"error": f"Unknown action '{action}'. Use: list, switch, rename, archive, unarchive, delete, important, unimportant, truncate, fork"} + except Exception as e: + logger.error(f"manage_session failed: {e}") + return {"error": str(e)} + finally: + db.close() + + +# --------------------------------------------------------------------------- +# Handler classes registered in TOOL_HANDLERS +# --------------------------------------------------------------------------- + +class CreateSessionTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await create_session(content, ctx.get("session_id"), owner=ctx.get("owner")) + + +class ListSessionsTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await list_sessions(content, ctx.get("session_id"), owner=ctx.get("owner")) + + +class SendToSessionTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await send_to_session(content, ctx.get("session_id"), owner=ctx.get("owner")) + + +class ManageSessionTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await manage_session(content, ctx.get("session_id"), owner=ctx.get("owner")) diff --git a/src/ai_interaction.py b/src/ai_interaction.py index 667df8fb5..6655beaf4 100644 --- a/src/ai_interaction.py +++ b/src/ai_interaction.py @@ -1,12 +1,14 @@ """ ai_interaction.py -AI-to-AI interaction tools: create_session, list_sessions, send_to_session, -pipeline, plus shared model resolution (_resolve_model). +AI-to-AI interaction tools: pipeline and manage_memory, plus shared model +resolution (_resolve_model), the session-manager singleton, and dispatch_ai_tool. -chat_with_model, ask_teacher and list_models were moved to -src/agent_tools/model_interaction_tools.py as part of the tool -> registry -migration (#3629); they still reuse _resolve_model / AI_CHAT_TIMEOUT from here. +As part of the tool -> registry migration (#3629), chat_with_model, ask_teacher +and list_models moved to src/agent_tools/model_interaction_tools.py, and +create_session, list_sessions, send_to_session and manage_session moved to +src/agent_tools/session_tools.py. Those modules reuse get_session_manager / +_resolve_model / AI_CHAT_TIMEOUT from here. These are agent tools — the LLM writes fenced code blocks and they execute through the standard agent_tools.py pipeline. @@ -165,204 +167,6 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di -async def do_create_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """Create a new chat session. - - Content format: - Line 1: session name - Line 2: model_name (or model_name@endpoint_name) - """ - if not _session_manager: - return {"error": "Session manager not available"} - - lines = content.strip().split("\n") - if len(lines) < 2: - return {"error": "Need 2 lines: session name, then model spec"} - - name = lines[0].strip() - model_spec = lines[1].strip() - - if not name: - return {"error": "Session name cannot be empty"} - - try: - url, model, headers = _resolve_model(model_spec, owner=owner) - except ValueError as e: - return {"error": str(e)} - - sid = str(uuid.uuid4())[:8] - try: - _session_manager.create_session( - session_id=sid, - name=name, - endpoint_url=url, - model=model, - rag=False, - owner=owner, - ) - # Store headers on session for future calls - sess = _session_manager.get_session(sid) - if sess and headers: - sess.headers = headers - try: - from src.event_bus import fire_event - fire_event("session_created", owner) - except Exception: - logger.debug("session_created event dispatch failed", exc_info=True) - - return {"session_id": sid, "name": name, "model": model, "endpoint_url": url} - except Exception as e: - logger.error(f"create_session failed: {e}") - return {"error": f"Failed to create session: {e}"} - - -async def do_list_sessions(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """List sessions sorted by most-recently-active first. - - Output includes a relative "last active" timestamp per row so the - agent can answer "open my last chat" without guessing from titles. - The most-recent session is always first in the list. - - Content = optional filter keyword (matches session name). - """ - if not _session_manager: - return {"error": "Session manager not available"} - - keyword = content.strip().lower() if content.strip() else None - - try: - from core.database import SessionLocal, Session as DbSession - from datetime import datetime, timezone - - # Pull every session's last_accessed from the DB so we can sort - # by recency. In-memory sessions hold name + model + msg_count; - # the DB row holds the timestamps. - db = SessionLocal() - try: - db_rows = {r.id: r for r in db.query(DbSession).all()} - finally: - db.close() - - # SECURITY: scope to the caller's sessions. Passing None returned - # every user's sessions, which the agent tool then exposed via the - # "list my chats" reply. - sessions = _session_manager.get_sessions_for_user(owner) - rows = [] - for sid, sess in sessions.items(): - if keyword and keyword not in (sess.name or "").lower(): - continue - db_row = db_rows.get(sid) - # Prefer last_accessed; fall back to updated_at, then created_at. - ts = None - if db_row: - ts = getattr(db_row, 'last_accessed', None) or getattr(db_row, 'updated_at', None) or getattr(db_row, 'created_at', None) - rows.append((ts, sid, sess)) - - # Sort by timestamp DESC; rows without a timestamp sink to the bottom. - rows.sort(key=lambda r: r[0] or datetime.min, reverse=True) - - def _rel(ts): - if not ts: - return 'never' - now = datetime.utcnow() - try: - if ts.tzinfo is not None: - now = datetime.now(timezone.utc) - diff = (now - ts).total_seconds() - except Exception: - return 'unknown' - if diff < 60: return 'just now' - if diff < 3600: return f'{int(diff / 60)}m ago' - if diff < 86400: return f'{int(diff / 3600)}h ago' - if diff < 86400 * 7: return f'{int(diff / 86400)}d ago' - return ts.strftime('%Y-%m-%d') - - lines = [] - for i, (ts, sid, sess) in enumerate(rows): - if i >= 50: - lines.append(f"... and {len(rows) - 50} more (showing first 50)") - break - safe_name = (sess.name or "Untitled").replace("[", "\\[").replace("]", "\\]") - msg_count = getattr(sess, "message_count", 0) or 0 - model = getattr(sess, "model", "unknown") - marker = " ← most recent" if i == 0 else "" - lines.append(f"- **[{safe_name}](#session-{sid})** (id: `{sid}`, model: {model}, {msg_count} msgs, last active {_rel(ts)}){marker}") - - if not lines: - return {"results": "No sessions found" + (f" matching '{keyword}'" if keyword else "") + "."} - - return { - "results": ( - f"Found {len(rows)} session(s), sorted most-recent first:\n" - + "\n".join(lines) - + "\n\nAssistant: when replying to the user, preserve the chat-title markdown links exactly as shown, e.g. `[Chat](#session-id)`. Do not rewrite this as a plain, non-clickable table." - ) - } - except Exception as e: - logger.error(f"list_sessions failed: {e}") - return {"error": str(e)} - - -async def do_send_to_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """Send a message to an existing session and get a response. - - Content format: - Line 1: session_id - Line 2+: message - """ - from src.llm_core import llm_call_async - from core.models import ChatMessage - - if not _session_manager: - return {"error": "Session manager not available"} - - lines = content.strip().split("\n", 1) - if len(lines) < 2: - return {"error": "Need 2 lines: session_id, then message"} - - target_sid = lines[0].strip() - message = lines[1].strip() - - sess = _session_manager.get_session(target_sid) - if not sess: - return {"error": f"Session '{target_sid}' not found"} - - # Owner-scope: reject access to another user's session - if owner and getattr(sess, "owner", None) and sess.owner != owner: - return {"error": f"Session '{target_sid}' not found"} - - if not message: - return {"error": "No message provided"} - - try: - # Build context from session history - context = sess.get_context_messages() - context.append({"role": "user", "content": message}) - - response = await llm_call_async( - sess.endpoint_url, sess.model, context, - headers=sess.headers, - timeout=AI_CHAT_TIMEOUT, - ) - - # Save both messages to session - sess.add_message(ChatMessage("user", message)) - sess.add_message(ChatMessage("assistant", response)) - - # Truncate for tool output - if len(response) > 10000: - response = response[:10000] + "\n... (truncated)" - - return { - "session_id": target_sid, - "session_name": sess.name, - "response": response, - } - except Exception as e: - logger.error(f"send_to_session failed: {e}") - return {"error": f"Failed to send to session: {e}"} - - async def stream_ai_tool(tool: str, content: str, session_id: Optional[str] = None, owner: Optional[str] = None): """Dispatcher for streaming AI tools. Yields events as async generator.""" # Fallback: run non-streaming and yield final result @@ -483,229 +287,6 @@ async def do_pipeline(content: str, session_id: Optional[str] = None, owner: Opt # Session management tool # --------------------------------------------------------------------------- -async def do_manage_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """Manage sessions: rename, archive, delete, important, truncate, fork. - - Content format: - Line 1: action (rename|archive|unarchive|delete|important|unimportant|truncate|fork) - Line 2: target session_id (or "current" to use the active session) - Line 3+: action-specific params (e.g. new name for rename, keep_count for truncate) - """ - if not _session_manager: - return {"error": "Session manager not available"} - - from src.database import SessionLocal, Session as DbSession - - # Accept BOTH the structured JSON args the tool schema advertises - # ({action, session_id, value}) AND the legacy line-based format - # (line1=action, line2=session_id, line3=value). Native function-calling - # models send JSON; fenced-block callers send lines. Previously only the - # line format was parsed, so a model that followed the schema (JSON) got - # "Need at least 2 lines" / "Rename needs line 3" and couldn't drive it. - _raw = (content or "").strip() - action = "" - target_sid = "" - value = None # the action param: new name (rename) / keep_count (truncate, fork) - _list_filter = "" - _parsed = None - if _raw.startswith("{"): - try: - _parsed = json.loads(_raw) - except Exception: - _parsed = None - if isinstance(_parsed, dict): - action = str(_parsed.get("action") or "").strip().lower() - target_sid = str(_parsed.get("session_id") or _parsed.get("session") or _parsed.get("id") or "").strip() - _v = _parsed.get("value") - if _v is None: - _v = (_parsed.get("name") or _parsed.get("new_name") - or _parsed.get("title") or _parsed.get("keep_count")) - value = None if _v is None else str(_v).strip() - _list_filter = str(_parsed.get("filter") or "").strip() - else: - lines = _raw.split("\n") - if not lines or not lines[0].strip(): - return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"} - action = lines[0].strip().lower() - target_sid = lines[1].strip() if len(lines) >= 2 else "" - value = lines[2].strip() if len(lines) >= 3 else None - _list_filter = "\n".join(lines[1:]).strip() - - if not action: - return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"} - - # `list` alias — dispatch to do_list_sessions so the agent's natural - # first guess (every other manage_* tool has a `list` action) works. - if action == "list": - return await do_list_sessions(_list_filter, session_id, owner=owner) - - if not target_sid: - return {"error": "Need a session_id (or 'current' for the active chat)"} - - # Allow "current" to refer to the active session - if target_sid.lower() == "current" and session_id: - target_sid = session_id - - # `switch` / `open` / `select` / `view` — the agent reaches for - # these when the user asks to "open" or "switch to" a session. - # There's no server-side way to make the browser navigate, so we - # just return a clickable anchor link the user can click. The - # frontend's chat-history click delegate routes `#session-` - # to selectSession(). The agent's reply naturally embeds this - # result so the user sees a single clickable line. - def _session_query(db): - query = db.query(DbSession).filter(DbSession.id == target_sid) - if owner is not None: - query = query.filter(DbSession.owner == owner) - return query - - if action in ("switch", "open", "select", "view"): - db = SessionLocal() - try: - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - name = db_sess.name or target_sid - finally: - db.close() - return { - "action": action, - "session_id": target_sid, - "name": name, - "results": f"[{name}](#session-{target_sid}) — click to open.", - } - - db = SessionLocal() - try: - if action == "rename": - if not value: - return {"error": "rename needs a new name (the `value` arg, or line 3 in the legacy format)"} - new_name = value - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - db_sess.name = new_name - db.commit() - _session_manager.update_session_name(target_sid, new_name) - return {"action": "rename", "session_id": target_sid, "name": new_name, - "results": f"Session renamed to '{new_name}'"} - - elif action == "archive": - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - db_sess.archived = True - db.commit() - return {"action": "archive", "session_id": target_sid, - "results": f"Session '{db_sess.name}' archived"} - - elif action == "unarchive": - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - db_sess.archived = False - db.commit() - return {"action": "unarchive", "session_id": target_sid, - "results": f"Session '{db_sess.name}' unarchived"} - - elif action == "delete": - if target_sid == session_id: - return {"error": "Cannot delete the current session while chatting in it. Delete other sessions first."} - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Refusing to delete an unknown chat id; use the exact id from list_sessions."} - if db_sess and db_sess.is_important: - return {"error": f"Session '{db_sess.name}' is starred/favorited. Unstar it first before deleting."} - try: - ok = _session_manager.delete_session(target_sid) - if not ok: - return {"error": f"Session '{target_sid}' was not deleted because it no longer exists."} - return {"action": "delete", "session_id": target_sid, - "results": f"Session '{db_sess.name or target_sid}' deleted"} - except Exception as e: - return {"error": f"Failed to delete session: {e}"} - - elif action in ("important", "unimportant"): - is_important = action == "important" - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - # Prevent AI from unstarring sessions — only the user can do that manually - if not is_important and db_sess.is_important: - return {"error": f"Session '{db_sess.name}' is starred by the user. Only the user can unstar sessions manually."} - db_sess.is_important = is_important - db.commit() - status = "marked as important" if is_important else "unmarked as important" - return {"action": action, "session_id": target_sid, - "results": f"Session '{db_sess.name}' {status}"} - - elif action == "truncate": - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - keep_count = 10 - if value: - try: - keep_count = int(value) - except ValueError: - pass - success = _session_manager.truncate_messages(target_sid, keep_count) - if success: - return {"action": "truncate", "session_id": target_sid, - "results": f"Session truncated to last {keep_count} messages"} - return {"error": f"Failed to truncate session '{target_sid}'"} - - elif action == "fork": - db_sess = _session_query(db).first() - if not db_sess: - return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."} - keep_count = 0 # 0 = all messages - if value: - try: - keep_count = int(value) - except ValueError: - pass - - source = _session_manager.get_session(target_sid) - if not source: - return {"error": f"Session '{target_sid}' not found"} - - new_sid = str(uuid.uuid4())[:8] - _session_manager.create_session( - session_id=new_sid, - name=f"Fork: {source.name}", - endpoint_url=source.endpoint_url, - model=source.model, - rag=False, - owner=owner, - ) - # Copy messages - history = source.get_context_messages() - if keep_count > 0: - history = history[:keep_count] - from core.models import ChatMessage as InMemoryMsg - new_sess = _session_manager.get_session(new_sid) - for msg in history: - new_sess.add_message(InMemoryMsg(msg["role"], msg["content"])) - try: - from src.event_bus import fire_event - fire_event("session_created", owner) - except Exception: - logger.debug("session_created event dispatch failed", exc_info=True) - - return {"action": "fork", "session_id": new_sid, - "source_session": target_sid, "messages_copied": len(history), - "results": f"Forked session '{source.name}' -> new session {new_sid} ({len(history)} messages)"} - - else: - return {"error": f"Unknown action '{action}'. Use: list, switch, rename, archive, unarchive, delete, important, unimportant, truncate, fork"} - except Exception as e: - logger.error(f"manage_session failed: {e}") - return {"error": str(e)} - finally: - db.close() - - # --------------------------------------------------------------------------- # Memory management tool # --------------------------------------------------------------------------- @@ -1522,30 +1103,10 @@ async def dispatch_ai_tool( ) -> Tuple[str, Dict]: """Dispatch an AI interaction tool. Returns (description, result_dict).""" - if tool == "create_session": - name = content.split("\n")[0].strip()[:60] - desc = f"create_session: {name}" - result = await do_create_session(content, session_id, owner=owner) - - elif tool == "list_sessions": - keyword = content.strip()[:40] - desc = f"list_sessions{': ' + keyword if keyword else ''}" - result = await do_list_sessions(content, session_id, owner=owner) - - elif tool == "send_to_session": - sid = content.split("\n")[0].strip()[:20] - desc = f"send_to_session: {sid}" - result = await do_send_to_session(content, session_id, owner=owner) - - elif tool == "pipeline": + if tool == "pipeline": desc = "pipeline: running steps" result = await do_pipeline(content, session_id, owner=owner) - elif tool == "manage_session": - action = content.split("\n")[0].strip()[:40] - desc = f"manage_session: {action}" - result = await do_manage_session(content, session_id, owner=owner) - elif tool == "manage_memory": action = content.split("\n")[0].strip()[:40] desc = f"manage_memory: {action}" diff --git a/src/tool_execution.py b/src/tool_execution.py index c13910e3a..3b4ba5eab 100644 --- a/src/tool_execution.py +++ b/src/tool_execution.py @@ -781,16 +781,21 @@ async def _execute_tool_block_impl( elif tool in ("chat_with_model", "ask_teacher", "list_models"): # Migrated to the agent_tools registry (#3629): dispatched through # TOOL_HANDLERS with the owner/session ctx these tools need, instead - # of the legacy dispatch_ai_tool elif. The do_* impls stay in - # ai_interaction.py (dispatch_ai_tool + the owner-scope test use them). + # of the legacy dispatch_ai_tool elif. The impls live in + # src/agent_tools/model_interaction_tools.py. first_line = content.split(chr(10))[0].strip()[:60] desc = f"{tool}: {first_line}" if first_line else tool result = await _document_tool_dispatch(tool, content, session_id, owner) \ or {"error": f"{tool}: execution failed", "exit_code": 1} - elif tool in ("create_session", "list_sessions", - "send_to_session", "pipeline", - "manage_session", "manage_memory", - "ui_control"): + elif tool in ("create_session", "list_sessions", "send_to_session", "manage_session"): + # Migrated to the agent_tools registry (#3629): dispatched through + # TOOL_HANDLERS with the owner/session ctx these tools need. The impls + # live in src/agent_tools/session_tools.py. + first_line = content.split(chr(10))[0].strip()[:60] + desc = f"{tool}: {first_line}" if first_line else tool + result = await _document_tool_dispatch(tool, content, session_id, owner) \ + or {"error": f"{tool}: execution failed", "exit_code": 1} + elif tool in ("pipeline", "manage_memory", "ui_control"): from src.ai_interaction import dispatch_ai_tool desc, result = await dispatch_ai_tool(tool, content, session_id, owner=owner) elif tool == "manage_tasks": diff --git a/tests/test_session_tools_registry.py b/tests/test_session_tools_registry.py new file mode 100644 index 000000000..804cfdbdc --- /dev/null +++ b/tests/test_session_tools_registry.py @@ -0,0 +1,149 @@ +"""Tests for the session tools' move to the agent_tools registry (#3629): +create_session, list_sessions, send_to_session, manage_session. + +The implementations now live in src/agent_tools/session_tools.py (moved out of +src/ai_interaction.py). These assert (1) the handlers are registered in +TOOL_HANDLERS, (2) the moved logic runs and threads owner/session from ctx +(the session manager is fetched via ai_interaction.get_session_manager), and +(3) tool_execution.py dispatches them through the registry rather than the +legacy dispatch_ai_tool elif. +""" +import asyncio +from pathlib import Path + +import src.ai_interaction as ai_interaction +import src.database as database +from src.agent_tools import TOOL_HANDLERS +from src.agent_tools import session_tools as st + +_SESSION_TOOLS = ("create_session", "list_sessions", "send_to_session", "manage_session") + + +def test_session_tools_registered(): + for name in _SESSION_TOOLS: + assert name in TOOL_HANDLERS, f"{name} missing from TOOL_HANDLERS" + + +def test_list_sessions_handler_threads_ctx(monkeypatch): + # The handler must thread content + session_id + owner from ctx into the + # moved list_sessions implementation. Spy at the function boundary so the + # test does not depend on list_sessions' DB internals. + seen = {} + + async def spy(content, session_id=None, owner=None): + seen.update(content=content, session_id=session_id, owner=owner) + return {"results": "ok"} + + monkeypatch.setattr(st, "list_sessions", spy) + res = asyncio.run(st.ListSessionsTool().execute("q", {"owner": "alice", "session_id": "s1"})) + assert res == {"results": "ok"} + assert seen == {"content": "q", "session_id": "s1", "owner": "alice"} + + +def test_manage_session_list_delegates_to_list_sessions(monkeypatch): + # manage_session("list") must delegate to list_sessions; guards against a + # stale do_list_sessions reference surviving the move (caught live in e2e). + called = {} + + async def spy(content, session_id=None, owner=None): + called["owner"] = owner + return {"results": "ok"} + + monkeypatch.setattr(st, "list_sessions", spy) + # manage_session imports `Session` from src.database before the list branch; + # the src.database test double may not expose it, so provide a stand-in. + monkeypatch.setattr(database, "Session", object, raising=False) + monkeypatch.setattr(ai_interaction, "_session_manager", object()) # truthy: pass the guard + res = asyncio.run(st.ManageSessionTool().execute("list", {"owner": "carol"})) + assert called.get("owner") == "carol" + assert res == {"results": "ok"} + + +def test_create_session_reaches_uuid_and_creates(monkeypatch): + # Regression for the missing `import uuid` (PR review): create_session must + # get past _resolve_model and mint a session id without NameError. + monkeypatch.setattr(st, "_resolve_model", lambda spec, owner=None: ("http://x", "model-x", {})) + created = {} + + class FakeMgr: + def create_session(self, **kw): + created.update(kw) + + def get_session(self, sid): + return None + + monkeypatch.setattr(ai_interaction, "_session_manager", FakeMgr()) + res = asyncio.run(st.CreateSessionTool().execute("My Chat\nmodel-x", {"owner": "alice"})) + assert res.get("name") == "My Chat" and res.get("model") == "model-x" + assert isinstance(res.get("session_id"), str) and res["session_id"] + assert created.get("name") == "My Chat" # the uuid-minted id reached the manager + + +def test_manage_session_fork_reaches_uuid(monkeypatch): + # Regression for the missing `import uuid`: the fork action also mints a new + # session id and must not NameError. Mocks the DB query layer so the fork + # branch reaches the uuid call without a real sessions table. + class FakeDbSession: + id = "id" + owner = "owner" + + class FakeQ: + def filter(self, *a, **k): + return self + + def first(self): + return object() + + class FakeDB: + def query(self, *a, **k): + return FakeQ() + + def close(self): + pass + + monkeypatch.setattr(database, "Session", FakeDbSession, raising=False) + monkeypatch.setattr(database, "SessionLocal", lambda: FakeDB(), raising=False) + + class Src: + name = "Orig" + endpoint_url = "http://x" + model = "m" + + def get_context_messages(self): + return [] + + created = {} + + class FakeMgr: + def get_session(self, sid): + return Src() if sid == "abc" else type("S", (), {"add_message": lambda self, m: None})() + + def create_session(self, **kw): + created.update(kw) + + monkeypatch.setattr(ai_interaction, "_session_manager", FakeMgr()) + res = asyncio.run(st.ManageSessionTool().execute('{"action":"fork","session_id":"abc"}', {"owner": "owner"})) + assert res.get("action") == "fork" + assert isinstance(res.get("session_id"), str) and res["session_id"] + assert created.get("name") == "Fork: Orig" # uuid-minted new session was created + + +def test_no_session_manager_is_handled(monkeypatch): + # With no session manager set, the moved function must fail gracefully + # (proves the handler reached the impl, not an "unknown tool"). + monkeypatch.setattr(ai_interaction, "_session_manager", None) + res = asyncio.run(st.ListSessionsTool().execute("", {"owner": "bob"})) + assert isinstance(res, dict) + assert "error" in res or "results" in res + + +def test_dispatched_via_registry_not_dispatch_ai_tool(): + source = (Path(__file__).resolve().parent.parent / "src" / "tool_execution.py").read_text(encoding="utf-8") + assert 'elif tool in ("create_session", "list_sessions", "send_to_session", "manage_session"):' in source + + marker = "from src.ai_interaction import dispatch_ai_tool" + idx = source.index(marker) + branch_head = source.rfind("elif tool in (", 0, idx) + legacy_tuple = source[branch_head:idx] + for name in _SESSION_TOOLS: + assert f'"{name}"' not in legacy_tuple, f"{name} still routed via dispatch_ai_tool"