fix(hwfit): validate remote SSH detection targets (#3718)

This commit is contained in:
RaresKeY
2026-06-11 01:43:49 +03:00
committed by GitHub
parent 218b9ecbc8
commit d1a5a7d680
8 changed files with 164 additions and 62 deletions
+15 -1
View File
@@ -1,7 +1,9 @@
import re
from copy import deepcopy
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from routes._validators import validate_remote_host, validate_ssh_port
# Backends the manual hardware simulator accepts. Must stay a subset of what
@@ -11,6 +13,14 @@ from fastapi import APIRouter
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
host_value = validate_remote_host(host) or ""
port_value = validate_ssh_port(ssh_port) or ""
if port_value and not host_value:
raise HTTPException(400, "ssh_port requires host")
return host_value, port_value
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
"""Manual hardware is a "what if I had this setup" simulator —
REPLACES the detected hardware entirely instead of adding to it.
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
"""Detect and return current system hardware info. Pass host=user@server for remote.
fresh=true bypasses the per-host cache (the Rescan button)."""
from services.hwfit.hardware import detect_system
host, ssh_port = _validate_detection_target(host, ssh_port)
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
@router.get("/models")
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
from services.hwfit.hardware import detect_system
from services.hwfit.fit import rank_models
from services.hwfit.models import get_models, model_catalog_path
host, ssh_port = _validate_detection_target(host, ssh_port)
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
@@ -229,6 +241,7 @@ def setup_hwfit_routes():
from services.hwfit.hardware import detect_system
from services.hwfit.models import get_models
from services.hwfit.profiles import compute_serve_profiles
host, ssh_port = _validate_detection_target(host, ssh_port)
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
if system.get("error"):
return {"system": system, "profiles": [], "error": system["error"]}
@@ -279,6 +292,7 @@ def setup_hwfit_routes():
"""Rank image generation models against detected hardware."""
from services.hwfit.hardware import detect_system
from services.hwfit.image_models import rank_image_models
host, ssh_port = _validate_detection_target(host, ssh_port)
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}