mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
fix(platform): Improve WSL SSH remote compatibility (#3316)
* fix(platform): add WSL compatibility functions and path translation fix(cookbook): enhance model scan script to support additional HuggingFace cache paths fix(hardware): improve cache key generation for remote SSH context test(tests): add tests for WSL detection and path translation functionality * fix(cookbook): prefer prebuilt wheels for llama-cpp-python and normalize package aliases * fix: enable StrictHostKeyChecking in nvidia probe refactor: consolidate ssh & powershell command execution to utility functions in core module refactor: consolidate nvidia path candidates into single variables in core module tests: add tests for new utility functions * fix: correct wrong variable name
This commit is contained in:
@@ -161,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = (
|
|||||||
("usr", "bin", "bash.exe"),
|
("usr", "bin", "bash.exe"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Directories appended to PATH in remote SSH probe commands so that tools such
# as nvidia-smi are found even when the remote shell's PATH omits them.
_SSH_PATH_MEMBERS = (
    "/usr/bin",
    "/usr/local/bin",
    "/usr/local/cuda/bin",
    "/usr/lib/wsl/lib",
)

# Absolute fallback locations for nvidia-smi on WSL and other Linux distros
# where it may not be on PATH. Derived from _SSH_PATH_MEMBERS so the directory
# list is maintained in exactly one place and the two cannot drift apart.
NVIDIA_PATH_CANDIDATES = tuple(
    f"{directory}/nvidia-smi" for directory in _SSH_PATH_MEMBERS
)


def _ssh_path_override() -> str:
    """Build the PATH export snippet used for remote SSH shell probes.

    Returns:
        A shell fragment of the form ``export PATH="$PATH:<dirs>"; `` (with a
        trailing space) that callers prepend to the command they want to run.
    """
    extra_dirs = ":".join(_SSH_PATH_MEMBERS)
    return f'export PATH="$PATH:{extra_dirs}"; '


# Computed once at import time; callers embed it verbatim in remote commands.
SSH_PATH_OVERRIDE = _ssh_path_override()
|
||||||
|
|
||||||
|
|
||||||
def _windows_bash_fallbacks() -> List[str]:
|
def _windows_bash_fallbacks() -> List[str]:
|
||||||
roots: List[str] = []
|
roots: List[str] = []
|
||||||
@@ -268,3 +291,156 @@ def run_script_argv(script_path) -> List[str]:
|
|||||||
comspec = os.environ.get("ComSpec", "cmd.exe")
|
comspec = os.environ.get("ComSpec", "cmd.exe")
|
||||||
return [comspec, "/c", str(script_path)]
|
return [comspec, "/c", str(script_path)]
|
||||||
return ["sh", str(script_path)]
|
return ["sh", str(script_path)]
|
||||||
|
|
||||||
|
|
||||||
|
def is_wsl() -> bool:
    """Report whether the process runs inside Windows Subsystem for Linux.

    Detection reads ``/proc/version`` and looks for the substring
    ``microsoft`` (case-insensitive). Any failure to read the file is treated
    as "not WSL".
    """
    import sys

    posix_like = sys.platform.startswith("linux") or os.name == "posix"
    if not posix_like:
        return False

    try:
        with open("/proc/version", "r") as version_file:
            return "microsoft" in version_file.read().lower()
    except Exception:
        # Best-effort probe: an unreadable /proc/version means "not WSL".
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def translate_path(path_str: str) -> str:
    """Convert a path, possibly Windows-style, to the current OS format.

    Under WSL, backslashes become forward slashes and a drive-letter path
    such as ``C:\\foo`` or ``C:/foo`` is mapped to ``/mnt/c/foo``. Every other
    non-empty input is normalized with ``Path.resolve()``; if that fails the
    string is returned as-is (with backslashes already swapped when on WSL).
    Empty input is returned unchanged.
    """
    if not path_str:
        return path_str

    if is_wsl():
        import re

        path_str = path_str.replace("\\", "/")
        drive_match = re.match(r"^([a-zA-Z]):(.*)", path_str)
        if drive_match is not None:
            letter, remainder = drive_match.groups()
            # Drive-relative input ("C:foo") still lands under the drive root.
            if not remainder.startswith("/"):
                remainder = "/" + remainder
            return "/mnt/" + letter.lower() + remainder

    try:
        resolved = str(Path(path_str).resolve())
    except Exception:
        return path_str
    else:
        return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def get_wsl_windows_user_profile() -> Optional[str]:
    """Locate the Windows host's user profile directory from inside WSL.

    Returns:
        The translated profile path, or ``None`` when not running under WSL
        or when no profile directory could be determined.
    """
    if not is_wsl():
        return None

    # Preferred source: ask Windows for %USERPROFILE% through PowerShell.
    try:
        result = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5)
        if result.returncode == 0:
            reported = result.stdout.strip()
            if reported:
                return translate_path(reported)
    except Exception:
        pass

    # Fallback: first real directory under /mnt/c/Users that is not one of
    # the well-known shared/system entries.
    ignored_entries = ("All Users", "Default", "Default User", "desktop.ini", "Public")
    users_root = "/mnt/c/Users"
    try:
        if not os.path.isdir(users_root):
            return None
        for name in os.listdir(users_root):
            if name in ignored_entries:
                continue
            candidate = os.path.join(users_root, name)
            if os.path.isdir(candidate):
                return candidate
    except Exception:
        pass
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _ssh_exec_argv(
|
||||||
|
remote: str,
|
||||||
|
ssh_port: str | None,
|
||||||
|
*,
|
||||||
|
remote_cmd: str | None = None,
|
||||||
|
connect_timeout: int | None = None,
|
||||||
|
strict_host_key_checking: bool | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build a consistent ssh argv for remote command execution."""
|
||||||
|
argv = ["ssh"]
|
||||||
|
if connect_timeout is not None:
|
||||||
|
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||||
|
if strict_host_key_checking is not None:
|
||||||
|
argv.extend(
|
||||||
|
[
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=yes"
|
||||||
|
if strict_host_key_checking
|
||||||
|
else "StrictHostKeyChecking=no",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if ssh_port and ssh_port != "22":
|
||||||
|
argv.extend(["-p", str(ssh_port)])
|
||||||
|
argv.append(remote)
|
||||||
|
if remote_cmd is not None:
|
||||||
|
argv.append(remote_cmd)
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def run_ssh_command(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    text: bool = True,
) -> subprocess.CompletedProcess:
    """Execute ``remote_cmd`` on ``remote`` over ssh, capturing its output.

    The argv is built by ``_ssh_exec_argv``; stdout and stderr are captured
    and ``timeout`` is passed to ``subprocess.run``, which raises
    ``subprocess.TimeoutExpired`` when it elapses.

    Returns:
        The ``CompletedProcess`` from ``subprocess.run`` (output decoded as
        text unless ``text`` is false).
    """
    ssh_argv = _ssh_exec_argv(
        remote,
        ssh_port,
        remote_cmd=remote_cmd,
        connect_timeout=connect_timeout,
        strict_host_key_checking=strict_host_key_checking,
    )
    return subprocess.run(ssh_argv, capture_output=True, text=text, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def _windows_powershell_argv(
|
||||||
|
command: str,
|
||||||
|
*,
|
||||||
|
no_profile: bool = True,
|
||||||
|
non_interactive: bool = True,
|
||||||
|
) -> List[str]:
|
||||||
|
argv: List[str] = ["powershell.exe"]
|
||||||
|
if no_profile:
|
||||||
|
argv.append("-NoProfile")
|
||||||
|
if non_interactive:
|
||||||
|
argv.append("-NonInteractive")
|
||||||
|
argv.extend(["-Command", command])
|
||||||
|
return argv
|
||||||
|
|
||||||
|
|
||||||
|
def run_wsl_windows_powershell(
    command: str,
    *,
    timeout: float = 5,
) -> subprocess.CompletedProcess[str]:
    """Execute a PowerShell command on the Windows host from within WSL.

    Output is captured as text and ``timeout`` is forwarded to
    ``subprocess.run``.

    Raises ``RuntimeError`` when called outside WSL.
    """
    if is_wsl():
        powershell_argv = _windows_powershell_argv(command)
        return subprocess.run(
            powershell_argv,
            timeout=timeout,
            text=True,
            capture_output=True,
        )
    raise RuntimeError("run_wsl_windows_powershell is only supported in WSL")
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import shlex
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.platform_compat import _ssh_exec_argv
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -213,7 +215,10 @@ def _pip_install_fallback_chain(package: str, *, python_cmd: str = "python3 -m p
|
|||||||
# before being embedded in the install command. Plain names (e.g.
|
# before being embedded in the install command. Plain names (e.g.
|
||||||
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
# ``huggingface_hub``) are returned unchanged by ``shlex.quote``.
|
||||||
pkg = shlex.quote(package)
|
pkg = shlex.quote(package)
|
||||||
if IS_WINDOWS and "llama-cpp-python" in package:
|
# llama-cpp-python source builds are brittle on older distro pip/packaging
|
||||||
|
# stacks (common on WSL images). Prefer the prebuilt wheel index whenever
|
||||||
|
# this package is requested so dependency-install tasks are reliable.
|
||||||
|
if "llama-cpp-python" in package:
|
||||||
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
pkg += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
|
|
||||||
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
base = _pip_install_attempt(f"{python_cmd} install -q{upgrade_flag} {pkg}")
|
||||||
@@ -275,11 +280,14 @@ def _user_shell_path_bootstrap() -> list[str]:
|
|||||||
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
' if [ -n "$ODYSSEUS_USER_PATH" ]; then export PATH="$ODYSSEUS_USER_PATH:$PATH"; fi',
|
||||||
'fi',
|
'fi',
|
||||||
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }',
|
||||||
|
'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
def _cached_model_scan_script(model_dirs: list[str] | None = None, add_hf_cache: str | None = None) -> str:
|
||||||
"""Build the standalone Python scanner used by /api/model/cached."""
|
"""Build the standalone Python scanner used by /api/model/cached.
|
||||||
|
Allows for an additional HuggingFace cache path to be scanned (i.e. Windows HF cache for local WSL envs.)
|
||||||
|
"""
|
||||||
lines = [
|
lines = [
|
||||||
"import json, os, re, shutil, subprocess, urllib.request",
|
"import json, os, re, shutil, subprocess, urllib.request",
|
||||||
"models = []",
|
"models = []",
|
||||||
@@ -372,6 +380,7 @@ def _cached_model_scan_script(model_dirs: list[str] | None = None) -> str:
|
|||||||
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
|
" # Docker images mount ./data/huggingface at /app/.cache/huggingface.",
|
||||||
" # When HOME is /root, expanduser() misses that persisted cache.",
|
" # When HOME is /root, expanduser() misses that persisted cache.",
|
||||||
" add('/app/.cache/huggingface/hub')",
|
" add('/app/.cache/huggingface/hub')",
|
||||||
|
f" add({add_hf_cache!r})" if add_hf_cache else "",
|
||||||
" return candidates",
|
" return candidates",
|
||||||
"def scan_dir(p):",
|
"def scan_dir(p):",
|
||||||
" if not os.path.isdir(p) or not safe_path(p): return",
|
" if not os.path.isdir(p) or not safe_path(p): return",
|
||||||
@@ -989,3 +998,40 @@ def _diagnose_serve_output(text: str) -> dict | None:
|
|||||||
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
"suggestions": [{"label": "inspect traceback and retry with adjusted backend/settings", "op": "manual"}],
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def run_ssh_command_async(
    remote: str,
    ssh_port: str | None,
    remote_cmd: str,
    *,
    timeout: float,
    connect_timeout: int | None = None,
    strict_host_key_checking: bool | None = None,
    stdin_data: bytes | None = None,
) -> tuple[int, bytes, bytes]:
    """Run an ssh command with centralized timeout and stderr/stdout capture.

    Async counterpart of ``core.platform_compat.run_ssh_command``.

    Args:
        remote: Destination passed to ssh.
        ssh_port: Port forwarded to ``_ssh_exec_argv``.
        remote_cmd: Command string executed on the remote side.
        timeout: Seconds to wait for the ssh process to finish.
        connect_timeout: Optional ssh ``ConnectTimeout`` value.
        strict_host_key_checking: Optional ``StrictHostKeyChecking`` toggle.
        stdin_data: Bytes piped to the process's stdin; when ``None`` no
            stdin pipe is created.

    Returns:
        ``(returncode, stdout, stderr)`` with the streams as raw bytes.

    Raises:
        asyncio.TimeoutError: If the process does not finish within
            ``timeout``. The process is killed and reaped before re-raising.
    """
    import asyncio

    proc = await asyncio.create_subprocess_exec(
        *_ssh_exec_argv(
            remote,
            ssh_port,
            remote_cmd=remote_cmd,
            connect_timeout=connect_timeout,
            strict_host_key_checking=strict_host_key_checking,
        ),
        stdin=asyncio.subprocess.PIPE if stdin_data is not None else None,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(input=stdin_data), timeout=timeout
        )
    except asyncio.TimeoutError:
        try:
            proc.kill()
        except ProcessLookupError:
            # The process exited between the timeout firing and the kill;
            # nothing is left to signal, so still report the timeout.
            pass
        # Reap the child so no zombie process or open pipes are left behind.
        await proc.communicate()
        raise
    return proc.returncode or 0, stdout, stderr
|
||||||
+66
-18
@@ -20,6 +20,8 @@ from pydantic import BaseModel
|
|||||||
from core.middleware import require_admin
|
from core.middleware import require_admin
|
||||||
from core.platform_compat import (
|
from core.platform_compat import (
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
|
SSH_PATH_OVERRIDE,
|
||||||
|
NVIDIA_PATH_CANDIDATES,
|
||||||
detached_popen_kwargs,
|
detached_popen_kwargs,
|
||||||
find_bash,
|
find_bash,
|
||||||
git_bash_path,
|
git_bash_path,
|
||||||
@@ -27,6 +29,8 @@ from core.platform_compat import (
|
|||||||
pid_alive,
|
pid_alive,
|
||||||
safe_chmod,
|
safe_chmod,
|
||||||
which_tool,
|
which_tool,
|
||||||
|
translate_path,
|
||||||
|
get_wsl_windows_user_profile,
|
||||||
)
|
)
|
||||||
from routes.shell_routes import TMUX_LOG_DIR
|
from routes.shell_routes import TMUX_LOG_DIR
|
||||||
|
|
||||||
@@ -41,7 +45,7 @@ from routes.cookbook_helpers import (
|
|||||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||||
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
|
_append_vllm_linux_preflight_lines, _ollama_bind_from_cmd, _pip_install_fallback_chain,
|
||||||
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
_pip_install_no_cache, _user_shell_path_bootstrap, _venv_safe_local_pip_install_cmd,
|
||||||
_diagnose_serve_output,
|
_diagnose_serve_output, run_ssh_command_async,
|
||||||
ModelDownloadRequest, ServeRequest,
|
ModelDownloadRequest, ServeRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -557,24 +561,35 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
for d in model_dir.split(','):
|
for d in model_dir.split(','):
|
||||||
d = d.strip()
|
d = d.strip()
|
||||||
if d:
|
if d:
|
||||||
model_dirs.append(d)
|
translated_d = translate_path(d) if not host else d
|
||||||
paths_code = _cached_model_scan_script(model_dirs)
|
model_dirs.append(translated_d)
|
||||||
|
win_hf_hub = None
|
||||||
|
if not host:
|
||||||
|
win_profile = get_wsl_windows_user_profile()
|
||||||
|
win_hf_hub = os.path.join(win_profile, ".cache", "huggingface", "hub") if win_profile else None
|
||||||
|
|
||||||
|
paths_code = _cached_model_scan_script(model_dirs, win_hf_hub)
|
||||||
|
|
||||||
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
scan_py = TMUX_LOG_DIR / "scan_cache.py"
|
||||||
scan_py.write_text(paths_code, encoding="utf-8")
|
scan_py.write_text(paths_code, encoding="utf-8")
|
||||||
|
scan_payload = scan_py.read_bytes()
|
||||||
|
|
||||||
if host:
|
if host:
|
||||||
_pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
|
||||||
if platform == "windows":
|
if platform == "windows":
|
||||||
# Windows: use 'python' and pipe via stdin with double-quote wrapping
|
remote_cmd = "python -"
|
||||||
cmd = f'ssh {_pf}{host} "python -" < \'{scan_py}\''
|
|
||||||
else:
|
else:
|
||||||
cmd = f"ssh {_pf}{host} 'python3 -' < '{scan_py}'"
|
# POSIX: use 'python3' if available, fall back to 'python'; throw if neither is found.
|
||||||
proc = await asyncio.create_subprocess_shell(
|
remote_cmd = (
|
||||||
cmd,
|
"if command -v python3 >/dev/null 2>&1; then python3 -; "
|
||||||
stdout=asyncio.subprocess.PIPE,
|
"elif command -v python >/dev/null 2>&1; then python -; "
|
||||||
stderr=asyncio.subprocess.PIPE,
|
"else echo \"python3/python not found\" >&2; exit 127; fi"
|
||||||
cwd=str(Path.home()),
|
)
|
||||||
|
rc, stdout_b, stderr_b = await run_ssh_command_async(
|
||||||
|
host,
|
||||||
|
ssh_port,
|
||||||
|
remote_cmd,
|
||||||
|
timeout=60,
|
||||||
|
stdin_data=scan_payload,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
# LOCAL scan: use sys.executable (the venv Python Odysseus is already
|
||||||
@@ -594,7 +609,7 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
cwd=str(Path.home()),
|
cwd=str(Path.home()),
|
||||||
)
|
)
|
||||||
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=60)
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
try:
|
try:
|
||||||
@@ -874,6 +889,12 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
# pip cache so they don't fail mid-build with "No space left" (#1219)
|
||||||
# and leave the dep installed-but-unusable (#1459).
|
# and leave the dep installed-but-unusable (#1459).
|
||||||
req.cmd = _pip_install_no_cache(req.cmd)
|
req.cmd = _pip_install_no_cache(req.cmd)
|
||||||
|
# Accept common aliases and enforce server extras for llama-cpp so
|
||||||
|
# `python -m llama_cpp.server` has all runtime dependencies.
|
||||||
|
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])", "llama-cpp-python[server]", req.cmd)
|
||||||
|
req.cmd = re.sub(r"(?<![A-Za-z0-9_.-])llama-cpp-python(?!\[)", "llama-cpp-python[server]", req.cmd)
|
||||||
|
if "llama-cpp-python" in req.cmd and "--extra-index-url" not in req.cmd:
|
||||||
|
req.cmd += " --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu"
|
||||||
# PEP-508-style package spec — letters, digits, `.-_` for the
|
# PEP-508-style package spec — letters, digits, `.-_` for the
|
||||||
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
# name; `[` `]` for extras; `<>=!~,` for version specifiers.
|
||||||
# v2 review HIGH-14: tightened from the previous regex which
|
# v2 review HIGH-14: tightened from the previous regex which
|
||||||
@@ -1354,11 +1375,38 @@ def setup_cookbook_routes() -> APIRouter:
|
|||||||
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
async def _run_nvidia_smi(query: str, host: str | None, ssh_port: str | None, timeout: int = 8):
|
||||||
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
"""Run nvidia-smi locally or over SSH. Returns (stdout, error_or_None)."""
|
||||||
if host:
|
if host:
|
||||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port != "22" else ""
|
candidates = [query]
|
||||||
cmd = f"ssh -o ConnectTimeout=5 -o StrictHostKeyChecking=no {pf}{host} '{query}'"
|
stripped = query.strip()
|
||||||
proc = await asyncio.create_subprocess_shell(
|
if stripped.startswith("nvidia-smi "):
|
||||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
args = stripped[len("nvidia-smi "):]
|
||||||
)
|
candidates.append(
|
||||||
|
"bash -lc "
|
||||||
|
+ shlex.quote(
|
||||||
|
f"{SSH_PATH_OVERRIDE}"
|
||||||
|
f"nvidia-smi {args}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||||
|
candidates.append(f"{nvidia_path} {args}")
|
||||||
|
|
||||||
|
last_err = "nvidia-smi failed"
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
rc, stdout, stderr = await run_ssh_command_async(
|
||||||
|
host,
|
||||||
|
ssh_port,
|
||||||
|
candidate,
|
||||||
|
connect_timeout=5,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return None, "nvidia-smi timed out"
|
||||||
|
if rc == 0:
|
||||||
|
return stdout.decode("utf-8", errors="replace"), None
|
||||||
|
err = (stderr.decode("utf-8", errors="replace") or "").strip()[:200]
|
||||||
|
if err:
|
||||||
|
last_err = err
|
||||||
|
return None, last_err
|
||||||
else:
|
else:
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
*shlex.split(query),
|
*shlex.split(query),
|
||||||
|
|||||||
@@ -47,6 +47,9 @@ _STATE_PATH = _DATA_DIR / "cookbook_state.json"
|
|||||||
import tempfile
|
import tempfile
|
||||||
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
_TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
||||||
|
|
||||||
|
from core.platform_compat import NVIDIA_PATH_CANDIDATES, SSH_PATH_OVERRIDE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def fail(msg: str, code: int = 1) -> None:
|
def fail(msg: str, code: int = 1) -> None:
|
||||||
sys.stderr.write(f"error: {msg}\n")
|
sys.stderr.write(f"error: {msg}\n")
|
||||||
@@ -160,7 +163,26 @@ def cmd_gpus(args) -> None:
|
|||||||
prefix = _ssh_prefix(args.host, args.ssh_port)
|
prefix = _ssh_prefix(args.host, args.ssh_port)
|
||||||
cmd = prefix + (query.split() if not prefix else [query])
|
cmd = prefix + (query.split() if not prefix else [query])
|
||||||
try:
|
try:
|
||||||
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
if prefix:
|
||||||
|
candidates = [query]
|
||||||
|
args_part = query[len("nvidia-smi "):]
|
||||||
|
candidates.append(
|
||||||
|
"bash -lc "
|
||||||
|
+ repr(
|
||||||
|
f"{SSH_PATH_OVERRIDE}"
|
||||||
|
f"nvidia-smi {args_part}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for nvidia_path in NVIDIA_PATH_CANDIDATES:
|
||||||
|
candidates.append(f"{nvidia_path} {args_part}")
|
||||||
|
|
||||||
|
out = None
|
||||||
|
for candidate in candidates:
|
||||||
|
out = subprocess.run(prefix + [candidate], capture_output=True, text=True, timeout=15)
|
||||||
|
if out.returncode == 0:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
# No nvidia-smi locally → try the Metal fallback before giving up.
|
# No nvidia-smi locally → try the Metal fallback before giving up.
|
||||||
if not prefix:
|
if not prefix:
|
||||||
|
|||||||
+32
-11
@@ -4,6 +4,13 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
|
import shlex
|
||||||
|
|
||||||
|
from core.platform_compat import (
|
||||||
|
NVIDIA_PATH_CANDIDATES,
|
||||||
|
SSH_PATH_OVERRIDE,
|
||||||
|
run_ssh_command,
|
||||||
|
)
|
||||||
|
|
||||||
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
||||||
# from 30 min so changing filters doesn't keep re-probing the rig every
|
# from 30 min so changing filters doesn't keep re-probing the rig every
|
||||||
@@ -21,16 +28,17 @@ def _run(cmd):
|
|||||||
if _remote_host:
|
if _remote_host:
|
||||||
# Run command on remote host via SSH
|
# Run command on remote host via SSH
|
||||||
if isinstance(cmd, list):
|
if isinstance(cmd, list):
|
||||||
cmd_str = " ".join(cmd)
|
cmd_str = shlex.join(str(c) for c in cmd)
|
||||||
else:
|
else:
|
||||||
cmd_str = cmd
|
cmd_str = cmd
|
||||||
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
|
r = run_ssh_command(
|
||||||
if _remote_port and _remote_port != "22":
|
_remote_host,
|
||||||
ssh_cmd += ["-p", _remote_port]
|
_remote_port,
|
||||||
ssh_cmd += [_remote_host, cmd_str]
|
cmd_str,
|
||||||
r = subprocess.run(
|
timeout=15,
|
||||||
ssh_cmd,
|
connect_timeout=5,
|
||||||
capture_output=True, text=True, timeout=15,
|
strict_host_key_checking=False,
|
||||||
|
text=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||||
@@ -83,7 +91,7 @@ def _detect_nvidia():
|
|||||||
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
||||||
if not out and _remote_host:
|
if not out and _remote_host:
|
||||||
out = _run(
|
out = _run(
|
||||||
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin:/usr/lib/wsl/lib\"; "
|
f"bash -lc '{SSH_PATH_OVERRIDE}"
|
||||||
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
||||||
)
|
)
|
||||||
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
||||||
@@ -92,7 +100,7 @@ def _detect_nvidia():
|
|||||||
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
|
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
|
||||||
# that may not be in the server process's PATH.
|
# that may not be in the server process's PATH.
|
||||||
if not out:
|
if not out:
|
||||||
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi", "/usr/lib/wsl/lib/nvidia-smi"):
|
for _p in NVIDIA_PATH_CANDIDATES:
|
||||||
# Use list form so subprocess.run (local) resolves the absolute path
|
# Use list form so subprocess.run (local) resolves the absolute path
|
||||||
# correctly instead of treating the whole string as an executable name.
|
# correctly instead of treating the whole string as an executable name.
|
||||||
if _remote_host:
|
if _remote_host:
|
||||||
@@ -590,6 +598,19 @@ def _detect_windows():
|
|||||||
_cache_by_host = {} # host -> (timestamp, result)
|
_cache_by_host = {} # host -> (timestamp, result)
|
||||||
|
|
||||||
|
|
||||||
|
def _cache_key(host: str, ssh_port: str, platform_name: str):
|
||||||
|
"""Build a stable cache key that isolates remote SSH context.
|
||||||
|
|
||||||
|
Same host aliases can have different hardware due to visibility, forwarding etc.
|
||||||
|
To avoid using the wrong cached hardware info, include the SSH port and platform in the cache key.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
host or "_local",
|
||||||
|
str(ssh_port or ""),
|
||||||
|
str(platform_name or "").lower(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||||
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
||||||
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
||||||
@@ -599,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False):
|
|||||||
"""
|
"""
|
||||||
global _remote_host, _remote_port, _remote_platform
|
global _remote_host, _remote_port, _remote_platform
|
||||||
|
|
||||||
cache_key = host or "_local"
|
cache_key = _cache_key(host, ssh_port, platform)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if not fresh and cache_key in _cache_by_host:
|
if not fresh and cache_key in _cache_by_host:
|
||||||
ts, cached = _cache_by_host[cache_key]
|
ts, cached = _cache_by_host[cache_key]
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from routes.cookbook_helpers import (
|
|||||||
_validate_serve_cmd,
|
_validate_serve_cmd,
|
||||||
_validate_serve_model_id,
|
_validate_serve_model_id,
|
||||||
_validate_ssh_port,
|
_validate_ssh_port,
|
||||||
|
run_ssh_command_async,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,6 +36,56 @@ def test_safe_env_prefix_accepts_quoted_venv_path():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_ssh_command_executes_with_stdin_and_returns_output(monkeypatch):
    """The async ssh runner builds the expected argv and pipes stdin through."""
    recorded = {}

    class _FakeProcess:
        returncode = 0

        async def communicate(self, input=None):
            recorded["input"] = input
            return b"stdout", b"stderr"

    async def _spawn_stub(*argv, **options):
        # Capture what the runner would have handed to the real spawner.
        recorded["args"] = list(argv)
        for stream in ("stdin", "stdout", "stderr"):
            recorded[stream] = options.get(stream)
        return _FakeProcess()

    monkeypatch.setattr("asyncio.create_subprocess_exec", _spawn_stub)

    rc, out, err = await run_ssh_command_async(
        "alice@gpu-box",
        "2222",
        "python -",
        timeout=5,
        connect_timeout=4,
        strict_host_key_checking=False,
        stdin_data=b"python -m pip install vllm",
    )

    assert (rc, out, err) == (0, b"stdout", b"stderr")
    assert recorded["args"] == [
        "ssh",
        "-o",
        "ConnectTimeout=4",
        "-o",
        "StrictHostKeyChecking=no",
        "-p",
        "2222",
        "alice@gpu-box",
        "python -",
    ]
    # All three standard streams must have been requested as pipes.
    for stream in ("stdin", "stdout", "stderr"):
        assert recorded[stream] is not None
    assert recorded["input"] == b"python -m pip install vllm"
|
||||||
|
|
||||||
|
|
||||||
def test_safe_env_prefix_leaves_compound_conda_prefix_unchanged():
|
def test_safe_env_prefix_leaves_compound_conda_prefix_unchanged():
|
||||||
prefix = 'eval "$(conda shell.bash hook)" && conda activate qwen35'
|
prefix = 'eval "$(conda shell.bash hook)" && conda activate qwen35'
|
||||||
assert _safe_env_prefix(prefix) == prefix
|
assert _safe_env_prefix(prefix) == prefix
|
||||||
@@ -170,6 +221,8 @@ def test_pip_install_fallback_chain_quotes_extras_spec():
|
|||||||
chain = _pip_install_fallback_chain("llama-cpp-python[server]", python_cmd="pip")
|
chain = _pip_install_fallback_chain("llama-cpp-python[server]", python_cmd="pip")
|
||||||
# Quoted in both the plain and the --user attempt.
|
# Quoted in both the plain and the --user attempt.
|
||||||
assert chain.count("'llama-cpp-python[server]'") == 2
|
assert chain.count("'llama-cpp-python[server]'") == 2
|
||||||
|
# llama-cpp installs must prefer prebuilt wheels to avoid fragile source builds.
|
||||||
|
assert "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu" in chain
|
||||||
# Never the unquoted form (bracket-glob risk).
|
# Never the unquoted form (bracket-glob risk).
|
||||||
assert "install -q llama-cpp-python[server]" not in chain
|
assert "install -q llama-cpp-python[server]" not in chain
|
||||||
# A plain package name is still passed through unquoted (no regression).
|
# A plain package name is still passed through unquoted (no regression).
|
||||||
@@ -194,6 +247,17 @@ def test_serve_runner_installs_llama_cpp_server_extra():
|
|||||||
assert "_pip_install_fallback_chain('llama-cpp-python[server]'" in src
|
assert "_pip_install_fallback_chain('llama-cpp-python[server]'" in src
|
||||||
|
|
||||||
|
|
||||||
|
def test_serve_pip_install_normalizes_llama_cpp_alias_and_adds_wheel_index():
    """cookbook_routes.py must contain the llama_cpp alias rewrite and wheel index."""
    import pathlib

    repo_root = pathlib.Path(__file__).resolve().parent.parent
    routes_file = repo_root / "routes" / "cookbook_routes.py"
    src = routes_file.read_text(encoding="utf-8")

    # Source-level check: each snippet must appear verbatim in the routes module.
    required_snippets = (
        "re.sub(r\"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])\", \"llama-cpp-python[server]\", req.cmd)",
        "if \"llama-cpp-python\" in req.cmd and \"--extra-index-url\" not in req.cmd:",
        "https://abetlen.github.io/llama-cpp-python/whl/cpu",
    )
    for snippet in required_snippets:
        assert snippet in src
|
||||||
|
|
||||||
|
|
||||||
def test_vllm_preflight_reports_cli_and_version():
|
def test_vllm_preflight_reports_cli_and_version():
|
||||||
lines = []
|
lines = []
|
||||||
|
|
||||||
@@ -289,6 +353,7 @@ def test_local_tooling_path_export_converts_windows_paths_for_bash():
|
|||||||
def test_user_shell_path_bootstrap_falls_back_to_python_on_windows_bash():
|
def test_user_shell_path_bootstrap_falls_back_to_python_on_windows_bash():
|
||||||
script = "\n".join(_user_shell_path_bootstrap())
|
script = "\n".join(_user_shell_path_bootstrap())
|
||||||
assert 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }' in script
|
assert 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }' in script
|
||||||
|
assert 'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }' in script
|
||||||
|
|
||||||
|
|
||||||
def test_serve_preflight_failure_keeps_tmux_pane_visible():
|
def test_serve_preflight_failure_keeps_tmux_pane_visible():
|
||||||
@@ -606,3 +671,35 @@ def test_pip_install_no_cache_is_idempotent_and_scoped():
|
|||||||
# not a pip install -> unchanged
|
# not a pip install -> unchanged
|
||||||
assert _pip_install_no_cache("vllm serve --model x") == "vllm serve --model x"
|
assert _pip_install_no_cache("vllm serve --model x") == "vllm serve --model x"
|
||||||
assert _pip_install_no_cache("") == ""
|
assert _pip_install_no_cache("") == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_cached_model_scan_runs_additional_hf_cache(tmp_path):
|
||||||
|
extra_cache = tmp_path / "extra_hf_cache"
|
||||||
|
model_dir = extra_cache / "models--acme--sample-7b"
|
||||||
|
snap = model_dir / "snapshots" / "rev-1"
|
||||||
|
snap.mkdir(parents=True)
|
||||||
|
weights = snap / "model.safetensors"
|
||||||
|
weights.write_bytes(b"abc123")
|
||||||
|
|
||||||
|
scan_py = tmp_path / "scan_cache.py"
|
||||||
|
scan_py.write_text(
|
||||||
|
_cached_model_scan_script(add_hf_cache=str(extra_cache)),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
proc = subprocess.run(
|
||||||
|
[sys.executable, str(scan_py)],
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
models = json.loads(proc.stdout)
|
||||||
|
by_repo = {m["repo_id"]: m for m in models}
|
||||||
|
|
||||||
|
assert "acme/sample-7b" in by_repo
|
||||||
|
rec = by_repo["acme/sample-7b"]
|
||||||
|
assert rec["path"] == str(extra_cache)
|
||||||
|
assert rec["nb_files"] == 1
|
||||||
|
assert rec["size_bytes"] == len(b"abc123")
|
||||||
|
assert rec["has_incomplete"] is False
|
||||||
|
assert rec["is_diffusion"] is False
|
||||||
|
|||||||
@@ -71,3 +71,81 @@ def test_no_gpu_still_none(monkeypatch):
|
|||||||
"""No nvidia-smi output → still None, no spurious unified GPU."""
|
"""No nvidia-smi output → still None, no spurious unified GPU."""
|
||||||
monkeypatch.setattr(hardware, "_run", lambda cmd: None)
|
monkeypatch.setattr(hardware, "_run", lambda cmd: None)
|
||||||
assert hardware._detect_nvidia() is None
|
assert hardware._detect_nvidia() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_system_cache_separates_same_host_different_ports(monkeypatch):
|
||||||
|
"""Keep cache separate by host+port+platform, don't use cached data"""
|
||||||
|
ram_gb = 0
|
||||||
|
|
||||||
|
def _ram():
|
||||||
|
nonlocal ram_gb
|
||||||
|
ram_gb += 1
|
||||||
|
return ram_gb * 64.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(hardware, "_get_ram_gb", _ram)
|
||||||
|
monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 40.0)
|
||||||
|
monkeypatch.setattr(hardware, "_get_cpu_count", lambda: 16)
|
||||||
|
monkeypatch.setattr(hardware, "_get_cpu_name", lambda: "AMD Ryzen")
|
||||||
|
monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: None)
|
||||||
|
monkeypatch.setattr(hardware, "_detect_nvidia", lambda: None)
|
||||||
|
monkeypatch.setattr(hardware, "_detect_amd", lambda: None)
|
||||||
|
monkeypatch.setattr(hardware, "_run", lambda _cmd: "x86_64")
|
||||||
|
|
||||||
|
def _windows_probe():
|
||||||
|
nonlocal ram_gb
|
||||||
|
ram_gb += 1
|
||||||
|
return {
|
||||||
|
"total_ram_gb": ram_gb * 64.0,
|
||||||
|
"available_ram_gb": 40.0,
|
||||||
|
"cpu_cores": 16,
|
||||||
|
"cpu_name": "AMD Ryzen",
|
||||||
|
"has_gpu": False,
|
||||||
|
"gpu_name": None,
|
||||||
|
"gpu_vram_gb": None,
|
||||||
|
"gpu_count": 0,
|
||||||
|
"backend": "cpu_x86",
|
||||||
|
"homogeneous": True,
|
||||||
|
"gpu_error": None,
|
||||||
|
"platform": "windows",
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(hardware, "_detect_windows", _windows_probe)
|
||||||
|
hardware._cache_by_host.clear()
|
||||||
|
|
||||||
|
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
|
||||||
|
hardware.detect_system(host="user@wsl-host", ssh_port="2222", platform="linux", fresh=False)
|
||||||
|
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="windows", fresh=False)
|
||||||
|
|
||||||
|
assert len(hardware._cache_by_host) == 3
|
||||||
|
assert hardware._cache_by_host[("user@wsl-host", "22", "linux")][1]["total_ram_gb"] == 64.0
|
||||||
|
assert hardware._cache_by_host[("user@wsl-host", "2222", "linux")][1]["total_ram_gb"] == 128.0
|
||||||
|
assert hardware._cache_by_host[("user@wsl-host", "22", "windows")][1]["total_ram_gb"] == 192.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_system_cache_hits_when_remote_context_matches(monkeypatch):
|
||||||
|
"""Cache hits when host+port+platform match"""
|
||||||
|
ram_gb = 0
|
||||||
|
|
||||||
|
def _ram():
|
||||||
|
nonlocal ram_gb
|
||||||
|
ram_gb += 1
|
||||||
|
return ram_gb * 64.0
|
||||||
|
|
||||||
|
monkeypatch.setattr(hardware, "_get_ram_gb", _ram)
|
||||||
|
monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 40.0)
|
||||||
|
monkeypatch.setattr(hardware, "_get_cpu_count", lambda: 16)
|
||||||
|
monkeypatch.setattr(hardware, "_get_cpu_name", lambda: "AMD Ryzen")
|
||||||
|
monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: None)
|
||||||
|
monkeypatch.setattr(hardware, "_detect_nvidia", lambda: None)
|
||||||
|
monkeypatch.setattr(hardware, "_detect_amd", lambda: None)
|
||||||
|
monkeypatch.setattr(hardware, "_run", lambda _cmd: "x86_64")
|
||||||
|
hardware._cache_by_host.clear()
|
||||||
|
|
||||||
|
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
|
||||||
|
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
|
||||||
|
hardware.detect_system(fresh=False)
|
||||||
|
hardware.detect_system(fresh=False)
|
||||||
|
|
||||||
|
assert len(hardware._cache_by_host) == 2
|
||||||
|
assert hardware._cache_by_host[("user@wsl-host", "22", "linux")][1]["total_ram_gb"] == 64.0
|
||||||
|
assert hardware._cache_by_host[("_local", "", "")][1]["total_ram_gb"] == 128.0
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Regression tests for cross-platform helper behavior."""
|
"""Regression tests for cross-platform helper behavior."""
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import io
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
@@ -59,3 +61,243 @@ def test_find_bash_skips_windows_wsl_stub(monkeypatch):
|
|||||||
monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected)
|
monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected)
|
||||||
|
|
||||||
assert platform_compat.find_bash() == expected
|
assert platform_compat.find_bash() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_wsl_true_when_proc_version_mentions_microsoft(monkeypatch):
|
||||||
|
monkeypatch.setattr(sys, "platform", "linux", raising=False)
|
||||||
|
|
||||||
|
def fake_open(path, mode="r", *args, **kwargs):
|
||||||
|
assert path == "/proc/version"
|
||||||
|
assert mode == "r"
|
||||||
|
return io.StringIO("Linux version 6.6.0 microsoft standard")
|
||||||
|
|
||||||
|
monkeypatch.setattr("builtins.open", fake_open)
|
||||||
|
|
||||||
|
assert platform_compat.is_wsl() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_wsl_false_when_proc_version_is_not_microsoft(monkeypatch):
|
||||||
|
monkeypatch.setattr(sys, "platform", "linux", raising=False)
|
||||||
|
monkeypatch.setattr("builtins.open", lambda *_a, **_k: io.StringIO("Linux version 6.6.0 generic"))
|
||||||
|
|
||||||
|
assert platform_compat.is_wsl() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_wsl_false_on_non_posix_without_proc_probe(monkeypatch):
|
||||||
|
monkeypatch.setattr(sys, "platform", "win32", raising=False)
|
||||||
|
monkeypatch.setattr(platform_compat.os, "name", "nt", raising=False)
|
||||||
|
|
||||||
|
def fail_open(*_args, **_kwargs):
|
||||||
|
raise AssertionError("open should not be called when platform is not Linux/POSIX")
|
||||||
|
|
||||||
|
monkeypatch.setattr("builtins.open", fail_open)
|
||||||
|
|
||||||
|
assert platform_compat.is_wsl() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_path_converts_windows_drive_path_on_wsl(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||||
|
|
||||||
|
out = platform_compat.translate_path(r"C:\Users\alice\models\qwen.gguf")
|
||||||
|
|
||||||
|
assert out == "/mnt/c/Users/alice/models/qwen.gguf"
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_path_resolves_paths_when_not_wsl(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
|
||||||
|
|
||||||
|
assert platform_compat.translate_path(".") == str(Path(".").resolve())
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_path_returns_input_when_resolve_fails(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
|
||||||
|
|
||||||
|
class _BrokenPath:
|
||||||
|
def __init__(self, _value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def resolve(self):
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
monkeypatch.setattr(platform_compat, "Path", _BrokenPath)
|
||||||
|
|
||||||
|
assert platform_compat.translate_path("weird::path") == "weird::path"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_wsl_windows_user_profile_prefers_powershell(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||||
|
|
||||||
|
class _Result:
|
||||||
|
returncode = 0
|
||||||
|
stdout = "C:\\Users\\alice\\n"
|
||||||
|
|
||||||
|
monkeypatch.setattr(platform_compat.subprocess, "run", lambda *_a, **_k: _Result())
|
||||||
|
monkeypatch.setattr(platform_compat, "translate_path", lambda _v: "/mnt/c/Users/alice")
|
||||||
|
|
||||||
|
assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_wsl_windows_user_profile_falls_back_to_users_dir(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||||
|
|
||||||
|
def raise_run(*_a, **_k):
|
||||||
|
raise OSError("powershell unavailable")
|
||||||
|
|
||||||
|
monkeypatch.setattr(platform_compat.subprocess, "run", raise_run)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
platform_compat.os,
|
||||||
|
"listdir",
|
||||||
|
lambda _path: ["All Users", "Default", "Public", "alice"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def fake_isdir(path):
|
||||||
|
return path in {"/mnt/c/Users", "/mnt/c/Users/alice"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(platform_compat.os.path, "isdir", fake_isdir)
|
||||||
|
|
||||||
|
assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_wsl_windows_user_profile_returns_none_when_nothing_found(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
platform_compat.subprocess,
|
||||||
|
"run",
|
||||||
|
lambda *_a, **_k: (_ for _ in ()).throw(OSError("powershell unavailable")),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(platform_compat.os.path, "isdir", lambda _path: False)
|
||||||
|
|
||||||
|
assert platform_compat.get_wsl_windows_user_profile() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_nvidia_path_override_is_correct_string(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "_SSH_PATH_MEMBERS", ["path1", "path2"])
|
||||||
|
assert platform_compat._ssh_path_override() == "export PATH=\"$PATH:path1:path2\"; "
|
||||||
|
|
||||||
|
|
||||||
|
def test_windows_powershell_argv_defaults_include_no_profile_and_noninteractive():
|
||||||
|
argv = platform_compat._windows_powershell_argv("Write-Output Hello")
|
||||||
|
assert argv == [
|
||||||
|
"powershell.exe",
|
||||||
|
"-NoProfile",
|
||||||
|
"-NonInteractive",
|
||||||
|
"-Command",
|
||||||
|
"Write-Output Hello",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_windows_powershell_argv_respects_disabled_flags():
|
||||||
|
argv = platform_compat._windows_powershell_argv(
|
||||||
|
"Write-Output Hello",
|
||||||
|
no_profile=False,
|
||||||
|
non_interactive=False,
|
||||||
|
)
|
||||||
|
assert argv == ["powershell.exe", "-Command", "Write-Output Hello"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_wsl_windows_powershell_raises_outside_wsl(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
|
||||||
|
try:
|
||||||
|
platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=2)
|
||||||
|
raise AssertionError("Expected RuntimeError")
|
||||||
|
except RuntimeError as exc:
|
||||||
|
assert "only supported in WSL" in str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_wsl_windows_powershell_calls_subprocess_with_expected_argv(monkeypatch):
|
||||||
|
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class _Result:
|
||||||
|
returncode = 0
|
||||||
|
stdout = "ok\n"
|
||||||
|
stderr = ""
|
||||||
|
|
||||||
|
def _fake_run(args, **kwargs):
|
||||||
|
captured["args"] = list(args)
|
||||||
|
captured["kwargs"] = kwargs
|
||||||
|
return _Result()
|
||||||
|
|
||||||
|
monkeypatch.setattr(platform_compat.subprocess, "run", _fake_run)
|
||||||
|
|
||||||
|
result = platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=9)
|
||||||
|
|
||||||
|
assert result.returncode == 0
|
||||||
|
assert captured["args"] == [
|
||||||
|
"powershell.exe",
|
||||||
|
"-NoProfile",
|
||||||
|
"-NonInteractive",
|
||||||
|
"-Command",
|
||||||
|
"Write-Output Hello",
|
||||||
|
]
|
||||||
|
assert captured["kwargs"]["capture_output"] is True
|
||||||
|
assert captured["kwargs"]["text"] is True
|
||||||
|
assert captured["kwargs"]["timeout"] == 9
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssh_exec_argv_builds_default_command():
|
||||||
|
argv = platform_compat._ssh_exec_argv("alice@gpu-box", None, remote_cmd="echo ok")
|
||||||
|
assert argv == ["ssh", "alice@gpu-box", "echo ok"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssh_exec_argv_includes_port_and_options():
|
||||||
|
argv = platform_compat._ssh_exec_argv(
|
||||||
|
"alice@gpu-box",
|
||||||
|
"2222",
|
||||||
|
remote_cmd="tmux ls",
|
||||||
|
connect_timeout=6,
|
||||||
|
strict_host_key_checking=False,
|
||||||
|
)
|
||||||
|
assert argv == [
|
||||||
|
"ssh",
|
||||||
|
"-o",
|
||||||
|
"ConnectTimeout=6",
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=no",
|
||||||
|
"-p",
|
||||||
|
"2222",
|
||||||
|
"alice@gpu-box",
|
||||||
|
"tmux ls",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_ssh_command_uses_built_argv(monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class _Result:
|
||||||
|
returncode = 0
|
||||||
|
stdout = "ok"
|
||||||
|
stderr = ""
|
||||||
|
|
||||||
|
def _fake_run(args, **kwargs):
|
||||||
|
captured["args"] = list(args)
|
||||||
|
captured["kwargs"] = kwargs
|
||||||
|
return _Result()
|
||||||
|
|
||||||
|
monkeypatch.setattr(platform_compat.subprocess, "run", _fake_run)
|
||||||
|
|
||||||
|
result = platform_compat.run_ssh_command(
|
||||||
|
"alice@gpu-box",
|
||||||
|
"2200",
|
||||||
|
"tmux ls",
|
||||||
|
timeout=7,
|
||||||
|
connect_timeout=3,
|
||||||
|
strict_host_key_checking=True,
|
||||||
|
text=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode == 0
|
||||||
|
assert captured["args"] == [
|
||||||
|
"ssh",
|
||||||
|
"-o",
|
||||||
|
"ConnectTimeout=3",
|
||||||
|
"-o",
|
||||||
|
"StrictHostKeyChecking=yes",
|
||||||
|
"-p",
|
||||||
|
"2200",
|
||||||
|
"alice@gpu-box",
|
||||||
|
"tmux ls",
|
||||||
|
]
|
||||||
|
assert captured["kwargs"]["timeout"] == 7
|
||||||
|
assert captured["kwargs"]["capture_output"] is True
|
||||||
|
assert captured["kwargs"]["text"] is False
|
||||||
|
|||||||
Reference in New Issue
Block a user