mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-19 11:15:24 -04:00
fix(cookbook): validate agent SSH targets (#4429)
This commit is contained in:
@@ -12,12 +12,24 @@ import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from src.constants import MAX_READ_CHARS, DEEP_RESEARCH_DIR, VAULT_FILE
|
||||
from src.tool_utils import get_mcp_manager
|
||||
from core.constants import internal_api_base
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _string_arg(value: Any) -> str:
|
||||
return "" if value is None else str(value).strip()
|
||||
|
||||
|
||||
def _validate_cookbook_ssh_target(remote_host: Any, ssh_port: Any = "") -> tuple[str, str]:
|
||||
remote = validate_remote_host(_string_arg(remote_host) or None) or ""
|
||||
sport = validate_ssh_port(_string_arg(ssh_port) or None) or ""
|
||||
return remote, sport
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active email state
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -3025,6 +3037,10 @@ async def _cookbook_kill_session(session_id: str, *, remote_host: str = "",
|
||||
break
|
||||
|
||||
if remote:
|
||||
try:
|
||||
remote, sport = _validate_cookbook_ssh_target(remote, sport)
|
||||
except HTTPException as e:
|
||||
return {"error": str(getattr(e, "detail", e)), "exit_code": 1}
|
||||
_pf = f"-p {shlex.quote(str(sport))} " if sport and str(sport) != "22" else ""
|
||||
cmd = (
|
||||
f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no "
|
||||
@@ -3113,8 +3129,8 @@ async def do_tail_serve_output(content: str, owner: Optional[str] = None) -> Dic
|
||||
tail = 400
|
||||
tail = max(20, min(tail, 4000))
|
||||
headers = _internal_headers()
|
||||
remote = (args.get("remote_host") or args.get("host") or "").strip()
|
||||
sport = (args.get("ssh_port") or "").strip()
|
||||
remote = _string_arg(args.get("remote_host") or args.get("host"))
|
||||
sport = _string_arg(args.get("ssh_port"))
|
||||
# Resolve host from cookbook state if caller didn't pass one — same
|
||||
# lookup _cookbook_kill_session uses.
|
||||
if not remote:
|
||||
@@ -3132,6 +3148,12 @@ async def do_tail_serve_output(content: str, owner: Optional[str] = None) -> Dic
|
||||
if not sport:
|
||||
sport = t.get("sshPort") or ""
|
||||
break
|
||||
if remote:
|
||||
try:
|
||||
remote, sport = _validate_cookbook_ssh_target(remote, sport)
|
||||
except HTTPException as e:
|
||||
return {"error": str(getattr(e, "detail", e)), "exit_code": 1}
|
||||
|
||||
# Prefer the persisted /tmp/odysseus-tmux/SESSION.log file over the
|
||||
# live tmux pane. The pane is what the user would see scrolling on
|
||||
# their screen — including the post-crash neofetch banner and the
|
||||
@@ -3309,7 +3331,7 @@ async def do_adopt_served_model(content: str, owner: Optional[str] = None) -> Di
|
||||
except ValueError:
|
||||
return {"error": "Invalid JSON arguments", "exit_code": 1}
|
||||
|
||||
host = (args.get("host") or args.get("remote_host") or "").strip()
|
||||
host = _string_arg(args.get("host") or args.get("remote_host"))
|
||||
sess = (args.get("tmux_session") or args.get("session_id") or "").strip()
|
||||
model = (args.get("model") or args.get("repo_id") or "").strip()
|
||||
port = args.get("port") or 8000
|
||||
@@ -3320,6 +3342,12 @@ async def do_adopt_served_model(content: str, owner: Optional[str] = None) -> Di
|
||||
return {"error": "tmux_session and model are required", "exit_code": 1}
|
||||
|
||||
# Verify tmux session exists on the target host
|
||||
if host:
|
||||
try:
|
||||
host, _ = _validate_cookbook_ssh_target(host)
|
||||
except HTTPException as e:
|
||||
return {"error": str(getattr(e, "detail", e)), "exit_code": 1}
|
||||
|
||||
headers = _internal_headers()
|
||||
if host:
|
||||
check = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {shlex.quote(host)} 'tmux has-session -t {shlex.quote(sess)} 2>&1'"
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from src import tool_implementations as tools
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, data=None, status_code=200):
|
||||
self._data = data or {}
|
||||
self.status_code = status_code
|
||||
self.text = json.dumps(self._data)
|
||||
self.content = self.text.encode("utf-8")
|
||||
self.headers = {"content-type": "application/json"}
|
||||
|
||||
def json(self):
|
||||
return self._data
|
||||
|
||||
|
||||
def _install_httpx_client(monkeypatch, *, state=None, posts=None):
|
||||
import httpx
|
||||
|
||||
posts = posts if posts is not None else []
|
||||
state = state if state is not None else {"tasks": []}
|
||||
|
||||
class FakeAsyncClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
return FakeResponse(state)
|
||||
|
||||
async def post(self, url, json=None, **kwargs):
|
||||
posts.append((url, json, kwargs))
|
||||
return FakeResponse({"stdout": "", "stderr": "", "exit_code": 0})
|
||||
|
||||
monkeypatch.setattr(httpx, "AsyncClient", FakeAsyncClient)
|
||||
return posts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_served_model_rejects_invalid_remote_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(monkeypatch)
|
||||
|
||||
result = await tools.do_stop_served_model(
|
||||
json.dumps({"session_id": "serve-abc123", "remote_host": "-bad"})
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_served_model_rejects_invalid_state_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(
|
||||
monkeypatch,
|
||||
state={
|
||||
"tasks": [
|
||||
{
|
||||
"sessionId": "serve-abc123",
|
||||
"remoteHost": "-bad",
|
||||
"sshPort": "22",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
result = await tools.do_stop_served_model(
|
||||
json.dumps({"session_id": "serve-abc123"})
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_served_model_rejects_invalid_ssh_port_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(monkeypatch)
|
||||
|
||||
result = await tools.do_stop_served_model(
|
||||
json.dumps(
|
||||
{
|
||||
"session_id": "serve-abc123",
|
||||
"remote_host": "gpu-box",
|
||||
"ssh_port": "not-a-port",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid ssh_port" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_served_model_uses_validated_remote_target(monkeypatch):
|
||||
posts = _install_httpx_client(monkeypatch)
|
||||
|
||||
result = await tools.do_stop_served_model(
|
||||
json.dumps(
|
||||
{
|
||||
"session_id": "serve-abc123",
|
||||
"remote_host": "user@gpu-box",
|
||||
"ssh_port": 2222,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 0
|
||||
assert len(posts) == 1
|
||||
command = posts[0][1]["command"]
|
||||
assert "ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no" in command
|
||||
assert "-p 2222 user@gpu-box" in command
|
||||
assert "tmux kill-session -t serve-abc123" in command
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_download_rejects_invalid_remote_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(monkeypatch)
|
||||
|
||||
result = await tools.do_cancel_download(
|
||||
json.dumps({"session_id": "cookbook-abc123", "remote_host": "-bad"})
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_download_rejects_invalid_state_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(
|
||||
monkeypatch,
|
||||
state={
|
||||
"tasks": [
|
||||
{
|
||||
"sessionId": "cookbook-abc123",
|
||||
"remoteHost": "-bad",
|
||||
"sshPort": "22",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
result = await tools.do_cancel_download(
|
||||
json.dumps({"session_id": "cookbook-abc123"})
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tail_serve_output_rejects_invalid_remote_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(monkeypatch)
|
||||
|
||||
result = await tools.do_tail_serve_output(
|
||||
json.dumps({"session_id": "serve-abc123", "remote_host": "-bad"})
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tail_serve_output_rejects_invalid_state_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(
|
||||
monkeypatch,
|
||||
state={
|
||||
"tasks": [
|
||||
{
|
||||
"sessionId": "serve-abc123",
|
||||
"remoteHost": "-bad",
|
||||
"sshPort": "22",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
result = await tools.do_tail_serve_output(
|
||||
json.dumps({"session_id": "serve-abc123"})
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adopt_served_model_rejects_invalid_remote_host_before_shell(monkeypatch):
|
||||
posts = _install_httpx_client(monkeypatch)
|
||||
|
||||
result = await tools.do_adopt_served_model(
|
||||
json.dumps(
|
||||
{
|
||||
"tmux_session": "serve_abc123",
|
||||
"model": "org/model",
|
||||
"host": "-bad",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
assert result["exit_code"] == 1
|
||||
assert "Invalid remote_host" in result["error"]
|
||||
assert posts == []
|
||||
Reference in New Issue
Block a user