mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
fix(platform): Improve WSL SSH remote compatibility (#3316)
* fix(platform): add WSL compatibility functions and path translation fix(cookbook): enhance model scan script to support additional HuggingFace cache paths fix(hardware): improve cache key generation for remote SSH context test(tests): add tests for WSL detection and path translation functionality * fix(cookbook): prefer prebuilt wheels for llama-cpp-python and normalize package aliases * fix: enable StrictHostKeyChecking in nvidia probe refactor: consolidate ssh & powershell command execution to utility functions in core module refactor: consolidate nvidia path candidates in to single variables in core module tests: add tests for new utility functions * fix: correct wrong variable name
This commit is contained in:
@@ -25,6 +25,7 @@ from routes.cookbook_helpers import (
|
||||
_validate_serve_cmd,
|
||||
_validate_serve_model_id,
|
||||
_validate_ssh_port,
|
||||
run_ssh_command_async,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,6 +36,56 @@ def test_safe_env_prefix_accepts_quoted_venv_path():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_ssh_command_executes_with_stdin_and_returns_output(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class _Proc:
|
||||
returncode = 0
|
||||
|
||||
async def communicate(self, input=None):
|
||||
captured["input"] = input
|
||||
return b"stdout", b"stderr"
|
||||
|
||||
async def _fake_exec(*args, **kwargs):
|
||||
captured["args"] = list(args)
|
||||
captured["stdin"] = kwargs.get("stdin")
|
||||
captured["stdout"] = kwargs.get("stdout")
|
||||
captured["stderr"] = kwargs.get("stderr")
|
||||
return _Proc()
|
||||
|
||||
monkeypatch.setattr("asyncio.create_subprocess_exec", _fake_exec)
|
||||
|
||||
rc, out, err = await run_ssh_command_async(
|
||||
"alice@gpu-box",
|
||||
"2222",
|
||||
"python -",
|
||||
timeout=5,
|
||||
connect_timeout=4,
|
||||
strict_host_key_checking=False,
|
||||
stdin_data=b"python -m pip install vllm",
|
||||
)
|
||||
|
||||
assert rc == 0
|
||||
assert out == b"stdout"
|
||||
assert err == b"stderr"
|
||||
assert captured["args"] == [
|
||||
"ssh",
|
||||
"-o",
|
||||
"ConnectTimeout=4",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-p",
|
||||
"2222",
|
||||
"alice@gpu-box",
|
||||
"python -",
|
||||
]
|
||||
assert captured["stdin"] is not None
|
||||
assert captured["stdout"] is not None
|
||||
assert captured["stderr"] is not None
|
||||
assert captured["input"] == b"python -m pip install vllm"
|
||||
|
||||
|
||||
def test_safe_env_prefix_leaves_compound_conda_prefix_unchanged():
|
||||
prefix = 'eval "$(conda shell.bash hook)" && conda activate qwen35'
|
||||
assert _safe_env_prefix(prefix) == prefix
|
||||
@@ -170,6 +221,8 @@ def test_pip_install_fallback_chain_quotes_extras_spec():
|
||||
chain = _pip_install_fallback_chain("llama-cpp-python[server]", python_cmd="pip")
|
||||
# Quoted in both the plain and the --user attempt.
|
||||
assert chain.count("'llama-cpp-python[server]'") == 2
|
||||
# llama-cpp installs must prefer prebuilt wheels to avoid fragile source builds.
|
||||
assert "--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu" in chain
|
||||
# Never the unquoted form (bracket-glob risk).
|
||||
assert "install -q llama-cpp-python[server]" not in chain
|
||||
# A plain package name is still passed through unquoted (no regression).
|
||||
@@ -194,6 +247,17 @@ def test_serve_runner_installs_llama_cpp_server_extra():
|
||||
assert "_pip_install_fallback_chain('llama-cpp-python[server]'" in src
|
||||
|
||||
|
||||
def test_serve_pip_install_normalizes_llama_cpp_alias_and_adds_wheel_index():
|
||||
import pathlib
|
||||
|
||||
src = (pathlib.Path(__file__).resolve().parent.parent
|
||||
/ "routes" / "cookbook_routes.py").read_text(encoding="utf-8")
|
||||
|
||||
assert "re.sub(r\"(?<![A-Za-z0-9_.-])llama_cpp(?![A-Za-z0-9_.-])\", \"llama-cpp-python[server]\", req.cmd)" in src
|
||||
assert "if \"llama-cpp-python\" in req.cmd and \"--extra-index-url\" not in req.cmd:" in src
|
||||
assert "https://abetlen.github.io/llama-cpp-python/whl/cpu" in src
|
||||
|
||||
|
||||
def test_vllm_preflight_reports_cli_and_version():
|
||||
lines = []
|
||||
|
||||
@@ -289,6 +353,7 @@ def test_local_tooling_path_export_converts_windows_paths_for_bash():
|
||||
def test_user_shell_path_bootstrap_falls_back_to_python_on_windows_bash():
|
||||
script = "\n".join(_user_shell_path_bootstrap())
|
||||
assert 'command -v python3 >/dev/null 2>&1 || python3() { python "$@"; }' in script
|
||||
assert 'command -v python >/dev/null 2>&1 || python() { python3 "$@"; }' in script
|
||||
|
||||
|
||||
def test_serve_preflight_failure_keeps_tmux_pane_visible():
|
||||
@@ -606,3 +671,35 @@ def test_pip_install_no_cache_is_idempotent_and_scoped():
|
||||
# not a pip install -> unchanged
|
||||
assert _pip_install_no_cache("vllm serve --model x") == "vllm serve --model x"
|
||||
assert _pip_install_no_cache("") == ""
|
||||
|
||||
|
||||
def test_cached_model_scan_runs_additional_hf_cache(tmp_path):
|
||||
extra_cache = tmp_path / "extra_hf_cache"
|
||||
model_dir = extra_cache / "models--acme--sample-7b"
|
||||
snap = model_dir / "snapshots" / "rev-1"
|
||||
snap.mkdir(parents=True)
|
||||
weights = snap / "model.safetensors"
|
||||
weights.write_bytes(b"abc123")
|
||||
|
||||
scan_py = tmp_path / "scan_cache.py"
|
||||
scan_py.write_text(
|
||||
_cached_model_scan_script(add_hf_cache=str(extra_cache)),
|
||||
encoding="utf-8",
|
||||
)
|
||||
proc = subprocess.run(
|
||||
[sys.executable, str(scan_py)],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
models = json.loads(proc.stdout)
|
||||
by_repo = {m["repo_id"]: m for m in models}
|
||||
|
||||
assert "acme/sample-7b" in by_repo
|
||||
rec = by_repo["acme/sample-7b"]
|
||||
assert rec["path"] == str(extra_cache)
|
||||
assert rec["nb_files"] == 1
|
||||
assert rec["size_bytes"] == len(b"abc123")
|
||||
assert rec["has_incomplete"] is False
|
||||
assert rec["is_diffusion"] is False
|
||||
|
||||
@@ -71,3 +71,81 @@ def test_no_gpu_still_none(monkeypatch):
|
||||
"""No nvidia-smi output → still None, no spurious unified GPU."""
|
||||
monkeypatch.setattr(hardware, "_run", lambda cmd: None)
|
||||
assert hardware._detect_nvidia() is None
|
||||
|
||||
|
||||
def test_detect_system_cache_separates_same_host_different_ports(monkeypatch):
|
||||
"""Keep cache separate by host+port+platform, don't use cached data"""
|
||||
ram_gb = 0
|
||||
|
||||
def _ram():
|
||||
nonlocal ram_gb
|
||||
ram_gb += 1
|
||||
return ram_gb * 64.0
|
||||
|
||||
monkeypatch.setattr(hardware, "_get_ram_gb", _ram)
|
||||
monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 40.0)
|
||||
monkeypatch.setattr(hardware, "_get_cpu_count", lambda: 16)
|
||||
monkeypatch.setattr(hardware, "_get_cpu_name", lambda: "AMD Ryzen")
|
||||
monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: None)
|
||||
monkeypatch.setattr(hardware, "_detect_nvidia", lambda: None)
|
||||
monkeypatch.setattr(hardware, "_detect_amd", lambda: None)
|
||||
monkeypatch.setattr(hardware, "_run", lambda _cmd: "x86_64")
|
||||
|
||||
def _windows_probe():
|
||||
nonlocal ram_gb
|
||||
ram_gb += 1
|
||||
return {
|
||||
"total_ram_gb": ram_gb * 64.0,
|
||||
"available_ram_gb": 40.0,
|
||||
"cpu_cores": 16,
|
||||
"cpu_name": "AMD Ryzen",
|
||||
"has_gpu": False,
|
||||
"gpu_name": None,
|
||||
"gpu_vram_gb": None,
|
||||
"gpu_count": 0,
|
||||
"backend": "cpu_x86",
|
||||
"homogeneous": True,
|
||||
"gpu_error": None,
|
||||
"platform": "windows",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(hardware, "_detect_windows", _windows_probe)
|
||||
hardware._cache_by_host.clear()
|
||||
|
||||
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
|
||||
hardware.detect_system(host="user@wsl-host", ssh_port="2222", platform="linux", fresh=False)
|
||||
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="windows", fresh=False)
|
||||
|
||||
assert len(hardware._cache_by_host) == 3
|
||||
assert hardware._cache_by_host[("user@wsl-host", "22", "linux")][1]["total_ram_gb"] == 64.0
|
||||
assert hardware._cache_by_host[("user@wsl-host", "2222", "linux")][1]["total_ram_gb"] == 128.0
|
||||
assert hardware._cache_by_host[("user@wsl-host", "22", "windows")][1]["total_ram_gb"] == 192.0
|
||||
|
||||
|
||||
def test_detect_system_cache_hits_when_remote_context_matches(monkeypatch):
|
||||
"""Cache hits when host+port+platform match"""
|
||||
ram_gb = 0
|
||||
|
||||
def _ram():
|
||||
nonlocal ram_gb
|
||||
ram_gb += 1
|
||||
return ram_gb * 64.0
|
||||
|
||||
monkeypatch.setattr(hardware, "_get_ram_gb", _ram)
|
||||
monkeypatch.setattr(hardware, "_get_available_ram_gb", lambda: 40.0)
|
||||
monkeypatch.setattr(hardware, "_get_cpu_count", lambda: 16)
|
||||
monkeypatch.setattr(hardware, "_get_cpu_name", lambda: "AMD Ryzen")
|
||||
monkeypatch.setattr(hardware, "_detect_apple_silicon", lambda: None)
|
||||
monkeypatch.setattr(hardware, "_detect_nvidia", lambda: None)
|
||||
monkeypatch.setattr(hardware, "_detect_amd", lambda: None)
|
||||
monkeypatch.setattr(hardware, "_run", lambda _cmd: "x86_64")
|
||||
hardware._cache_by_host.clear()
|
||||
|
||||
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
|
||||
hardware.detect_system(host="user@wsl-host", ssh_port="22", platform="linux", fresh=False)
|
||||
hardware.detect_system(fresh=False)
|
||||
hardware.detect_system(fresh=False)
|
||||
|
||||
assert len(hardware._cache_by_host) == 2
|
||||
assert hardware._cache_by_host[("user@wsl-host", "22", "linux")][1]["total_ram_gb"] == 64.0
|
||||
assert hardware._cache_by_host[("_local", "", "")][1]["total_ram_gb"] == 128.0
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Regression tests for cross-platform helper behavior."""
|
||||
|
||||
import importlib.util
|
||||
import io
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -59,3 +61,243 @@ def test_find_bash_skips_windows_wsl_stub(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat.os.path, "exists", lambda path: path == expected)
|
||||
|
||||
assert platform_compat.find_bash() == expected
|
||||
|
||||
|
||||
def test_is_wsl_true_when_proc_version_mentions_microsoft(monkeypatch):
|
||||
monkeypatch.setattr(sys, "platform", "linux", raising=False)
|
||||
|
||||
def fake_open(path, mode="r", *args, **kwargs):
|
||||
assert path == "/proc/version"
|
||||
assert mode == "r"
|
||||
return io.StringIO("Linux version 6.6.0 microsoft standard")
|
||||
|
||||
monkeypatch.setattr("builtins.open", fake_open)
|
||||
|
||||
assert platform_compat.is_wsl() is True
|
||||
|
||||
|
||||
def test_is_wsl_false_when_proc_version_is_not_microsoft(monkeypatch):
|
||||
monkeypatch.setattr(sys, "platform", "linux", raising=False)
|
||||
monkeypatch.setattr("builtins.open", lambda *_a, **_k: io.StringIO("Linux version 6.6.0 generic"))
|
||||
|
||||
assert platform_compat.is_wsl() is False
|
||||
|
||||
|
||||
def test_is_wsl_false_on_non_posix_without_proc_probe(monkeypatch):
|
||||
monkeypatch.setattr(sys, "platform", "win32", raising=False)
|
||||
monkeypatch.setattr(platform_compat.os, "name", "nt", raising=False)
|
||||
|
||||
def fail_open(*_args, **_kwargs):
|
||||
raise AssertionError("open should not be called when platform is not Linux/POSIX")
|
||||
|
||||
monkeypatch.setattr("builtins.open", fail_open)
|
||||
|
||||
assert platform_compat.is_wsl() is False
|
||||
|
||||
|
||||
def test_translate_path_converts_windows_drive_path_on_wsl(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||
|
||||
out = platform_compat.translate_path(r"C:\Users\alice\models\qwen.gguf")
|
||||
|
||||
assert out == "/mnt/c/Users/alice/models/qwen.gguf"
|
||||
|
||||
|
||||
def test_translate_path_resolves_paths_when_not_wsl(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
|
||||
|
||||
assert platform_compat.translate_path(".") == str(Path(".").resolve())
|
||||
|
||||
|
||||
def test_translate_path_returns_input_when_resolve_fails(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
|
||||
|
||||
class _BrokenPath:
|
||||
def __init__(self, _value):
|
||||
pass
|
||||
|
||||
def resolve(self):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(platform_compat, "Path", _BrokenPath)
|
||||
|
||||
assert platform_compat.translate_path("weird::path") == "weird::path"
|
||||
|
||||
|
||||
def test_get_wsl_windows_user_profile_prefers_powershell(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||
|
||||
class _Result:
|
||||
returncode = 0
|
||||
stdout = "C:\\Users\\alice\\n"
|
||||
|
||||
monkeypatch.setattr(platform_compat.subprocess, "run", lambda *_a, **_k: _Result())
|
||||
monkeypatch.setattr(platform_compat, "translate_path", lambda _v: "/mnt/c/Users/alice")
|
||||
|
||||
assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice"
|
||||
|
||||
|
||||
def test_get_wsl_windows_user_profile_falls_back_to_users_dir(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||
|
||||
def raise_run(*_a, **_k):
|
||||
raise OSError("powershell unavailable")
|
||||
|
||||
monkeypatch.setattr(platform_compat.subprocess, "run", raise_run)
|
||||
monkeypatch.setattr(
|
||||
platform_compat.os,
|
||||
"listdir",
|
||||
lambda _path: ["All Users", "Default", "Public", "alice"],
|
||||
)
|
||||
|
||||
def fake_isdir(path):
|
||||
return path in {"/mnt/c/Users", "/mnt/c/Users/alice"}
|
||||
|
||||
monkeypatch.setattr(platform_compat.os.path, "isdir", fake_isdir)
|
||||
|
||||
assert platform_compat.get_wsl_windows_user_profile() == "/mnt/c/Users/alice"
|
||||
|
||||
|
||||
def test_get_wsl_windows_user_profile_returns_none_when_nothing_found(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
platform_compat.subprocess,
|
||||
"run",
|
||||
lambda *_a, **_k: (_ for _ in ()).throw(OSError("powershell unavailable")),
|
||||
)
|
||||
monkeypatch.setattr(platform_compat.os.path, "isdir", lambda _path: False)
|
||||
|
||||
assert platform_compat.get_wsl_windows_user_profile() is None
|
||||
|
||||
|
||||
def test_nvidia_path_override_is_correct_string(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "_SSH_PATH_MEMBERS", ["path1", "path2"])
|
||||
assert platform_compat._ssh_path_override() == "export PATH=\"$PATH:path1:path2\"; "
|
||||
|
||||
|
||||
def test_windows_powershell_argv_defaults_include_no_profile_and_noninteractive():
|
||||
argv = platform_compat._windows_powershell_argv("Write-Output Hello")
|
||||
assert argv == [
|
||||
"powershell.exe",
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
"Write-Output Hello",
|
||||
]
|
||||
|
||||
|
||||
def test_windows_powershell_argv_respects_disabled_flags():
|
||||
argv = platform_compat._windows_powershell_argv(
|
||||
"Write-Output Hello",
|
||||
no_profile=False,
|
||||
non_interactive=False,
|
||||
)
|
||||
assert argv == ["powershell.exe", "-Command", "Write-Output Hello"]
|
||||
|
||||
|
||||
def test_run_wsl_windows_powershell_raises_outside_wsl(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: False)
|
||||
try:
|
||||
platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=2)
|
||||
raise AssertionError("Expected RuntimeError")
|
||||
except RuntimeError as exc:
|
||||
assert "only supported in WSL" in str(exc)
|
||||
|
||||
|
||||
def test_run_wsl_windows_powershell_calls_subprocess_with_expected_argv(monkeypatch):
|
||||
monkeypatch.setattr(platform_compat, "is_wsl", lambda: True)
|
||||
captured = {}
|
||||
|
||||
class _Result:
|
||||
returncode = 0
|
||||
stdout = "ok\n"
|
||||
stderr = ""
|
||||
|
||||
def _fake_run(args, **kwargs):
|
||||
captured["args"] = list(args)
|
||||
captured["kwargs"] = kwargs
|
||||
return _Result()
|
||||
|
||||
monkeypatch.setattr(platform_compat.subprocess, "run", _fake_run)
|
||||
|
||||
result = platform_compat.run_wsl_windows_powershell("Write-Output Hello", timeout=9)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert captured["args"] == [
|
||||
"powershell.exe",
|
||||
"-NoProfile",
|
||||
"-NonInteractive",
|
||||
"-Command",
|
||||
"Write-Output Hello",
|
||||
]
|
||||
assert captured["kwargs"]["capture_output"] is True
|
||||
assert captured["kwargs"]["text"] is True
|
||||
assert captured["kwargs"]["timeout"] == 9
|
||||
|
||||
|
||||
def test_ssh_exec_argv_builds_default_command():
|
||||
argv = platform_compat._ssh_exec_argv("alice@gpu-box", None, remote_cmd="echo ok")
|
||||
assert argv == ["ssh", "alice@gpu-box", "echo ok"]
|
||||
|
||||
|
||||
def test_ssh_exec_argv_includes_port_and_options():
|
||||
argv = platform_compat._ssh_exec_argv(
|
||||
"alice@gpu-box",
|
||||
"2222",
|
||||
remote_cmd="tmux ls",
|
||||
connect_timeout=6,
|
||||
strict_host_key_checking=False,
|
||||
)
|
||||
assert argv == [
|
||||
"ssh",
|
||||
"-o",
|
||||
"ConnectTimeout=6",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-p",
|
||||
"2222",
|
||||
"alice@gpu-box",
|
||||
"tmux ls",
|
||||
]
|
||||
|
||||
|
||||
def test_run_ssh_command_uses_built_argv(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class _Result:
|
||||
returncode = 0
|
||||
stdout = "ok"
|
||||
stderr = ""
|
||||
|
||||
def _fake_run(args, **kwargs):
|
||||
captured["args"] = list(args)
|
||||
captured["kwargs"] = kwargs
|
||||
return _Result()
|
||||
|
||||
monkeypatch.setattr(platform_compat.subprocess, "run", _fake_run)
|
||||
|
||||
result = platform_compat.run_ssh_command(
|
||||
"alice@gpu-box",
|
||||
"2200",
|
||||
"tmux ls",
|
||||
timeout=7,
|
||||
connect_timeout=3,
|
||||
strict_host_key_checking=True,
|
||||
text=False,
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert captured["args"] == [
|
||||
"ssh",
|
||||
"-o",
|
||||
"ConnectTimeout=3",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=yes",
|
||||
"-p",
|
||||
"2200",
|
||||
"alice@gpu-box",
|
||||
"tmux ls",
|
||||
]
|
||||
assert captured["kwargs"]["timeout"] == 7
|
||||
assert captured["kwargs"]["capture_output"] is True
|
||||
assert captured["kwargs"]["text"] is False
|
||||
|
||||
Reference in New Issue
Block a user