From 9c90f626574684afd623f4d079508e8dbbad13f7 Mon Sep 17 00:00:00 2001 From: horribleCodes Date: Mon, 8 Jun 2026 00:33:50 +0200 Subject: [PATCH] fix(platform): Improve WSL SSH remote compatibility (#3316) * fix(platform): add WSL compatibility functions and path translation fix(cookbook): enhance model scan script to support additional HuggingFace cache paths fix(hardware): improve cache key generation for remote SSH context test(tests): add tests for WSL detection and path translation functionality * fix(cookbook): prefer prebuilt wheels for llama-cpp-python and normalize package aliases * fix: enable StrictHostKeyChecking in nvidia probe refactor: consolidate ssh & powershell command execution to utility functions in core module refactor: consolidate nvidia path candidates into single variables in core module tests: add tests for new utility functions * fix: correct wrong variable name --- core/platform_compat.py | 176 +++++++++++++++++++++ routes/cookbook_helpers.py | 52 ++++++- routes/cookbook_routes.py | 84 +++++++--- scripts/odysseus-cookbook | 24 ++- services/hwfit/hardware.py | 43 +++-- tests/test_cookbook_helpers.py | 97 ++++++++++++ tests/test_hwfit_unified_nvidia.py | 78 ++++++++++ tests/test_platform_compat.py | 242 +++++++++++++++++++++++++++++ 8 files changed, 763 insertions(+), 33 deletions(-) diff --git a/core/platform_compat.py b/core/platform_compat.py index f2141ea75..3eda4a107 100644 --- a/core/platform_compat.py +++ b/core/platform_compat.py @@ -161,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = ( ("usr", "bin", "bash.exe"), ) +# Paths to add to the remote SSH probe command to find tools like nvidia-smi that may not be on PATH. +_SSH_PATH_MEMBERS = ( + "/usr/bin", + "/usr/local/bin", + "/usr/local/cuda/bin", + "/usr/lib/wsl/lib" +) +# Fallback locations for nvidia-smi on WSL and other Linux distros where it may not be on PATH. 
+NVIDIA_PATH_CANDIDATES = ( + "/usr/bin/nvidia-smi", + "/usr/local/bin/nvidia-smi", + "/usr/local/cuda/bin/nvidia-smi", + "/usr/lib/wsl/lib/nvidia-smi", +) + + +def _ssh_path_override() -> str: + """Build the PATH export snippet used for remote SSH shell probes.""" + return f"export PATH=\"$PATH:{':'.join(_SSH_PATH_MEMBERS)}\"; " + + +SSH_PATH_OVERRIDE = _ssh_path_override() + def _windows_bash_fallbacks() -> List[str]: roots: List[str] = [] @@ -268,3 +291,156 @@ def run_script_argv(script_path) -> List[str]: comspec = os.environ.get("ComSpec", "cmd.exe") return [comspec, "/c", str(script_path)] return ["sh", str(script_path)] + + +def is_wsl() -> bool: + """True if running inside Windows Subsystem for Linux (WSL).""" + import sys + if sys.platform.startswith("linux") or os.name == "posix": + try: + with open("/proc/version", "r") as f: + if "microsoft" in f.read().lower(): + return True + except Exception: + pass + return False + + +def translate_path(path_str: str) -> str: + """Translate a path (possibly a Windows path) to the current OS format. + + Particularly handles Windows paths (e.g. C:\\foo or C:/foo) when running + under WSL, translating them to /mnt/c/foo. + Also handles standard path normalization to avoid string breakages. 
+ """ + if not path_str: + return path_str + + if is_wsl(): + path_str = path_str.replace("\\", "/") + import re + m = re.match(r"^([a-zA-Z]):(.*)", path_str) + if m: + drive = m.group(1).lower() + rest = m.group(2) + if not rest.startswith("/"): + rest = "/" + rest + return f"/mnt/{drive}{rest}" + + try: + return str(Path(path_str).resolve()) + except Exception: + return path_str + + +def get_wsl_windows_user_profile() -> Optional[str]: + """Retrieve the Windows host User Profile path from inside WSL.""" + if not is_wsl(): + return None + try: + r = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5) + if r.returncode == 0 and r.stdout.strip(): + return translate_path(r.stdout.strip()) + except Exception: + pass + + try: + users_dir = "/mnt/c/Users" + if os.path.isdir(users_dir): + for entry in os.listdir(users_dir): + if entry not in ("All Users", "Default", "Default User", "desktop.ini", "Public"): + path = os.path.join(users_dir, entry) + if os.path.isdir(path): + return path + except Exception: + pass + return None + + +def _ssh_exec_argv( + remote: str, + ssh_port: str | None, + *, + remote_cmd: str | None = None, + connect_timeout: int | None = None, + strict_host_key_checking: bool | None = None, +) -> list[str]: + """Build a consistent ssh argv for remote command execution.""" + argv = ["ssh"] + if connect_timeout is not None: + argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"]) + if strict_host_key_checking is not None: + argv.extend( + [ + "-o", + "StrictHostKeyChecking=yes" + if strict_host_key_checking + else "StrictHostKeyChecking=no", + ] + ) + if ssh_port and ssh_port != "22": + argv.extend(["-p", str(ssh_port)]) + argv.append(remote) + if remote_cmd is not None: + argv.append(remote_cmd) + return argv + + +def run_ssh_command( + remote: str, + ssh_port: str | None, + remote_cmd: str, + *, + timeout: float, + connect_timeout: int | None = None, + strict_host_key_checking: bool | None = None, + text: bool = True, +) -> 
subprocess.CompletedProcess: + """Run an ssh command with centralized timeout and stderr/stdout capture.""" + return subprocess.run( + _ssh_exec_argv( + remote, + ssh_port, + remote_cmd=remote_cmd, + connect_timeout=connect_timeout, + strict_host_key_checking=strict_host_key_checking, + ), + timeout=timeout, + capture_output=True, + text=text, + ) + + +def _windows_powershell_argv( + command: str, + *, + no_profile: bool = True, + non_interactive: bool = True, +) -> List[str]: + argv: List[str] = ["powershell.exe"] + if no_profile: + argv.append("-NoProfile") + if non_interactive: + argv.append("-NonInteractive") + argv.extend(["-Command", command]) + return argv + + +def run_wsl_windows_powershell( + command: str, + *, + timeout: float = 5, +) -> subprocess.CompletedProcess[str]: + """Run a PowerShell command on the Windows host from WSL. + + Raises ``RuntimeError`` when called outside WSL. + """ + + if not is_wsl(): + raise RuntimeError("run_wsl_windows_powershell is only supported in WSL") + return subprocess.run( + _windows_powershell_argv(command), + capture_output=True, + text=True, + timeout=timeout, + ) diff --git a/routes/cookbook_helpers.py b/routes/cookbook_helpers.py index 3af227861..a154f3718 100644 --- a/routes/cookbook_helpers.py +++ b/routes/cookbook_helpers.py @@ -11,6 +11,8 @@ import shlex from fastapi import HTTPException from pydantic import BaseModel +from core.platform_compat import _ssh_exec_argv + logger = logging.getLogger(__name__) @@ -213,7 +215,10 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p # before being embedded in the install command. Plain names (e.g. # ``huggingface_hub``) are returned unchanged by ``shlex.quote``. pkg = shlex.quote(package) - if IS_WINDOWS and "llama-cpp-python" in package: + # llama-cpp-python source builds are brittle on older distro pip/packaging + # stacks (common on WSL images). 
Prefer the prebuilt wheel index whenever + # this package is requested so dependency-install tasks are reliable. + if "llama-cpp-python" in package: pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu" base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}") @@ -275,11 +280,14 @@ def _user_shell_path_bootstrap() -> list[str]: ' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi', 'fi', 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }', + 'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }', ] -def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str: - """Build the standalone Python scanner used by /api/model/cached.""" +def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str: + """Build the standalone Python scanner used by /api/model/cached. + Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.) 
+ """ lines = [ "import json, os, re, shutil, subprocess, urllib.request", "models = []", @@ -372,6 +380,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str: " # Docker images mount ./data/huggingface at /app/.cache/huggingface.", " # When HOME is /root, expanduser() misses that persisted cache.", " add('/app/.cache/huggingface/hub')", + f" add({add_hf_cache!r})" if add_hf_cache else "", " return candidates", "def scan_dir(p):", " if not os.path.isdir(p) or not safe_path(p): return", @@ -989,3 +998,40 @@ def _diagnose_serve_output(text: str) -> dict | None: "suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}], } return None + + +async def run_ssh_command_async( + remote: str, + ssh_port: str | None, + remote_cmd: str, + *, + timeout: float, + connect_timeout: int | None = None, + strict_host_key_checking: bool | None = None, + stdin_data: bytes | None = None, +) -> tuple[int, bytes, bytes]: + """Run an ssh command with centralized timeout and stderr/stdout capture. + Async version of core.platform_compat.run_ssh_command. 
+ """ + import asyncio + proc = await asyncio.create_subprocess_exec( + *_ssh_exec_argv( + remote, + ssh_port, + remote_cmd=remote_cmd, + connect_timeout=connect_timeout, + strict_host_key_checking=strict_host_key_checking, + ), + stdin=asyncio.subprocess.PIPE if stdin_data is not None else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(input=stdin_data), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + await proc.communicate() + raise + return proc.returncode or 0, stdout, stderr \ No newline at end of file diff --git a/routes/cookbook_routes.py b/routes/cookbook_routes.py index 04ad05522..84ec80a71 100644 --- a/routes/cookbook_routes.py +++ b/routes/cookbook_routes.py @@ -20,6 +20,8 @@ from pydantic import BaseModel from core.middleware import require_admin from core.platform_compat import ( IS_WINDOWS, + SSH_PATH_OVERRIDE, + NVIDIA_PATH_CANDIDATES, detached_popen_kwargs, find_bash, git_bash_path, @@ -27,6 +29,8 @@ from core.platform_compat import ( pid_alive, safe_chmod, which_tool, + translate_path, + get_wsl_windows_user_profile, ) from routes.shell_routes import TMUX_LOG_DIR @@ -41,7 +45,7 @@ from routes.cookbook_helpers import ( _append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script, _append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain, _pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd, - _diagnose_serve_output, + _diagnose_serve_output, run_ssh_command_async, ModelDownloadRequest, ServeRequest, ) @@ -557,24 +561,35 @@ def setup_cookbook_routes() -> APIRouter: for d in model_dir.split(','): d = d.strip() if d: - model_dirs.append(d) - paths_code = _cached_model_scan_script(model_dirs) + translated_d = translate_path(d) if not host else d + model_dirs.append(translated_d) + win_hf_hub = None + if not host: + win_profile = 
get_wsl_windows_user_profile() + win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None + + paths_code = _cached_model_scan_script(model_dirs, win_hf_hub) scan_py = TMUX_LOG_DIR / "scan_cache.py" scan_py.write_text(paths_code, encoding="utf-8") + scan_payload = scan_py.read_bytes() if host: - _pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else "" if platform == "windows": - # Windows: use 'python' and pipe via stdin with double-quote wrapping - cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\'' + remote_cmd = "python -" else: - cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'" - proc = await asyncio.create_subprocess_shell( - cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - cwd=str(Path.home()), + # POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found. + remote_cmd = ( + "if command -v python3 >/dev/null 2>&1; then python3 -; " + "elif command -v python >/dev/null 2>&1; then python -; " + "else echo \"python3/python not found\" >&2; exit 127; fi" + ) + rc, stdout_b, stderr_b = await run_ssh_command_async( + host, + ssh_port, + remote_cmd, + timeout=60, + stdin_data=scan_payload, ) else: # LOCAL scan: use sys.executable (the venv Python Odysseus is already @@ -594,7 +609,7 @@ def setup_cookbook_routes() -> APIRouter: stderr=asyncio.subprocess.PIPE, cwd=str(Path.home()), ) - stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60) + stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60) models = [] try: @@ -874,6 +889,12 @@ def setup_cookbook_routes() -> APIRouter: # pip cache so they don't fail mid-build with "No space left" (#1219) # and leave the dep installed-but-unusable (#1459). req.cmd = _pip_install_no_cache(req.cmd) + # Accept common aliases and enforce server extras for llama-cpp so + # `python -m llama_cpp.server` has all runtime dependencies. + req.cmd = re.sub(r"(?=!~,` for version specifiers. 
# v2 review HIGH-14: tightened from the previous regex which @@ -1354,11 +1375,38 @@ def setup_cookbook_routes() -> APIRouter: async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8): """Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None).""" if host: - pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else "" - cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'" - proc = await asyncio.create_subprocess_shell( - cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) + candidates = [query] + stripped = query.strip() + if stripped.startswith("nvidia-smi "): + args = stripped[len("nvidia-smi "):] + candidates.append( + "bash -lc " + + shlex.quote( + f"{SSH_PATH_OVERRIDE}" + f"nvidia-smi {args}" + ) + ) + for nvidia_path in NVIDIA_PATH_CANDIDATES: + candidates.append(f"{nvidia_path} {args}") + + last_err = "nvidia-smi failed" + for candidate in candidates: + try: + rc, stdout, stderr = await run_ssh_command_async( + host, + ssh_port, + candidate, + connect_timeout=5, + timeout=timeout, + ) + except asyncio.TimeoutError: + return None, "nvidia-smi timed out" + if rc == 0: + return stdout.decode("utf-8", errors="replace"), None + err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200] + if err: + last_err = err + return None, last_err else: proc = await asyncio.create_subprocess_exec( *shlex.split(query), diff --git a/scripts/odysseus-cookbook b/scripts/odysseus-cookbook index 860a7903b..66a3057d2 100755 --- a/scripts/odysseus-cookbook +++ b/scripts/odysseus-cookbook @@ -47,6 +47,9 @@ _STATE_PATH = _DATA_DIR / "cookbook_state.json" import tempfile _TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux" +from core.platform_compat import NVIDIA_PATH_CANDIDATES, SSH_PATH_OVERRIDE + + def fail(msg: str, code: int = 1) -> None: sys.stderr.write(f"error: {msg}\n") @@ -160,7 +163,26 @@ def cmd_gpus(args) -> None: prefix = _ssh_prefix(args.host, 
args.ssh_port) cmd = prefix + (query.split() if not prefix else [query]) try: - out = subprocess.run(cmd, capture_output=True, text=True, timeout=15) + if prefix: + candidates = [query] + args_part = query[len("nvidia-smi "):] + candidates.append( + "bash -lc " + + repr( + f"{SSH_PATH_OVERRIDE}" + f"nvidia-smi {args_part}" + ) + ) + for nvidia_path in NVIDIA_PATH_CANDIDATES: + candidates.append(f"{nvidia_path} {args_part}") + + out = None + for candidate in candidates: + out = subprocess.run(prefix + [candidate], capture_output=True, text=True, timeout=15) + if out.returncode == 0: + break + else: + out = subprocess.run(cmd, capture_output=True, text=True, timeout=15) except FileNotFoundError: # No nvidia-smi locally → try the Metal fallback before giving up. if not prefix: diff --git a/services/hwfit/hardware.py b/services/hwfit/hardware.py index 2b47ffa2a..47ec94d44 100644 --- a/services/hwfit/hardware.py +++ b/services/hwfit/hardware.py @@ -4,6 +4,13 @@ import re import shutil import subprocess import time +import shlex + +from core.platform_compat import ( + NVIDIA_PATH_CANDIDATES, + SSH_PATH_OVERRIDE, + run_ssh_command, +) CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped # from 30 min so changing filters doesn't keep re-probing the rig every @@ -21,16 +28,17 @@ def _run(cmd): if _remote_host: # Run command on remote host via SSH if isinstance(cmd, list): - cmd_str = " ".join(cmd) + cmd_str = shlex.join(str(c) for c in cmd) else: cmd_str = cmd - ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"] - if _remote_port and _remote_port != "22": - ssh_cmd += ["-p", _remote_port] - ssh_cmd += [_remote_host, cmd_str] - r = subprocess.run( - ssh_cmd, - capture_output=True, text=True, timeout=15, + r = run_ssh_command( + _remote_host, + _remote_port, + cmd_str, + timeout=15, + connect_timeout=5, + strict_host_key_checking=False, + text=True, ) else: r = subprocess.run(cmd, capture_output=True, 
text=True, timeout=10) @@ -83,7 +91,7 @@ def _detect_nvidia(): # Retry through a login shell with the common CUDA bin dirs on PATH. if not out and _remote_host: out = _run( - "bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin:/usr/lib/wsl/lib\"; " + f"bash -lc '{SSH_PATH_OVERRIDE}" "nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'" ) # Last resort: call nvidia-smi by absolute path. Some hosts have a login @@ -92,7 +100,7 @@ def _detect_nvidia(): # Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path # that may not be in the server process's PATH. if not out: - for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi", "/usr/lib/wsl/lib/nvidia-smi"): + for _p in NVIDIA_PATH_CANDIDATES: # Use list form so subprocess.run (local) resolves the absolute path # correctly instead of treating the whole string as an executable name. if _remote_host: @@ -590,6 +598,19 @@ def _detect_windows(): _cache_by_host = {} # host -> (timestamp, result) +def _cache_key(host: str, ssh_port: str, platform_name: str): + """Build a stable cache key that isolates remote SSH context. + + Same host aliases can have different hardware due to visibility, forwarding etc. + To avoid using the wrong cached hardware info, include the SSH port and platform in the cache key. + """ + return ( + host or "_local", + str(ssh_port or ""), + str(platform_name or "").lower(), + ) + + def detect_system(host="", ssh_port="", platform="", fresh=False): """Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely changes, and probing a remote host over SSH is slow). 
Pass fresh=True to @@ -599,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False): """ global _remote_host, _remote_port, _remote_platform - cache_key = host or "_local" + cache_key = _cache_key(host, ssh_port, platform) now = time.time() if not fresh and cache_key in _cache_by_host: ts, cached = _cache_by_host[cache_key] diff --git a/tests/test_cookbook_helpers.py b/tests/test_cookbook_helpers.py index bd05dd8a5..5666a7fd2 100644 --- a/tests/test_cookbook_helpers.py +++ b/tests/test_cookbook_helpers.py @@ -25,6 +25,7 @@ from routes.cookbook_helpers import ( _validate_serve_cmd, _validate_serve_model_id, _validate_ssh_port, + run_ssh_command_async, ) @@ -35,6 +36,56 @@ def test_safe_env_prefix_accepts_quoted_venv_path(): ) +@pytest.mark.asyncio +async def test_run_ssh_command_executes_with_stdin_and_returns_output(monkeypatch): + captured = {} + + class _Proc: + returncode = 0 + + async def communicate(self, input=None): + captured["input"] = input + return b"stdout", b"stderr" + + async def _fake_exec(*args, **kwargs): + captured["args"] = list(args) + captured["stdin"] = kwargs.get("stdin") + captured["stdout"] = kwargs.get("stdout") + captured["stderr"] = kwargs.get("stderr") + return _Proc() + + monkeypatch.setattr("asyncio.create_subprocess_exec", _fake_exec) + + rc, out, err = await run_ssh_command_async( + "alice@gpu-box", + "2222", + "python -", + timeout=5, + connect_timeout=4, + strict_host_key_checking=False, + stdin_data=b"python -m pip install vllm", + ) + + assert rc == 0 + assert out == b"stdout" + assert err == b"stderr" + assert captured["args"] == [ + "ssh", + "-o", + "ConnectTimeout=4", + "-o", + "StrictHostKeyChecking=no", + "-p", + "2222", + "alice@gpu-box", + "python -", + ] + assert captured["stdin"] is not None + assert captured["stdout"] is not None + assert captured["stderr"] is not None + assert captured["input"] == b"python -m pip install vllm" + + def test_safe_env_prefix_leaves_compound_conda_prefix_unchanged(): 
prefix = 'eval "$(conda shell.bash hook)" && conda activate qwen35' assert _safe_env_prefix(prefix) == prefix @@ -170,6 +221,8 @@ def test_pip_install_fallback_chain_quotes_extras_spec(): chain = _pip_install_fallback_chain("llama-cpp-python[server]", python_cmd="pip") # Quoted in both the plain and the --user attempt. assert chain.count("'llama-cpp-python[server]'") == 2 + # llama-cpp installs must prefer prebuilt wheels to avoid fragile source builds. + assert "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu" in chain # Never the unquoted form (bracket-glob risk). assert "install -q llama-cpp-python[server]" not in chain # A plain package name is still passed through unquoted (no regression). @@ -194,6 +247,17 @@ def test_serve_runner_installs_llama_cpp_server_extra(): assert "_pip_install_fallback_chain('llama-cpp-python[server]'" in src +def test_serve_pip_install_normalizes_llama_cpp_alias_and_adds_wheel_index(): + import pathlib + + src = (pathlib.Path(__file__).resolve().parent.parent + / "routes" / "cookbook_routes.py").read_text(encoding="utf-8") + + assert "re.sub(r\"(?/dev/null 2>&1 || python3() { python "$@"; }' in script + assert 'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }' in script def test_serve_preflight_failure_keeps_tmux_pane_visible(): @@ -606,3 +671,35 @@ def test_pip_install_no_cache_is_idempotent_and_scoped(): # not a pip install -> unchanged assert _pip_install_no_cache("vllm serve --model x") == "vllm serve --model x" assert _pip_install_no_cache("") == "" + + +def test_cached_model_scan_runs_additional_hf_cache(tmp_path): + extra_cache = tmp_path / "extra_hf_cache" + model_dir = extra_cache / "models--acme--sample-7b" + snap = model_dir / "snapshots" / "rev-1" + snap.mkdir(parents=True) + weights = snap / "model.safetensors" + weights.write_bytes(b"abc123") + + scan_py = tmp_path / "scan_cache.py" + scan_py.write_text( + _cached_model_scan_script(add_hf_cache=str(extra_cache)), + encoding="utf-8", 
+ ) + proc = subprocess.run( + [sys.executable, str(scan_py)], + check=True, + capture_output=True, + text=True, + ) + + models = json.loads(proc.stdout) + by_repo = {m["repo_id"]: m for m in models} + + assert "acme/sample-7b" in by_repo + rec = by_repo["acme/sample-7b"] + assert rec["path"] == str(extra_cache) + assert rec["nb_files"] == 1 + assert rec["size_bytes"] == len(b"abc123") + assert rec["has_incomplete"] is False + assert rec["is_diffusion"] is False diff --git a/tests/test_hwfit_unified_nvidia.py b/tests/test_hwfit_unified_nvidia.py index 009288e31..0fdf751dd 100644 --- a/tests/test_hwfit_unified_nvidia.py +++ b/tests/test_hwfit_unified_nvidia.py @@ -71,3 +71,81 @@ def test_no_gpu_still_none(monkeypatch): """No nvidia-smi output → still None, no spurious unified GPU.""" monkeypatch.setattr(hardware, "_run", lambda cmd: None) assert hardware._detect_nvidia() is None + + +def test_detect_system_cache_separates_same_host_different_ports(monkeypatch): + """Keep cache separate by host+port+platform, don't use cached data""" + ram_gb = 0 + + def _ram(): + nonlocal ram_gb + ram_gb += 1 + return ram_gb * 64.0 + + monkeypatch.setattr(hardware, "_get_ram_gb", _ram) + monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 40.0) + monkeypatch.setattr(hardware, "_get_cpu_count", lambda: 16) + monkeypatch.setattr(hardware, "_get_cpu_name", lambda: "AMD Ryzen") + monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: None) + monkeypatch.setattr(hardware, "_detect_nvidia", lambda: None) + monkeypatch.setattr(hardware, "_detect_amd", lambda: None) + monkeypatch.setattr(hardware, "_run", lambda _cmd: "x86_64") + + def _windows_probe(): + nonlocal ram_gb + ram_gb += 1 + return { + "total_ram_gb": ram_gb * 64.0, + "available_ram_gb": 40.0, + "cpu_cores": 16, + "cpu_name": "AMD Ryzen", + "has_gpu": False, + "gpu_name": None, + "gpu_vram_gb": None, + "gpu_count": 0, + "backend": "cpu_x86", + "homogeneous": True, + "gpu_error": None, + "platform": "windows", 
+ } + + monkeypatch.setattr(hardware, "_detect_windows", _windows_probe) + hardware._cache_by_host.clear() + + hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False) + hardware.detect_system(host="user@wsl-host", ssh_port="2222", platform="linux", fresh=False) + hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="windows", fresh=False) + + assert len(hardware._cache_by_host) == 3 + assert hardware._cache_by_host[("user@wsl-host", "22", "linux")][1]["total_ram_gb"] == 64.0 + assert hardware._cache_by_host[("user@wsl-host", "2222", "linux")][1]["total_ram_gb"] == 128.0 + assert hardware._cache_by_host[("user@wsl-host", "22", "windows")][1]["total_ram_gb"] == 192.0 + + +def test_detect_system_cache_hits_when_remote_context_matches(monkeypatch): + """Cache hits when host+port+platform match""" + ram_gb = 0 + + def _ram(): + nonlocal ram_gb + ram_gb += 1 + return ram_gb * 64.0 + + monkeypatch.setattr(hardware, "_get_ram_gb", _ram) + monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 40.0) + monkeypatch.setattr(hardware, "_get_cpu_count", lambda: 16) + monkeypatch.setattr(hardware, "_get_cpu_name", lambda: "AMD Ryzen") + monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: None) + monkeypatch.setattr(hardware, "_detect_nvidia", lambda: None) + monkeypatch.setattr(hardware, "_detect_amd", lambda: None) + monkeypatch.setattr(hardware, "_run", lambda _cmd: "x86_64") + hardware._cache_by_host.clear() + + hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False) + hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False) + hardware.detect_system(fresh=False) + hardware.detect_system(fresh=False) + + assert len(hardware._cache_by_host) == 2 + assert hardware._cache_by_host[("user@wsl-host", "22", "linux")][1]["total_ram_gb"] == 64.0 + assert hardware._cache_by_host[("_local", "", "")][1]["total_ram_gb"] == 128.0 diff --git 
a/tests/test_platform_compat.py b/tests/test_platform_compat.py index fbb43b802..2c45b9ce0 100644 --- a/tests/test_platform_compat.py +++ b/tests/test_platform_compat.py @@ -1,6 +1,8 @@ """Regression tests for cross-platform helper behavior.""" import importlib.util +import io +import sys from pathlib import Path @@ -59,3 +61,243 @@ def test_find_bash_skips_windows_wsl_stub(monkeypatch): monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected) assert platform_compat.find_bash() == expected + + +def test_is_wsl_true_when_proc_version_mentions_microsoft(monkeypatch): + monkeypatch.setattr(sys, "platform", "linux", raising=False) + + def fake_open(path, mode="r", *args, **kwargs): + assert path == "/proc/version" + assert mode == "r" + return io.StringIO("Linux version 6.6.0 microsoft standard") + + monkeypatch.setattr("builtins.open", fake_open) + + assert platform_compat.is_wsl() is True + + +def test_is_wsl_false_when_proc_version_is_not_microsoft(monkeypatch): + monkeypatch.setattr(sys, "platform", "linux", raising=False) + monkeypatch.setattr("builtins.open", lambda *_a, **_k: io.StringIO("Linux version 6.6.0 generic")) + + assert platform_compat.is_wsl() is False + + +def test_is_wsl_false_on_non_posix_without_proc_probe(monkeypatch): + monkeypatch.setattr(sys, "platform", "win32", raising=False) + monkeypatch.setattr(platform_compat.os, "name", "nt", raising=False) + + def fail_open(*_args, **_kwargs): + raise AssertionError("open should not be called when platform is not Linux/POSIX") + + monkeypatch.setattr("builtins.open", fail_open) + + assert platform_compat.is_wsl() is False + + +def test_translate_path_converts_windows_drive_path_on_wsl(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: True) + + out = platform_compat.translate_path(r"C:\Users\alice\models\qwen.gguf") + + assert out == "/mnt/c/Users/alice/models/qwen.gguf" + + +def test_translate_path_resolves_paths_when_not_wsl(monkeypatch): + 
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False) + + assert platform_compat.translate_path(".") == str(Path(".").resolve()) + + +def test_translate_path_returns_input_when_resolve_fails(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: False) + + class _BrokenPath: + def __init__(self, _value): + pass + + def resolve(self): + raise RuntimeError("boom") + + monkeypatch.setattr(platform_compat, "Path", _BrokenPath) + + assert platform_compat.translate_path("weird::path") == "weird::path" + + +def test_get_wsl_windows_user_profile_prefers_powershell(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: True) + + class _Result: + returncode = 0 + stdout = "C:\\Users\\alice\\n" + + monkeypatch.setattr(platform_compat.subprocess, "run", lambda *_a, **_k: _Result()) + monkeypatch.setattr(platform_compat, "translate_path", lambda _v: "/mnt/c/Users/alice") + + assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice" + + +def test_get_wsl_windows_user_profile_falls_back_to_users_dir(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: True) + + def raise_run(*_a, **_k): + raise OSError("powershell unavailable") + + monkeypatch.setattr(platform_compat.subprocess, "run", raise_run) + monkeypatch.setattr( + platform_compat.os, + "listdir", + lambda _path: ["All Users", "Default", "Public", "alice"], + ) + + def fake_isdir(path): + return path in {"/mnt/c/Users", "/mnt/c/Users/alice"} + + monkeypatch.setattr(platform_compat.os.path, "isdir", fake_isdir) + + assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice" + + +def test_get_wsl_windows_user_profile_returns_none_when_nothing_found(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: True) + monkeypatch.setattr( + platform_compat.subprocess, + "run", + lambda *_a, **_k: (_ for _ in ()).throw(OSError("powershell unavailable")), + ) + monkeypatch.setattr(platform_compat.os.path, "isdir", lambda 
_path: False) + + assert platform_compat.get_wsl_windows_user_profile() is None + + +def test_nvidia_path_override_is_correct_string(monkeypatch): + monkeypatch.setattr(platform_compat, "_SSH_PATH_MEMBERS", ["path1", "path2"]) + assert platform_compat._ssh_path_override() == "export PATH=\"$PATH:path1:path2\"; " + + +def test_windows_powershell_argv_defaults_include_no_profile_and_noninteractive(): + argv = platform_compat._windows_powershell_argv("Write-Output Hello") + assert argv == [ + "powershell.exe", + "-NoProfile", + "-NonInteractive", + "-Command", + "Write-Output Hello", + ] + + +def test_windows_powershell_argv_respects_disabled_flags(): + argv = platform_compat._windows_powershell_argv( + "Write-Output Hello", + no_profile=False, + non_interactive=False, + ) + assert argv == ["powershell.exe", "-Command", "Write-Output Hello"] + + +def test_run_wsl_windows_powershell_raises_outside_wsl(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: False) + try: + platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=2) + raise AssertionError("Expected RuntimeError") + except RuntimeError as exc: + assert "only supported in WSL" in str(exc) + + +def test_run_wsl_windows_powershell_calls_subprocess_with_expected_argv(monkeypatch): + monkeypatch.setattr(platform_compat, "is_wsl", lambda: True) + captured = {} + + class _Result: + returncode = 0 + stdout = "ok\n" + stderr = "" + + def _fake_run(args, **kwargs): + captured["args"] = list(args) + captured["kwargs"] = kwargs + return _Result() + + monkeypatch.setattr(platform_compat.subprocess, "run", _fake_run) + + result = platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=9) + + assert result.returncode == 0 + assert captured["args"] == [ + "powershell.exe", + "-NoProfile", + "-NonInteractive", + "-Command", + "Write-Output Hello", + ] + assert captured["kwargs"]["capture_output"] is True + assert captured["kwargs"]["text"] is True + assert 
captured["kwargs"]["timeout"] == 9 + + +def test_ssh_exec_argv_builds_default_command(): + argv = platform_compat._ssh_exec_argv("alice@gpu-box", None, remote_cmd="echo ok") + assert argv == ["ssh", "alice@gpu-box", "echo ok"] + + +def test_ssh_exec_argv_includes_port_and_options(): + argv = platform_compat._ssh_exec_argv( + "alice@gpu-box", + "2222", + remote_cmd="tmux ls", + connect_timeout=6, + strict_host_key_checking=False, + ) + assert argv == [ + "ssh", + "-o", + "ConnectTimeout=6", + "-o", + "StrictHostKeyChecking=no", + "-p", + "2222", + "alice@gpu-box", + "tmux ls", + ] + + +def test_run_ssh_command_uses_built_argv(monkeypatch): + captured = {} + + class _Result: + returncode = 0 + stdout = "ok" + stderr = "" + + def _fake_run(args, **kwargs): + captured["args"] = list(args) + captured["kwargs"] = kwargs + return _Result() + + monkeypatch.setattr(platform_compat.subprocess, "run", _fake_run) + + result = platform_compat.run_ssh_command( + "alice@gpu-box", + "2200", + "tmux ls", + timeout=7, + connect_timeout=3, + strict_host_key_checking=True, + text=False, + ) + + assert result.returncode == 0 + assert captured["args"] == [ + "ssh", + "-o", + "ConnectTimeout=3", + "-o", + "StrictHostKeyChecking=yes", + "-p", + "2200", + "alice@gpu-box", + "tmux ls", + ] + assert captured["kwargs"]["timeout"] == 7 + assert captured["kwargs"]["capture_output"] is True + assert captured["kwargs"]["text"] is False