From d1a5a7d680e5b06249ad19a8790a9347d78961d9 Mon Sep 17 00:00:00 2001 From: RaresKeY <158580472+RaresKeY@users.noreply.github.com> Date: Thu, 11 Jun 2026 01:43:49 +0300 Subject: [PATCH] fix(hwfit): validate remote SSH detection targets (#3718) --- core/platform_compat.py | 4 ++ routes/_validators.py | 31 +++++++++++ routes/cookbook_helpers.py | 24 +-------- routes/cookbook_routes.py | 74 ++++++++++++++++----------- routes/hwfit_routes.py | 16 +++++- tests/test_cookbook_helpers.py | 7 --- tests/test_hwfit_remote_validation.py | 47 +++++++++++++++++ tests/test_route_validators.py | 23 +++++++++ 8 files changed, 164 insertions(+), 62 deletions(-) create mode 100644 routes/_validators.py create mode 100644 tests/test_hwfit_remote_validation.py create mode 100644 tests/test_route_validators.py diff --git a/core/platform_compat.py b/core/platform_compat.py index 3eda4a107..b3b157111 100644 --- a/core/platform_compat.py +++ b/core/platform_compat.py @@ -366,6 +366,10 @@ def _ssh_exec_argv( strict_host_key_checking: bool | None = None, ) -> list[str]: """Build a consistent ssh argv for remote command execution.""" + remote_value = str(remote or "").strip() + remote_host = remote_value.rsplit("@", 1)[-1] + if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"): + raise ValueError("Invalid SSH remote host") argv = ["ssh"] if connect_timeout is not None: argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"]) diff --git a/routes/_validators.py b/routes/_validators.py new file mode 100644 index 000000000..aa4cf00cc --- /dev/null +++ b/routes/_validators.py @@ -0,0 +1,31 @@ +import re + +from fastapi import HTTPException + + +_REMOTE_HOST_RE = re.compile( + r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$" +) +_SSH_PORT_RE = re.compile(r"^\d{1,5}$") + + +def validate_remote_host(v: str | None) -> str | None: + if v is None or v == "": + return None + if not _REMOTE_HOST_RE.match(v): + raise HTTPException( + 400, + "Invalid remote_host — must be host or user@host, no SSH option syntax", + ) + return v + + +def validate_ssh_port(v: str | None) -> str | None: + if v is None or v == "": + return None + if not _SSH_PORT_RE.fullmatch(str(v)): + raise HTTPException(400, "Invalid ssh_port") + port = int(v) + if port < 1 or port > 65535: + raise HTTPException(400, "Invalid ssh_port") + return str(port) diff --git a/routes/cookbook_helpers.py b/routes/cookbook_helpers.py index 709245287..53bdde80e 100644 --- a/routes/cookbook_helpers.py +++ b/routes/cookbook_helpers.py @@ -11,6 +11,7 @@ import shlex from fastapi import HTTPException from pydantic import BaseModel +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import _ssh_exec_argv logger = logging.getLogger(__name__) @@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") _OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$") # Include pattern is a glob: allow typical safe glyphs only. _INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$") -# Remote host: either `user@host` or plain `host` (alias is allowed), where host -# is a safe DNS-like token or a short SSH config alias. -_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$") # HF tokens and API tokens are url-safe base64-like. _TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$") # Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef". # Anything beyond plain alphanumerics + dash + underscore could break out # of the shell/PowerShell contexts the value lands in. _SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") -_SSH_PORT_RE = re.compile(r"^\d{1,5}$") _GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$") # A download target directory. Absolute or ~-relative path; safe path glyphs # only (no quotes or shell metacharacters). Spaces are allowed because command @@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None: return v -def _validate_remote_host(v: str | None) -> str | None: - if v is None or v == "": - return None - if not _REMOTE_HOST_RE.match(v): - raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax") - return v - - def _validate_token(v: str | None) -> str | None: if v is None or v == "": return None @@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None: return v -def _validate_ssh_port(v: str | None) -> str | None: - if v is None or v == "": - return None - if not _SSH_PORT_RE.fullmatch(str(v)): - raise HTTPException(400, "Invalid ssh_port") - port = int(v) - if port < 1 or port > 65535: - raise HTTPException(400, "Invalid ssh_port") - return str(port) - - def _validate_gpus(v: str | None) -> str | None: if v is None or v == "": return None diff --git a/routes/cookbook_routes.py b/routes/cookbook_routes.py index 4a4764232..36f98aeae 100644 --- a/routes/cookbook_routes.py +++ b/routes/cookbook_routes.py @@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE from pydantic import BaseModel from core.middleware import require_admin +from routes._validators import validate_remote_host, validate_ssh_port from core.platform_compat import ( IS_WINDOWS, detached_popen_kwargs, @@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR logger = logging.getLogger(__name__) from routes.cookbook_helpers import ( - _SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE, - _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token, - _validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path, + _SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token, + _validate_local_dir, _validate_gpus, _shell_path, _ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase, _safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines, _append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script, @@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter: else: _validate_repo_id(req.repo_id) _validate_include(req.include) - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.local_dir = _validate_local_dir(req.local_dir) req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token()) _validate_token(req.hf_token) @@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter: # Validate shell-bound inputs, matching the sibling list_gpus endpoint — # `host`/`ssh_port` are interpolated into an ssh command below, so an # unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection. - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True) model_dirs = [] @@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter: # listening" check without requiring ss/netstat/nmap. ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if ssh_port and str(ssh_port) != "22": - if not _SSH_PORT_RE.match(str(ssh_port)): + try: + ssh_port = validate_ssh_port(ssh_port) + except HTTPException: return None ssh_base.extend(["-p", str(ssh_port)]) - host_arg = remote - if not _REMOTE_HOST_RE.match(host_arg): + try: + host_arg = validate_remote_host(remote) + except HTTPException: + return None + if not host_arg: return None probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1)) script = ( @@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter: """ require_admin(request) # Defence-in-depth: reject values that could break out of shell contexts. - _validate_remote_host(req.remote_host) - req.ssh_port = _validate_ssh_port(req.ssh_port) + validate_remote_host(req.remote_host) + req.ssh_port = validate_ssh_port(req.ssh_port) req.gpus = _validate_gpus(req.gpus) req.hf_token = req.hf_token or _load_stored_hf_token() _validate_token(req.hf_token) @@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter: async def server_setup(request: Request, req: SetupRequest): """Install required dependencies on a remote server via SSH.""" require_admin(request) - host = _validate_remote_host(req.host) + host = validate_remote_host(req.host) if not host: raise HTTPException(400, "host is required") port = req.ssh_port - if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port): - raise HTTPException(400, "Invalid ssh_port") + port = validate_ssh_port(port) pf = f"-p {port} " if port and port != "22" else "" # Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux @@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter: `busy` is True when free_mb/total_mb < 0.5. """ require_admin(request) - host = _validate_remote_host(host) - if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(host) + ssh_port = validate_ssh_port(ssh_port) gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits" nvidia_error = None try: @@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter: sig = (req.signal or "TERM").upper() if sig not in ("TERM", "KILL", "INT"): raise HTTPException(400, "signal must be TERM, KILL, or INT") - host = _validate_remote_host(req.host) - if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port): - raise HTTPException(400, "Invalid ssh_port") + host = validate_remote_host(req.host) + req.ssh_port = validate_ssh_port(req.ssh_port) kill_cmd = f"kill -{sig} {req.pid}" try: if host: @@ -2382,14 +2383,19 @@ def setup_cookbook_routes() -> APIRouter: host = (srv.get("host") or "").strip() if not host: continue # local-only entry; the /proc scan handles it - if not _REMOTE_HOST_RE.match(host): + try: + host = validate_remote_host(host) + except HTTPException: continue sport = str(srv.get("port") or "").strip() ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"] if sport and sport != "22": - if not _SSH_PORT_RE.match(sport): + try: + sport = validate_ssh_port(sport) + except HTTPException: continue - ssh_base.extend(["-p", sport]) + if sport != "22": + ssh_base.extend(["-p", sport]) try: ls = subprocess.run( @@ -2743,12 +2749,18 @@ def setup_cookbook_routes() -> APIRouter: if not _SESSION_ID_RE.match(session_id): logger.warning(f"Skipping task with unsafe session_id: {session_id!r}") continue - if remote and not _REMOTE_HOST_RE.match(remote): - logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") - continue - if _tport and not _SSH_PORT_RE.match(str(_tport)): - logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") - continue + if remote: + try: + remote = validate_remote_host(remote) + except HTTPException: + logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}") + continue + if _tport: + try: + _tport = validate_ssh_port(str(_tport)) + except HTTPException: + logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}") + continue if task_platform == "windows" and remote: # Windows: check PID file + Get-Process, read log tail sd = "$env:TEMP\\odysseus-sessions" diff --git a/routes/hwfit_routes.py b/routes/hwfit_routes.py index eb408ac9d..564c3a03c 100644 --- a/routes/hwfit_routes.py +++ b/routes/hwfit_routes.py @@ -1,7 +1,9 @@ import re from copy import deepcopy -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port # Backends the manual hardware simulator accepts. Must stay a subset of what @@ -11,6 +13,14 @@ from fastapi import APIRouter _MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"} +def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]: + host_value = validate_remote_host(host) or "" + port_value = validate_ssh_port(ssh_port) or "" + if port_value and not host_value: + raise HTTPException(400, "ssh_port requires host") + return host_value, port_value + + def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""): """Manual hardware is a "what if I had this setup" simulator — REPLACES the detected hardware entirely instead of adding to it. @@ -105,6 +115,7 @@ def setup_hwfit_routes(): """Detect and return current system hardware info. Pass host=user@server for remote. fresh=true bypasses the per-host cache (the Rescan button).""" from services.hwfit.hardware import detect_system + host, ssh_port = _validate_detection_target(host, ssh_port) return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) @router.get("/models") @@ -118,6 +129,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.fit import rank_models from services.hwfit.models import get_models, model_catalog_path + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} @@ -229,6 +241,7 @@ def setup_hwfit_routes(): from services.hwfit.hardware import detect_system from services.hwfit.models import get_models from services.hwfit.profiles import compute_serve_profiles + host, ssh_port = _validate_detection_target(host, ssh_port) system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh) if system.get("error"): return {"system": system, "profiles": [], "error": system["error"]} @@ -279,6 +292,7 @@ def setup_hwfit_routes(): """Rank image generation models against detected hardware.""" from services.hwfit.hardware import detect_system from services.hwfit.image_models import rank_image_models + host, ssh_port = _validate_detection_target(host, ssh_port) system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)) if system.get("error"): return {"system": system, "models": [], "error": system["error"]} diff --git a/tests/test_cookbook_helpers.py b/tests/test_cookbook_helpers.py index acc001812..779b48e3c 100644 --- a/tests/test_cookbook_helpers.py +++ b/tests/test_cookbook_helpers.py @@ -26,7 +26,6 @@ from routes.cookbook_helpers import ( _validate_repo_id, _validate_serve_cmd, _validate_serve_model_id, - _validate_ssh_port, _shell_path, run_ssh_command_async, ) @@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path(): ) -def test_validate_ssh_port_rejects_shell_payload(): - with pytest.raises(HTTPException): - _validate_ssh_port("22; touch /tmp/pwned") - assert _validate_ssh_port("2222") == "2222" - - def test_validate_local_dir_accepts_external_drive_paths_with_spaces(): path = "/Volumes/T7 2TB/AI Models/llamacpp" diff --git a/tests/test_hwfit_remote_validation.py b/tests/test_hwfit_remote_validation.py new file mode 100644 index 000000000..aee2aaadb --- /dev/null +++ b/tests/test_hwfit_remote_validation.py @@ -0,0 +1,47 @@ +import pytest +from fastapi import HTTPException + +from core.platform_compat import _ssh_exec_argv +from routes.hwfit_routes import setup_hwfit_routes + + +def _endpoint(path: str): + router = setup_hwfit_routes() + for route in router.routes: + if getattr(route, "path", "") == path: + return route.endpoint + raise AssertionError(f"{path} route not found") + + +@pytest.mark.parametrize( + "path,kwargs", + [ + ("/api/hwfit/system", {}), + ("/api/hwfit/models", {"limit": 1}), + ("/api/hwfit/profiles", {"model": "demo"}), + ("/api/hwfit/image-models", {}), + ], +) +def test_hwfit_routes_reject_ssh_option_host(path, kwargs): + endpoint = _endpoint(path) + + with pytest.raises(HTTPException) as exc: + endpoint(host="-oProxyCommand=sh", ssh_port="22", **kwargs) + + assert exc.value.status_code == 400 + + +def test_hwfit_routes_reject_port_without_host(): + endpoint = _endpoint("/api/hwfit/system") + + with pytest.raises(HTTPException) as exc: + endpoint(host="", ssh_port="2222") + + assert exc.value.status_code == 400 + + +def test_ssh_argv_rejects_option_shaped_remote(): + with pytest.raises(ValueError): + _ssh_exec_argv("-oProxyCommand=sh", "22", remote_cmd="true") + with pytest.raises(ValueError): + _ssh_exec_argv("alice@-oProxyCommand=sh", "22", remote_cmd="true") diff --git a/tests/test_route_validators.py b/tests/test_route_validators.py new file mode 100644 index 000000000..a6fc07a98 --- /dev/null +++ b/tests/test_route_validators.py @@ -0,0 +1,23 @@ +import pytest +from fastapi import HTTPException + +from routes._validators import validate_remote_host, validate_ssh_port + + +def test_validate_ssh_port_rejects_shell_payload(): + for port in ["22;id", "$(id)", "-p 22", "0", "65536"]: + with pytest.raises(HTTPException): + validate_ssh_port(port) + assert validate_ssh_port("2222") == "2222" + + +def test_validate_remote_host_rejects_ssh_option_shape(): + for host in [ + "-oProxyCommand=sh", + "alice@-oProxyCommand=sh", + "--", + "-p2222", + ]: + with pytest.raises(HTTPException): + validate_remote_host(host) + assert validate_remote_host("alice@gpu-box_1") == "alice@gpu-box_1"