fix(hwfit): validate remote SSH detection targets (#3718)

This commit is contained in:
RaresKeY
2026-06-11 01:43:49 +03:00
committed by GitHub
parent 218b9ecbc8
commit d1a5a7d680
8 changed files with 164 additions and 62 deletions
+4
View File
@@ -366,6 +366,10 @@ def _ssh_exec_argv(
strict_host_key_checking: bool | None = None,
) -> list[str]:
"""Build a consistent ssh argv for remote command execution."""
remote_value = str(remote or "").strip()
remote_host = remote_value.rsplit("@", 1)[-1]
if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"):
raise ValueError("Invalid SSH remote host")
argv = ["ssh"]
if connect_timeout is not None:
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
+31
View File
@@ -0,0 +1,31 @@
import re
from fastapi import HTTPException
# Remote host: either `user@host` or plain `host`. Both the user part and the
# host part must start with an alphanumeric, so a value can never begin with
# "-" and be read by ssh as an option (e.g. "-oProxyCommand=...").
# The end is anchored with `\Z` rather than `$`: `$` also matches just before
# a trailing newline, which would let "host\n" through a `.match()` call.
_REMOTE_HOST_RE = re.compile(
    r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*\Z"
)
# SSH port shape: one to five digits, nothing else (same `\Z` anchoring).
_SSH_PORT_RE = re.compile(r"^\d{1,5}\Z")
def validate_remote_host(v: str | None) -> str | None:
    """Validate an SSH target given as ``host`` or ``user@host``.

    Args:
        v: Candidate remote host. ``None`` and ``""`` mean "no remote host".

    Returns:
        ``None`` when *v* is ``None`` or empty, otherwise *v* unchanged.

    Raises:
        HTTPException: 400 when *v* does not match ``_REMOTE_HOST_RE`` in full.
    """
    if v is None or v == "":
        return None
    # fullmatch, not match: with `.match()` a `$`-anchored pattern also accepts
    # a value that ends in a newline ("host\n"), which must not reach ssh.
    if not _REMOTE_HOST_RE.fullmatch(v):
        raise HTTPException(
            400,
            "Invalid remote_host — must be host or user@host, no SSH option syntax",
        )
    return v
def validate_ssh_port(v: str | None) -> str | None:
    """Validate an SSH port and return it as a decimal string.

    Args:
        v: Candidate port. ``None`` and ``""`` mean "no port given".

    Returns:
        ``None`` when *v* is ``None`` or empty, otherwise ``str(int(v))``.

    Raises:
        HTTPException: 400 when *v* does not fully match ``_SSH_PORT_RE`` or
            the number lies outside 1-65535.
    """
    if v is None or v == "":
        return None
    # Shape check first, so int() below is only reached for digit-only values.
    if _SSH_PORT_RE.fullmatch(str(v)) is None:
        raise HTTPException(400, "Invalid ssh_port")
    number = int(v)
    if not 1 <= number <= 65535:
        raise HTTPException(400, "Invalid ssh_port")
    return str(number)
+1 -23
View File
@@ -11,6 +11,7 @@ import shlex
from fastapi import HTTPException
from pydantic import BaseModel
from routes._validators import validate_remote_host, validate_ssh_port
from core.platform_compat import _ssh_exec_argv
logger = logging.getLogger(__name__)
@@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
# Include pattern is a glob: allow typical safe glyphs only.
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
# Remote host: either `user@host` or plain `host` (alias is allowed), where host
# is a safe DNS-like token or a short SSH config alias.
_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$")
# HF tokens and API tokens are url-safe base64-like.
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
# Anything beyond plain alphanumerics + dash + underscore could break out
# of the shell/PowerShell contexts the value lands in.
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
# A download target directory. Absolute or ~-relative path; safe path glyphs
# only (no quotes or shell metacharacters). Spaces are allowed because command
@@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None:
return v
def _validate_remote_host(v: str | None) -> str | None:
    """Validate an SSH target given as ``host`` or ``user@host``.

    Args:
        v: Candidate remote host. ``None`` and ``""`` mean "no remote host".

    Returns:
        ``None`` when *v* is ``None`` or empty, otherwise *v* unchanged.

    Raises:
        HTTPException: 400 when *v* does not match ``_REMOTE_HOST_RE`` in full.
    """
    if v is None or v == "":
        return None
    # fullmatch, not match: with `.match()` a `$`-anchored pattern also accepts
    # a value that ends in a newline ("host\n").
    if not _REMOTE_HOST_RE.fullmatch(v):
        raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax")
    return v
def _validate_token(v: str | None) -> str | None:
if v is None or v == "":
return None
@@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None:
return v
def _validate_ssh_port(v: str | None) -> str | None:
    """Check an SSH port value and return it as a decimal string.

    ``None`` and ``""`` are passed through as ``None``. Anything that does not
    fully match ``_SSH_PORT_RE``, or whose numeric value is outside 1-65535,
    raises ``HTTPException(400, "Invalid ssh_port")``.
    """
    if v is None or v == "":
        return None
    shape_ok = _SSH_PORT_RE.fullmatch(str(v))
    if not shape_ok:
        raise HTTPException(400, "Invalid ssh_port")
    as_int = int(v)
    in_range = 1 <= as_int <= 65535
    if not in_range:
        raise HTTPException(400, "Invalid ssh_port")
    return str(as_int)
def _validate_gpus(v: str | None) -> str | None:
if v is None or v == "":
return None
+43 -31
View File
@@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE
from pydantic import BaseModel
from core.middleware import require_admin
from routes._validators import validate_remote_host, validate_ssh_port
from core.platform_compat import (
IS_WINDOWS,
detached_popen_kwargs,
@@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR
logger = logging.getLogger(__name__)
from routes.cookbook_helpers import (
_SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE,
_validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token,
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
_SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token,
_validate_local_dir, _validate_gpus, _shell_path,
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
@@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter:
else:
_validate_repo_id(req.repo_id)
_validate_include(req.include)
_validate_remote_host(req.remote_host)
req.ssh_port = _validate_ssh_port(req.ssh_port)
validate_remote_host(req.remote_host)
req.ssh_port = validate_ssh_port(req.ssh_port)
req.local_dir = _validate_local_dir(req.local_dir)
req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token())
_validate_token(req.hf_token)
@@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter:
# Validate shell-bound inputs, matching the sibling list_gpus endpoint —
# `host`/`ssh_port` are interpolated into an ssh command below, so an
# unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection.
host = _validate_remote_host(host)
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
raise HTTPException(400, "Invalid ssh_port")
host = validate_remote_host(host)
ssh_port = validate_ssh_port(ssh_port)
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
model_dirs = []
@@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter:
# listening" check without requiring ss/netstat/nmap.
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
if ssh_port and str(ssh_port) != "22":
if not _SSH_PORT_RE.match(str(ssh_port)):
try:
ssh_port = validate_ssh_port(ssh_port)
except HTTPException:
return None
ssh_base.extend(["-p", str(ssh_port)])
host_arg = remote
if not _REMOTE_HOST_RE.match(host_arg):
try:
host_arg = validate_remote_host(remote)
except HTTPException:
return None
if not host_arg:
return None
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
script = (
@@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter:
"""
require_admin(request)
# Defence-in-depth: reject values that could break out of shell contexts.
_validate_remote_host(req.remote_host)
req.ssh_port = _validate_ssh_port(req.ssh_port)
validate_remote_host(req.remote_host)
req.ssh_port = validate_ssh_port(req.ssh_port)
req.gpus = _validate_gpus(req.gpus)
req.hf_token = req.hf_token or _load_stored_hf_token()
_validate_token(req.hf_token)
@@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter:
async def server_setup(request: Request, req: SetupRequest):
"""Install required dependencies on a remote server via SSH."""
require_admin(request)
host = _validate_remote_host(req.host)
host = validate_remote_host(req.host)
if not host:
raise HTTPException(400, "host is required")
port = req.ssh_port
if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port):
raise HTTPException(400, "Invalid ssh_port")
port = validate_ssh_port(port)
pf = f"-p {port} " if port and port != "22" else ""
# Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux
@@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter:
`busy` is True when free_mb/total_mb < 0.5.
"""
require_admin(request)
host = _validate_remote_host(host)
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
raise HTTPException(400, "Invalid ssh_port")
host = validate_remote_host(host)
ssh_port = validate_ssh_port(ssh_port)
gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits"
nvidia_error = None
try:
@@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter:
sig = (req.signal or "TERM").upper()
if sig not in ("TERM", "KILL", "INT"):
raise HTTPException(400, "signal must be TERM, KILL, or INT")
host = _validate_remote_host(req.host)
if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port):
raise HTTPException(400, "Invalid ssh_port")
host = validate_remote_host(req.host)
req.ssh_port = validate_ssh_port(req.ssh_port)
kill_cmd = f"kill -{sig} {req.pid}"
try:
if host:
@@ -2382,14 +2383,19 @@ def setup_cookbook_routes() -> APIRouter:
host = (srv.get("host") or "").strip()
if not host:
continue # local-only entry; the /proc scan handles it
if not _REMOTE_HOST_RE.match(host):
try:
host = validate_remote_host(host)
except HTTPException:
continue
sport = str(srv.get("port") or "").strip()
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
if sport and sport != "22":
if not _SSH_PORT_RE.match(sport):
try:
sport = validate_ssh_port(sport)
except HTTPException:
continue
ssh_base.extend(["-p", sport])
if sport != "22":
ssh_base.extend(["-p", sport])
try:
ls = subprocess.run(
@@ -2743,12 +2749,18 @@ def setup_cookbook_routes() -> APIRouter:
if not _SESSION_ID_RE.match(session_id):
logger.warning(f"Skipping task with unsafe session_id: {session_id!r}")
continue
if remote and not _REMOTE_HOST_RE.match(remote):
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
continue
if _tport and not _SSH_PORT_RE.match(str(_tport)):
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
continue
if remote:
try:
remote = validate_remote_host(remote)
except HTTPException:
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
continue
if _tport:
try:
_tport = validate_ssh_port(str(_tport))
except HTTPException:
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
continue
if task_platform == "windows" and remote:
# Windows: check PID file + Get-Process, read log tail
sd = "$env:TEMP\\odysseus-sessions"
+15 -1
View File
@@ -1,7 +1,9 @@
import re
from copy import deepcopy
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from routes._validators import validate_remote_host, validate_ssh_port
# Backends the manual hardware simulator accepts. Must stay a subset of what
@@ -11,6 +13,14 @@ from fastapi import APIRouter
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
    """Validate the ``host`` / ``ssh_port`` pair passed to hardware detection.

    Both values go through ``validate_remote_host`` and ``validate_ssh_port``.
    A value those validators report as absent (``None``) is returned as ``""``.

    Returns:
        ``(host, ssh_port)`` as validated strings.

    Raises:
        HTTPException: 400 from either validator, or when a port is supplied
            without a host.
    """
    checked_host = validate_remote_host(host)
    checked_port = validate_ssh_port(ssh_port)
    # A port alone has no SSH target to apply to.
    if checked_port and not checked_host:
        raise HTTPException(400, "ssh_port requires host")
    return checked_host or "", checked_port or ""
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
"""Manual hardware is a "what if I had this setup" simulator —
REPLACES the detected hardware entirely instead of adding to it.
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
"""Detect and return current system hardware info. Pass host=user@server for remote.
fresh=true bypasses the per-host cache (the Rescan button)."""
from services.hwfit.hardware import detect_system
host, ssh_port = _validate_detection_target(host, ssh_port)
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
@router.get("/models")
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
from services.hwfit.hardware import detect_system
from services.hwfit.fit import rank_models
from services.hwfit.models import get_models, model_catalog_path
host, ssh_port = _validate_detection_target(host, ssh_port)
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
@@ -229,6 +241,7 @@ def setup_hwfit_routes():
from services.hwfit.hardware import detect_system
from services.hwfit.models import get_models
from services.hwfit.profiles import compute_serve_profiles
host, ssh_port = _validate_detection_target(host, ssh_port)
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
if system.get("error"):
return {"system": system, "profiles": [], "error": system["error"]}
@@ -279,6 +292,7 @@ def setup_hwfit_routes():
"""Rank image generation models against detected hardware."""
from services.hwfit.hardware import detect_system
from services.hwfit.image_models import rank_image_models
host, ssh_port = _validate_detection_target(host, ssh_port)
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
-7
View File
@@ -26,7 +26,6 @@ from routes.cookbook_helpers import (
_validate_repo_id,
_validate_serve_cmd,
_validate_serve_model_id,
_validate_ssh_port,
_shell_path,
run_ssh_command_async,
)
@@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path():
)
def test_validate_ssh_port_rejects_shell_payload():
    """A port carrying shell syntax is rejected; a plain port is returned as-is."""
    payload = "22; touch /tmp/pwned"
    with pytest.raises(HTTPException):
        _validate_ssh_port(payload)

    accepted = _validate_ssh_port("2222")
    assert accepted == "2222"
def test_validate_local_dir_accepts_external_drive_paths_with_spaces():
path = "/Volumes/T7 2TB/AI Models/llamacpp"
+47
View File
@@ -0,0 +1,47 @@
import pytest
from fastapi import HTTPException
from core.platform_compat import _ssh_exec_argv
from routes.hwfit_routes import setup_hwfit_routes
def _endpoint(path: str):
    """Return the endpoint callable registered at *path* on a fresh hwfit router.

    Raises:
        AssertionError: when no route on the router has that path.
    """
    routes = setup_hwfit_routes().routes
    # Some route objects may lack a `path` attribute, hence the getattr default.
    found = next((r for r in routes if getattr(r, "path", "") == path), None)
    if found is None:
        raise AssertionError(f"{path} route not found")
    return found.endpoint
@pytest.mark.parametrize(
    ("path", "kwargs"),
    [
        ("/api/hwfit/system", {}),
        ("/api/hwfit/models", {"limit": 1}),
        ("/api/hwfit/profiles", {"model": "demo"}),
        ("/api/hwfit/image-models", {}),
    ],
)
def test_hwfit_routes_reject_ssh_option_host(path, kwargs):
    """Each parametrized hwfit route answers an option-shaped host with HTTP 400."""
    handler = _endpoint(path)
    with pytest.raises(HTTPException) as excinfo:
        handler(host="-oProxyCommand=sh", ssh_port="22", **kwargs)
    assert excinfo.value.status_code == 400
def test_hwfit_routes_reject_port_without_host():
    """An ssh_port with an empty host is rejected with HTTP 400."""
    system_handler = _endpoint("/api/hwfit/system")
    with pytest.raises(HTTPException) as excinfo:
        system_handler(host="", ssh_port="2222")
    status = excinfo.value.status_code
    assert status == 400
def test_ssh_argv_rejects_option_shaped_remote():
    """_ssh_exec_argv raises ValueError when the remote, or its host part, starts with '-'."""
    option_shaped = ("-oProxyCommand=sh", "alice@-oProxyCommand=sh")
    for remote in option_shaped:
        with pytest.raises(ValueError):
            _ssh_exec_argv(remote, "22", remote_cmd="true")
+23
View File
@@ -0,0 +1,23 @@
import pytest
from fastapi import HTTPException
from routes._validators import validate_remote_host, validate_ssh_port
def test_validate_ssh_port_rejects_shell_payload():
    """Shell payloads, option syntax and out-of-range ports are rejected; a valid port passes."""
    rejected = ("22;id", "$(id)", "-p 22", "0", "65536")
    for candidate in rejected:
        with pytest.raises(HTTPException):
            validate_ssh_port(candidate)

    assert validate_ssh_port("2222") == "2222"
def test_validate_remote_host_rejects_ssh_option_shape():
    """Hosts shaped like ssh options are rejected; a normal user@host is returned unchanged."""
    option_like_hosts = (
        "-oProxyCommand=sh",
        "alice@-oProxyCommand=sh",
        "--",
        "-p2222",
    )
    for candidate in option_like_hosts:
        with pytest.raises(HTTPException):
            validate_remote_host(candidate)

    good = "alice@gpu-box_1"
    assert validate_remote_host(good) == good