diff --git a/src/agent_tools/__init__.py b/src/agent_tools/__init__.py index 52fe4a99c..c2d910627 100644 --- a/src/agent_tools/__init__.py +++ b/src/agent_tools/__init__.py @@ -22,6 +22,7 @@ from .subprocess_tools import BashTool, PythonTool from .web_tools import WebSearchTool, WebFetchTool from .filesystem_tools import ReadFileTool, WriteFileTool, EditFileTool, LsTool, GlobTool, GrepTool, GetWorkspaceTool from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocumentTool, SuggestDocumentTool, ManageDocumentTool +from .model_interaction_tools import ChatWithModelTool, AskTeacherTool, ListModelsTool TOOL_HANDLERS = { "bash": BashTool().execute, @@ -40,6 +41,9 @@ TOOL_HANDLERS = { "suggest_document": SuggestDocumentTool().execute, "manage_documents": ManageDocumentTool().execute, "get_workspace": GetWorkspaceTool().execute, + "chat_with_model": ChatWithModelTool().execute, + "ask_teacher": AskTeacherTool().execute, + "list_models": ListModelsTool().execute, } # --------------------------------------------------------------------------- diff --git a/src/agent_tools/model_interaction_tools.py b/src/agent_tools/model_interaction_tools.py new file mode 100644 index 000000000..6cbabe919 --- /dev/null +++ b/src/agent_tools/model_interaction_tools.py @@ -0,0 +1,208 @@ +"""model_interaction_tools.py - agent tools for talking to other models. + +Owns the model-interaction tool implementations (chat_with_model, ask_teacher, +list_models) and their handler classes, registered in ``TOOL_HANDLERS``. Part +of the tool -> registry migration (#3629): the implementations were moved here +out of ``src.ai_interaction`` so dispatch flows through the registry instead of +the elif chain / dispatch_ai_tool in tool_execution.py. + +Shared helpers that still live in ``src.ai_interaction`` and are used by tools +not yet migrated (``_resolve_model``, ``AI_CHAT_TIMEOUT``) are imported lazily +inside the functions to avoid an import cycle at module load. 
+""" +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + + +_TEACHER_SYSTEM_PROMPT = ( + "You are a senior AI mentor. A less capable model is stuck on a problem and asking for help. " + "Provide clear, actionable guidance:\n" + "1. Brief analysis of the problem\n" + "2. Recommended approach (step by step)\n" + "3. Key things to watch out for\n\n" + "Be concise and practical. No preamble." +) + + +async def chat_with_model(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """Send a message to a specific model and return its response. + + Content format: + Line 1: model_name (or model_name@endpoint_name) + Line 2+: the message to send + """ + from src.ai_interaction import _resolve_model, AI_CHAT_TIMEOUT + from src.llm_core import llm_call_async + + lines = content.strip().split("\n", 1) + if not lines or not lines[0].strip(): + return {"error": "First line must be the model name"} + + model_spec = lines[0].strip() + message = lines[1].strip() if len(lines) > 1 else "" + if not message: + return {"error": "No message provided (line 2+ is the message)"} + + try: + url, model, headers = _resolve_model(model_spec, owner=owner) + except ValueError as e: + return {"error": str(e)} + + try: + response = await llm_call_async( + url, model, + [{"role": "user", "content": message}], + headers=headers, + timeout=AI_CHAT_TIMEOUT, + ) + # Truncate very long responses + if len(response) > 10000: + response = response[:10000] + "\n... (truncated)" + return {"model": model, "response": response} + except Exception as e: + logger.error(f"chat_with_model failed: {e}") + return {"error": f"Failed to get response from {model_spec}: {e}"} + + +async def ask_teacher(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """Ask a more capable model for help. 
+ + Content format: + Line 1: model_name (or 'auto') + Line 2+: the problem description + """ + from src.ai_interaction import _resolve_model, AI_CHAT_TIMEOUT + from src.llm_core import llm_call_async + from src.settings import get_setting + + lines = content.strip().split("\n", 1) + model_spec = lines[0].strip() if lines else "auto" + problem = lines[1].strip() if len(lines) > 1 else "" + + if not problem: + return {"error": "No problem description provided"} + + if model_spec.lower() in ("auto", ""): + model_spec = get_setting("teacher_model", "") + if not model_spec: + return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."} + + try: + url, model, headers = _resolve_model(model_spec, owner=owner) + except ValueError as e: + return {"error": str(e)} + + try: + response = await llm_call_async( + url, model, + [ + {"role": "system", "content": _TEACHER_SYSTEM_PROMPT}, + {"role": "user", "content": f"Problem:\n{problem}"}, + ], + headers=headers, + timeout=AI_CHAT_TIMEOUT, + ) + if len(response) > 8000: + response = response[:8000] + "\n... (truncated)" + return {"model": model, "response": response, "teacher": True} + except Exception as e: + logger.error(f"ask_teacher failed: {e}") + return {"error": f"Teacher call failed ({model_spec}): {e}"} + + +async def list_models(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: + """List all available models across configured endpoints. + + Content = optional filter keyword. 
+ """ + import json + import httpx + from src.database import SessionLocal, ModelEndpoint + from src.llm_core import _detect_provider, ANTHROPIC_MODELS + from src.auth_helpers import owner_filter + from src.endpoint_resolver import resolve_endpoint_runtime, build_headers, build_models_url + + keyword = content.strip().lower() if content.strip() else None + + db = SessionLocal() + try: + query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) + if owner: + query = owner_filter(query, ModelEndpoint, owner) + endpoints = query.all() + if not endpoints: + return {"results": "No enabled model endpoints configured."} + + result_lines = [] + total_models = 0 + + for ep in endpoints: + try: + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception: + continue + provider = _detect_provider(base) + headers = build_headers(api_key, base) + + model_ids = [] + if provider == "anthropic": + model_ids = list(ANTHROPIC_MODELS) + else: + try: + models_url = build_models_url(base) + if models_url: + r = httpx.get(models_url, headers=headers, timeout=5) + r.raise_for_status() + data = r.json() + model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not model_ids: + model_ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] + else: + model_ids = json.loads(ep.cached_models or "[]") + except Exception: + model_ids = ["(endpoint offline)"] + + if keyword: + model_ids = [m for m in model_ids if keyword in m.lower() or keyword in (ep.name or "").lower()] + + if model_ids: + result_lines.append(f"\n**{ep.name or base}** ({provider}):") + for mid in model_ids: + result_lines.append(f" - `{mid}`") + total_models += 1 + + if not result_lines: + return {"results": "No models found" + (f" matching '{keyword}'" if keyword else "") + "."} + + header = f"Available models ({total_models} total):" + return {"results": header + "\n".join(result_lines)} + except Exception 
as e: + logger.error(f"list_models failed: {e}") + return {"error": str(e)} + finally: + db.close() + + +# --------------------------------------------------------------------------- +# Handler classes registered in TOOL_HANDLERS +# --------------------------------------------------------------------------- + +class ChatWithModelTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await chat_with_model(content, ctx.get("session_id"), owner=ctx.get("owner")) + + +class AskTeacherTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await ask_teacher(content, ctx.get("session_id"), owner=ctx.get("owner")) + + +class ListModelsTool: + async def execute(self, content: str, ctx: dict) -> Dict: + return await list_models(content, ctx.get("session_id"), owner=ctx.get("owner")) diff --git a/src/ai_interaction.py b/src/ai_interaction.py index 33d5d28f7..667df8fb5 100644 --- a/src/ai_interaction.py +++ b/src/ai_interaction.py @@ -1,8 +1,12 @@ """ ai_interaction.py -AI-to-AI interaction tools: chat_with_model, create_session, list_sessions, -send_to_session, pipeline. +AI-to-AI interaction tools: create_session, list_sessions, send_to_session, +pipeline, plus shared model resolution (_resolve_model). + +chat_with_model, ask_teacher and list_models were moved to +src/agent_tools/model_interaction_tools.py as part of the tool -> registry +migration (#3629); they still reuse _resolve_model / AI_CHAT_TIMEOUT from here. These are agent tools — the LLM writes fenced code blocks and they execute through the standard agent_tools.py pipeline. @@ -159,242 +163,6 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di # Tool implementations # --------------------------------------------------------------------------- -async def do_chat_with_model(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """Send a message to a specific model and return its response. 
- - Content format: - Line 1: model_name (or model_name@endpoint_name) - Line 2+: the message to send - """ - from src.llm_core import llm_call_async - - lines = content.strip().split("\n", 1) - if not lines or not lines[0].strip(): - return {"error": "First line must be the model name"} - - model_spec = lines[0].strip() - message = lines[1].strip() if len(lines) > 1 else "" - if not message: - return {"error": "No message provided (line 2+ is the message)"} - - try: - url, model, headers = _resolve_model(model_spec, owner=owner) - except ValueError as e: - return {"error": str(e)} - - try: - response = await llm_call_async( - url, model, - [{"role": "user", "content": message}], - headers=headers, - timeout=AI_CHAT_TIMEOUT, - ) - # Truncate very long responses - if len(response) > 10000: - response = response[:10000] + "\n... (truncated)" - return {"model": model, "response": response} - except Exception as e: - logger.error(f"chat_with_model failed: {e}") - return {"error": f"Failed to get response from {model_spec}: {e}"} - - -_TEACHER_SYSTEM_PROMPT = ( - "You are a senior AI mentor. A less capable model is stuck on a problem and asking for help. " - "Provide clear, actionable guidance:\n" - "1. Brief analysis of the problem\n" - "2. Recommended approach (step by step)\n" - "3. Key things to watch out for\n\n" - "Be concise and practical. No preamble." -) - - -async def do_ask_teacher(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """Ask a more capable model for help. 
- - Content format: - Line 1: model_name (or 'auto') - Line 2+: the problem description - """ - from src.llm_core import llm_call_async - from src.settings import get_setting - - lines = content.strip().split("\n", 1) - model_spec = lines[0].strip() if lines else "auto" - problem = lines[1].strip() if len(lines) > 1 else "" - - if not problem: - return {"error": "No problem description provided"} - - if model_spec.lower() in ("auto", ""): - model_spec = get_setting("teacher_model", "") - if not model_spec: - return {"error": "No teacher model configured. Specify a model name or set teacher_model in settings."} - - try: - url, model, headers = _resolve_model(model_spec, owner=owner) - except ValueError as e: - return {"error": str(e)} - - try: - response = await llm_call_async( - url, model, - [ - {"role": "system", "content": _TEACHER_SYSTEM_PROMPT}, - {"role": "user", "content": f"Problem:\n{problem}"}, - ], - headers=headers, - timeout=AI_CHAT_TIMEOUT, - ) - if len(response) > 8000: - response = response[:8000] + "\n... (truncated)" - return {"model": model, "response": response, "teacher": True} - except Exception as e: - logger.error(f"ask_teacher failed: {e}") - return {"error": f"Teacher call failed ({model_spec}): {e}"} - - -async def do_second_opinion(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """Get a second opinion from another model, then have the original model - evaluate the feedback and produce a unified version. - - Content format: - Line 1: model_name (or model_name@endpoint_name) - Line 2+ (optional): specific question or focus area - - Flow: - 1. Pull recent conversation context - 2. Send to reviewer model → get honest feedback - 3. Send feedback back to the session's own model → evaluate & unify - 4. 
Return both the review and the unified response - """ - from src.llm_core import llm_call_async - - lines = content.strip().split("\n", 1) - if not lines or not lines[0].strip(): - return {"error": "First line must be the model name"} - - model_spec = lines[0].strip() - focus = lines[1].strip() if len(lines) > 1 else "" - - try: - reviewer_url, reviewer_model, reviewer_headers = _resolve_model(model_spec, owner=owner) - except ValueError as e: - return {"error": str(e)} - - # Pull recent conversation context from current session - context_text = "" - sess = None - if session_id and _session_manager: - sess = _session_manager.get_session(session_id) - if sess: - messages = sess.get_context_messages() - recent = messages[-15:] if len(messages) > 15 else messages - parts = [] - for m in recent: - role = m.get("role", "unknown").upper() - text = m.get("content", "") - if isinstance(text, list): - text = " ".join( - p.get("text", "") for p in text if isinstance(p, dict) - ) - if text: - parts.append(f"[{role}]: {text[:2000]}") - context_text = "\n\n".join(parts) - - if not context_text: - return {"error": "No conversation context found to review"} - - # ── Step 1: Get the reviewer's feedback ── - reviewer_system = ( - "You are giving a second opinion on a conversation between a user and an AI assistant. " - "Your job is to be genuinely helpful and honest — not a yes-man, but not a contrarian either.\n\n" - "Guidelines:\n" - "- If the plan/idea is solid, say so clearly. Don't manufacture problems that aren't there.\n" - "- If you spot a real flaw, blind spot, or simpler approach — call it out directly.\n" - "- Be practical. Don't over-engineer or over-analyze. Real-world tradeoffs matter.\n" - "- If there's a meaningfully better way to do something, suggest it concretely.\n" - "- Give credit where it's due — highlight what's working well.\n" - "- Keep it concise and actionable. No fluff.\n" - "- You're a second pair of eyes, not a professor grading a paper." 
- ) - - reviewer_message = f"Here's the conversation so far:\n\n{context_text}" - if focus: - reviewer_message += f"\n\n---\nSpecifically, I want your take on: {focus}" - else: - reviewer_message += "\n\n---\nGive me your honest second opinion on what's being discussed." - - try: - review = await llm_call_async( - reviewer_url, reviewer_model, - [ - {"role": "system", "content": reviewer_system}, - {"role": "user", "content": reviewer_message}, - ], - headers=reviewer_headers, - timeout=AI_CHAT_TIMEOUT, - ) - if len(review) > 8000: - review = review[:8000] + "\n... (truncated)" - except Exception as e: - logger.error(f"second_opinion reviewer call failed: {e}") - return {"error": f"Failed to get second opinion from {model_spec}: {e}"} - - # ── Step 2: Send review back to session's own model for evaluation ── - unified = "" - original_model = "unknown" - if sess: - original_url = sess.endpoint_url - original_model = sess.model - original_headers = getattr(sess, "headers", None) or {} - - unify_system = ( - "Another AI model just reviewed the conversation you've been having with the user. " - "Read their feedback carefully, then respond with:\n\n" - "1. **What you agree with** — acknowledge valid points honestly.\n" - "2. **What you disagree with** — explain why, briefly.\n" - "3. **Unified version** — produce an updated/refined version of whatever was being discussed, " - "incorporating the feedback you found valid. Don't accept every note blindly — " - "use your judgment on what actually improves things vs what's unnecessary.\n\n" - "Be concise and practical. The user wants a better result, not a meta-discussion." - ) - - unify_message = ( - f"Here's the conversation context:\n\n{context_text}\n\n" - f"---\n\n" - f"**Review from {reviewer_model}:**\n\n{review}\n\n" - f"---\n\n" - f"Evaluate this feedback and produce a unified improved version." 
- ) - - try: - unified = await llm_call_async( - original_url, original_model, - [ - {"role": "system", "content": unify_system}, - {"role": "user", "content": unify_message}, - ], - headers=original_headers, - timeout=AI_CHAT_TIMEOUT, - ) - if len(unified) > 10000: - unified = unified[:10000] + "\n... (truncated)" - except Exception as e: - logger.error(f"second_opinion unify call failed: {e}") - unified = f"(Failed to get unified response: {e})" - - # Build combined result - combined = ( - f"## Second Opinion from {reviewer_model}\n\n{review}" - f"\n\n---\n\n" - f"## {original_model}'s Response\n\n{unified}" - ) - - return { - "model": reviewer_model, - "response": combined, - "instruction": "Present these results to the user exactly as they are. Do NOT call second_opinion again. The user can continue the conversation from here.", - } async def do_create_session(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: @@ -1104,83 +872,6 @@ async def do_manage_memory(content: str, session_id: Optional[str] = None, owner return {"error": f"Unknown action '{action}'. Use: list, add, edit, delete, search"} -# --------------------------------------------------------------------------- -# List models tool -# --------------------------------------------------------------------------- - -async def do_list_models(content: str, session_id: Optional[str] = None, owner: Optional[str] = None) -> Dict: - """List all available models across configured endpoints. - - Content = optional filter keyword. 
- """ - import httpx - from src.database import SessionLocal, ModelEndpoint - from src.llm_core import _detect_provider, ANTHROPIC_MODELS - from src.auth_helpers import owner_filter - - keyword = content.strip().lower() if content.strip() else None - - db = SessionLocal() - try: - query = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) - if owner: - query = owner_filter(query, ModelEndpoint, owner) - endpoints = query.all() - if not endpoints: - return {"results": "No enabled model endpoints configured."} - - result_lines = [] - total_models = 0 - - for ep in endpoints: - try: - base, api_key = resolve_endpoint_runtime(ep, owner=owner) - except Exception: - continue - provider = _detect_provider(base) - headers = build_headers(api_key, base) - - model_ids = [] - if provider == "anthropic": - model_ids = list(ANTHROPIC_MODELS) - else: - try: - models_url = build_models_url(base) - if models_url: - r = httpx.get(models_url, headers=headers, timeout=5) - r.raise_for_status() - data = r.json() - model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if not model_ids: - model_ids = [ - m.get("name") or m.get("model") - for m in (data.get("models") or []) - if m.get("name") or m.get("model") - ] - else: - model_ids = json.loads(ep.cached_models or "[]") - except Exception: - model_ids = ["(endpoint offline)"] - - if keyword: - model_ids = [m for m in model_ids if keyword in m.lower() or keyword in (ep.name or "").lower()] - - if model_ids: - result_lines.append(f"\n**{ep.name or base}** ({provider}):") - for mid in model_ids: - result_lines.append(f" - `{mid}`") - total_models += 1 - - if not result_lines: - return {"results": "No models found" + (f" matching '{keyword}'" if keyword else "") + "."} - - header = f"Available models ({total_models} total):" - return {"results": header + "\n".join(result_lines)} - except Exception as e: - logger.error(f"list_models failed: {e}") - return {"error": str(e)} - finally: - db.close() # 
--------------------------------------------------------------------------- @@ -1831,12 +1522,7 @@ async def dispatch_ai_tool( ) -> Tuple[str, Dict]: """Dispatch an AI interaction tool. Returns (description, result_dict).""" - if tool == "chat_with_model": - model_spec = content.split("\n")[0].strip()[:60] - desc = f"chat_with_model: {model_spec}" - result = await do_chat_with_model(content, session_id, owner=owner) - - elif tool == "create_session": + if tool == "create_session": name = content.split("\n")[0].strip()[:60] desc = f"create_session: {name}" result = await do_create_session(content, session_id, owner=owner) @@ -1865,21 +1551,11 @@ async def dispatch_ai_tool( desc = f"manage_memory: {action}" result = await do_manage_memory(content, session_id, owner=owner) - elif tool == "list_models": - keyword = content.strip()[:40] - desc = f"list_models{': ' + keyword if keyword else ''}" - result = await do_list_models(content, session_id, owner=owner) - elif tool == "ui_control": action = content.split("\n")[0].strip()[:60] desc = f"ui_control: {action}" result = await do_ui_control(content, session_id, owner=owner) - elif tool == "ask_teacher": - problem = content.split("\n", 1)[-1].strip()[:60] - desc = f"ask_teacher: {problem}" - result = await do_ask_teacher(content, session_id, owner=owner) - else: desc = f"unknown ai tool: {tool}" result = {"error": f"Unknown AI interaction tool: {tool}"} diff --git a/src/tool_execution.py b/src/tool_execution.py index 8f3f7ed6f..05022bdba 100644 --- a/src/tool_execution.py +++ b/src/tool_execution.py @@ -766,10 +766,19 @@ async def _execute_tool_block_impl( query = content.split("\n")[0].strip() desc = f"search_chats: {query[:80]}" result = await do_search_chats(query, owner=owner) - elif tool in ("chat_with_model", "create_session", "list_sessions", + elif tool in ("chat_with_model", "ask_teacher", "list_models"): + # Migrated to the agent_tools registry (#3629): dispatched through + # TOOL_HANDLERS with the 
owner/session ctx these tools need, instead + # of the legacy dispatch_ai_tool elif. The impls now live in + # agent_tools/model_interaction_tools.py (the do_* versions were removed). + first_line = content.split(chr(10))[0].strip()[:60] + desc = f"{tool}: {first_line}" if first_line else tool + result = await _document_tool_dispatch(tool, content, session_id, owner) \ + or {"error": f"{tool}: execution failed", "exit_code": 1} + elif tool in ("create_session", "list_sessions", "send_to_session", "pipeline", - "manage_session", "manage_memory", "list_models", - "ui_control", "ask_teacher"): + "manage_session", "manage_memory", + "ui_control"): from src.ai_interaction import dispatch_ai_tool desc, result = await dispatch_ai_tool(tool, content, session_id, owner=owner) elif tool == "manage_tasks": diff --git a/static/js/admin.js b/static/js/admin.js index bd63e10db..58b8765a5 100644 --- a/static/js/admin.js +++ b/static/js/admin.js @@ -1756,7 +1756,6 @@ const TOOL_META = { manage_skills: { name: 'Skills', desc: 'Learn and use procedures', cat: 'Knowledge', ctx: '~200' }, manage_rag: { name: 'RAG / Docs', desc: 'Query indexed documents', cat: 'Knowledge', ctx: '~150' }, chat_with_model: { name: 'Chat with Model', desc: 'Talk to another AI model', cat: 'Multi-Agent', ctx: '~200' }, - second_opinion: { name: 'Second Opinion', desc: 'Get another model\'s take', cat: 'Multi-Agent', ctx: '~150' }, pipeline: { name: 'Pipeline', desc: 'Multi-step AI workflows', cat: 'Multi-Agent', ctx: '~200' }, ask_teacher: { name: 'Ask Teacher', desc: 'Query a more capable model', cat: 'Multi-Agent', ctx: '~150' }, send_to_session: { name: 'Send to Session', desc: 'Send message to another chat', cat: 'Sessions', ctx: '~100' }, diff --git a/static/js/assistant.js b/static/js/assistant.js index dca4bd55f..b4b9dc3cc 100644 --- a/static/js/assistant.js +++ b/static/js/assistant.js @@ -125,7 +125,7 @@ const TOOL_GROUPS = { 'Knowledge': ['web_search', 'read_file', 'manage_memory', 'manage_rag',
'search_chats'], 'Code': ['bash', 'python', 'write_file'], 'Documents': ['create_document', 'edit_document', 'update_document', 'suggest_document'], - 'AI & Models': ['chat_with_model', 'second_opinion', 'ask_teacher', 'pipeline', 'list_models', 'generate_image'], + 'AI & Models': ['chat_with_model', 'ask_teacher', 'pipeline', 'list_models', 'generate_image'], 'System': ['manage_session', 'manage_endpoints', 'manage_mcp', 'manage_settings', 'manage_skills', 'manage_webhooks', 'manage_tokens', 'manage_documents', 'create_session', 'list_sessions', 'send_to_session', 'ui_control'], }; diff --git a/tests/test_ai_interaction_owner_scope.py b/tests/test_ai_interaction_owner_scope.py index 7b2ac63bd..1cfe31c23 100644 --- a/tests/test_ai_interaction_owner_scope.py +++ b/tests/test_ai_interaction_owner_scope.py @@ -3,6 +3,7 @@ import inspect import pytest from src import ai_interaction +from src.agent_tools import model_interaction_tools def _source(fn) -> str: @@ -18,7 +19,8 @@ def test_model_resolver_applies_owner_filter(): def test_model_listing_and_image_fallback_are_owner_scoped(): - list_body = _source(ai_interaction.do_list_models) + # list_models moved to agent_tools.model_interaction_tools (#3629). + list_body = _source(model_interaction_tools.list_models) image_body = _source(ai_interaction.do_generate_image) assert "owner: Optional[str] = None" in list_body @@ -28,12 +30,13 @@ def test_model_listing_and_image_fallback_are_owner_scoped(): assert "_resolve_model(model_spec, owner=owner)" in image_body +# chat_with_model, list_models and ask_teacher moved to the registry (#3629) +# and no longer route through dispatch_ai_tool; their owner threading is covered +# by tests/test_model_interaction_registry.py. 
The remaining model-ish tools +# still dispatched here: @pytest.mark.parametrize("tool,content", [ - ("chat_with_model", "gpt-test\nhello"), ("pipeline", "gpt-test | summarize this"), - ("list_models", ""), ("ui_control", "switch_model gpt-test"), - ("ask_teacher", "gpt-test\nhelp me"), ]) async def test_dispatch_passes_owner_to_model_tools(monkeypatch, tool, content): seen = {} @@ -42,31 +45,16 @@ async def test_dispatch_passes_owner_to_model_tools(monkeypatch, tool, content): seen[name] = {"content": content, "session_id": session_id, "owner": owner} return {"ok": True} - monkeypatch.setattr( - ai_interaction, - "do_chat_with_model", - lambda content, session_id=None, owner=None: capture("chat_with_model", content, session_id, owner), - ) monkeypatch.setattr( ai_interaction, "do_pipeline", lambda content, session_id=None, owner=None: capture("pipeline", content, session_id, owner), ) - monkeypatch.setattr( - ai_interaction, - "do_list_models", - lambda content, session_id=None, owner=None: capture("list_models", content, session_id, owner), - ) monkeypatch.setattr( ai_interaction, "do_ui_control", lambda content, session_id=None, owner=None: capture("ui_control", content, session_id, owner), ) - monkeypatch.setattr( - ai_interaction, - "do_ask_teacher", - lambda content, session_id=None, owner=None: capture("ask_teacher", content, session_id, owner), - ) _desc, result = await ai_interaction.dispatch_ai_tool(tool, content, session_id="sid1", owner="alice") diff --git a/tests/test_model_interaction_registry.py b/tests/test_model_interaction_registry.py new file mode 100644 index 000000000..fcfdef3e6 --- /dev/null +++ b/tests/test_model_interaction_registry.py @@ -0,0 +1,104 @@ +"""Tests for the model-interaction tools after their move to the agent_tools +registry (#3629): chat_with_model, ask_teacher, list_models. + +The implementations now live in src/agent_tools/model_interaction_tools.py +(moved out of src/ai_interaction.py). 
These assert (1) the handlers are +registered in TOOL_HANDLERS, (2) each handler runs the moved logic and threads +session_id/owner from the ctx, and (3) tool_execution.py dispatches them +through the registry rather than the legacy dispatch_ai_tool elif. +""" +import asyncio +from pathlib import Path + +import src.ai_interaction as ai_interaction +import src.llm_core as llm_core +import src.database as database +from src.agent_tools import TOOL_HANDLERS +from src.agent_tools import model_interaction_tools as mit + +_MODEL_TOOLS = ("chat_with_model", "ask_teacher", "list_models") + + +def test_model_interaction_tools_registered(): + for name in _MODEL_TOOLS: + assert name in TOOL_HANDLERS, f"{name} missing from TOOL_HANDLERS" + + +def test_chat_with_model_threads_owner_and_returns(monkeypatch): + seen = {} + + def fake_resolve(spec, owner=None): + seen["spec"] = spec + seen["owner"] = owner + return ("http://x", "model-x", {}) + + async def fake_call(url, model, messages, headers=None, timeout=None): + seen["message"] = messages[-1]["content"] + return "hi back" + + monkeypatch.setattr(ai_interaction, "_resolve_model", fake_resolve) + monkeypatch.setattr(llm_core, "llm_call_async", fake_call) + + res = asyncio.run(mit.ChatWithModelTool().execute( + "model-x\nhello there", {"owner": "alice", "session_id": "s1"})) + + assert res == {"model": "model-x", "response": "hi back"} + assert seen["owner"] == "alice" + assert seen["spec"] == "model-x" + assert seen["message"] == "hello there" + + +def test_ask_teacher_threads_owner_and_marks_teacher(monkeypatch): + seen = {} + + def fake_resolve(spec, owner=None): + seen["owner"] = owner + return ("http://x", "teacher-x", {}) + + async def fake_call(url, model, messages, headers=None, timeout=None): + return "do this and that" + + monkeypatch.setattr(ai_interaction, "_resolve_model", fake_resolve) + monkeypatch.setattr(llm_core, "llm_call_async", fake_call) + + res = asyncio.run(mit.AskTeacherTool().execute( + "teacher-x\nI 
am stuck", {"owner": "bob"})) + + assert res["teacher"] is True + assert res["response"] == "do this and that" + assert seen["owner"] == "bob" + + +def test_list_models_no_endpoints(monkeypatch): + class _Q: + def filter(self, *a, **k): + return self + + def all(self): + return [] + + class _S: + def query(self, *a, **k): + return _Q() + + def close(self): + pass + + monkeypatch.setattr(database, "SessionLocal", lambda: _S()) + + res = asyncio.run(mit.ListModelsTool().execute("", {})) + assert res == {"results": "No enabled model endpoints configured."} + + +def test_dispatched_via_registry_not_dispatch_ai_tool(): + """The model tools route through the registry (_document_tool_dispatch), and + are no longer in the dispatch_ai_tool elif tuple.""" + source = (Path(__file__).resolve().parent.parent / "src" / "tool_execution.py").read_text(encoding="utf-8") + assert 'elif tool in ("chat_with_model", "ask_teacher", "list_models"):' in source + + marker = "from src.ai_interaction import dispatch_ai_tool" + idx = source.index(marker) + branch_head = source.rfind("elif tool in (", 0, idx) + legacy_tuple = source[branch_head:idx] + for name in _MODEL_TOOLS: + assert f'"{name}"' not in legacy_tuple, f"{name} still routed via dispatch_ai_tool"