From 5ce2056521bb68c18f7ddd882cc9a9c18ba8f417 Mon Sep 17 00:00:00 2001 From: Kenny Van de Maele Date: Wed, 24 Jun 2026 09:29:10 +0200 Subject: [PATCH] refactor(tools): migrate config/integration admin tools to the registry (#4742) Part of #3629 (the `admin_tools.py` bullet). Moves the config/integration admin tools off the legacy elif dispatch chain in tool_implementations.py onto the agent_tools registry: manage_endpoints, manage_mcp, manage_webhooks, manage_tokens, manage_settings The do_* implementations (and manage_mcp's command-allowlist / RCE guard: _validate_mcp_command, _mcp_allowed_commands, and the _MCP_* constants) move verbatim into the new src/agent_tools/admin_tools.py. They register through a single ADMIN_TOOL_HANDLERS map that TOOL_HANDLERS.update()s, and the five elif branches plus their imports are dropped from tool_execution.py, so these tools now flow through _direct_fallback like the other migrated clusters. The names are re-exported from src.agent_tools for back-compat. Dedup: - _parse_tool_args was duplicated in tool_implementations.py and document_tools.py. It now lives once in src.tool_utils (which imports nothing from the project beyond src.constants, so this introduces no cycle) and both call sites import it from there. The orphaned `import json` in document_tools is removed with it. - The five tools share one _owner_adapter(fn) factory that threads ctx["owner"] into the owner-taking do_* signature, instead of five near-identical wrappers. Tests: new tests/test_admin_tools_registry.py pins the registration, the re-export back-compat, the owner-threading adapter, and the single-source _parse_tool_args (across admin_tools and document_tools). Existing MCP / settings / webhook suites are repointed at the new module. --- src/agent_tools/__init__.py | 12 +- src/agent_tools/admin_tools.py | 784 ++++++++++++++++++++ src/agent_tools/document_tools.py | 34 +- src/tool_execution.py | 24 +- src/tool_implementations.py | 787 +-------------------- src/tool_utils.py | 35 + tests/test_admin_tools_registry.py | 69 ++ tests/test_context_budget.py | 3 +- tests/test_manage_mcp_command_allowlist.py | 4 +- tests/test_manage_settings_token_budget.py | 2 +- tests/test_mcp_reconnect_args.py | 4 +- tests/test_review_regressions.py | 2 +- 12 files changed, 912 insertions(+), 848 deletions(-) create mode 100644 src/agent_tools/admin_tools.py create mode 100644 tests/test_admin_tools_registry.py diff --git a/src/agent_tools/__init__.py b/src/agent_tools/__init__.py index ba39b4f53..b4a796850 100644 --- a/src/agent_tools/__init__.py +++ b/src/agent_tools/__init__.py @@ -25,6 +25,11 @@ from .document_tools import CreateDocumentTool, UpdateDocumentTool, EditDocument from .model_interaction_tools import ChatWithModelTool, AskTeacherTool, ListModelsTool from .bg_job_tools import ManageBgJobsTool from .session_tools import CreateSessionTool, ListSessionsTool, SendToSessionTool, ManageSessionTool +from .admin_tools import ( + ADMIN_TOOL_HANDLERS, + do_manage_endpoints, do_manage_mcp, do_manage_webhooks, + do_manage_tokens, do_manage_settings, +) TOOL_HANDLERS = { "bash": BashTool().execute, @@ -52,6 +57,8 @@ TOOL_HANDLERS = { "send_to_session": SendToSessionTool().execute, "manage_session": ManageSessionTool().execute, } +# Config/integration admin tools (manage_endpoints/mcp/webhooks/tokens/settings). +TOOL_HANDLERS.update(ADMIN_TOOL_HANDLERS) # --------------------------------------------------------------------------- # Constants (re-exported for backward compatibility — single source of truth @@ -138,10 +145,5 @@ from src.tool_implementations import ( # noqa: E402, F401 do_search_chats, do_manage_skills, do_manage_tasks, - do_manage_endpoints, - do_manage_mcp, - do_manage_webhooks, - do_manage_tokens, - do_manage_settings, do_api_call, ) diff --git a/src/agent_tools/admin_tools.py b/src/agent_tools/admin_tools.py new file mode 100644 index 000000000..bb54b6cbd --- /dev/null +++ b/src/agent_tools/admin_tools.py @@ -0,0 +1,784 @@ +"""Config/integration admin agent tools (TOOL_HANDLERS). + +Moved verbatim from tool_implementations.py as part of the tool-registry +migration (#3629, the `admin_tools.py` bullet): manage_endpoints / manage_mcp / +manage_webhooks / manage_tokens / manage_settings, plus manage_mcp's +command-allowlist guard. Each impl keeps its `do_*(content, owner)` shape; +ADMIN_TOOL_HANDLERS wraps them into registry `execute(content, ctx)` adapters +via one factory. +""" +import json +import os +import re +import logging +from typing import Optional, Dict + +from src.tool_utils import get_mcp_manager, _parse_tool_args + +logger = logging.getLogger(__name__) + + +async def do_manage_endpoints(content: str, owner: Optional[str] = None) -> Dict: + """Manage model endpoints: list, add, delete, enable, disable.""" + from core.database import SessionLocal, ModelEndpoint + try: + args = _parse_tool_args(content) + except ValueError: + return {"error": "Invalid JSON arguments", "exit_code": 1} + + action = args.get("action", "list") + db = SessionLocal() + try: + if action == "list": + eps = db.query(ModelEndpoint).all() + items = [{"id": e.id, "name": e.name, "base_url": e.base_url, + "is_enabled": e.is_enabled} for e in eps] + return {"response": f"{len(items)} endpoints", "endpoints": items, "exit_code": 0} + + elif action == "add": + import uuid as _uuid + name = args.get("name", "") + base_url = args.get("base_url", "") + api_key = args.get("api_key", "") + if not base_url: + return {"error": "base_url is required", "exit_code": 1} + eid = str(_uuid.uuid4())[:8] + from datetime import datetime + ep = ModelEndpoint(id=eid, name=name or base_url, base_url=base_url, + api_key=api_key, is_enabled=True, + created_at=datetime.utcnow(), updated_at=datetime.utcnow()) + db.add(ep) + db.commit() + return {"response": f"Added endpoint '{name or base_url}' (id: {eid})", "exit_code": 0} + + elif action == "delete": + eid = args.get("endpoint_id", "") + ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == eid).first() + if not ep: + return {"error": f"Endpoint {eid} not found", "exit_code": 1} + name = ep.name + db.delete(ep) + db.commit() + return {"response": f"Deleted endpoint '{name}'", "exit_code": 0} + + elif action in ("enable", "disable"): + eid = args.get("endpoint_id", "") + ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == eid).first() + if not ep: + return {"error": f"Endpoint {eid} not found", "exit_code": 1} + ep.is_enabled = (action == "enable") + db.commit() + return {"response": f"Endpoint '{ep.name}' {action}d", "exit_code": 0} + + else: + return {"error": f"Unknown action: {action}", "exit_code": 1} + except Exception as e: + logger.error(f"manage_endpoints error: {e}") + return {"error": str(e), "exit_code": 1} + finally: + db.close() + + +# --------------------------------------------------------------------------- +# MCP server management tool +# --------------------------------------------------------------------------- + +# Parallel to routes/cookbook_helpers._validate_serve_cmd but deliberately the +# opposite policy: that gate guards an admin-only serve command and allows +# interpreters (python3/etc) because model-serving needs them, whereas this is +# the model/prompt-injection-reachable manage_mcp path, so interpreters and +# runners are denied here. +# +# Commands that can execute arbitrary code regardless of their arguments. These +# are NEVER accepted on the manage_mcp agent path, even if an operator lists one +# in ODYSSEUS_MCP_ALLOWED_COMMANDS -- a stdio server that genuinely needs an +# interpreter or package runner must be registered via the trusted admin route. +_MCP_DENIED_COMMANDS = frozenset({ + "sh", "bash", "zsh", "fish", "dash", "ksh", "csh", "tcsh", "ash", "busybox", + "cmd", "command.com", "powershell", "pwsh", + "python", "pypy", "node", "nodejs", "deno", "bun", "ruby", "jruby", + "perl", "raku", "php", "lua", "luajit", "tclsh", "wish", "expect", "rscript", + "groovy", "scala", "elixir", "erl", "iex", "java", "javac", "jshell", "jbang", + "kotlin", "kotlinc", "dotnet", "mono", "swift", "osascript", "tsx", "ts-node", + "npx", "bunx", "uvx", "pipx", "npm", "pnpm", "yarn", "pip", "uv", + "gem", "cargo", "go", "bundle", "poetry", "conda", "mamba", "brew", + "apt", "apt-get", "yum", "dnf", "pacman", "apk", + "env", "xargs", "nohup", "setsid", "nice", "ionice", "time", "timeout", + "watch", "stdbuf", "unbuffer", "script", "ssh", "scp", "sshpass", "sudo", + "doas", "su", "make", "cmake", "docker", "podman", "kubectl", "find", + "awk", "gawk", "sed", "vi", "vim", "nvim", "emacs", "ed", "tee", "eval", +}) + +# Argv flags that make even an allowlisted binary execute inline code. Matched +# by prefix so glued forms (-cimport os, --eval=...) are caught, not just the +# exact-token form. +_MCP_CODE_EXEC_SHORT_FLAGS = ("-c", "-e", "-m") +_MCP_CODE_EXEC_LONG_FLAGS = ("--eval", "--exec", "--print", "--module", "--command", "--require") + +_MCP_URL_SCHEMES = ("http://", "https://", "ftp://", "ftps://", "file://", "data:", "jar:", "blob:") + +# Shell metacharacters refused in command/args. Args are passed as an argv list +# (no shell), but refusing these keeps the surface narrow and obvious. +_MCP_SHELL_METACHARS = set(";|&$`><\n\r") + +# Env vars that let a child process load attacker-supplied code before main(). +_MCP_DANGEROUS_ENV = frozenset({ + "LD_PRELOAD", "LD_LIBRARY_PATH", "LD_AUDIT", "DYLD_INSERT_LIBRARIES", + "DYLD_LIBRARY_PATH", "DYLD_FRAMEWORK_PATH", "PYTHONPATH", "PYTHONSTARTUP", + "PYTHONHOME", "PYTHONEXECUTABLE", "NODE_OPTIONS", "NODE_PATH", "BASH_ENV", + "ENV", "SHELLOPTS", "PERL5LIB", "PERL5OPT", "RUBYOPT", "RUBYLIB", "GEM_PATH", + "R_PROFILE", "R_HOME", "PATH", "IFS", "PROMPT_COMMAND", +}) + + +def _mcp_allowed_commands() -> set: + """Operator-configured allowlist of safe MCP launcher basenames for the agent + path. Empty by default; set ODYSSEUS_MCP_ALLOWED_COMMANDS (comma-separated) + to opt specific trusted binaries in. Denied commands are rejected even if + listed here.""" + raw = os.environ.get("ODYSSEUS_MCP_ALLOWED_COMMANDS", "") + return {c.strip().lower() for c in raw.split(",") if c.strip()} + + +def _validate_mcp_command(command, args, env) -> Optional[str]: + """Validate a model-supplied stdio MCP registration. Returns an error string + if it must be rejected, else None. + + Closes the RCE where manage_mcp 'add' passed prompt-injection-controlled + command/args/env straight to a subprocess spawn (issue #438): a payload + smuggled into a skill description, memory entry, fetched page, or email body + could register a stdio server running arbitrary code as the app UID. + """ + if not isinstance(command, str) or not command.strip(): + return "command must be a non-empty string" + command = command.strip() + if "/" in command or "\\" in command: + return "command must be a bare executable name, not a path" + if any(ch in _MCP_SHELL_METACHARS for ch in command): + return "command contains shell metacharacters" + base = command.lower() + if base.endswith(".exe") or base.endswith(".cmd") or base.endswith(".bat"): + base = base.rsplit(".", 1)[0] + # Canonicalize a trailing version suffix so versioned aliases collapse to the + # family name (python3.11 -> python, node18 -> node, pip3 -> pip); both the + # raw basename and the canonical form are denied, so an operator cannot + # accidentally allowlist a runtime alias back into the path. + canon = re.sub(r"[-_.]?\d+(?:\.\d+)*$", "", base) + if base in _MCP_DENIED_COMMANDS or canon in _MCP_DENIED_COMMANDS: + return ( + f"command '{command}' is not allowed on the agent MCP path: " + "interpreters, runtimes, package runners, and shells can execute " + "arbitrary code. Register such a server via the admin route instead." + ) + if base not in _mcp_allowed_commands(): + return ( + f"command '{command}' is not in the MCP allowlist. Add it to " + "ODYSSEUS_MCP_ALLOWED_COMMANDS if you trust it, or register the " + "server via the admin route." + ) + + if args is not None: + if isinstance(args, str): + try: + args = json.loads(args) + except Exception: + return "args must be a JSON list" + if not isinstance(args, list): + return "args must be a list" + for a in args: + if not isinstance(a, str): + return "args must all be strings" + s = a.strip() + low = s.lower() + if any(s == f or s.startswith(f) for f in _MCP_CODE_EXEC_SHORT_FLAGS): + return f"arg '{a}' is a code-execution flag and is not allowed" + if any(low == f or low.startswith(f + "=") for f in _MCP_CODE_EXEC_LONG_FLAGS): + return f"arg '{a}' is a code-execution flag and is not allowed" + if any(low.startswith(u) for u in _MCP_URL_SCHEMES): + return f"arg '{a}' is a remote URL and is not allowed" + if any(ch in _MCP_SHELL_METACHARS for ch in a): + return f"arg '{a}' contains shell metacharacters" + + if env: + if isinstance(env, str): + try: + env = json.loads(env) + except Exception: + return "env must be a JSON object" + if not isinstance(env, dict): + return "env must be an object" + for k in env: + if str(k).strip().upper() in _MCP_DANGEROUS_ENV: + return f"env var '{k}' can inject code into the child process and is not allowed" + + return None + + +async def do_manage_mcp(content: str, owner: Optional[str] = None) -> Dict: + """Manage MCP servers: list, add, delete, enable, disable, reconnect.""" + try: + args = _parse_tool_args(content) + except ValueError: + return {"error": "Invalid JSON arguments", "exit_code": 1} + + action = args.get("action", "list") + + if action == "list": + mcp = get_mcp_manager() + if not mcp: + return {"response": "No MCP manager available", "servers": [], "exit_code": 0} + from core.database import SessionLocal, McpServer + db = SessionLocal() + try: + servers = db.query(McpServer).all() + items = [] + for s in servers: + st = mcp.get_server_status(s.id) + status = st.get("status", "disconnected") + tool_count = st.get("tool_count", 0) + items.append({"id": s.id, "name": s.name, "transport": s.transport, + "is_enabled": s.is_enabled, "status": status, + "tool_count": tool_count}) + return {"response": f"{len(items)} MCP servers", "servers": items, "exit_code": 0} + finally: + db.close() + + elif action == "add": + from core.database import SessionLocal, McpServer + import uuid as _uuid + from datetime import datetime + name = args.get("name", "") + command = args.get("command", "") + cmd_args = args.get("args", []) + env = args.get("env", {}) + if not name or not command: + return {"error": "name and command are required", "exit_code": 1} + # Validate BEFORE any DB write or spawn: a rejected registration must + # leave no enabled row (which would otherwise auto-reconnect on restart) + # and must not attempt a connection. + _mcp_err = _validate_mcp_command(command, cmd_args, env) + if _mcp_err: + return {"error": f"manage_mcp: refused unsafe server registration: {_mcp_err}", "exit_code": 1} + sid = str(_uuid.uuid4())[:8] + db = SessionLocal() + try: + srv = McpServer(id=sid, name=name, transport="stdio", command=command, + args=json.dumps(cmd_args) if isinstance(cmd_args, list) else cmd_args, + env=json.dumps(env) if isinstance(env, dict) else env, + is_enabled=True, created_at=datetime.utcnow(), updated_at=datetime.utcnow()) + db.add(srv) + db.commit() + finally: + db.close() + # Try to connect + mcp = get_mcp_manager() + tool_count = 0 + if mcp: + try: + await mcp.connect_server( + sid, name, "stdio", command=command, + args=cmd_args if isinstance(cmd_args, list) else json.loads(cmd_args), + env=env if isinstance(env, dict) else json.loads(env), + ) + st = mcp.get_server_status(sid) + tool_count = st.get("tool_count", 0) + except Exception as e: + logger.warning(f"MCP connect failed for {name}: {e}") + return {"response": f"Added MCP server '{name}' ({tool_count} tools)", "exit_code": 0} + + elif action == "delete": + sid = args.get("server_id", "") + from core.database import SessionLocal, McpServer + db = SessionLocal() + try: + srv = db.query(McpServer).filter(McpServer.id == sid).first() + if not srv: + return {"error": f"Server {sid} not found", "exit_code": 1} + name = srv.name + mcp = get_mcp_manager() + if mcp: + try: + await mcp.disconnect_server(sid) + except Exception: + pass + db.delete(srv) + db.commit() + return {"response": f"Deleted MCP server '{name}'", "exit_code": 0} + finally: + db.close() + + elif action == "reconnect": + sid = args.get("server_id", "") + mcp = get_mcp_manager() + if not mcp: + return {"error": "MCP manager not available", "exit_code": 1} + try: + await mcp.disconnect_server(sid) + from core.database import SessionLocal, McpServer + db2 = SessionLocal() + try: + srv = db2.query(McpServer).filter(McpServer.id == sid).first() + if srv: + _args = json.loads(srv.args) if srv.args else [] + _env = json.loads(srv.env) if srv.env else {} + await mcp.connect_server( + server_id=sid, + name=srv.name, + transport=srv.transport, + command=srv.command, + args=_args, + env=_env, + url=srv.url, + ) + st = mcp.get_server_status(sid) + return {"response": f"Reconnected '{srv.name}' ({st.get('tool_count', 0)} tools)", "exit_code": 0} + return {"error": f"Server {sid} not found", "exit_code": 1} + finally: + db2.close() + except Exception as e: + return {"error": str(e), "exit_code": 1} + + elif action in ("enable", "disable"): + sid = args.get("server_id", "") + from core.database import SessionLocal, McpServer + db = SessionLocal() + try: + srv = db.query(McpServer).filter(McpServer.id == sid).first() + if not srv: + return {"error": f"Server {sid} not found", "exit_code": 1} + srv.is_enabled = (action == "enable") + db.commit() + return {"response": f"MCP server '{srv.name}' {action}d", "exit_code": 0} + finally: + db.close() + + elif action == "list_tools": + mcp = get_mcp_manager() + if not mcp: + return {"response": "No MCP manager", "tools": [], "exit_code": 0} + tools = mcp.get_all_tools() + items = [{"name": t["name"], "server": t["server_name"], + "description": t.get("description", "")[:100]} for t in tools] + return {"response": f"{len(items)} MCP tools available", "tools": items, "exit_code": 0} + + else: + return {"error": f"Unknown action: {action}", "exit_code": 1} + + +# --------------------------------------------------------------------------- +# Webhook management tool +# --------------------------------------------------------------------------- + +async def do_manage_webhooks(content: str, owner: Optional[str] = None) -> Dict: + """Manage webhooks: list, add, delete, enable, disable, test.""" + from core.database import SessionLocal + try: + args = _parse_tool_args(content) + except ValueError: + return {"error": "Invalid JSON arguments", "exit_code": 1} + + action = args.get("action", "list") + db = SessionLocal() + try: + from core.database import Webhook + if action == "list": + hooks = db.query(Webhook).all() + items = [{"id": h.id, "name": h.name, "url": h.url, + "events": h.events, "is_active": h.is_active} for h in hooks] + return {"response": f"{len(items)} webhooks", "webhooks": items, "exit_code": 0} + + elif action == "add": + import uuid as _uuid + from datetime import datetime + from src.webhook_manager import validate_events, validate_webhook_url + name = args.get("name", "") + url = args.get("url", "") + events = args.get("events", "chat.completed") + if not url: + return {"error": "url is required", "exit_code": 1} + try: + url = validate_webhook_url(url) + events = validate_events(events) + except ValueError as e: + return {"error": str(e), "exit_code": 1} + wid = str(_uuid.uuid4())[:8] + hook = Webhook(id=wid, name=name or url, url=url, + events=events, is_active=True, + created_at=datetime.utcnow(), updated_at=datetime.utcnow()) + db.add(hook) + db.commit() + return {"response": f"Added webhook '{name or url}'", "exit_code": 0} + + elif action == "delete": + wid = args.get("webhook_id", "") + hook = db.query(Webhook).filter(Webhook.id == wid).first() + if not hook: + return {"error": f"Webhook {wid} not found", "exit_code": 1} + name = hook.name + db.delete(hook) + db.commit() + return {"response": f"Deleted webhook '{name}'", "exit_code": 0} + + elif action in ("enable", "disable"): + wid = args.get("webhook_id", "") + hook = db.query(Webhook).filter(Webhook.id == wid).first() + if not hook: + return {"error": f"Webhook {wid} not found", "exit_code": 1} + hook.is_active = (action == "enable") + db.commit() + return {"response": f"Webhook '{hook.name}' {action}d", "exit_code": 0} + + else: + return {"error": f"Unknown action: {action}", "exit_code": 1} + except Exception as e: + logger.error(f"manage_webhooks error: {e}") + return {"error": str(e), "exit_code": 1} + finally: + db.close() + + +# --------------------------------------------------------------------------- +# API token management tool +# --------------------------------------------------------------------------- + +async def do_manage_tokens(content: str, owner: Optional[str] = None) -> Dict: + """Manage API tokens: list, create, delete.""" + from core.database import SessionLocal, ApiToken + try: + args = _parse_tool_args(content) + except ValueError: + return {"error": "Invalid JSON arguments", "exit_code": 1} + + action = args.get("action", "list") + db = SessionLocal() + try: + if action == "list": + tokens = db.query(ApiToken).all() + items = [{"id": t.id, "name": t.name, "token_prefix": t.token_prefix + "...", + "is_active": t.is_active} for t in tokens] + return {"response": f"{len(items)} API tokens", "tokens": items, "exit_code": 0} + + elif action == "create": + import uuid as _uuid, secrets, bcrypt + from datetime import datetime + name = args.get("name", "API Token") + raw_token = secrets.token_urlsafe(32) + token_hash = bcrypt.hashpw(raw_token.encode(), bcrypt.gensalt()).decode() + tid = str(_uuid.uuid4())[:8] + t = ApiToken(id=tid, name=name, token_hash=token_hash, + token_prefix=raw_token[:8], is_active=True, + created_at=datetime.utcnow(), updated_at=datetime.utcnow()) + db.add(t) + db.commit() + return {"response": f"Created token '{name}'", "token": raw_token, "exit_code": 0} + + elif action == "delete": + tid = args.get("token_id", "") + t = db.query(ApiToken).filter(ApiToken.id == tid).first() + if not t: + return {"error": f"Token {tid} not found", "exit_code": 1} + name = t.name + db.delete(t) + db.commit() + return {"response": f"Deleted token '{name}'", "exit_code": 0} + + else: + return {"error": f"Unknown action: {action}", "exit_code": 1} + except Exception as e: + logger.error(f"manage_tokens error: {e}") + return {"error": str(e), "exit_code": 1} + finally: + db.close() + +# --------------------------------------------------------------------------- +# Settings/preferences management tool +# --------------------------------------------------------------------------- + +async def do_manage_settings(content: str, owner: Optional[str] = None) -> Dict: + """Manage user settings and preferences.""" + try: + args = _parse_tool_args(content) + except ValueError: + return {"error": "Invalid JSON arguments", "exit_code": 1} + + action = args.get("action", "list") + + from core.database import SessionLocal + db = SessionLocal() + try: + # set/get/list/delete operate on the REAL app settings (the same store + # the Settings panel writes), so changing a model / voice / search + # engine / reminder channel from chat actually takes effect. + from src.settings import load_settings, save_settings, DEFAULT_SETTINGS + + # Secrets/credentials the agent must NOT write: kept read-only (masked) + # so API keys never flow through chat. User sets these in the panel. + _SECRET_KEYS = { + "brave_api_key", "google_pse_key", "google_pse_cx", + "tavily_api_key", "serper_api_key", "app_public_url", + } + def _is_secret(k): + # `token` must be a suffix, not a substring: otherwise the int + # setting `agent_input_token_budget` (which even has a "token budget" + # alias to set it from chat) is wrongly classified as a credential. + return ( + k in _SECRET_KEYS + or k.endswith("token") + or any(t in k for t in ("api_key", "_key", "secret", "password")) + ) + + # Friendly aliases → real keys, so natural phrasing resolves. + _ALIASES_SET = { + "voice": "tts_voice", "tts voice": "tts_voice", "tts": "tts_enabled", + "text to speech": "tts_enabled", "tts provider": "tts_provider", + "speech speed": "tts_speed", "voice speed": "tts_speed", + "stt": "stt_enabled", "speech to text": "stt_enabled", "transcription": "stt_enabled", + "search engine": "search_provider", "search provider": "search_provider", + "search results": "search_result_count", "result count": "search_result_count", + "default model": "default_model", "chat model": "default_model", + "default endpoint": "default_endpoint_id", + "task model": "task_model", "background model": "task_model", + "teacher model": "teacher_model", "teacher": "teacher_enabled", + "utility model": "utility_model", "research model": "research_model", + "research max tokens": "research_max_tokens", + "vision model": "vision_model", "vision": "vision_enabled", + "image model": "image_model", "image quality": "image_quality", + "image gen": "image_gen_enabled", "image generation": "image_gen_enabled", + "reminder channel": "reminder_channel", "reminders": "reminder_channel", + "ntfy topic": "reminder_ntfy_topic", + "webhook integration": "reminder_webhook_integration_id", + "webhook template": "reminder_webhook_payload_template", "webhook payload": "reminder_webhook_payload_template", + "agent tool calls": "agent_max_tool_calls", "max tool calls": "agent_max_tool_calls", + "agent timeout": "agent_stream_timeout_seconds", "stream timeout": "agent_stream_timeout_seconds", + "token budget": "agent_input_token_budget", "input budget": "agent_input_token_budget", + "hard max": "agent_input_token_hard_max", + "token budget cap": "agent_input_token_hard_max", + "input budget cap": "agent_input_token_hard_max", + } + def _resolve(k): + k2 = (k or "").strip().lower() + if k2 in DEFAULT_SETTINGS: + return k2 + return _ALIASES_SET.get(k2, (k or "").strip()) + + _ENUMS = { + "image_quality": ["low", "medium", "high"], + "reminder_channel": ["browser", "email", "ntfy", "webhook"], + } + def _coerce(value, default): + if isinstance(default, bool): + return value if isinstance(value, bool) else str(value).strip().lower() in ("true", "on", "yes", "1", "enable", "enabled") + if isinstance(default, int): + return int(value) + return value + + def _model_slug(value: str) -> str: + import re as _re + return _re.sub(r"[^a-z0-9]+", "", (value or "").lower()) + + def _endpoint_model_from_cache(model_query: str): + """Resolve friendly model text to an enabled endpoint + real model id. + + The Settings UI stores both `_endpoint_id` and + `_model`; writing only the model leaves the runtime on the + old endpoint. Prefer cached model lists so this stays fast/offline. + """ + import json as _json + import re as _re + from core.database import ModelEndpoint + + wanted = (model_query or "").strip() + wanted_slug = _model_slug(wanted) + wanted_tokens = [_model_slug(t) for t in _re.findall(r"[A-Za-z0-9]+", wanted)] + wanted_tokens = [t for t in wanted_tokens if t] + if not wanted_slug: + return None + best = None + for ep in db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all(): + raw_models = [] + try: + raw_models = _json.loads(ep.cached_models or "[]") or [] + except Exception: + raw_models = [] + # If cache is empty, still allow matching against endpoint name + # for callers using model@endpoint elsewhere later. + for mid in raw_models: + mid = str(mid) + mid_slug = _model_slug(mid) + if not mid_slug: + continue + exact = mid.lower() == wanted.lower() + compact_match = wanted_slug in mid_slug or mid_slug in wanted_slug + token_match = bool(wanted_tokens) and all(tok in mid_slug for tok in wanted_tokens) + if exact or compact_match or token_match: + score = 3 if exact else (2 if compact_match else 1) + if not best or score > best[0]: + best = (score, ep.id, mid) + if best: + return {"endpoint_id": best[1], "model": best[2]} + return None + + def _mask(k, v): + return "••••• (set in panel)" if _is_secret(k) and v else v + + if action == "list": + s = load_settings() + shown = {k: _mask(k, v) for k, v in s.items() if k in DEFAULT_SETTINGS and not isinstance(v, dict)} + return {"response": f"{len(shown)} settings (use get/set with a key)", "settings": shown, "exit_code": 0} + + elif action == "get": + key = _resolve(args.get("key", "")) + if not key: + return {"error": "key is required", "exit_code": 1} + if key not in DEFAULT_SETTINGS: + return {"error": f"Unknown setting '{args.get('key')}'. Use action='list' to see them.", "exit_code": 1} + val = load_settings().get(key, DEFAULT_SETTINGS.get(key)) + return {"response": f"{key} = {_mask(key, val)}", "value": _mask(key, val), "exit_code": 0} + + elif action == "set": + raw = args.get("key", "") + value = args.get("value") + if not raw: + return {"error": "key is required", "exit_code": 1} + key = _resolve(raw) + if key not in DEFAULT_SETTINGS: + return {"error": f"Unknown setting '{raw}'. Use action='list' to see available settings.", "exit_code": 1} + if _is_secret(key): + return {"response": f"'{key}' is a credential/secret. For security I can't set it from chat. Open Settings and set it there.", "exit_code": 0} + # Structured settings (dicts/lists like keybinds, default_model_fallbacks) + # have no safe scalar coercion; _coerce would pass a bare string + # straight through and clobber the structure. Refuse them here; they're + # edited in their dedicated panels. (reset/delete still restore the + # default structure, which is safe.) + if isinstance(DEFAULT_SETTINGS[key], (dict, list)): + return {"response": f"'{key}' is a structured setting. Edit it in its panel, not from chat. (You can reset it to default here.)", "exit_code": 0} + try: + value = _coerce(value, DEFAULT_SETTINGS[key]) + except (ValueError, TypeError): + return {"error": f"'{value}' isn't a valid value for {key} (expected {type(DEFAULT_SETTINGS[key]).__name__}).", "exit_code": 1} + if key in _ENUMS and str(value).lower() not in _ENUMS[key]: + return {"error": f"{key} must be one of: {', '.join(_ENUMS[key])}.", "exit_code": 1} + s = load_settings() + s[key] = value + if key in {"default_model", "research_model", "utility_model", "task_model", "vision_model", "image_model"}: + resolved = _endpoint_model_from_cache(str(value)) + if resolved: + prefix = key[:-6] + s[f"{prefix}_endpoint_id"] = resolved["endpoint_id"] + s[key] = resolved["model"] + value = resolved["model"] + save_settings(s) + if key.endswith("_model") and s.get(f"{key[:-6]}_endpoint_id"): + return {"response": f"Set {key} = {value} (endpoint {s.get(f'{key[:-6]}_endpoint_id')}).", "exit_code": 0} + return {"response": f"Set {key} = {value}.", "exit_code": 0} + + elif action == "delete" or action == "reset": + key = _resolve(args.get("key", "")) + if key not in DEFAULT_SETTINGS: + return {"error": f"Unknown setting '{args.get('key')}'.", "exit_code": 1} + if _is_secret(key): + return {"response": f"'{key}' is a credential. Reset it in the panel.", "exit_code": 0} + s = load_settings() + s[key] = DEFAULT_SETTINGS[key] + save_settings(s) + return {"response": f"Reset {key} to default ({DEFAULT_SETTINGS[key]}).", "exit_code": 0} + + elif action in ("disable_tool", "enable_tool", "list_tools"): + # Tool-toggle actions. These edit settings.json:disabled_tools + # (the global list read on every chat request) rather than + # prefs.json. Friendly aliases accepted: "shell" -> "bash", + # "search" -> "web_search", "browser" -> "builtin_browser", + # "documents" -> the document tool set, "memory" -> + # manage_memory, etc. + from src.settings import get_setting, save_settings, load_settings + _ALIASES = { + "shell": ["bash"], + "terminal": ["bash"], + "search": ["web_search", "web_fetch"], + "web": ["web_search", "web_fetch"], + "browser": ["builtin_browser"], + "documents": ["create_document", "edit_document", "update_document", "suggest_document"], + "doc": ["create_document", "edit_document", "update_document", "suggest_document"], + "memory": ["manage_memory"], + "skills": ["manage_skills"], + "images": ["generate_image"], + "image": ["generate_image"], + "tasks": ["manage_tasks"], + "notes": ["manage_notes"], + "calendar": ["manage_calendar"], + "email": ["mcp__email__list_emails", "mcp__email__read_email", "mcp__email__send_email"], + "research": ["web_search", "web_fetch"], # research is a per-request flag, not a tool (closest analog) + } + + if action == "list_tools": + current = get_setting("disabled_tools", []) or [] + return { + "response": ( + f"Currently disabled: {', '.join(current) if current else '(none)'}.\n" + "Common toggles: shell (bash), search (web_search), browser, documents, " + "memory, skills, images, tasks, notes, calendar, email." + ), + "disabled": list(current), + "exit_code": 0, + } + + tool_name = (args.get("tool") or args.get("name") or "").strip().lower() + if not tool_name: + return {"error": "tool name required (e.g. 'shell', 'search', 'bash')", "exit_code": 1} + targets = _ALIASES.get(tool_name, [tool_name]) + + settings = load_settings() + current = list(settings.get("disabled_tools") or []) + before = set(current) + if action == "disable_tool": + for t in targets: + if t not in current: + current.append(t) + else: # enable_tool + current = [t for t in current if t not in targets] + after = set(current) + settings["disabled_tools"] = current + save_settings(settings) + + verb = "Disabled" if action == "disable_tool" else "Enabled" + changed = sorted(after.symmetric_difference(before)) + return { + "response": ( + f"{verb} {tool_name} ({', '.join(targets)}). " + f"Now disabled: {', '.join(current) if current else '(none)'}." + ), + "changed": changed, + "disabled": list(current), + "exit_code": 0, + } + + else: + return {"error": f"Unknown action: {action}", "exit_code": 1} + except Exception as e: + logger.error(f"manage_settings error: {e}") + return {"error": str(e), "exit_code": 1} + finally: + db.close() + + +# --------------------------------------------------------------------------- +# API call tool +# --------------------------------------------------------------------------- + + + +# ── registry adapters ──────────────────────────────────────────────────────── +def _owner_adapter(fn): + """Wrap a do_*(content, owner) impl as a registry execute(content, ctx).""" + async def _execute(content: str, ctx: dict) -> dict: + return await fn(content, ctx.get("owner")) + return _execute + + +ADMIN_TOOL_HANDLERS = { + "manage_endpoints": _owner_adapter(do_manage_endpoints), + "manage_mcp": _owner_adapter(do_manage_mcp), + "manage_webhooks": _owner_adapter(do_manage_webhooks), + "manage_tokens": _owner_adapter(do_manage_tokens), + "manage_settings": _owner_adapter(do_manage_settings), +} diff --git a/src/agent_tools/document_tools.py b/src/agent_tools/document_tools.py index 33b10c8d3..dcbea8b99 100644 --- a/src/agent_tools/document_tools.py +++ b/src/agent_tools/document_tools.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List, Optional import logging import re -import json from src.constants import MAX_READ_CHARS +from src.tool_utils import _parse_tool_args logger = logging.getLogger(__name__) @@ -154,38 +154,6 @@ def _coerce_email_document_content(existing: str, incoming: str) -> str: body = new return header.rstrip() + "\n---\n" + body -def _parse_tool_args(content): - """Parse a tool-call argument blob. - - Accepts either a JSON string or an already-decoded dict. Unwraps the - common `{"body": {...}}` envelope that smaller models emit when they - read tool descriptions like "Body is JSON: {...}" literally — they - pass `body` as a field name rather than treating it as a noun. - - Returns a dict on success, raises ValueError on bad JSON. - """ - if isinstance(content, str): - try: - args = json.loads(content) if content.strip() else {} - except (json.JSONDecodeError, TypeError) as e: - raise ValueError(str(e)) - elif isinstance(content, dict): - args = content - else: - args = {} - # Unwrap {"body": {...}} envelope — but only if `body` is the sole key - # and points at a dict. We don't want to clobber a legitimate `body` - # field on tools where it's a real arg (e.g. send_email body text). - if ( - isinstance(args, dict) - and len(args) == 1 - and "body" in args - and isinstance(args["body"], dict) - and "action" in args["body"] # extra safety: only unwrap if the inner dict looks like a tool call - ): - args = args["body"] - return args - def parse_edit_blocks(content: str) -> list: """Parse <<>>...<<>>...<<>> blocks.""" edits = [] diff --git a/src/tool_execution.py b/src/tool_execution.py index 3b4ba5eab..532b8c3b6 100644 --- a/src/tool_execution.py +++ b/src/tool_execution.py @@ -563,9 +563,7 @@ async def _execute_tool_block_impl( """ from src.tool_implementations import ( do_search_chats, do_manage_tasks, - do_manage_skills, do_api_call, do_manage_endpoints, - do_manage_mcp, do_manage_webhooks, do_manage_tokens, - do_manage_settings, do_manage_notes, + do_manage_skills, do_api_call, do_manage_notes, do_manage_calendar, do_download_model, do_serve_model, do_list_served_models, do_stop_served_model, do_tail_serve_output, @@ -808,21 +806,11 @@ async def _execute_tool_block_impl( first_line = content.split("\n")[0].strip()[:60] desc = f"api_call: {first_line}" result = await do_api_call(content) - elif tool == "manage_endpoints": - desc = "manage_endpoints" - result = await do_manage_endpoints(content, owner=owner) - elif tool == "manage_mcp": - desc = "manage_mcp" - result = await do_manage_mcp(content, owner=owner) - elif tool == "manage_webhooks": - desc = "manage_webhooks" - result = await do_manage_webhooks(content, owner=owner) - elif tool == "manage_tokens": - desc = "manage_tokens" - result = await do_manage_tokens(content, owner=owner) - elif tool == "manage_settings": - desc = "manage_settings" - result = await do_manage_settings(content, owner=owner) + elif tool in ("manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "manage_settings"): + # Registry-dispatched (agent_tools.admin_tools); owner threaded for ownership/admin checks. + desc = tool + result = await _direct_fallback(tool, content, owner=owner) \ + or {"error": f"{tool}: execution failed", "exit_code": 1} elif tool == "manage_notes": desc = "manage_notes" result = await do_manage_notes(content, owner=owner) diff --git a/src/tool_implementations.py b/src/tool_implementations.py index 3ce3ee613..cb0495aa2 100644 --- a/src/tool_implementations.py +++ b/src/tool_implementations.py @@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional from fastapi import HTTPException from src.constants import MAX_READ_CHARS, DEEP_RESEARCH_DIR, VAULT_FILE -from src.tool_utils import get_mcp_manager +from src.tool_utils import get_mcp_manager, _parse_tool_args from core.constants import internal_api_base from routes._validators import validate_remote_host, validate_ssh_port @@ -68,38 +68,6 @@ def clear_active_email() -> None: # Argument parsing # --------------------------------------------------------------------------- -def _parse_tool_args(content): - """Parse a tool-call argument blob. - - Accepts either a JSON string or an already-decoded dict. Unwraps the - common `{"body": {...}}` envelope that smaller models emit when they - read tool descriptions like "Body is JSON: {...}" literally — they - pass `body` as a field name rather than treating it as a noun. - - Returns a dict on success, raises ValueError on bad JSON. - """ - if isinstance(content, str): - try: - args = json.loads(content) if content.strip() else {} - except (json.JSONDecodeError, TypeError) as e: - raise ValueError(str(e)) - elif isinstance(content, dict): - args = content - else: - args = {} - # Unwrap {"body": {...}} envelope — but only if `body` is the sole key - # and points at a dict. We don't want to clobber a legitimate `body` - # field on tools where it's a real arg (e.g. send_email body text). - if ( - isinstance(args, dict) - and len(args) == 1 - and "body" in args - and isinstance(args["body"], dict) - and "action" in args["body"] # extra safety: only unwrap if the inner dict looks like a tool call - ): - args = args["body"] - return args - # --------------------------------------------------------------------------- # Search chats # --------------------------------------------------------------------------- @@ -588,757 +556,6 @@ async def do_manage_tasks(content: str, owner: Optional[str] = None) -> Dict: db.close() -# --------------------------------------------------------------------------- -# Endpoint management tool -# --------------------------------------------------------------------------- - -async def do_manage_endpoints(content: str, owner: Optional[str] = None) -> Dict: - """Manage model endpoints: list, add, delete, enable, disable.""" - from core.database import SessionLocal, ModelEndpoint - try: - args = _parse_tool_args(content) - except ValueError: - return {"error": "Invalid JSON arguments", "exit_code": 1} - - action = args.get("action", "list") - db = SessionLocal() - try: - if action == "list": - eps = db.query(ModelEndpoint).all() - items = [{"id": e.id, "name": e.name, "base_url": e.base_url, - "is_enabled": e.is_enabled} for e in eps] - return {"response": f"{len(items)} endpoints", "endpoints": items, "exit_code": 0} - - elif action == "add": - import uuid as _uuid - name = args.get("name", "") - base_url = args.get("base_url", "") - api_key = args.get("api_key", "") - if not base_url: - return {"error": "base_url is required", "exit_code": 1} - eid = str(_uuid.uuid4())[:8] - from datetime import datetime - ep = ModelEndpoint(id=eid, name=name or base_url, base_url=base_url, - api_key=api_key, is_enabled=True, - created_at=datetime.utcnow(), updated_at=datetime.utcnow()) - db.add(ep) - db.commit() - return {"response": f"Added endpoint '{name or base_url}' (id: {eid})", "exit_code": 0} - - elif action == "delete": - eid = args.get("endpoint_id", "") - ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == eid).first() - if not ep: - return {"error": f"Endpoint {eid} not found", "exit_code": 1} - name = ep.name - db.delete(ep) - db.commit() - return {"response": f"Deleted endpoint '{name}'", "exit_code": 0} - - elif action in ("enable", "disable"): - eid = args.get("endpoint_id", "") - ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == eid).first() - if not ep: - return {"error": f"Endpoint {eid} not found", "exit_code": 1} - ep.is_enabled = (action == "enable") - db.commit() - return {"response": f"Endpoint '{ep.name}' {action}d", "exit_code": 0} - - else: - return {"error": f"Unknown action: {action}", "exit_code": 1} - except Exception as e: - logger.error(f"manage_endpoints error: {e}") - return {"error": str(e), "exit_code": 1} - finally: - db.close() - - -# --------------------------------------------------------------------------- -# MCP server management tool -# --------------------------------------------------------------------------- - -# Parallel to routes/cookbook_helpers._validate_serve_cmd but deliberately the -# opposite policy: that gate guards an admin-only serve command and allows -# interpreters (python3/etc) because model-serving needs them, whereas this is -# the model/prompt-injection-reachable manage_mcp path, so interpreters and -# runners are denied here. -# -# Commands that can execute arbitrary code regardless of their arguments. These -# are NEVER accepted on the manage_mcp agent path, even if an operator lists one -# in ODYSSEUS_MCP_ALLOWED_COMMANDS -- a stdio server that genuinely needs an -# interpreter or package runner must be registered via the trusted admin route. -_MCP_DENIED_COMMANDS = frozenset({ - "sh", "bash", "zsh", "fish", "dash", "ksh", "csh", "tcsh", "ash", "busybox", - "cmd", "command.com", "powershell", "pwsh", - "python", "pypy", "node", "nodejs", "deno", "bun", "ruby", "jruby", - "perl", "raku", "php", "lua", "luajit", "tclsh", "wish", "expect", "rscript", - "groovy", "scala", "elixir", "erl", "iex", "java", "javac", "jshell", "jbang", - "kotlin", "kotlinc", "dotnet", "mono", "swift", "osascript", "tsx", "ts-node", - "npx", "bunx", "uvx", "pipx", "npm", "pnpm", "yarn", "pip", "uv", - "gem", "cargo", "go", "bundle", "poetry", "conda", "mamba", "brew", - "apt", "apt-get", "yum", "dnf", "pacman", "apk", - "env", "xargs", "nohup", "setsid", "nice", "ionice", "time", "timeout", - "watch", "stdbuf", "unbuffer", "script", "ssh", "scp", "sshpass", "sudo", - "doas", "su", "make", "cmake", "docker", "podman", "kubectl", "find", - "awk", "gawk", "sed", "vi", "vim", "nvim", "emacs", "ed", "tee", "eval", -}) - -# Argv flags that make even an allowlisted binary execute inline code. Matched -# by prefix so glued forms (-cimport os, --eval=...) are caught, not just the -# exact-token form. -_MCP_CODE_EXEC_SHORT_FLAGS = ("-c", "-e", "-m") -_MCP_CODE_EXEC_LONG_FLAGS = ("--eval", "--exec", "--print", "--module", "--command", "--require") - -_MCP_URL_SCHEMES = ("http://", "https://", "ftp://", "ftps://", "file://", "data:", "jar:", "blob:") - -# Shell metacharacters refused in command/args. Args are passed as an argv list -# (no shell), but refusing these keeps the surface narrow and obvious. -_MCP_SHELL_METACHARS = set(";|&$`><\n\r") - -# Env vars that let a child process load attacker-supplied code before main(). -_MCP_DANGEROUS_ENV = frozenset({ - "LD_PRELOAD", "LD_LIBRARY_PATH", "LD_AUDIT", "DYLD_INSERT_LIBRARIES", - "DYLD_LIBRARY_PATH", "DYLD_FRAMEWORK_PATH", "PYTHONPATH", "PYTHONSTARTUP", - "PYTHONHOME", "PYTHONEXECUTABLE", "NODE_OPTIONS", "NODE_PATH", "BASH_ENV", - "ENV", "SHELLOPTS", "PERL5LIB", "PERL5OPT", "RUBYOPT", "RUBYLIB", "GEM_PATH", - "R_PROFILE", "R_HOME", "PATH", "IFS", "PROMPT_COMMAND", -}) - - -def _mcp_allowed_commands() -> set: - """Operator-configured allowlist of safe MCP launcher basenames for the agent - path. Empty by default; set ODYSSEUS_MCP_ALLOWED_COMMANDS (comma-separated) - to opt specific trusted binaries in. Denied commands are rejected even if - listed here.""" - raw = os.environ.get("ODYSSEUS_MCP_ALLOWED_COMMANDS", "") - return {c.strip().lower() for c in raw.split(",") if c.strip()} - - -def _validate_mcp_command(command, args, env) -> Optional[str]: - """Validate a model-supplied stdio MCP registration. Returns an error string - if it must be rejected, else None. - - Closes the RCE where manage_mcp 'add' passed prompt-injection-controlled - command/args/env straight to a subprocess spawn (issue #438): a payload - smuggled into a skill description, memory entry, fetched page, or email body - could register a stdio server running arbitrary code as the app UID. - """ - if not isinstance(command, str) or not command.strip(): - return "command must be a non-empty string" - command = command.strip() - if "/" in command or "\\" in command: - return "command must be a bare executable name, not a path" - if any(ch in _MCP_SHELL_METACHARS for ch in command): - return "command contains shell metacharacters" - base = command.lower() - if base.endswith(".exe") or base.endswith(".cmd") or base.endswith(".bat"): - base = base.rsplit(".", 1)[0] - # Canonicalize a trailing version suffix so versioned aliases collapse to the - # family name (python3.11 -> python, node18 -> node, pip3 -> pip); both the - # raw basename and the canonical form are denied, so an operator cannot - # accidentally allowlist a runtime alias back into the path. - canon = re.sub(r"[-_.]?\d+(?:\.\d+)*$", "", base) - if base in _MCP_DENIED_COMMANDS or canon in _MCP_DENIED_COMMANDS: - return ( - f"command '{command}' is not allowed on the agent MCP path: " - "interpreters, runtimes, package runners, and shells can execute " - "arbitrary code. Register such a server via the admin route instead." - ) - if base not in _mcp_allowed_commands(): - return ( - f"command '{command}' is not in the MCP allowlist. Add it to " - "ODYSSEUS_MCP_ALLOWED_COMMANDS if you trust it, or register the " - "server via the admin route." - ) - - if args is not None: - if isinstance(args, str): - try: - args = json.loads(args) - except Exception: - return "args must be a JSON list" - if not isinstance(args, list): - return "args must be a list" - for a in args: - if not isinstance(a, str): - return "args must all be strings" - s = a.strip() - low = s.lower() - if any(s == f or s.startswith(f) for f in _MCP_CODE_EXEC_SHORT_FLAGS): - return f"arg '{a}' is a code-execution flag and is not allowed" - if any(low == f or low.startswith(f + "=") for f in _MCP_CODE_EXEC_LONG_FLAGS): - return f"arg '{a}' is a code-execution flag and is not allowed" - if any(low.startswith(u) for u in _MCP_URL_SCHEMES): - return f"arg '{a}' is a remote URL and is not allowed" - if any(ch in _MCP_SHELL_METACHARS for ch in a): - return f"arg '{a}' contains shell metacharacters" - - if env: - if isinstance(env, str): - try: - env = json.loads(env) - except Exception: - return "env must be a JSON object" - if not isinstance(env, dict): - return "env must be an object" - for k in env: - if str(k).strip().upper() in _MCP_DANGEROUS_ENV: - return f"env var '{k}' can inject code into the child process and is not allowed" - - return None - - -async def do_manage_mcp(content: str, owner: Optional[str] = None) -> Dict: - """Manage MCP servers: list, add, delete, enable, disable, reconnect.""" - try: - args = _parse_tool_args(content) - except ValueError: - return {"error": "Invalid JSON arguments", "exit_code": 1} - - action = args.get("action", "list") - - if action == "list": - mcp = get_mcp_manager() - if not mcp: - return {"response": "No MCP manager available", "servers": [], "exit_code": 0} - from core.database import SessionLocal, McpServer - db = SessionLocal() - try: - servers = db.query(McpServer).all() - items = [] - for s in servers: - st = mcp.get_server_status(s.id) - status = st.get("status", "disconnected") - tool_count = st.get("tool_count", 0) - items.append({"id": s.id, "name": s.name, "transport": s.transport, - "is_enabled": s.is_enabled, "status": status, - "tool_count": tool_count}) - return {"response": f"{len(items)} MCP servers", "servers": items, "exit_code": 0} - finally: - db.close() - - elif action == "add": - from core.database import SessionLocal, McpServer - import uuid as _uuid - from datetime import datetime - name = args.get("name", "") - command = args.get("command", "") - cmd_args = args.get("args", []) - env = args.get("env", {}) - if not name or not command: - return {"error": "name and command are required", "exit_code": 1} - # Validate BEFORE any DB write or spawn: a rejected registration must - # leave no enabled row (which would otherwise auto-reconnect on restart) - # and must not attempt a connection. - _mcp_err = _validate_mcp_command(command, cmd_args, env) - if _mcp_err: - return {"error": f"manage_mcp: refused unsafe server registration: {_mcp_err}", "exit_code": 1} - sid = str(_uuid.uuid4())[:8] - db = SessionLocal() - try: - srv = McpServer(id=sid, name=name, transport="stdio", command=command, - args=json.dumps(cmd_args) if isinstance(cmd_args, list) else cmd_args, - env=json.dumps(env) if isinstance(env, dict) else env, - is_enabled=True, created_at=datetime.utcnow(), updated_at=datetime.utcnow()) - db.add(srv) - db.commit() - finally: - db.close() - # Try to connect - mcp = get_mcp_manager() - tool_count = 0 - if mcp: - try: - await mcp.connect_server( - sid, name, "stdio", command=command, - args=cmd_args if isinstance(cmd_args, list) else json.loads(cmd_args), - env=env if isinstance(env, dict) else json.loads(env), - ) - st = mcp.get_server_status(sid) - tool_count = st.get("tool_count", 0) - except Exception as e: - logger.warning(f"MCP connect failed for {name}: {e}") - return {"response": f"Added MCP server '{name}' ({tool_count} tools)", "exit_code": 0} - - elif action == "delete": - sid = args.get("server_id", "") - from core.database import SessionLocal, McpServer - db = SessionLocal() - try: - srv = db.query(McpServer).filter(McpServer.id == sid).first() - if not srv: - return {"error": f"Server {sid} not found", "exit_code": 1} - name = srv.name - mcp = get_mcp_manager() - if mcp: - try: - await mcp.disconnect_server(sid) - except Exception: - pass - db.delete(srv) - db.commit() - return {"response": f"Deleted MCP server '{name}'", "exit_code": 0} - finally: - db.close() - - elif action == "reconnect": - sid = args.get("server_id", "") - mcp = get_mcp_manager() - if not mcp: - return {"error": "MCP manager not available", "exit_code": 1} - try: - await mcp.disconnect_server(sid) - from core.database import SessionLocal, McpServer - db2 = SessionLocal() - try: - srv = db2.query(McpServer).filter(McpServer.id == sid).first() - if srv: - _args = json.loads(srv.args) if srv.args else [] - _env = json.loads(srv.env) if srv.env else {} - await mcp.connect_server( - server_id=sid, - name=srv.name, - transport=srv.transport, - command=srv.command, - args=_args, - env=_env, - url=srv.url, - ) - st = mcp.get_server_status(sid) - return {"response": f"Reconnected '{srv.name}' ({st.get('tool_count', 0)} tools)", "exit_code": 0} - return {"error": f"Server {sid} not found", "exit_code": 1} - finally: - db2.close() - except Exception as e: - return {"error": str(e), "exit_code": 1} - - elif action in ("enable", "disable"): - sid = args.get("server_id", "") - from core.database import SessionLocal, McpServer - db = SessionLocal() - try: - srv = db.query(McpServer).filter(McpServer.id == sid).first() - if not srv: - return {"error": f"Server {sid} not found", "exit_code": 1} - srv.is_enabled = (action == "enable") - db.commit() - return {"response": f"MCP server '{srv.name}' {action}d", "exit_code": 0} - finally: - db.close() - - elif action == "list_tools": - mcp = get_mcp_manager() - if not mcp: - return {"response": "No MCP manager", "tools": [], "exit_code": 0} - tools = mcp.get_all_tools() - items = [{"name": t["name"], "server": t["server_name"], - "description": t.get("description", "")[:100]} for t in tools] - return {"response": f"{len(items)} MCP tools available", "tools": items, "exit_code": 0} - - else: - return {"error": f"Unknown action: {action}", "exit_code": 1} - - -# --------------------------------------------------------------------------- -# Webhook management tool -# --------------------------------------------------------------------------- - -async def do_manage_webhooks(content: str, owner: Optional[str] = None) -> Dict: - """Manage webhooks: list, add, delete, enable, disable, test.""" - from core.database import SessionLocal - try: - args = _parse_tool_args(content) - except ValueError: - return {"error": "Invalid JSON arguments", "exit_code": 1} - - action = args.get("action", "list") - db = SessionLocal() - try: - from core.database import Webhook - if action == "list": - hooks = db.query(Webhook).all() - items = [{"id": h.id, "name": h.name, "url": h.url, - "events": h.events, "is_active": h.is_active} for h in hooks] - return {"response": f"{len(items)} webhooks", "webhooks": items, "exit_code": 0} - - elif action == "add": - import uuid as _uuid - from datetime import datetime - from src.webhook_manager import validate_events, validate_webhook_url - name = args.get("name", "") - url = args.get("url", "") - events = args.get("events", "chat.completed") - if not url: - return {"error": "url is required", "exit_code": 1} - try: - url = validate_webhook_url(url) - events = validate_events(events) - except ValueError as e: - return {"error": str(e), "exit_code": 1} - wid = str(_uuid.uuid4())[:8] - hook = Webhook(id=wid, name=name or url, url=url, - events=events, is_active=True, - created_at=datetime.utcnow(), updated_at=datetime.utcnow()) - db.add(hook) - db.commit() - return {"response": f"Added webhook '{name or url}'", "exit_code": 0} - - elif action == "delete": - wid = args.get("webhook_id", "") - hook = db.query(Webhook).filter(Webhook.id == wid).first() - if not hook: - return {"error": f"Webhook {wid} not found", "exit_code": 1} - name = hook.name - db.delete(hook) - db.commit() - return {"response": f"Deleted webhook '{name}'", "exit_code": 0} - - elif action in ("enable", "disable"): - wid = args.get("webhook_id", "") - hook = db.query(Webhook).filter(Webhook.id == wid).first() - if not hook: - return {"error": f"Webhook {wid} not found", "exit_code": 1} - hook.is_active = (action == "enable") - db.commit() - return {"response": f"Webhook '{hook.name}' {action}d", "exit_code": 0} - - else: - return {"error": f"Unknown action: {action}", "exit_code": 1} - except Exception as e: - logger.error(f"manage_webhooks error: {e}") - return {"error": str(e), "exit_code": 1} - finally: - db.close() - - -# --------------------------------------------------------------------------- -# API token management tool -# --------------------------------------------------------------------------- - -async def do_manage_tokens(content: str, owner: Optional[str] = None) -> Dict: - """Manage API tokens: list, create, delete.""" - from core.database import SessionLocal, ApiToken - try: - args = _parse_tool_args(content) - except ValueError: - return {"error": "Invalid JSON arguments", "exit_code": 1} - - action = args.get("action", "list") - db = SessionLocal() - try: - if action == "list": - tokens = db.query(ApiToken).all() - items = [{"id": t.id, "name": t.name, "token_prefix": t.token_prefix + "...", - "is_active": t.is_active} for t in tokens] - return {"response": f"{len(items)} API tokens", "tokens": items, "exit_code": 0} - - elif action == "create": - import uuid as _uuid, secrets, bcrypt - from datetime import datetime - name = args.get("name", "API Token") - raw_token = secrets.token_urlsafe(32) - token_hash = bcrypt.hashpw(raw_token.encode(), bcrypt.gensalt()).decode() - tid = str(_uuid.uuid4())[:8] - t = ApiToken(id=tid, name=name, token_hash=token_hash, - token_prefix=raw_token[:8], is_active=True, - created_at=datetime.utcnow(), updated_at=datetime.utcnow()) - db.add(t) - db.commit() - return {"response": f"Created token '{name}'", "token": raw_token, "exit_code": 0} - - elif action == "delete": - tid = args.get("token_id", "") - t = db.query(ApiToken).filter(ApiToken.id == tid).first() - if not t: - return {"error": f"Token {tid} not found", "exit_code": 1} - name = t.name - db.delete(t) - db.commit() - return {"response": f"Deleted token '{name}'", "exit_code": 0} - - else: - return {"error": f"Unknown action: {action}", "exit_code": 1} - except Exception as e: - logger.error(f"manage_tokens error: {e}") - return {"error": str(e), "exit_code": 1} - finally: - db.close() - -# --------------------------------------------------------------------------- -# Settings/preferences management tool -# --------------------------------------------------------------------------- - -async def do_manage_settings(content: str, owner: Optional[str] = None) -> Dict: - """Manage user settings and preferences.""" - try: - args = _parse_tool_args(content) - except ValueError: - return {"error": "Invalid JSON arguments", "exit_code": 1} - - action = args.get("action", "list") - - from core.database import SessionLocal - db = SessionLocal() - try: - # set/get/list/delete operate on the REAL app settings (the same store - # the Settings panel writes), so changing a model / voice / search - # engine / reminder channel from chat actually takes effect. - from src.settings import load_settings, save_settings, DEFAULT_SETTINGS - - # Secrets/credentials the agent must NOT write — kept read-only (masked) - # so API keys never flow through chat. User sets these in the panel. - _SECRET_KEYS = { - "brave_api_key", "google_pse_key", "google_pse_cx", - "tavily_api_key", "serper_api_key", "app_public_url", - } - def _is_secret(k): - # `token` must be a suffix, not a substring: otherwise the int - # setting `agent_input_token_budget` (which even has a "token budget" - # alias to set it from chat) is wrongly classified as a credential. - return ( - k in _SECRET_KEYS - or k.endswith("token") - or any(t in k for t in ("api_key", "_key", "secret", "password")) - ) - - # Friendly aliases → real keys, so natural phrasing resolves. - _ALIASES_SET = { - "voice": "tts_voice", "tts voice": "tts_voice", "tts": "tts_enabled", - "text to speech": "tts_enabled", "tts provider": "tts_provider", - "speech speed": "tts_speed", "voice speed": "tts_speed", - "stt": "stt_enabled", "speech to text": "stt_enabled", "transcription": "stt_enabled", - "search engine": "search_provider", "search provider": "search_provider", - "search results": "search_result_count", "result count": "search_result_count", - "default model": "default_model", "chat model": "default_model", - "default endpoint": "default_endpoint_id", - "task model": "task_model", "background model": "task_model", - "teacher model": "teacher_model", "teacher": "teacher_enabled", - "utility model": "utility_model", "research model": "research_model", - "research max tokens": "research_max_tokens", - "vision model": "vision_model", "vision": "vision_enabled", - "image model": "image_model", "image quality": "image_quality", - "image gen": "image_gen_enabled", "image generation": "image_gen_enabled", - "reminder channel": "reminder_channel", "reminders": "reminder_channel", - "ntfy topic": "reminder_ntfy_topic", - "webhook integration": "reminder_webhook_integration_id", - "webhook template": "reminder_webhook_payload_template", "webhook payload": "reminder_webhook_payload_template", - "agent tool calls": "agent_max_tool_calls", "max tool calls": "agent_max_tool_calls", - "agent timeout": "agent_stream_timeout_seconds", "stream timeout": "agent_stream_timeout_seconds", - "token budget": "agent_input_token_budget", "input budget": "agent_input_token_budget", - "hard max": "agent_input_token_hard_max", - "token budget cap": "agent_input_token_hard_max", - "input budget cap": "agent_input_token_hard_max", - } - def _resolve(k): - k2 = (k or "").strip().lower() - if k2 in DEFAULT_SETTINGS: - return k2 - return _ALIASES_SET.get(k2, (k or "").strip()) - - _ENUMS = { - "image_quality": ["low", "medium", "high"], - "reminder_channel": ["browser", "email", "ntfy", "webhook"], - } - def _coerce(value, default): - if isinstance(default, bool): - return value if isinstance(value, bool) else str(value).strip().lower() in ("true", "on", "yes", "1", "enable", "enabled") - if isinstance(default, int): - return int(value) - return value - - def _model_slug(value: str) -> str: - import re as _re - return _re.sub(r"[^a-z0-9]+", "", (value or "").lower()) - - def _endpoint_model_from_cache(model_query: str): - """Resolve friendly model text to an enabled endpoint + real model id. - - The Settings UI stores both `_endpoint_id` and - `_model`; writing only the model leaves the runtime on the - old endpoint. Prefer cached model lists so this stays fast/offline. - """ - import json as _json - import re as _re - from core.database import ModelEndpoint - - wanted = (model_query or "").strip() - wanted_slug = _model_slug(wanted) - wanted_tokens = [_model_slug(t) for t in _re.findall(r"[A-Za-z0-9]+", wanted)] - wanted_tokens = [t for t in wanted_tokens if t] - if not wanted_slug: - return None - best = None - for ep in db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all(): - raw_models = [] - try: - raw_models = _json.loads(ep.cached_models or "[]") or [] - except Exception: - raw_models = [] - # If cache is empty, still allow matching against endpoint name - # for callers using model@endpoint elsewhere later. - for mid in raw_models: - mid = str(mid) - mid_slug = _model_slug(mid) - if not mid_slug: - continue - exact = mid.lower() == wanted.lower() - compact_match = wanted_slug in mid_slug or mid_slug in wanted_slug - token_match = bool(wanted_tokens) and all(tok in mid_slug for tok in wanted_tokens) - if exact or compact_match or token_match: - score = 3 if exact else (2 if compact_match else 1) - if not best or score > best[0]: - best = (score, ep.id, mid) - if best: - return {"endpoint_id": best[1], "model": best[2]} - return None - - def _mask(k, v): - return "••••• (set in panel)" if _is_secret(k) and v else v - - if action == "list": - s = load_settings() - shown = {k: _mask(k, v) for k, v in s.items() if k in DEFAULT_SETTINGS and not isinstance(v, dict)} - return {"response": f"{len(shown)} settings (use get/set with a key)", "settings": shown, "exit_code": 0} - - elif action == "get": - key = _resolve(args.get("key", "")) - if not key: - return {"error": "key is required", "exit_code": 1} - if key not in DEFAULT_SETTINGS: - return {"error": f"Unknown setting '{args.get('key')}'. Use action='list' to see them.", "exit_code": 1} - val = load_settings().get(key, DEFAULT_SETTINGS.get(key)) - return {"response": f"{key} = {_mask(key, val)}", "value": _mask(key, val), "exit_code": 0} - - elif action == "set": - raw = args.get("key", "") - value = args.get("value") - if not raw: - return {"error": "key is required", "exit_code": 1} - key = _resolve(raw) - if key not in DEFAULT_SETTINGS: - return {"error": f"Unknown setting '{raw}'. Use action='list' to see available settings.", "exit_code": 1} - if _is_secret(key): - return {"response": f"'{key}' is a credential/secret — for security I can't set it from chat. Open Settings and set it there.", "exit_code": 0} - # Structured settings (dicts/lists like keybinds, default_model_fallbacks) - # have no safe scalar coercion — _coerce would pass a bare string - # straight through and clobber the structure. Refuse them here; they're - # edited in their dedicated panels. (reset/delete still restore the - # default structure, which is safe.) - if isinstance(DEFAULT_SETTINGS[key], (dict, list)): - return {"response": f"'{key}' is a structured setting — edit it in its panel, not from chat. (You can reset it to default here.)", "exit_code": 0} - try: - value = _coerce(value, DEFAULT_SETTINGS[key]) - except (ValueError, TypeError): - return {"error": f"'{value}' isn't a valid value for {key} (expected {type(DEFAULT_SETTINGS[key]).__name__}).", "exit_code": 1} - if key in _ENUMS and str(value).lower() not in _ENUMS[key]: - return {"error": f"{key} must be one of: {', '.join(_ENUMS[key])}.", "exit_code": 1} - s = load_settings() - s[key] = value - if key in {"default_model", "research_model", "utility_model", "task_model", "vision_model", "image_model"}: - resolved = _endpoint_model_from_cache(str(value)) - if resolved: - prefix = key[:-6] - s[f"{prefix}_endpoint_id"] = resolved["endpoint_id"] - s[key] = resolved["model"] - value = resolved["model"] - save_settings(s) - if key.endswith("_model") and s.get(f"{key[:-6]}_endpoint_id"): - return {"response": f"Set {key} = {value} (endpoint {s.get(f'{key[:-6]}_endpoint_id')}).", "exit_code": 0} - return {"response": f"Set {key} = {value}.", "exit_code": 0} - - elif action == "delete" or action == "reset": - key = _resolve(args.get("key", "")) - if key not in DEFAULT_SETTINGS: - return {"error": f"Unknown setting '{args.get('key')}'.", "exit_code": 1} - if _is_secret(key): - return {"response": f"'{key}' is a credential — reset it in the panel.", "exit_code": 0} - s = load_settings() - s[key] = DEFAULT_SETTINGS[key] - save_settings(s) - return {"response": f"Reset {key} to default ({DEFAULT_SETTINGS[key]}).", "exit_code": 0} - - elif action in ("disable_tool", "enable_tool", "list_tools"): - # Tool-toggle actions. These edit settings.json:disabled_tools - # (the global list read on every chat request) rather than - # prefs.json. Friendly aliases accepted: "shell" -> "bash", - # "search" -> "web_search", "browser" -> "builtin_browser", - # "documents" -> the document tool set, "memory" -> - # manage_memory, etc. - from src.settings import get_setting, save_settings, load_settings - _ALIASES = { - "shell": ["bash"], - "terminal": ["bash"], - "search": ["web_search", "web_fetch"], - "web": ["web_search", "web_fetch"], - "browser": ["builtin_browser"], - "documents": ["create_document", "edit_document", "update_document", "suggest_document"], - "doc": ["create_document", "edit_document", "update_document", "suggest_document"], - "memory": ["manage_memory"], - "skills": ["manage_skills"], - "images": ["generate_image"], - "image": ["generate_image"], - "tasks": ["manage_tasks"], - "notes": ["manage_notes"], - "calendar": ["manage_calendar"], - "email": ["mcp__email__list_emails", "mcp__email__read_email", "mcp__email__send_email"], - "research": ["web_search", "web_fetch"], # research is a per-request flag, not a tool — closest analog - } - - if action == "list_tools": - current = get_setting("disabled_tools", []) or [] - return { - "response": ( - f"Currently disabled: {', '.join(current) if current else '(none)'}.\n" - "Common toggles: shell (bash), search (web_search), browser, documents, " - "memory, skills, images, tasks, notes, calendar, email." - ), - "disabled": list(current), - "exit_code": 0, - } - - tool_name = (args.get("tool") or args.get("name") or "").strip().lower() - if not tool_name: - return {"error": "tool name required (e.g. 'shell', 'search', 'bash')", "exit_code": 1} - targets = _ALIASES.get(tool_name, [tool_name]) - - settings = load_settings() - current = list(settings.get("disabled_tools") or []) - before = set(current) - if action == "disable_tool": - for t in targets: - if t not in current: - current.append(t) - else: # enable_tool - current = [t for t in current if t not in targets] - after = set(current) - settings["disabled_tools"] = current - save_settings(settings) - - verb = "Disabled" if action == "disable_tool" else "Enabled" - changed = sorted(after.symmetric_difference(before)) - return { - "response": ( - f"{verb} {tool_name} ({', '.join(targets)}). " - f"Now disabled: {', '.join(current) if current else '(none)'}." - ), - "changed": changed, - "disabled": list(current), - "exit_code": 0, - } - - else: - return {"error": f"Unknown action: {action}", "exit_code": 1} - except Exception as e: - logger.error(f"manage_settings error: {e}") - return {"error": str(e), "exit_code": 1} - finally: - db.close() - - -# --------------------------------------------------------------------------- -# API call tool -# --------------------------------------------------------------------------- - async def do_api_call(content: str) -> Dict: """Execute an API call to a registered integration.""" from src.integrations import execute_api_call, load_integrations @@ -3452,7 +2669,7 @@ async def do_adopt_served_model(content: str, owner: Optional[str] = None) -> Di host_only = host.split("@", 1)[-1] if host else "localhost" endpoint_url = f"http://{host_only}:{int(port)}/v1" try: - from src.tool_implementations import do_manage_endpoints # avoid forward ref issues + from src.agent_tools.admin_tools import do_manage_endpoints # moved in #3629 except Exception: do_manage_endpoints = None if do_manage_endpoints is not None: diff --git a/src/tool_utils.py b/src/tool_utils.py index cf71e78c5..bb60a1095 100644 --- a/src/tool_utils.py +++ b/src/tool_utils.py @@ -4,6 +4,8 @@ src.constants which imports nothing from src). Adding a project import here will reintroduce the circular dependency that this module exists to break. """ +import json + from src.constants import MAX_OUTPUT_CHARS _mcp_manager = None @@ -37,3 +39,36 @@ def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str: if len(text) > limit: return text[:limit] + f"\n... (truncated, {len(text)} chars total)" return text + + +def _parse_tool_args(content): + """Parse a tool-call argument blob. + + Accepts either a JSON string or an already-decoded dict. Unwraps the + common `{"body": {...}}` envelope that smaller models emit when they + read tool descriptions like "Body is JSON: {...}" literally and + pass `body` as a field name rather than treating it as a noun. + + Returns a dict on success, raises ValueError on bad JSON. + """ + if isinstance(content, str): + try: + args = json.loads(content) if content.strip() else {} + except (json.JSONDecodeError, TypeError) as e: + raise ValueError(str(e)) + elif isinstance(content, dict): + args = content + else: + args = {} + # Unwrap {"body": {...}} envelope, but only if `body` is the sole key + # and points at a dict. We don't want to clobber a legitimate `body` + # field on tools where it's a real arg (e.g. send_email body text). + if ( + isinstance(args, dict) + and len(args) == 1 + and "body" in args + and isinstance(args["body"], dict) + and "action" in args["body"] # extra safety: only unwrap if the inner dict looks like a tool call + ): + args = args["body"] + return args diff --git a/tests/test_admin_tools_registry.py b/tests/test_admin_tools_registry.py new file mode 100644 index 000000000..6d50d5601 --- /dev/null +++ b/tests/test_admin_tools_registry.py @@ -0,0 +1,69 @@ +"""Registry wiring for the config/integration admin tools (#3629). + +manage_endpoints/mcp/webhooks/tokens/settings moved from tool_implementations +into agent_tools.admin_tools. These pin the registration + the single +owner-threading adapter factory, without touching the DB (the do_* impls +themselves are exercised by their own suites). +""" +import asyncio + +from src.agent_tools import TOOL_HANDLERS +from src.agent_tools.admin_tools import ( + ADMIN_TOOL_HANDLERS, _owner_adapter, + do_manage_endpoints, do_manage_mcp, do_manage_webhooks, + do_manage_tokens, do_manage_settings, +) + +_NAMES = ["manage_endpoints", "manage_mcp", "manage_webhooks", "manage_tokens", "manage_settings"] + + +def test_all_registered_in_tool_handlers(): + for n in _NAMES: + assert n in TOOL_HANDLERS, f"{n} missing from TOOL_HANDLERS" + assert n in ADMIN_TOOL_HANDLERS + + +def test_re_exported_from_agent_tools(): + # Back-compat: importers that used `from src.agent_tools import do_manage_*` + # keep working after the move. + from src.agent_tools import ( # noqa: F401 + do_manage_endpoints, do_manage_mcp, do_manage_webhooks, + do_manage_tokens, do_manage_settings, + ) + + +def test_owner_adapter_threads_owner_from_ctx(): + seen = {} + + async def _spy(content, owner): + seen["content"] = content + seen["owner"] = owner + return {"response": "ok", "exit_code": 0} + + handler = _owner_adapter(_spy) + res = asyncio.run(handler('{"action":"list"}', {"owner": "alice", "session_id": "s1"})) + assert res["exit_code"] == 0 + assert seen == {"content": '{"action":"list"}', "owner": "alice"} + + +def test_owner_adapter_defaults_owner_to_none(): + captured = {} + + async def _spy(content, owner): + captured["owner"] = owner + return {"exit_code": 0} + + asyncio.run(_owner_adapter(_spy)("{}", {})) # ctx without owner + assert captured["owner"] is None + + +def test_parse_tool_args_lives_in_tool_utils_single_source(): + # The helper was de-duplicated into tool_utils; admin_tools imports it + # from there rather than carrying its own copy. + from src.tool_utils import _parse_tool_args + from src.agent_tools import admin_tools, document_tools + assert admin_tools._parse_tool_args is _parse_tool_args + assert document_tools._parse_tool_args is _parse_tool_args + assert _parse_tool_args('{"action":"add"}') == {"action": "add"} + # body-envelope unwrap still works + assert _parse_tool_args('{"body":{"action":"x"}}') == {"action": "x"} diff --git a/tests/test_context_budget.py b/tests/test_context_budget.py index eec8d046e..05db1d811 100644 --- a/tests/test_context_budget.py +++ b/tests/test_context_budget.py @@ -86,7 +86,8 @@ def test_default_settings_registers_hard_max_key(): def test_alias_map_registers_friendly_names(): """`manage_settings` should accept 'hard max' and friends.""" from pathlib import Path - src = Path("src/tool_implementations.py").read_text() + # manage_settings (and its alias map) moved to agent_tools/admin_tools.py in #3629. + src = Path("src/agent_tools/admin_tools.py").read_text() assert '"hard max": "agent_input_token_hard_max"' in src assert '"token budget cap": "agent_input_token_hard_max"' in src assert '"input budget cap": "agent_input_token_hard_max"' in src diff --git a/tests/test_manage_mcp_command_allowlist.py b/tests/test_manage_mcp_command_allowlist.py index 2d1c49e4b..8daea2b0b 100644 --- a/tests/test_manage_mcp_command_allowlist.py +++ b/tests/test_manage_mcp_command_allowlist.py @@ -26,8 +26,8 @@ clear_fake_database_modules() import core.database as cdb from core.database import McpServer -import src.tool_implementations as ti -from src.tool_implementations import _validate_mcp_command +import src.agent_tools.admin_tools as ti # do_manage_mcp/get_mcp_manager moved here in the registry migration +from src.agent_tools.admin_tools import _validate_mcp_command _TS, _ENGINE, _TMPDB = make_temp_sqlite(cdb.Base.metadata) diff --git a/tests/test_manage_settings_token_budget.py b/tests/test_manage_settings_token_budget.py index 31fce6dba..82cd01401 100644 --- a/tests/test_manage_settings_token_budget.py +++ b/tests/test_manage_settings_token_budget.py @@ -3,7 +3,7 @@ import asyncio import json import src.settings as settings_mod -from src.tool_implementations import do_manage_settings +from src.agent_tools.admin_tools import do_manage_settings def test_set_token_budget_is_not_refused_as_secret(monkeypatch): diff --git a/tests/test_mcp_reconnect_args.py b/tests/test_mcp_reconnect_args.py index b2a1e8b4f..8fbf61218 100644 --- a/tests/test_mcp_reconnect_args.py +++ b/tests/test_mcp_reconnect_args.py @@ -8,7 +8,7 @@ from types import SimpleNamespace def test_reconnect_passes_full_server_config(): """do_manage_mcp reconnect must pass name/transport/command/args/env/url.""" - from src.tool_implementations import do_manage_mcp + from src.agent_tools.admin_tools import do_manage_mcp fake_mcp = MagicMock() fake_mcp.disconnect_server = AsyncMock() @@ -28,7 +28,7 @@ def test_reconnect_passes_full_server_config(): fake_db = MagicMock() fake_db.query.return_value.filter.return_value.first.return_value = fake_srv - with patch("src.tool_implementations.get_mcp_manager", return_value=fake_mcp), \ + with patch("src.agent_tools.admin_tools.get_mcp_manager", return_value=fake_mcp), \ patch("core.database.SessionLocal", return_value=fake_db): result = asyncio.run(do_manage_mcp( json.dumps({"action": "reconnect", "server_id": "srv-123"}) diff --git a/tests/test_review_regressions.py b/tests/test_review_regressions.py index 20e201f5b..c10aaba85 100644 --- a/tests/test_review_regressions.py +++ b/tests/test_review_regressions.py @@ -821,7 +821,7 @@ async def test_webhook_tool_reuses_private_url_validation(): monkeypatch.setitem(sys.modules, "core.database", fake_core_db) monkeypatch.setitem(sys.modules, "src.database", fake_src_db) - from src.tool_implementations import do_manage_webhooks + from src.agent_tools.admin_tools import do_manage_webhooks try: result = await do_manage_webhooks(