mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-19 11:15:24 -04:00
refactor(tools): move session tools to the agent_tools registry (#4454)
Moves create_session, list_sessions, send_to_session and manage_session out of ai_interaction.py into src/agent_tools/session_tools.py (the do_ prefix dropped) and registers them in TOOL_HANDLERS, so dispatch flows through the registry instead of the dispatch_ai_tool elif in tool_execution.py. Same pattern as the model-interaction move. The bodies move verbatim; each fetches the runtime-set session manager via a get_session_manager() shim, and reuses _resolve_model / AI_CHAT_TIMEOUT from ai_interaction. manage_session's internal 'list' alias is repointed from the old do_list_sessions to the moved list_sessions. stream_ai_tool (dead, no callers) and do_pipeline stay put. dispatch_ai_tool loses its four now-unused branches. Tests: test_session_tools_registry covers registration, owner threading, the manage_session->list_sessions delegation, graceful no-manager handling, and registry dispatch. Verified end-to-end against a live SessionManager.
This commit is contained in:
committed by
GitHub
parent
076e8c93c9
commit
ed18192a8e
@@ -24,6 +24,7 @@ from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool,
|
||||
from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool
|
||||
from .model_interaction_tools import ChatWithModelTool, AskTeacherTool, ListModelsTool
|
||||
from .bg_job_tools import ManageBgJobsTool
|
||||
from .session_tools import CreateSessionTool, ListSessionsTool, SendToSessionTool, ManageSessionTool
|
||||
|
||||
TOOL_HANDLERS = {
|
||||
"bash": BashTool().execute,
|
||||
@@ -46,6 +47,10 @@ TOOL_HANDLERS = {
|
||||
"ask_teacher": AskTeacherTool().execute,
|
||||
"list_models": ListModelsTool().execute,
|
||||
"manage_bg_jobs": ManageBgJobsTool().execute,
|
||||
"create_session": CreateSessionTool().execute,
|
||||
"list_sessions": ListSessionsTool().execute,
|
||||
"send_to_session": SendToSessionTool().execute,
|
||||
"manage_session": ManageSessionTool().execute,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,464 @@
|
||||
"""session_tools.py - agent tools for AI-to-AI session management.
|
||||
|
||||
Owns create_session, list_sessions, send_to_session and manage_session, moved
|
||||
out of src.ai_interaction as part of the tool -> registry migration (#3629), and
|
||||
their handler classes registered in TOOL_HANDLERS.
|
||||
|
||||
The session manager is a runtime-set singleton in src.ai_interaction, so each
|
||||
function fetches it via get_session_manager() (imported here); _resolve_model and
|
||||
AI_CHAT_TIMEOUT are reused from there too.
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.ai_interaction import get_session_manager, _resolve_model, AI_CHAT_TIMEOUT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Create a new chat session.
|
||||
|
||||
Content format:
|
||||
Line 1: session name
|
||||
Line 2: model_name (or model_name@endpoint_name)
|
||||
"""
|
||||
_session_manager = get_session_manager()
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return {"error": "Need 2 lines: session name, then model spec"}
|
||||
|
||||
name = lines[0].strip()
|
||||
model_spec = lines[1].strip()
|
||||
|
||||
if not name:
|
||||
return {"error": "Session name cannot be empty"}
|
||||
|
||||
try:
|
||||
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
sid = str(uuid.uuid4())[:8]
|
||||
try:
|
||||
_session_manager.create_session(
|
||||
session_id=sid,
|
||||
name=name,
|
||||
endpoint_url=url,
|
||||
model=model,
|
||||
rag=False,
|
||||
owner=owner,
|
||||
)
|
||||
# Store headers on session for future calls
|
||||
sess = _session_manager.get_session(sid)
|
||||
if sess and headers:
|
||||
sess.headers = headers
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", owner)
|
||||
except Exception:
|
||||
logger.debug("session_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {"session_id": sid, "name": name, "model": model, "endpoint_url": url}
|
||||
except Exception as e:
|
||||
logger.error(f"create_session failed: {e}")
|
||||
return {"error": f"Failed to create session: {e}"}
|
||||
|
||||
async def list_sessions(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""List sessions sorted by most-recently-active first.
|
||||
|
||||
Output includes a relative "last active" timestamp per row so the
|
||||
agent can answer "open my last chat" without guessing from titles.
|
||||
The most-recent session is always first in the list.
|
||||
|
||||
Content = optional filter keyword (matches session name).
|
||||
"""
|
||||
_session_manager = get_session_manager()
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
keyword = content.strip().lower() if content.strip() else None
|
||||
|
||||
try:
|
||||
from core.database import SessionLocal, Session as DbSession
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Pull every session's last_accessed from the DB so we can sort
|
||||
# by recency. In-memory sessions hold name + model + msg_count;
|
||||
# the DB row holds the timestamps.
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_rows = {r.id: r for r in db.query(DbSession).all()}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# SECURITY: scope to the caller's sessions. Passing None returned
|
||||
# every user's sessions, which the agent tool then exposed via the
|
||||
# "list my chats" reply.
|
||||
sessions = _session_manager.get_sessions_for_user(owner)
|
||||
rows = []
|
||||
for sid, sess in sessions.items():
|
||||
if keyword and keyword not in (sess.name or "").lower():
|
||||
continue
|
||||
db_row = db_rows.get(sid)
|
||||
# Prefer last_accessed; fall back to updated_at, then created_at.
|
||||
ts = None
|
||||
if db_row:
|
||||
ts = getattr(db_row, 'last_accessed', None) or getattr(db_row, 'updated_at', None) or getattr(db_row, 'created_at', None)
|
||||
rows.append((ts, sid, sess))
|
||||
|
||||
# Sort by timestamp DESC; rows without a timestamp sink to the bottom.
|
||||
rows.sort(key=lambda r: r[0] or datetime.min, reverse=True)
|
||||
|
||||
def _rel(ts):
|
||||
if not ts:
|
||||
return 'never'
|
||||
now = datetime.utcnow()
|
||||
try:
|
||||
if ts.tzinfo is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
diff = (now - ts).total_seconds()
|
||||
except Exception:
|
||||
return 'unknown'
|
||||
if diff < 60: return 'just now'
|
||||
if diff < 3600: return f'{int(diff / 60)}m ago'
|
||||
if diff < 86400: return f'{int(diff / 3600)}h ago'
|
||||
if diff < 86400 * 7: return f'{int(diff / 86400)}d ago'
|
||||
return ts.strftime('%Y-%m-%d')
|
||||
|
||||
lines = []
|
||||
for i, (ts, sid, sess) in enumerate(rows):
|
||||
if i >= 50:
|
||||
lines.append(f"... and {len(rows) - 50} more (showing first 50)")
|
||||
break
|
||||
safe_name = (sess.name or "Untitled").replace("[", "\\[").replace("]", "\\]")
|
||||
msg_count = getattr(sess, "message_count", 0) or 0
|
||||
model = getattr(sess, "model", "unknown")
|
||||
marker = " ← most recent" if i == 0 else ""
|
||||
lines.append(f"- **[{safe_name}](#session-{sid})** (id: `{sid}`, model: {model}, {msg_count} msgs, last active {_rel(ts)}){marker}")
|
||||
|
||||
if not lines:
|
||||
return {"results": "No sessions found" + (f" matching '{keyword}'" if keyword else "") + "."}
|
||||
|
||||
return {
|
||||
"results": (
|
||||
f"Found {len(rows)} session(s), sorted most-recent first:\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\n\nAssistant: when replying to the user, preserve the chat-title markdown links exactly as shown, e.g. `[Chat](#session-id)`. Do not rewrite this as a plain, non-clickable table."
|
||||
)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"list_sessions failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def send_to_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Send a message to an existing session and get a response.
|
||||
|
||||
Content format:
|
||||
Line 1: session_id
|
||||
Line 2+: message
|
||||
"""
|
||||
_session_manager = get_session_manager()
|
||||
from src.llm_core import llm_call_async
|
||||
from core.models import ChatMessage
|
||||
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
lines = content.strip().split("\n", 1)
|
||||
if len(lines) < 2:
|
||||
return {"error": "Need 2 lines: session_id, then message"}
|
||||
|
||||
target_sid = lines[0].strip()
|
||||
message = lines[1].strip()
|
||||
|
||||
sess = _session_manager.get_session(target_sid)
|
||||
if not sess:
|
||||
return {"error": f"Session '{target_sid}' not found"}
|
||||
|
||||
# Owner-scope: reject access to another user's session
|
||||
if owner and getattr(sess, "owner", None) and sess.owner != owner:
|
||||
return {"error": f"Session '{target_sid}' not found"}
|
||||
|
||||
if not message:
|
||||
return {"error": "No message provided"}
|
||||
|
||||
try:
|
||||
# Build context from session history
|
||||
context = sess.get_context_messages()
|
||||
context.append({"role": "user", "content": message})
|
||||
|
||||
response = await llm_call_async(
|
||||
sess.endpoint_url, sess.model, context,
|
||||
headers=sess.headers,
|
||||
timeout=AI_CHAT_TIMEOUT,
|
||||
)
|
||||
|
||||
# Save both messages to session
|
||||
sess.add_message(ChatMessage("user", message))
|
||||
sess.add_message(ChatMessage("assistant", response))
|
||||
|
||||
# Truncate for tool output
|
||||
if len(response) > 10000:
|
||||
response = response[:10000] + "\n... (truncated)"
|
||||
|
||||
return {
|
||||
"session_id": target_sid,
|
||||
"session_name": sess.name,
|
||||
"response": response,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"send_to_session failed: {e}")
|
||||
return {"error": f"Failed to send to session: {e}"}
|
||||
|
||||
async def manage_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Manage sessions: rename, archive, delete, important, truncate, fork.
|
||||
|
||||
Content format:
|
||||
Line 1: action (rename|archive|unarchive|delete|important|unimportant|truncate|fork)
|
||||
Line 2: target session_id (or "current" to use the active session)
|
||||
Line 3+: action-specific params (e.g. new name for rename, keep_count for truncate)
|
||||
"""
|
||||
_session_manager = get_session_manager()
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
from src.database import SessionLocal, Session as DbSession
|
||||
|
||||
# Accept BOTH the structured JSON args the tool schema advertises
|
||||
# ({action, session_id, value}) AND the legacy line-based format
|
||||
# (line1=action, line2=session_id, line3=value). Native function-calling
|
||||
# models send JSON; fenced-block callers send lines. Previously only the
|
||||
# line format was parsed, so a model that followed the schema (JSON) got
|
||||
# "Need at least 2 lines" / "Rename needs line 3" and couldn't drive it.
|
||||
_raw = (content or "").strip()
|
||||
action = ""
|
||||
target_sid = ""
|
||||
value = None # the action param: new name (rename) / keep_count (truncate, fork)
|
||||
_list_filter = ""
|
||||
_parsed = None
|
||||
if _raw.startswith("{"):
|
||||
try:
|
||||
_parsed = json.loads(_raw)
|
||||
except Exception:
|
||||
_parsed = None
|
||||
if isinstance(_parsed, dict):
|
||||
action = str(_parsed.get("action") or "").strip().lower()
|
||||
target_sid = str(_parsed.get("session_id") or _parsed.get("session") or _parsed.get("id") or "").strip()
|
||||
_v = _parsed.get("value")
|
||||
if _v is None:
|
||||
_v = (_parsed.get("name") or _parsed.get("new_name")
|
||||
or _parsed.get("title") or _parsed.get("keep_count"))
|
||||
value = None if _v is None else str(_v).strip()
|
||||
_list_filter = str(_parsed.get("filter") or "").strip()
|
||||
else:
|
||||
lines = _raw.split("\n")
|
||||
if not lines or not lines[0].strip():
|
||||
return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"}
|
||||
action = lines[0].strip().lower()
|
||||
target_sid = lines[1].strip() if len(lines) >= 2 else ""
|
||||
value = lines[2].strip() if len(lines) >= 3 else None
|
||||
_list_filter = "\n".join(lines[1:]).strip()
|
||||
|
||||
if not action:
|
||||
return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"}
|
||||
|
||||
# `list` alias - dispatch to list_sessions so the agent's natural
|
||||
# first guess (every other manage_* tool has a `list` action) works.
|
||||
if action == "list":
|
||||
return await list_sessions(_list_filter, session_id, owner=owner)
|
||||
|
||||
if not target_sid:
|
||||
return {"error": "Need a session_id (or 'current' for the active chat)"}
|
||||
|
||||
# Allow "current" to refer to the active session
|
||||
if target_sid.lower() == "current" and session_id:
|
||||
target_sid = session_id
|
||||
|
||||
# `switch` / `open` / `select` / `view` - the agent reaches for
|
||||
# these when the user asks to "open" or "switch to" a session.
|
||||
# There's no server-side way to make the browser navigate, so we
|
||||
# just return a clickable anchor link the user can click. The
|
||||
# frontend's chat-history click delegate routes `#session-<id>`
|
||||
# to selectSession(). The agent's reply naturally embeds this
|
||||
# result so the user sees a single clickable line.
|
||||
def _session_query(db):
|
||||
query = db.query(DbSession).filter(DbSession.id == target_sid)
|
||||
if owner is not None:
|
||||
query = query.filter(DbSession.owner == owner)
|
||||
return query
|
||||
|
||||
if action in ("switch", "open", "select", "view"):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
name = db_sess.name or target_sid
|
||||
finally:
|
||||
db.close()
|
||||
return {
|
||||
"action": action,
|
||||
"session_id": target_sid,
|
||||
"name": name,
|
||||
"results": f"[{name}](#session-{target_sid}) - click to open.",
|
||||
}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if action == "rename":
|
||||
if not value:
|
||||
return {"error": "rename needs a new name (the `value` arg, or line 3 in the legacy format)"}
|
||||
new_name = value
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
db_sess.name = new_name
|
||||
db.commit()
|
||||
_session_manager.update_session_name(target_sid, new_name)
|
||||
return {"action": "rename", "session_id": target_sid, "name": new_name,
|
||||
"results": f"Session renamed to '{new_name}'"}
|
||||
|
||||
elif action == "archive":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
db_sess.archived = True
|
||||
db.commit()
|
||||
return {"action": "archive", "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name}' archived"}
|
||||
|
||||
elif action == "unarchive":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
db_sess.archived = False
|
||||
db.commit()
|
||||
return {"action": "unarchive", "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name}' unarchived"}
|
||||
|
||||
elif action == "delete":
|
||||
if target_sid == session_id:
|
||||
return {"error": "Cannot delete the current session while chatting in it. Delete other sessions first."}
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Refusing to delete an unknown chat id; use the exact id from list_sessions."}
|
||||
if db_sess and db_sess.is_important:
|
||||
return {"error": f"Session '{db_sess.name}' is starred/favorited. Unstar it first before deleting."}
|
||||
try:
|
||||
ok = _session_manager.delete_session(target_sid)
|
||||
if not ok:
|
||||
return {"error": f"Session '{target_sid}' was not deleted because it no longer exists."}
|
||||
return {"action": "delete", "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name or target_sid}' deleted"}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to delete session: {e}"}
|
||||
|
||||
elif action in ("important", "unimportant"):
|
||||
is_important = action == "important"
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
# Prevent AI from unstarring sessions - only the user can do that manually
|
||||
if not is_important and db_sess.is_important:
|
||||
return {"error": f"Session '{db_sess.name}' is starred by the user. Only the user can unstar sessions manually."}
|
||||
db_sess.is_important = is_important
|
||||
db.commit()
|
||||
status = "marked as important" if is_important else "unmarked as important"
|
||||
return {"action": action, "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name}' {status}"}
|
||||
|
||||
elif action == "truncate":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
keep_count = 10
|
||||
if value:
|
||||
try:
|
||||
keep_count = int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
success = _session_manager.truncate_messages(target_sid, keep_count)
|
||||
if success:
|
||||
return {"action": "truncate", "session_id": target_sid,
|
||||
"results": f"Session truncated to last {keep_count} messages"}
|
||||
return {"error": f"Failed to truncate session '{target_sid}'"}
|
||||
|
||||
elif action == "fork":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
keep_count = 0 # 0 = all messages
|
||||
if value:
|
||||
try:
|
||||
keep_count = int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
source = _session_manager.get_session(target_sid)
|
||||
if not source:
|
||||
return {"error": f"Session '{target_sid}' not found"}
|
||||
|
||||
new_sid = str(uuid.uuid4())[:8]
|
||||
_session_manager.create_session(
|
||||
session_id=new_sid,
|
||||
name=f"Fork: {source.name}",
|
||||
endpoint_url=source.endpoint_url,
|
||||
model=source.model,
|
||||
rag=False,
|
||||
owner=owner,
|
||||
)
|
||||
# Copy messages
|
||||
history = source.get_context_messages()
|
||||
if keep_count > 0:
|
||||
history = history[:keep_count]
|
||||
from core.models import ChatMessage as InMemoryMsg
|
||||
new_sess = _session_manager.get_session(new_sid)
|
||||
for msg in history:
|
||||
new_sess.add_message(InMemoryMsg(msg["role"], msg["content"]))
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", owner)
|
||||
except Exception:
|
||||
logger.debug("session_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {"action": "fork", "session_id": new_sid,
|
||||
"source_session": target_sid, "messages_copied": len(history),
|
||||
"results": f"Forked session '{source.name}' -> new session {new_sid} ({len(history)} messages)"}
|
||||
|
||||
else:
|
||||
return {"error": f"Unknown action '{action}'. Use: list, switch, rename, archive, unarchive, delete, important, unimportant, truncate, fork"}
|
||||
except Exception as e:
|
||||
logger.error(f"manage_session failed: {e}")
|
||||
return {"error": str(e)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler classes registered in TOOL_HANDLERS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CreateSessionTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
return await create_session(content, ctx.get("session_id"), owner=ctx.get("owner"))
|
||||
|
||||
|
||||
class ListSessionsTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
return await list_sessions(content, ctx.get("session_id"), owner=ctx.get("owner"))
|
||||
|
||||
|
||||
class SendToSessionTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
return await send_to_session(content, ctx.get("session_id"), owner=ctx.get("owner"))
|
||||
|
||||
|
||||
class ManageSessionTool:
|
||||
async def execute(self, content: str, ctx: dict) -> Dict:
|
||||
return await manage_session(content, ctx.get("session_id"), owner=ctx.get("owner"))
|
||||
+8
-447
@@ -1,12 +1,14 @@
|
||||
"""
|
||||
ai_interaction.py
|
||||
|
||||
AI-to-AI interaction tools: create_session, list_sessions, send_to_session,
|
||||
pipeline, plus shared model resolution (_resolve_model).
|
||||
AI-to-AI interaction tools: pipeline and manage_memory, plus shared model
|
||||
resolution (_resolve_model), the session-manager singleton, and dispatch_ai_tool.
|
||||
|
||||
chat_with_model, ask_teacher and list_models were moved to
|
||||
src/agent_tools/model_interaction_tools.py as part of the tool -> registry
|
||||
migration (#3629); they still reuse _resolve_model / AI_CHAT_TIMEOUT from here.
|
||||
As part of the tool -> registry migration (#3629), chat_with_model, ask_teacher
|
||||
and list_models moved to src/agent_tools/model_interaction_tools.py, and
|
||||
create_session, list_sessions, send_to_session and manage_session moved to
|
||||
src/agent_tools/session_tools.py. Those modules reuse get_session_manager /
|
||||
_resolve_model / AI_CHAT_TIMEOUT from here.
|
||||
|
||||
These are agent tools — the LLM writes fenced code blocks and they execute
|
||||
through the standard agent_tools.py pipeline.
|
||||
@@ -165,204 +167,6 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
|
||||
|
||||
|
||||
async def do_create_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Create a new chat session.
|
||||
|
||||
Content format:
|
||||
Line 1: session name
|
||||
Line 2: model_name (or model_name@endpoint_name)
|
||||
"""
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) < 2:
|
||||
return {"error": "Need 2 lines: session name, then model spec"}
|
||||
|
||||
name = lines[0].strip()
|
||||
model_spec = lines[1].strip()
|
||||
|
||||
if not name:
|
||||
return {"error": "Session name cannot be empty"}
|
||||
|
||||
try:
|
||||
url, model, headers = _resolve_model(model_spec, owner=owner)
|
||||
except ValueError as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
sid = str(uuid.uuid4())[:8]
|
||||
try:
|
||||
_session_manager.create_session(
|
||||
session_id=sid,
|
||||
name=name,
|
||||
endpoint_url=url,
|
||||
model=model,
|
||||
rag=False,
|
||||
owner=owner,
|
||||
)
|
||||
# Store headers on session for future calls
|
||||
sess = _session_manager.get_session(sid)
|
||||
if sess and headers:
|
||||
sess.headers = headers
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", owner)
|
||||
except Exception:
|
||||
logger.debug("session_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {"session_id": sid, "name": name, "model": model, "endpoint_url": url}
|
||||
except Exception as e:
|
||||
logger.error(f"create_session failed: {e}")
|
||||
return {"error": f"Failed to create session: {e}"}
|
||||
|
||||
|
||||
async def do_list_sessions(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""List sessions sorted by most-recently-active first.
|
||||
|
||||
Output includes a relative "last active" timestamp per row so the
|
||||
agent can answer "open my last chat" without guessing from titles.
|
||||
The most-recent session is always first in the list.
|
||||
|
||||
Content = optional filter keyword (matches session name).
|
||||
"""
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
keyword = content.strip().lower() if content.strip() else None
|
||||
|
||||
try:
|
||||
from core.database import SessionLocal, Session as DbSession
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Pull every session's last_accessed from the DB so we can sort
|
||||
# by recency. In-memory sessions hold name + model + msg_count;
|
||||
# the DB row holds the timestamps.
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_rows = {r.id: r for r in db.query(DbSession).all()}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# SECURITY: scope to the caller's sessions. Passing None returned
|
||||
# every user's sessions, which the agent tool then exposed via the
|
||||
# "list my chats" reply.
|
||||
sessions = _session_manager.get_sessions_for_user(owner)
|
||||
rows = []
|
||||
for sid, sess in sessions.items():
|
||||
if keyword and keyword not in (sess.name or "").lower():
|
||||
continue
|
||||
db_row = db_rows.get(sid)
|
||||
# Prefer last_accessed; fall back to updated_at, then created_at.
|
||||
ts = None
|
||||
if db_row:
|
||||
ts = getattr(db_row, 'last_accessed', None) or getattr(db_row, 'updated_at', None) or getattr(db_row, 'created_at', None)
|
||||
rows.append((ts, sid, sess))
|
||||
|
||||
# Sort by timestamp DESC; rows without a timestamp sink to the bottom.
|
||||
rows.sort(key=lambda r: r[0] or datetime.min, reverse=True)
|
||||
|
||||
def _rel(ts):
|
||||
if not ts:
|
||||
return 'never'
|
||||
now = datetime.utcnow()
|
||||
try:
|
||||
if ts.tzinfo is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
diff = (now - ts).total_seconds()
|
||||
except Exception:
|
||||
return 'unknown'
|
||||
if diff < 60: return 'just now'
|
||||
if diff < 3600: return f'{int(diff / 60)}m ago'
|
||||
if diff < 86400: return f'{int(diff / 3600)}h ago'
|
||||
if diff < 86400 * 7: return f'{int(diff / 86400)}d ago'
|
||||
return ts.strftime('%Y-%m-%d')
|
||||
|
||||
lines = []
|
||||
for i, (ts, sid, sess) in enumerate(rows):
|
||||
if i >= 50:
|
||||
lines.append(f"... and {len(rows) - 50} more (showing first 50)")
|
||||
break
|
||||
safe_name = (sess.name or "Untitled").replace("[", "\\[").replace("]", "\\]")
|
||||
msg_count = getattr(sess, "message_count", 0) or 0
|
||||
model = getattr(sess, "model", "unknown")
|
||||
marker = " ← most recent" if i == 0 else ""
|
||||
lines.append(f"- **[{safe_name}](#session-{sid})** (id: `{sid}`, model: {model}, {msg_count} msgs, last active {_rel(ts)}){marker}")
|
||||
|
||||
if not lines:
|
||||
return {"results": "No sessions found" + (f" matching '{keyword}'" if keyword else "") + "."}
|
||||
|
||||
return {
|
||||
"results": (
|
||||
f"Found {len(rows)} session(s), sorted most-recent first:\n"
|
||||
+ "\n".join(lines)
|
||||
+ "\n\nAssistant: when replying to the user, preserve the chat-title markdown links exactly as shown, e.g. `[Chat](#session-id)`. Do not rewrite this as a plain, non-clickable table."
|
||||
)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"list_sessions failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def do_send_to_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Send a message to an existing session and get a response.
|
||||
|
||||
Content format:
|
||||
Line 1: session_id
|
||||
Line 2+: message
|
||||
"""
|
||||
from src.llm_core import llm_call_async
|
||||
from core.models import ChatMessage
|
||||
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
lines = content.strip().split("\n", 1)
|
||||
if len(lines) < 2:
|
||||
return {"error": "Need 2 lines: session_id, then message"}
|
||||
|
||||
target_sid = lines[0].strip()
|
||||
message = lines[1].strip()
|
||||
|
||||
sess = _session_manager.get_session(target_sid)
|
||||
if not sess:
|
||||
return {"error": f"Session '{target_sid}' not found"}
|
||||
|
||||
# Owner-scope: reject access to another user's session
|
||||
if owner and getattr(sess, "owner", None) and sess.owner != owner:
|
||||
return {"error": f"Session '{target_sid}' not found"}
|
||||
|
||||
if not message:
|
||||
return {"error": "No message provided"}
|
||||
|
||||
try:
|
||||
# Build context from session history
|
||||
context = sess.get_context_messages()
|
||||
context.append({"role": "user", "content": message})
|
||||
|
||||
response = await llm_call_async(
|
||||
sess.endpoint_url, sess.model, context,
|
||||
headers=sess.headers,
|
||||
timeout=AI_CHAT_TIMEOUT,
|
||||
)
|
||||
|
||||
# Save both messages to session
|
||||
sess.add_message(ChatMessage("user", message))
|
||||
sess.add_message(ChatMessage("assistant", response))
|
||||
|
||||
# Truncate for tool output
|
||||
if len(response) > 10000:
|
||||
response = response[:10000] + "\n... (truncated)"
|
||||
|
||||
return {
|
||||
"session_id": target_sid,
|
||||
"session_name": sess.name,
|
||||
"response": response,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"send_to_session failed: {e}")
|
||||
return {"error": f"Failed to send to session: {e}"}
|
||||
|
||||
|
||||
async def stream_ai_tool(tool: str, content: str, session_id: Optional[str] = None, owner: Optional[str] = None):
|
||||
"""Dispatcher for streaming AI tools. Yields events as async generator."""
|
||||
# Fallback: run non-streaming and yield final result
|
||||
@@ -483,229 +287,6 @@ async def do_pipeline(content: str, session_id: Optional[str] = None, owner: Opt
|
||||
# Session management tool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def do_manage_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict:
|
||||
"""Manage sessions: rename, archive, delete, important, truncate, fork.
|
||||
|
||||
Content format:
|
||||
Line 1: action (rename|archive|unarchive|delete|important|unimportant|truncate|fork)
|
||||
Line 2: target session_id (or "current" to use the active session)
|
||||
Line 3+: action-specific params (e.g. new name for rename, keep_count for truncate)
|
||||
"""
|
||||
if not _session_manager:
|
||||
return {"error": "Session manager not available"}
|
||||
|
||||
from src.database import SessionLocal, Session as DbSession
|
||||
|
||||
# Accept BOTH the structured JSON args the tool schema advertises
|
||||
# ({action, session_id, value}) AND the legacy line-based format
|
||||
# (line1=action, line2=session_id, line3=value). Native function-calling
|
||||
# models send JSON; fenced-block callers send lines. Previously only the
|
||||
# line format was parsed, so a model that followed the schema (JSON) got
|
||||
# "Need at least 2 lines" / "Rename needs line 3" and couldn't drive it.
|
||||
_raw = (content or "").strip()
|
||||
action = ""
|
||||
target_sid = ""
|
||||
value = None # the action param: new name (rename) / keep_count (truncate, fork)
|
||||
_list_filter = ""
|
||||
_parsed = None
|
||||
if _raw.startswith("{"):
|
||||
try:
|
||||
_parsed = json.loads(_raw)
|
||||
except Exception:
|
||||
_parsed = None
|
||||
if isinstance(_parsed, dict):
|
||||
action = str(_parsed.get("action") or "").strip().lower()
|
||||
target_sid = str(_parsed.get("session_id") or _parsed.get("session") or _parsed.get("id") or "").strip()
|
||||
_v = _parsed.get("value")
|
||||
if _v is None:
|
||||
_v = (_parsed.get("name") or _parsed.get("new_name")
|
||||
or _parsed.get("title") or _parsed.get("keep_count"))
|
||||
value = None if _v is None else str(_v).strip()
|
||||
_list_filter = str(_parsed.get("filter") or "").strip()
|
||||
else:
|
||||
lines = _raw.split("\n")
|
||||
if not lines or not lines[0].strip():
|
||||
return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"}
|
||||
action = lines[0].strip().lower()
|
||||
target_sid = lines[1].strip() if len(lines) >= 2 else ""
|
||||
value = lines[2].strip() if len(lines) >= 3 else None
|
||||
_list_filter = "\n".join(lines[1:]).strip()
|
||||
|
||||
if not action:
|
||||
return {"error": "Missing action (rename|archive|delete|important|truncate|fork|list|switch)"}
|
||||
|
||||
# `list` alias — dispatch to do_list_sessions so the agent's natural
|
||||
# first guess (every other manage_* tool has a `list` action) works.
|
||||
if action == "list":
|
||||
return await do_list_sessions(_list_filter, session_id, owner=owner)
|
||||
|
||||
if not target_sid:
|
||||
return {"error": "Need a session_id (or 'current' for the active chat)"}
|
||||
|
||||
# Allow "current" to refer to the active session
|
||||
if target_sid.lower() == "current" and session_id:
|
||||
target_sid = session_id
|
||||
|
||||
# `switch` / `open` / `select` / `view` — the agent reaches for
|
||||
# these when the user asks to "open" or "switch to" a session.
|
||||
# There's no server-side way to make the browser navigate, so we
|
||||
# just return a clickable anchor link the user can click. The
|
||||
# frontend's chat-history click delegate routes `#session-<id>`
|
||||
# to selectSession(). The agent's reply naturally embeds this
|
||||
# result so the user sees a single clickable line.
|
||||
def _session_query(db):
|
||||
query = db.query(DbSession).filter(DbSession.id == target_sid)
|
||||
if owner is not None:
|
||||
query = query.filter(DbSession.owner == owner)
|
||||
return query
|
||||
|
||||
if action in ("switch", "open", "select", "view"):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
name = db_sess.name or target_sid
|
||||
finally:
|
||||
db.close()
|
||||
return {
|
||||
"action": action,
|
||||
"session_id": target_sid,
|
||||
"name": name,
|
||||
"results": f"[{name}](#session-{target_sid}) — click to open.",
|
||||
}
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if action == "rename":
|
||||
if not value:
|
||||
return {"error": "rename needs a new name (the `value` arg, or line 3 in the legacy format)"}
|
||||
new_name = value
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
db_sess.name = new_name
|
||||
db.commit()
|
||||
_session_manager.update_session_name(target_sid, new_name)
|
||||
return {"action": "rename", "session_id": target_sid, "name": new_name,
|
||||
"results": f"Session renamed to '{new_name}'"}
|
||||
|
||||
elif action == "archive":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
db_sess.archived = True
|
||||
db.commit()
|
||||
return {"action": "archive", "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name}' archived"}
|
||||
|
||||
elif action == "unarchive":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
db_sess.archived = False
|
||||
db.commit()
|
||||
return {"action": "unarchive", "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name}' unarchived"}
|
||||
|
||||
elif action == "delete":
|
||||
if target_sid == session_id:
|
||||
return {"error": "Cannot delete the current session while chatting in it. Delete other sessions first."}
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Refusing to delete an unknown chat id; use the exact id from list_sessions."}
|
||||
if db_sess and db_sess.is_important:
|
||||
return {"error": f"Session '{db_sess.name}' is starred/favorited. Unstar it first before deleting."}
|
||||
try:
|
||||
ok = _session_manager.delete_session(target_sid)
|
||||
if not ok:
|
||||
return {"error": f"Session '{target_sid}' was not deleted because it no longer exists."}
|
||||
return {"action": "delete", "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name or target_sid}' deleted"}
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to delete session: {e}"}
|
||||
|
||||
elif action in ("important", "unimportant"):
|
||||
is_important = action == "important"
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
# Prevent AI from unstarring sessions — only the user can do that manually
|
||||
if not is_important and db_sess.is_important:
|
||||
return {"error": f"Session '{db_sess.name}' is starred by the user. Only the user can unstar sessions manually."}
|
||||
db_sess.is_important = is_important
|
||||
db.commit()
|
||||
status = "marked as important" if is_important else "unmarked as important"
|
||||
return {"action": action, "session_id": target_sid,
|
||||
"results": f"Session '{db_sess.name}' {status}"}
|
||||
|
||||
elif action == "truncate":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
keep_count = 10
|
||||
if value:
|
||||
try:
|
||||
keep_count = int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
success = _session_manager.truncate_messages(target_sid, keep_count)
|
||||
if success:
|
||||
return {"action": "truncate", "session_id": target_sid,
|
||||
"results": f"Session truncated to last {keep_count} messages"}
|
||||
return {"error": f"Failed to truncate session '{target_sid}'"}
|
||||
|
||||
elif action == "fork":
|
||||
db_sess = _session_query(db).first()
|
||||
if not db_sess:
|
||||
return {"error": f"Session '{target_sid}' not found. Use list_sessions and pass the exact id it returned."}
|
||||
keep_count = 0 # 0 = all messages
|
||||
if value:
|
||||
try:
|
||||
keep_count = int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
source = _session_manager.get_session(target_sid)
|
||||
if not source:
|
||||
return {"error": f"Session '{target_sid}' not found"}
|
||||
|
||||
new_sid = str(uuid.uuid4())[:8]
|
||||
_session_manager.create_session(
|
||||
session_id=new_sid,
|
||||
name=f"Fork: {source.name}",
|
||||
endpoint_url=source.endpoint_url,
|
||||
model=source.model,
|
||||
rag=False,
|
||||
owner=owner,
|
||||
)
|
||||
# Copy messages
|
||||
history = source.get_context_messages()
|
||||
if keep_count > 0:
|
||||
history = history[:keep_count]
|
||||
from core.models import ChatMessage as InMemoryMsg
|
||||
new_sess = _session_manager.get_session(new_sid)
|
||||
for msg in history:
|
||||
new_sess.add_message(InMemoryMsg(msg["role"], msg["content"]))
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", owner)
|
||||
except Exception:
|
||||
logger.debug("session_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {"action": "fork", "session_id": new_sid,
|
||||
"source_session": target_sid, "messages_copied": len(history),
|
||||
"results": f"Forked session '{source.name}' -> new session {new_sid} ({len(history)} messages)"}
|
||||
|
||||
else:
|
||||
return {"error": f"Unknown action '{action}'. Use: list, switch, rename, archive, unarchive, delete, important, unimportant, truncate, fork"}
|
||||
except Exception as e:
|
||||
logger.error(f"manage_session failed: {e}")
|
||||
return {"error": str(e)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory management tool
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1522,30 +1103,10 @@ async def dispatch_ai_tool(
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Dispatch an AI interaction tool. Returns (description, result_dict)."""
|
||||
|
||||
if tool == "create_session":
|
||||
name = content.split("\n")[0].strip()[:60]
|
||||
desc = f"create_session: {name}"
|
||||
result = await do_create_session(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "list_sessions":
|
||||
keyword = content.strip()[:40]
|
||||
desc = f"list_sessions{': ' + keyword if keyword else ''}"
|
||||
result = await do_list_sessions(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "send_to_session":
|
||||
sid = content.split("\n")[0].strip()[:20]
|
||||
desc = f"send_to_session: {sid}"
|
||||
result = await do_send_to_session(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "pipeline":
|
||||
if tool == "pipeline":
|
||||
desc = "pipeline: running steps"
|
||||
result = await do_pipeline(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "manage_session":
|
||||
action = content.split("\n")[0].strip()[:40]
|
||||
desc = f"manage_session: {action}"
|
||||
result = await do_manage_session(content, session_id, owner=owner)
|
||||
|
||||
elif tool == "manage_memory":
|
||||
action = content.split("\n")[0].strip()[:40]
|
||||
desc = f"manage_memory: {action}"
|
||||
|
||||
+11
-6
@@ -781,16 +781,21 @@ async def _execute_tool_block_impl(
|
||||
elif tool in ("chat_with_model", "ask_teacher", "list_models"):
|
||||
# Migrated to the agent_tools registry (#3629): dispatched through
|
||||
# TOOL_HANDLERS with the owner/session ctx these tools need, instead
|
||||
# of the legacy dispatch_ai_tool elif. The do_* impls stay in
|
||||
# ai_interaction.py (dispatch_ai_tool + the owner-scope test use them).
|
||||
# of the legacy dispatch_ai_tool elif. The impls live in
|
||||
# src/agent_tools/model_interaction_tools.py.
|
||||
first_line = content.split(chr(10))[0].strip()[:60]
|
||||
desc = f"{tool}: {first_line}" if first_line else tool
|
||||
result = await _document_tool_dispatch(tool, content, session_id, owner) \
|
||||
or {"error": f"{tool}: execution failed", "exit_code": 1}
|
||||
elif tool in ("create_session", "list_sessions",
|
||||
"send_to_session", "pipeline",
|
||||
"manage_session", "manage_memory",
|
||||
"ui_control"):
|
||||
elif tool in ("create_session", "list_sessions", "send_to_session", "manage_session"):
|
||||
# Migrated to the agent_tools registry (#3629): dispatched through
|
||||
# TOOL_HANDLERS with the owner/session ctx these tools need. The impls
|
||||
# live in src/agent_tools/session_tools.py.
|
||||
first_line = content.split(chr(10))[0].strip()[:60]
|
||||
desc = f"{tool}: {first_line}" if first_line else tool
|
||||
result = await _document_tool_dispatch(tool, content, session_id, owner) \
|
||||
or {"error": f"{tool}: execution failed", "exit_code": 1}
|
||||
elif tool in ("pipeline", "manage_memory", "ui_control"):
|
||||
from src.ai_interaction import dispatch_ai_tool
|
||||
desc, result = await dispatch_ai_tool(tool, content, session_id, owner=owner)
|
||||
elif tool == "manage_tasks":
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Tests for the session tools' move to the agent_tools registry (#3629):
|
||||
create_session, list_sessions, send_to_session, manage_session.
|
||||
|
||||
The implementations now live in src/agent_tools/session_tools.py (moved out of
|
||||
src/ai_interaction.py). These assert (1) the handlers are registered in
|
||||
TOOL_HANDLERS, (2) the moved logic runs and threads owner/session from ctx
|
||||
(the session manager is fetched via ai_interaction.get_session_manager), and
|
||||
(3) tool_execution.py dispatches them through the registry rather than the
|
||||
legacy dispatch_ai_tool elif.
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import src.ai_interaction as ai_interaction
|
||||
import src.database as database
|
||||
from src.agent_tools import TOOL_HANDLERS
|
||||
from src.agent_tools import session_tools as st
|
||||
|
||||
_SESSION_TOOLS = ("create_session", "list_sessions", "send_to_session", "manage_session")
|
||||
|
||||
|
||||
def test_session_tools_registered():
|
||||
for name in _SESSION_TOOLS:
|
||||
assert name in TOOL_HANDLERS, f"{name} missing from TOOL_HANDLERS"
|
||||
|
||||
|
||||
def test_list_sessions_handler_threads_ctx(monkeypatch):
|
||||
# The handler must thread content + session_id + owner from ctx into the
|
||||
# moved list_sessions implementation. Spy at the function boundary so the
|
||||
# test does not depend on list_sessions' DB internals.
|
||||
seen = {}
|
||||
|
||||
async def spy(content, session_id=None, owner=None):
|
||||
seen.update(content=content, session_id=session_id, owner=owner)
|
||||
return {"results": "ok"}
|
||||
|
||||
monkeypatch.setattr(st, "list_sessions", spy)
|
||||
res = asyncio.run(st.ListSessionsTool().execute("q", {"owner": "alice", "session_id": "s1"}))
|
||||
assert res == {"results": "ok"}
|
||||
assert seen == {"content": "q", "session_id": "s1", "owner": "alice"}
|
||||
|
||||
|
||||
def test_manage_session_list_delegates_to_list_sessions(monkeypatch):
|
||||
# manage_session("list") must delegate to list_sessions; guards against a
|
||||
# stale do_list_sessions reference surviving the move (caught live in e2e).
|
||||
called = {}
|
||||
|
||||
async def spy(content, session_id=None, owner=None):
|
||||
called["owner"] = owner
|
||||
return {"results": "ok"}
|
||||
|
||||
monkeypatch.setattr(st, "list_sessions", spy)
|
||||
# manage_session imports `Session` from src.database before the list branch;
|
||||
# the src.database test double may not expose it, so provide a stand-in.
|
||||
monkeypatch.setattr(database, "Session", object, raising=False)
|
||||
monkeypatch.setattr(ai_interaction, "_session_manager", object()) # truthy: pass the guard
|
||||
res = asyncio.run(st.ManageSessionTool().execute("list", {"owner": "carol"}))
|
||||
assert called.get("owner") == "carol"
|
||||
assert res == {"results": "ok"}
|
||||
|
||||
|
||||
def test_create_session_reaches_uuid_and_creates(monkeypatch):
|
||||
# Regression for the missing `import uuid` (PR review): create_session must
|
||||
# get past _resolve_model and mint a session id without NameError.
|
||||
monkeypatch.setattr(st, "_resolve_model", lambda spec, owner=None: ("http://x", "model-x", {}))
|
||||
created = {}
|
||||
|
||||
class FakeMgr:
|
||||
def create_session(self, **kw):
|
||||
created.update(kw)
|
||||
|
||||
def get_session(self, sid):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(ai_interaction, "_session_manager", FakeMgr())
|
||||
res = asyncio.run(st.CreateSessionTool().execute("My Chat\nmodel-x", {"owner": "alice"}))
|
||||
assert res.get("name") == "My Chat" and res.get("model") == "model-x"
|
||||
assert isinstance(res.get("session_id"), str) and res["session_id"]
|
||||
assert created.get("name") == "My Chat" # the uuid-minted id reached the manager
|
||||
|
||||
|
||||
def test_manage_session_fork_reaches_uuid(monkeypatch):
|
||||
# Regression for the missing `import uuid`: the fork action also mints a new
|
||||
# session id and must not NameError. Mocks the DB query layer so the fork
|
||||
# branch reaches the uuid call without a real sessions table.
|
||||
class FakeDbSession:
|
||||
id = "id"
|
||||
owner = "owner"
|
||||
|
||||
class FakeQ:
|
||||
def filter(self, *a, **k):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return object()
|
||||
|
||||
class FakeDB:
|
||||
def query(self, *a, **k):
|
||||
return FakeQ()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(database, "Session", FakeDbSession, raising=False)
|
||||
monkeypatch.setattr(database, "SessionLocal", lambda: FakeDB(), raising=False)
|
||||
|
||||
class Src:
|
||||
name = "Orig"
|
||||
endpoint_url = "http://x"
|
||||
model = "m"
|
||||
|
||||
def get_context_messages(self):
|
||||
return []
|
||||
|
||||
created = {}
|
||||
|
||||
class FakeMgr:
|
||||
def get_session(self, sid):
|
||||
return Src() if sid == "abc" else type("S", (), {"add_message": lambda self, m: None})()
|
||||
|
||||
def create_session(self, **kw):
|
||||
created.update(kw)
|
||||
|
||||
monkeypatch.setattr(ai_interaction, "_session_manager", FakeMgr())
|
||||
res = asyncio.run(st.ManageSessionTool().execute('{"action":"fork","session_id":"abc"}', {"owner": "owner"}))
|
||||
assert res.get("action") == "fork"
|
||||
assert isinstance(res.get("session_id"), str) and res["session_id"]
|
||||
assert created.get("name") == "Fork: Orig" # uuid-minted new session was created
|
||||
|
||||
|
||||
def test_no_session_manager_is_handled(monkeypatch):
|
||||
# With no session manager set, the moved function must fail gracefully
|
||||
# (proves the handler reached the impl, not an "unknown tool").
|
||||
monkeypatch.setattr(ai_interaction, "_session_manager", None)
|
||||
res = asyncio.run(st.ListSessionsTool().execute("", {"owner": "bob"}))
|
||||
assert isinstance(res, dict)
|
||||
assert "error" in res or "results" in res
|
||||
|
||||
|
||||
def test_dispatched_via_registry_not_dispatch_ai_tool():
|
||||
source = (Path(__file__).resolve().parent.parent / "src" / "tool_execution.py").read_text(encoding="utf-8")
|
||||
assert 'elif tool in ("create_session", "list_sessions", "send_to_session", "manage_session"):' in source
|
||||
|
||||
marker = "from src.ai_interaction import dispatch_ai_tool"
|
||||
idx = source.index(marker)
|
||||
branch_head = source.rfind("elif tool in (", 0, idx)
|
||||
legacy_tuple = source[branch_head:idx]
|
||||
for name in _SESSION_TOOLS:
|
||||
assert f'"{name}"' not in legacy_tuple, f"{name} still routed via dispatch_ai_tool"
|
||||
Reference in New Issue
Block a user