fix(platform): Improve WSL SSH remote compatibility (#3316)

* fix(platform): add WSL compatibility functions and path translation
fix(cookbook): enhance model scan script to support additional HuggingFace cache paths
fix(hardware): improve cache key generation for remote SSH context
test(tests): add tests for WSL detection and path translation functionality

* fix(cookbook): prefer prebuilt wheels for llama-cpp-python and normalize package aliases

* fix: enable StrictHostKeyChecking in nvidia probe
refactor: consolidate ssh & powershell command execution to utility functions in core module
refactor: consolidate nvidia path candidates into single variables in core module
tests: add tests for new utility functions

* fix: correct wrong variable name
This commit is contained in:
horribleCodes
2026-06-08 00:33:50 +02:00
committed by GitHub
parent 73315e6ddc
commit 9c90f62657
8 changed files with 763 additions and 33 deletions
+176
View File
@@ -161,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = (
("usr", "bin", "bash.exe"), ("usr", "bin", "bash.exe"),
) )
# Directories appended to PATH by the remote SSH probe command so that tools
# such as nvidia-smi can be found even when they are not on the remote PATH.
_SSH_PATH_MEMBERS = (
    "/usr/bin",
    "/usr/local/bin",
    "/usr/local/cuda/bin",
    "/usr/lib/wsl/lib",
)

# Absolute nvidia-smi locations tried as a fallback on WSL and other Linux
# distros where the binary may not be on PATH.
NVIDIA_PATH_CANDIDATES = (
    "/usr/bin/nvidia-smi",
    "/usr/local/bin/nvidia-smi",
    "/usr/local/cuda/bin/nvidia-smi",
    "/usr/lib/wsl/lib/nvidia-smi",
)


def _ssh_path_override() -> str:
    """Return the ``export PATH=...; `` shell prefix used by remote SSH probes.

    The snippet appends every entry of ``_SSH_PATH_MEMBERS`` to the existing
    ``$PATH`` and ends with ``"; "`` so a command can be concatenated after it.
    """
    extra_dirs = ":".join(_SSH_PATH_MEMBERS)
    return f"export PATH=\"$PATH:{extra_dirs}\"; "


# Computed once at import time; importers use this constant directly.
SSH_PATH_OVERRIDE = _ssh_path_override()
def _windows_bash_fallbacks() -> List[str]: def _windows_bash_fallbacks() -> List[str]:
roots: List[str] = [] roots: List[str] = []
@@ -268,3 +291,156 @@ def run_script_argv(script_path) -> List[str]:
comspec = os.environ.get("ComSpec", "cmd.exe") comspec = os.environ.get("ComSpec", "cmd.exe")
return [comspec, "/c", str(script_path)] return [comspec, "/c", str(script_path)]
return ["sh", str(script_path)] return ["sh", str(script_path)]
def is_wsl() -> bool:
    """Report whether the process runs inside Windows Subsystem for Linux.

    Detection only happens on Linux/POSIX platforms and looks for the word
    "microsoft" (case-insensitive) in ``/proc/version``. Any error while
    reading that file is treated as "not WSL".
    """
    import sys

    looks_posix = sys.platform.startswith("linux") or os.name == "posix"
    if not looks_posix:
        return False

    try:
        with open("/proc/version", "r") as version_file:
            return "microsoft" in version_file.read().lower()
    except Exception:
        # Best-effort probe: an unreadable /proc/version means "not WSL".
        return False
def translate_path(path_str: str) -> str:
    """Convert *path_str* into a path usable on the current OS.

    Under WSL, backslashes are turned into forward slashes and a leading
    drive letter (``C:\\foo`` or ``C:/foo``) is rewritten to its ``/mnt/<drive>``
    mount point, e.g. ``/mnt/c/foo``. Every other path is resolved through
    ``Path.resolve()``; if resolution raises, the (possibly slash-normalized)
    input string is returned unchanged. Falsy input is returned as-is.
    """
    if not path_str:
        return path_str

    if is_wsl():
        import re

        path_str = path_str.replace("\\", "/")
        drive_match = re.match(r"^([a-zA-Z]):(.*)", path_str)
        if drive_match is not None:
            drive, rest = drive_match.groups()
            drive = drive.lower()
            # Drive-relative input such as "C:foo" has no separator after the colon.
            if not rest.startswith("/"):
                rest = "/" + rest
            return f"/mnt/{drive}{rest}"

    try:
        resolved = Path(path_str).resolve()
        return str(resolved)
    except Exception:
        return path_str
def get_wsl_windows_user_profile() -> Optional[str]:
    """Return the Windows host user-profile directory as a WSL path.

    Returns ``None`` when not running under WSL or when no profile can be
    determined. Two strategies are tried in order, each best-effort:

    1. Ask Windows PowerShell for ``$env:USERPROFILE`` and translate the
       answer with ``translate_path``.
    2. Scan ``/mnt/c/Users`` and return the first directory entry that is not
       one of the well-known system entries. This is a heuristic: with several
       Windows accounts, the result depends on ``os.listdir`` ordering.
    """
    if not is_wsl():
        return None

    try:
        result = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5)
        if result.returncode == 0:
            profile = result.stdout.strip()
            if profile:
                return translate_path(profile)
    except Exception:
        # PowerShell may be missing or slow; fall through to the directory scan.
        pass

    system_entries = ("All Users", "Default", "Default User", "desktop.ini", "Public")
    try:
        users_dir = "/mnt/c/Users"
        if os.path.isdir(users_dir):
            for entry in os.listdir(users_dir):
                if entry in system_entries:
                    continue
                candidate = os.path.join(users_dir, entry)
                if os.path.isdir(candidate):
                    return candidate
    except Exception:
        pass

    return None
def _ssh_exec_argv(
remote: str,
ssh_port: str | None,
*,
remote_cmd: str | None = None,
connect_timeout: int | None = None,
strict_host_key_checking: bool | None = None,
) -> list[str]:
"""Build a consistent ssh argv for remote command execution."""
argv = ["ssh"]
if connect_timeout is not None:
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
if strict_host_key_checking is not None:
argv.extend(
[
"-o",
"StrictHostKeyChecking=yes"
if strict_host_key_checking
else "StrictHostKeyChecking=no",
]
)
if ssh_port and ssh_port != "22":
argv.extend(["-p", str(ssh_port)])
argv.append(remote)
if remote_cmd is not None:
argv.append(remote_cmd)
return argv
def run_ssh_command(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    text: bool = True,
) -> subprocess.CompletedProcess:
    """Execute *remote_cmd* on *remote* over ssh and capture its output.

    The argv is built by ``_ssh_exec_argv``; stdout and stderr are always
    captured. ``timeout`` is handed to ``subprocess.run``, and ``text``
    selects between ``str`` and ``bytes`` output.
    """
    argv = _ssh_exec_argv(
        remote,
        ssh_port,
        remote_cmd=remote_cmd,
        connect_timeout=connect_timeout,
        strict_host_key_checking=strict_host_key_checking,
    )
    return subprocess.run(argv, timeout=timeout, capture_output=True, text=text)
def _windows_powershell_argv(
    command: str,
    *,
    no_profile: bool = True,
    non_interactive: bool = True,
) -> List[str]:
    """Build the ``powershell.exe`` argv that runs *command* via ``-Command``.

    ``-NoProfile`` and ``-NonInteractive`` are included (in that order) when
    the corresponding keyword flag is true.
    """
    switches = (
        ("-NoProfile", no_profile),
        ("-NonInteractive", non_interactive),
    )
    enabled = [switch for switch, wanted in switches if wanted]
    return ["powershell.exe", *enabled, "-Command", command]
def run_wsl_windows_powershell(
    command: str,
    *,
    timeout: float = 5,
) -> subprocess.CompletedProcess[str]:
    """Execute a PowerShell *command* on the Windows host from inside WSL.

    Output is captured as text and ``timeout`` is forwarded to
    ``subprocess.run``.

    Raises:
        RuntimeError: when called outside WSL.
    """
    if is_wsl():
        return subprocess.run(
            _windows_powershell_argv(command),
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    raise RuntimeError("run_wsl_windows_powershell is only supported in WSL")
+49 -3
View File
@@ -11,6 +11,8 @@ import shlex
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from core.platform_compat import _ssh_exec_argv
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -213,7 +215,10 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
# before being embedded in the install command. Plain names (e.g. # before being embedded in the install command. Plain names (e.g.
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``. # ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
pkg = shlex.quote(package) pkg = shlex.quote(package)
if IS_WINDOWS and "llama-cpp-python" in package: # llama-cpp-python source builds are brittle on older distro pip/packaging
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
# this package is requested so dependency-install tasks are reliable.
if "llama-cpp-python" in package:
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu" pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}") base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
@@ -275,11 +280,14 @@ def _user_shell_path_bootstrap() -> list[str]:
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi', ' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
'fi', 'fi',
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }', 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
] ]
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str: def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
"""Build the standalone Python scanner used by /api/model/cached.""" """Build the standalone Python scanner used by /api/model/cached.
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
"""
lines = [ lines = [
"import json, os, re, shutil, subprocess, urllib.request", "import json, os, re, shutil, subprocess, urllib.request",
"models = []", "models = []",
@@ -372,6 +380,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.", " # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
" # When HOME is /root, expanduser() misses that persisted cache.", " # When HOME is /root, expanduser() misses that persisted cache.",
" add('/app/.cache/huggingface/hub')", " add('/app/.cache/huggingface/hub')",
f" add({add_hf_cache!r})" if add_hf_cache else "",
" return candidates", " return candidates",
"def scan_dir(p):", "def scan_dir(p):",
" if not os.path.isdir(p) or not safe_path(p): return", " if not os.path.isdir(p) or not safe_path(p): return",
@@ -989,3 +998,40 @@ def _diagnose_serve_output(text: str) -> dict | None:
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}], "suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
} }
return None return None
async def run_ssh_command_async(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    stdin_data: bytes | None = None,
) -> tuple[int, bytes, bytes]:
    """Run an ssh command asynchronously and capture stdout/stderr as bytes.

    Async counterpart of ``core.platform_compat.run_ssh_command``; the argv is
    built by the same ``_ssh_exec_argv`` helper.

    Args:
        remote: SSH destination.
        ssh_port: Port forwarded to ``_ssh_exec_argv``.
        remote_cmd: Command string executed on the remote side.
        timeout: Seconds to wait for the whole command to finish.
        connect_timeout: Optional ssh ``ConnectTimeout`` in seconds.
        strict_host_key_checking: Optional ``StrictHostKeyChecking`` override.
        stdin_data: Bytes piped to the remote command's stdin. When ``None``,
            no stdin pipe is created.

    Returns:
        ``(returncode, stdout, stderr)``.

    Raises:
        asyncio.TimeoutError: when the command exceeds ``timeout``. The child
            process is killed and reaped before the error propagates.
    """
    import asyncio

    proc = await asyncio.create_subprocess_exec(
        *_ssh_exec_argv(
            remote,
            ssh_port,
            remote_cmd=remote_cmd,
            connect_timeout=connect_timeout,
            strict_host_key_checking=strict_host_key_checking,
        ),
        stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(input=stdin_data), timeout=timeout
        )
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the timeout firing and the kill;
            # callers must still see the TimeoutError, not this race.
            pass
        # Reap the child so it does not linger as a zombie.
        await proc.communicate()
        raise

    return proc.returncode or 0, stdout, stderr
+66 -18
View File
@@ -20,6 +20,8 @@ from pydantic import BaseModel
from core.middleware import require_admin from core.middleware import require_admin
from core.platform_compat import ( from core.platform_compat import (
IS_WINDOWS, IS_WINDOWS,
SSH_PATH_OVERRIDE,
NVIDIA_PATH_CANDIDATES,
detached_popen_kwargs, detached_popen_kwargs,
find_bash, find_bash,
git_bash_path, git_bash_path,
@@ -27,6 +29,8 @@ from core.platform_compat import (
pid_alive, pid_alive,
safe_chmod, safe_chmod,
which_tool, which_tool,
translate_path,
get_wsl_windows_user_profile,
) )
from routes.shell_routes import TMUX_LOG_DIR from routes.shell_routes import TMUX_LOG_DIR
@@ -41,7 +45,7 @@ from routes.cookbook_helpers import (
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script, _append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain, _append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd, _pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
_diagnose_serve_output, _diagnose_serve_output, run_ssh_command_async,
ModelDownloadRequest, ServeRequest, ModelDownloadRequest, ServeRequest,
) )
@@ -557,24 +561,35 @@ def setup_cookbook_routes() -> APIRouter:
for d in model_dir.split(','): for d in model_dir.split(','):
d = d.strip() d = d.strip()
if d: if d:
model_dirs.append(d) translated_d = translate_path(d) if not host else d
paths_code = _cached_model_scan_script(model_dirs) model_dirs.append(translated_d)
win_hf_hub = None
if not host:
win_profile = get_wsl_windows_user_profile()
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
scan_py = TMUX_LOG_DIR / "scan_cache.py" scan_py = TMUX_LOG_DIR / "scan_cache.py"
scan_py.write_text(paths_code, encoding="utf-8") scan_py.write_text(paths_code, encoding="utf-8")
scan_payload = scan_py.read_bytes()
if host: if host:
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
if platform == "windows": if platform == "windows":
# Windows: use 'python' and pipe via stdin with double-quote wrapping remote_cmd = "python -"
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
else: else:
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'" # POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
proc = await asyncio.create_subprocess_shell( remote_cmd = (
cmd, "if command -v python3 >/dev/null 2>&1; then python3 -; "
stdout=asyncio.subprocess.PIPE, "elif command -v python >/dev/null 2>&1; then python -; "
stderr=asyncio.subprocess.PIPE, "else echo \"python3/python not found\" >&2; exit 127; fi"
cwd=str(Path.home()), )
rc, stdout_b, stderr_b = await run_ssh_command_async(
host,
ssh_port,
remote_cmd,
timeout=60,
stdin_data=scan_payload,
) )
else: else:
# LOCAL scan: use sys.executable (the venv Python Odysseus is already # LOCAL scan: use sys.executable (the venv Python Odysseus is already
@@ -594,7 +609,7 @@ def setup_cookbook_routes() -> APIRouter:
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
cwd=str(Path.home()), cwd=str(Path.home()),
) )
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60) stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
models = [] models = []
try: try:
@@ -874,6 +889,12 @@ def setup_cookbook_routes() -> APIRouter:
# pip cache so they don't fail mid-build with "No space left" (#1219) # pip cache so they don't fail mid-build with "No space left" (#1219)
# and leave the dep installed-but-unusable (#1459). # and leave the dep installed-but-unusable (#1459).
req.cmd = _pip_install_no_cache(req.cmd) req.cmd = _pip_install_no_cache(req.cmd)
# Accept common aliases and enforce server extras for llama-cpp so
# `python -m llama_cpp.server` has all runtime dependencies.
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
# PEP-508-style package spec — letters, digits, `.-_` for the # PEP-508-style package spec — letters, digits, `.-_` for the
# name; `[` `]` for extras; `<>=!~,` for version specifiers. # name; `[` `]` for extras; `<>=!~,` for version specifiers.
# v2 review HIGH-14: tightened from the previous regex which # v2 review HIGH-14: tightened from the previous regex which
@@ -1354,11 +1375,38 @@ def setup_cookbook_routes() -> APIRouter:
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8): async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None).""" """Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
if host: if host:
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else "" candidates = [query]
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'" stripped = query.strip()
proc = await asyncio.create_subprocess_shell( if stripped.startswith("nvidia-smi "):
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE args = stripped[len("nvidia-smi "):]
) candidates.append(
"bash -lc "
+ shlex.quote(
f"{SSH_PATH_OVERRIDE}"
f"nvidia-smi {args}"
)
)
for nvidia_path in NVIDIA_PATH_CANDIDATES:
candidates.append(f"{nvidia_path} {args}")
last_err = "nvidia-smi failed"
for candidate in candidates:
try:
rc, stdout, stderr = await run_ssh_command_async(
host,
ssh_port,
candidate,
connect_timeout=5,
timeout=timeout,
)
except asyncio.TimeoutError:
return None, "nvidia-smi timed out"
if rc == 0:
return stdout.decode("utf-8", errors="replace"), None
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
if err:
last_err = err
return None, last_err
else: else:
proc = await asyncio.create_subprocess_exec( proc = await asyncio.create_subprocess_exec(
*shlex.split(query), *shlex.split(query),
+23 -1
View File
@@ -47,6 +47,9 @@ _STATE_PATH = _DATA_DIR / "cookbook_state.json"
import tempfile import tempfile
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux" _TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
from core.platform_compat import NVIDIA_PATH_CANDIDATES, SSH_PATH_OVERRIDE
def fail(msg: str, code: int = 1) -> None: def fail(msg: str, code: int = 1) -> None:
sys.stderr.write(f"error: {msg}\n") sys.stderr.write(f"error: {msg}\n")
@@ -160,7 +163,26 @@ def cmd_gpus(args) -> None:
prefix = _ssh_prefix(args.host, args.ssh_port) prefix = _ssh_prefix(args.host, args.ssh_port)
cmd = prefix + (query.split() if not prefix else [query]) cmd = prefix + (query.split() if not prefix else [query])
try: try:
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15) if prefix:
candidates = [query]
args_part = query[len("nvidia-smi "):]
candidates.append(
"bash -lc "
+ repr(
f"{SSH_PATH_OVERRIDE}"
f"nvidia-smi {args_part}"
)
)
for nvidia_path in NVIDIA_PATH_CANDIDATES:
candidates.append(f"{nvidia_path} {args_part}")
out = None
for candidate in candidates:
out = subprocess.run(prefix + [candidate], capture_output=True, text=True, timeout=15)
if out.returncode == 0:
break
else:
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
except FileNotFoundError: except FileNotFoundError:
# No nvidia-smi locally → try the Metal fallback before giving up. # No nvidia-smi locally → try the Metal fallback before giving up.
if not prefix: if not prefix:
+32 -11
View File
@@ -4,6 +4,13 @@ import re
import shutil import shutil
import subprocess import subprocess
import time import time
import shlex
from core.platform_compat import (
NVIDIA_PATH_CANDIDATES,
SSH_PATH_OVERRIDE,
run_ssh_command,
)
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
# from 30 min so changing filters doesn't keep re-probing the rig every # from 30 min so changing filters doesn't keep re-probing the rig every
@@ -21,16 +28,17 @@ def _run(cmd):
if _remote_host: if _remote_host:
# Run command on remote host via SSH # Run command on remote host via SSH
if isinstance(cmd, list): if isinstance(cmd, list):
cmd_str = " ".join(cmd) cmd_str = shlex.join(str(c) for c in cmd)
else: else:
cmd_str = cmd cmd_str = cmd
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"] r = run_ssh_command(
if _remote_port and _remote_port != "22": _remote_host,
ssh_cmd += ["-p", _remote_port] _remote_port,
ssh_cmd += [_remote_host, cmd_str] cmd_str,
r = subprocess.run( timeout=15,
ssh_cmd, connect_timeout=5,
capture_output=True, text=True, timeout=15, strict_host_key_checking=False,
text=True,
) )
else: else:
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10) r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
@@ -83,7 +91,7 @@ def _detect_nvidia():
# Retry through a login shell with the common CUDA bin dirs on PATH. # Retry through a login shell with the common CUDA bin dirs on PATH.
if not out and _remote_host: if not out and _remote_host:
out = _run( out = _run(
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin:/usr/lib/wsl/lib\"; " f"bash -lc '{SSH_PATH_OVERRIDE}"
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'" "nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
) )
# Last resort: call nvidia-smi by absolute path. Some hosts have a login # Last resort: call nvidia-smi by absolute path. Some hosts have a login
@@ -92,7 +100,7 @@ def _detect_nvidia():
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path # Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
# that may not be in the server process's PATH. # that may not be in the server process's PATH.
if not out: if not out:
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi", "/usr/lib/wsl/lib/nvidia-smi"): for _p in NVIDIA_PATH_CANDIDATES:
# Use list form so subprocess.run (local) resolves the absolute path # Use list form so subprocess.run (local) resolves the absolute path
# correctly instead of treating the whole string as an executable name. # correctly instead of treating the whole string as an executable name.
if _remote_host: if _remote_host:
@@ -590,6 +598,19 @@ def _detect_windows():
_cache_by_host = {} # host -> (timestamp, result) _cache_by_host = {} # host -> (timestamp, result)
def _cache_key(host: str, ssh_port: str, platform_name: str):
    """Return the ``(host, port, platform)`` tuple used to index the hardware cache.

    The same host alias can expose different hardware depending on the SSH
    port and target platform, so all three take part in the key. An empty host
    maps to ``"_local"``, an empty port to ``""``, and the platform name is
    lower-cased.
    """
    host_part = host if host else "_local"
    port_part = str(ssh_port) if ssh_port else ""
    platform_part = str(platform_name).lower() if platform_name else ""
    return (host_part, port_part, platform_part)
def detect_system(host="", ssh_port="", platform="", fresh=False): def detect_system(host="", ssh_port="", platform="", fresh=False):
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely """Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
changes, and probing a remote host over SSH is slow). Pass fresh=True to changes, and probing a remote host over SSH is slow). Pass fresh=True to
@@ -599,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False):
""" """
global _remote_host, _remote_port, _remote_platform global _remote_host, _remote_port, _remote_platform
cache_key = host or "_local" cache_key = _cache_key(host, ssh_port, platform)
now = time.time() now = time.time()
if not fresh and cache_key in _cache_by_host: if not fresh and cache_key in _cache_by_host:
ts, cached = _cache_by_host[cache_key] ts, cached = _cache_by_host[cache_key]
+97
View File
@@ -25,6 +25,7 @@ from routes.cookbook_helpers import (
_validate_serve_cmd, _validate_serve_cmd,
_validate_serve_model_id, _validate_serve_model_id,
_validate_ssh_port, _validate_ssh_port,
run_ssh_command_async,
) )
@@ -35,6 +36,56 @@ def test_safe_env_prefix_accepts_quoted_venv_path():
) )
@pytest.mark.asyncio
async def test_run_ssh_command_executes_with_stdin_and_returns_output(monkeypatch):
    """run_ssh_command_async builds the expected argv, pipes stdin and returns output."""
    seen = {}
    payload = b"python -m pip install vllm"
    stream_names = ("stdin", "stdout", "stderr")

    class _FakeProcess:
        returncode = 0

        async def communicate(self, input=None):
            seen["input"] = input
            return b"stdout", b"stderr"

    async def _spawn(*argv, **options):
        seen["args"] = list(argv)
        for stream in stream_names:
            seen[stream] = options.get(stream)
        return _FakeProcess()

    monkeypatch.setattr("asyncio.create_subprocess_exec", _spawn)

    result = await run_ssh_command_async(
        "alice@gpu-box",
        "2222",
        "python -",
        timeout=5,
        connect_timeout=4,
        strict_host_key_checking=False,
        stdin_data=payload,
    )

    assert result == (0, b"stdout", b"stderr")
    assert seen["args"] == [
        "ssh",
        "-o",
        "ConnectTimeout=4",
        "-o",
        "StrictHostKeyChecking=no",
        "-p",
        "2222",
        "alice@gpu-box",
        "python -",
    ]
    # All three standard streams must have been requested as pipes.
    for stream in stream_names:
        assert seen[stream] is not None
    assert seen["input"] == payload
def test_safe_env_prefix_leaves_compound_conda_prefix_unchanged(): def test_safe_env_prefix_leaves_compound_conda_prefix_unchanged():
prefix = 'eval "$(conda shell.bash hook)" && conda activate qwen35' prefix = 'eval "$(conda shell.bash hook)" && conda activate qwen35'
assert _safe_env_prefix(prefix) == prefix assert _safe_env_prefix(prefix) == prefix
@@ -170,6 +221,8 @@ def test_pip_install_fallback_chain_quotes_extras_spec():
chain = _pip_install_fallback_chain("llama-cpp-python[server]", python_cmd="pip") chain = _pip_install_fallback_chain("llama-cpp-python[server]", python_cmd="pip")
# Quoted in both the plain and the --user attempt. # Quoted in both the plain and the --user attempt.
assert chain.count("'llama-cpp-python[server]'") == 2 assert chain.count("'llama-cpp-python[server]'") == 2
# llama-cpp installs must prefer prebuilt wheels to avoid fragile source builds.
assert "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu" in chain
# Never the unquoted form (bracket-glob risk). # Never the unquoted form (bracket-glob risk).
assert "install -q llama-cpp-python[server]" not in chain assert "install -q llama-cpp-python[server]" not in chain
# A plain package name is still passed through unquoted (no regression). # A plain package name is still passed through unquoted (no regression).
@@ -194,6 +247,17 @@ def test_serve_runner_installs_llama_cpp_server_extra():
assert "_pip_install_fallback_chain('llama-cpp-python[server]'" in src assert "_pip_install_fallback_chain('llama-cpp-python[server]'" in src
def test_serve_pip_install_normalizes_llama_cpp_alias_and_adds_wheel_index():
    """The cookbook routes source must contain the llama-cpp alias rewrite and wheel index."""
    import pathlib

    routes_file = pathlib.Path(__file__).resolve().parents[1] / "routes" / "cookbook_routes.py"
    src = routes_file.read_text(encoding="utf-8")

    expected_snippets = (
        "re.sub(r\"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])\", \"llama-cpp-python[server]\", req.cmd)",
        "if \"llama-cpp-python\" in req.cmd and \"--extra-index-url\" not in req.cmd:",
        "https://abetlen.github.io/llama-cpp-python/whl/cpu",
    )
    for snippet in expected_snippets:
        assert snippet in src
def test_vllm_preflight_reports_cli_and_version(): def test_vllm_preflight_reports_cli_and_version():
lines = [] lines = []
@@ -289,6 +353,7 @@ def test_local_tooling_path_export_converts_windows_paths_for_bash():
def test_user_shell_path_bootstrap_falls_back_to_python_on_windows_bash(): def test_user_shell_path_bootstrap_falls_back_to_python_on_windows_bash():
script = "\n".join(_user_shell_path_bootstrap()) script = "\n".join(_user_shell_path_bootstrap())
assert 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }' in script assert 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }' in script
assert 'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }' in script
def test_serve_preflight_failure_keeps_tmux_pane_visible(): def test_serve_preflight_failure_keeps_tmux_pane_visible():
@@ -606,3 +671,35 @@ def test_pip_install_no_cache_is_idempotent_and_scoped():
# not a pip install -> unchanged # not a pip install -> unchanged
assert _pip_install_no_cache("vllm serve --model x") == "vllm serve --model x" assert _pip_install_no_cache("vllm serve --model x") == "vllm serve --model x"
assert _pip_install_no_cache("") == "" assert _pip_install_no_cache("") == ""
def test_cached_model_scan_runs_additional_hf_cache(tmp_path):
    """A model placed in the extra HuggingFace cache directory is reported by the scan script."""
    weights_payload = b"abc123"
    extra_cache = tmp_path / "extra_hf_cache"
    snapshot_dir = extra_cache / "models--acme--sample-7b" / "snapshots" / "rev-1"
    snapshot_dir.mkdir(parents=True)
    (snapshot_dir / "model.safetensors").write_bytes(weights_payload)

    scan_py = tmp_path / "scan_cache.py"
    script_source = _cached_model_scan_script(add_hf_cache=str(extra_cache))
    scan_py.write_text(script_source, encoding="utf-8")

    # Run the generated scanner as a standalone script, as the route does.
    completed = subprocess.run(
        [sys.executable, str(scan_py)],
        check=True,
        capture_output=True,
        text=True,
    )
    by_repo = {model["repo_id"]: model for model in json.loads(completed.stdout)}

    assert "acme/sample-7b" in by_repo
    rec = by_repo["acme/sample-7b"]
    assert rec["path"] == str(extra_cache)
    assert rec["nb_files"] == 1
    assert rec["size_bytes"] == len(weights_payload)
    assert rec["has_incomplete"] is False
    assert rec["is_diffusion"] is False
+78
View File
@@ -71,3 +71,81 @@ def test_no_gpu_still_none(monkeypatch):
"""No nvidia-smi output → still None, no spurious unified GPU.""" """No nvidia-smi output → still None, no spurious unified GPU."""
monkeypatch.setattr(hardware, "_run", lambda cmd: None) monkeypatch.setattr(hardware, "_run", lambda cmd: None)
assert hardware._detect_nvidia() is None assert hardware._detect_nvidia() is None
def test_detect_system_cache_separates_same_host_different_ports(monkeypatch):
    """Each distinct host+port+platform combination gets its own cache entry."""
    probe_count = 0

    def _next_ram():
        # Every real probe reports 64 GB more than the previous one, so a
        # cached (stale) answer is distinguishable from a fresh one.
        nonlocal probe_count
        probe_count += 1
        return probe_count * 64.0

    def _windows_probe():
        return {
            "total_ram_gb": _next_ram(),
            "available_ram_gb": 40.0,
            "cpu_cores": 16,
            "cpu_name": "AMD Ryzen",
            "has_gpu": False,
            "gpu_name": None,
            "gpu_vram_gb": None,
            "gpu_count": 0,
            "backend": "cpu_x86",
            "homogeneous": True,
            "gpu_error": None,
            "platform": "windows",
        }

    stubs = {
        "_get_ram_gb": _next_ram,
        "_get_available_ram_gb": lambda: 40.0,
        "_get_cpu_count": lambda: 16,
        "_get_cpu_name": lambda: "AMD Ryzen",
        "_detect_apple_silicon": lambda: None,
        "_detect_nvidia": lambda: None,
        "_detect_amd": lambda: None,
        "_run": lambda _cmd: "x86_64",
        "_detect_windows": _windows_probe,
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(hardware, attr, stub)

    hardware._cache_by_host.clear()

    contexts = (
        ("22", "linux"),
        ("2222", "linux"),
        ("22", "windows"),
    )
    for port, platform_name in contexts:
        hardware.detect_system(host="user@wsl-host", ssh_port=port, platform=platform_name, fresh=False)

    expected_ram = {
        ("user@wsl-host", "22", "linux"): 64.0,
        ("user@wsl-host", "2222", "linux"): 128.0,
        ("user@wsl-host", "22", "windows"): 192.0,
    }
    assert len(hardware._cache_by_host) == 3
    for key, ram in expected_ram.items():
        assert hardware._cache_by_host[key][1]["total_ram_gb"] == ram
def test_detect_system_cache_hits_when_remote_context_matches(monkeypatch):
    """Repeated calls with the same host+port+platform reuse the cached result."""
    probe_count = 0

    def _next_ram():
        # Grows by 64 GB per real probe; a cache hit leaves the value unchanged.
        nonlocal probe_count
        probe_count += 1
        return probe_count * 64.0

    stubs = {
        "_get_ram_gb": _next_ram,
        "_get_available_ram_gb": lambda: 40.0,
        "_get_cpu_count": lambda: 16,
        "_get_cpu_name": lambda: "AMD Ryzen",
        "_detect_apple_silicon": lambda: None,
        "_detect_nvidia": lambda: None,
        "_detect_amd": lambda: None,
        "_run": lambda _cmd: "x86_64",
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(hardware, attr, stub)

    hardware._cache_by_host.clear()

    # Two identical remote lookups, then two identical local lookups.
    for _ in range(2):
        hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
    for _ in range(2):
        hardware.detect_system(fresh=False)

    assert len(hardware._cache_by_host) == 2
    remote_entry = hardware._cache_by_host[("user@wsl-host", "22", "linux")]
    local_entry = hardware._cache_by_host[("_local", "", "")]
    assert remote_entry[1]["total_ram_gb"] == 64.0
    assert local_entry[1]["total_ram_gb"] == 128.0
+242
View File
@@ -1,6 +1,8 @@
"""Regression tests for cross-platform helper behavior.""" """Regression tests for cross-platform helper behavior."""
import importlib.util import importlib.util
import io
import sys
from pathlib import Path from pathlib import Path
@@ -59,3 +61,243 @@ def test_find_bash_skips_windows_wsl_stub(monkeypatch):
monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected) monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected)
assert platform_compat.find_bash() == expected assert platform_compat.find_bash() == expected
def test_is_wsl_true_when_proc_version_mentions_microsoft(monkeypatch):
    """is_wsl() is True when /proc/version contains "microsoft"."""
    banner = "Linux version 6.6.0 microsoft standard"

    def _open_proc_version(path, mode="r", *args, **kwargs):
        assert (path, mode) == ("/proc/version", "r")
        return io.StringIO(banner)

    monkeypatch.setattr(sys, "platform", "linux", raising=False)
    monkeypatch.setattr("builtins.open", _open_proc_version)

    assert platform_compat.is_wsl() is True
def test_is_wsl_false_when_proc_version_is_not_microsoft(monkeypatch):
    """is_wsl() is False on Linux when the stubbed file has no Microsoft marker."""

    def _stub_open(*_a, **_k):
        # Every open() call yields a generic kernel banner.
        return io.StringIO("Linux version 6.6.0 generic")

    monkeypatch.setattr(sys, "platform", "linux", raising=False)
    monkeypatch.setattr("builtins.open", _stub_open)

    assert platform_compat.is_wsl() is False
def test_is_wsl_false_on_non_posix_without_proc_probe(monkeypatch):
    """is_wsl() is False when the platform is patched to look like Windows.

    ``open`` is replaced with a stub that raises ``AssertionError`` so that a
    file probe on a non-Linux platform is flagged.
    """

    def _forbidden_open(*_args, **_kwargs):
        raise AssertionError("open should not be called when platform is not Linux/POSIX")

    # Make both platform indicators report Windows.
    monkeypatch.setattr(sys, "platform", "win32", raising=False)
    monkeypatch.setattr(platform_compat.os, "name", "nt", raising=False)
    monkeypatch.setattr("builtins.open", _forbidden_open)

    assert platform_compat.is_wsl() is False
def test_translate_path_converts_windows_drive_path_on_wsl(monkeypatch):
    """With is_wsl() forced True, a C:\\ drive path maps to its /mnt/c form."""
    monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)

    windows_path = r"C:\Users\alice\models\qwen.gguf"
    translated = platform_compat.translate_path(windows_path)

    assert translated == "/mnt/c/Users/alice/models/qwen.gguf"
def test_translate_path_resolves_paths_when_not_wsl(monkeypatch):
    """With is_wsl() forced False, translate_path('.') equals the resolved cwd."""
    monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)

    translated = platform_compat.translate_path(".")
    resolved_here = str(Path(".").resolve())

    assert translated == resolved_here
def test_translate_path_returns_input_when_resolve_fails(monkeypatch):
    """translate_path returns its input unchanged when Path.resolve() raises."""

    class _ExplodingPath:
        """Path stand-in that accepts any value and fails on resolve()."""

        def __init__(self, _value):
            pass

        def resolve(self):
            raise RuntimeError("boom")

    monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
    monkeypatch.setattr(platform_compat, "Path", _ExplodingPath)

    unresolvable = "weird::path"
    assert platform_compat.translate_path(unresolvable) == "weird::path"
def test_get_wsl_windows_user_profile_prefers_powershell(monkeypatch):
    """A successful PowerShell call yields the translated Windows profile path.

    ``subprocess.run`` is stubbed to report success with a Windows profile path
    on stdout, and ``translate_path`` is stubbed to return the WSL mount path.
    """
    monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)

    class _Result:
        """Minimal stand-in for a successful subprocess result."""

        returncode = 0
        # Ends with a real newline, as command output does. The previous
        # literal ended in an escaped backslash followed by the letter "n".
        stdout = "C:\\Users\\alice\n"

    monkeypatch.setattr(platform_compat.subprocess, "run", lambda *_a, **_k: _Result())
    monkeypatch.setattr(platform_compat, "translate_path", lambda _v: "/mnt/c/Users/alice")

    assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice"
def test_get_wsl_windows_user_profile_falls_back_to_users_dir(monkeypatch):
    """When subprocess.run raises OSError, the profile comes from the stubbed listing.

    Only "/mnt/c/Users" and "/mnt/c/Users/alice" are reported as directories,
    and the expected result is "/mnt/c/Users/alice".
    """
    existing_dirs = {"/mnt/c/Users", "/mnt/c/Users/alice"}

    def _powershell_missing(*_a, **_k):
        raise OSError("powershell unavailable")

    def _stub_listdir(_path):
        # A fresh list per call, like the real os.listdir.
        return ["All Users", "Default", "Public", "alice"]

    def _stub_isdir(path):
        return path in existing_dirs

    monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
    monkeypatch.setattr(platform_compat.subprocess, "run", _powershell_missing)
    monkeypatch.setattr(platform_compat.os, "listdir", _stub_listdir)
    monkeypatch.setattr(platform_compat.os.path, "isdir", _stub_isdir)

    assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice"
def test_get_wsl_windows_user_profile_returns_none_when_nothing_found(monkeypatch):
    """The result is None when subprocess.run raises and no path is a directory."""

    def _powershell_missing(*_a, **_k):
        raise OSError("powershell unavailable")

    def _nothing_is_a_dir(_path):
        return False

    monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
    monkeypatch.setattr(platform_compat.subprocess, "run", _powershell_missing)
    monkeypatch.setattr(platform_compat.os.path, "isdir", _nothing_is_a_dir)

    assert platform_compat.get_wsl_windows_user_profile() is None
def test_nvidia_path_override_is_correct_string(monkeypatch):
    """_ssh_path_override() appends the colon-joined members to $PATH in an export."""
    stub_members = ["path1", "path2"]
    monkeypatch.setattr(platform_compat, "_SSH_PATH_MEMBERS", stub_members)

    snippet = platform_compat._ssh_path_override()

    assert snippet == "export PATH=\"$PATH:path1:path2\"; "
def test_windows_powershell_argv_defaults_include_no_profile_and_noninteractive():
    """By default the argv carries -NoProfile and -NonInteractive before -Command."""
    expected_argv = [
        "powershell.exe",
        "-NoProfile",
        "-NonInteractive",
        "-Command",
        "Write-Output Hello",
    ]

    built = platform_compat._windows_powershell_argv("Write-Output Hello")

    assert built == expected_argv
def test_windows_powershell_argv_respects_disabled_flags():
    """Disabling both flags leaves only the executable, -Command and the script."""
    flags_off = {"no_profile": False, "non_interactive": False}

    built = platform_compat._windows_powershell_argv("Write-Output Hello", **flags_off)

    assert built == ["powershell.exe", "-Command", "Write-Output Hello"]
def test_run_wsl_windows_powershell_raises_outside_wsl(monkeypatch):
    """With is_wsl() forced False, the helper raises RuntimeError mentioning WSL."""
    monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)

    try:
        platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=2)
    except RuntimeError as err:
        assert "only supported in WSL" in str(err)
    else:
        # Reaching here means no exception was raised at all.
        raise AssertionError("Expected RuntimeError")
def test_run_wsl_windows_powershell_calls_subprocess_with_expected_argv(monkeypatch):
    """Under WSL the helper invokes subprocess.run with the default PowerShell argv.

    Also checks that ``capture_output``, ``text`` and the caller's ``timeout``
    are passed through as keyword arguments.
    """
    seen = {}

    class _Result:
        """Stand-in for a successful subprocess result."""

        returncode = 0
        stdout = "ok\n"
        stderr = ""

    def _recording_run(args, **kwargs):
        # Remember the most recent invocation for the assertions below.
        seen["args"] = list(args)
        seen["kwargs"] = kwargs
        return _Result()

    monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
    monkeypatch.setattr(platform_compat.subprocess, "run", _recording_run)

    outcome = platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=9)

    assert outcome.returncode == 0
    assert seen["args"] == [
        "powershell.exe",
        "-NoProfile",
        "-NonInteractive",
        "-Command",
        "Write-Output Hello",
    ]
    run_kwargs = seen["kwargs"]
    assert run_kwargs["capture_output"] is True
    assert run_kwargs["text"] is True
    assert run_kwargs["timeout"] == 9
def test_ssh_exec_argv_builds_default_command():
    """With no port and no options the argv is just ssh, the host and the command."""
    built = platform_compat._ssh_exec_argv("alice@gpu-box", None, remote_cmd="echo ok")

    expected_argv = ["ssh", "alice@gpu-box", "echo ok"]
    assert built == expected_argv
def test_ssh_exec_argv_includes_port_and_options():
    """Timeout, host-key and port settings appear as ssh options before the host.

    ``strict_host_key_checking=False`` is expected to render as
    ``StrictHostKeyChecking=no``.
    """
    options = {
        "remote_cmd": "tmux ls",
        "connect_timeout": 6,
        "strict_host_key_checking": False,
    }

    built = platform_compat._ssh_exec_argv("alice@gpu-box", "2222", **options)

    assert built == [
        "ssh",
        "-o",
        "ConnectTimeout=6",
        "-o",
        "StrictHostKeyChecking=no",
        "-p",
        "2222",
        "alice@gpu-box",
        "tmux ls",
    ]
def test_run_ssh_command_uses_built_argv(monkeypatch):
    """run_ssh_command passes the assembled ssh argv and options to subprocess.run.

    ``strict_host_key_checking=True`` is expected to render as
    ``StrictHostKeyChecking=yes``, and ``timeout``/``text`` are forwarded as
    given alongside ``capture_output=True``.
    """
    seen = {}

    class _Result:
        """Stand-in for a successful subprocess result."""

        returncode = 0
        stdout = "ok"
        stderr = ""

    def _recording_run(args, **kwargs):
        # Remember the most recent invocation for the assertions below.
        seen["args"] = list(args)
        seen["kwargs"] = kwargs
        return _Result()

    monkeypatch.setattr(platform_compat.subprocess, "run", _recording_run)

    outcome = platform_compat.run_ssh_command(
        "alice@gpu-box",
        "2200",
        "tmux ls",
        timeout=7,
        connect_timeout=3,
        strict_host_key_checking=True,
        text=False,
    )

    assert outcome.returncode == 0
    assert seen["args"] == [
        "ssh",
        "-o",
        "ConnectTimeout=3",
        "-o",
        "StrictHostKeyChecking=yes",
        "-p",
        "2200",
        "alice@gpu-box",
        "tmux ls",
    ]
    run_kwargs = seen["kwargs"]
    assert run_kwargs["timeout"] == 7
    assert run_kwargs["capture_output"] is True
    assert run_kwargs["text"] is False