mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
fix(platform): Improve WSL SSH remote compatibility (#3316)
* fix(platform): add WSL compatibility functions and path translation fix(cookbook): enhance model scan script to support additional HuggingFace cache paths fix(hardware): improve cache key generation for remote SSH context test(tests): add tests for WSL detection and path translation functionality * fix(cookbook): prefer prebuilt wheels for llama-cpp-python and normalize package aliases * fix: enable StrictHostKeyChecking in nvidia probe refactor: consolidate ssh & powershell command execution to utility functions in core module refactor: consolidate nvidia path candidates in to single variables in core module tests: add tests for new utility functions * fix: correct wrong variable name
This commit is contained in:
+32
-11
@@ -4,6 +4,13 @@ import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import shlex
|
||||
|
||||
from core.platform_compat import (
|
||||
NVIDIA_PATH_CANDIDATES,
|
||||
SSH_PATH_OVERRIDE,
|
||||
run_ssh_command,
|
||||
)
|
||||
|
||||
CACHE_TTL = 24 * 3600 # 24 h — hardware probes are user-initiated via the Rescan button; bumped
|
||||
# from 30 min so changing filters doesn't keep re-probing the rig every
|
||||
@@ -21,16 +28,17 @@ def _run(cmd):
|
||||
if _remote_host:
|
||||
# Run command on remote host via SSH
|
||||
if isinstance(cmd, list):
|
||||
cmd_str = " ".join(cmd)
|
||||
cmd_str = shlex.join(str(c) for c in cmd)
|
||||
else:
|
||||
cmd_str = cmd
|
||||
ssh_cmd = ["ssh", "-o", "ConnectTimeout=5", "-o", "StrictHostKeyChecking=no"]
|
||||
if _remote_port and _remote_port != "22":
|
||||
ssh_cmd += ["-p", _remote_port]
|
||||
ssh_cmd += [_remote_host, cmd_str]
|
||||
r = subprocess.run(
|
||||
ssh_cmd,
|
||||
capture_output=True, text=True, timeout=15,
|
||||
r = run_ssh_command(
|
||||
_remote_host,
|
||||
_remote_port,
|
||||
cmd_str,
|
||||
timeout=15,
|
||||
connect_timeout=5,
|
||||
strict_host_key_checking=False,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
r = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
@@ -83,7 +91,7 @@ def _detect_nvidia():
|
||||
# Retry through a login shell with the common CUDA bin dirs on PATH.
|
||||
if not out and _remote_host:
|
||||
out = _run(
|
||||
"bash -lc 'export PATH=\"$PATH:/usr/bin:/usr/local/bin:/usr/local/cuda/bin:/usr/lib/wsl/lib\"; "
|
||||
f"bash -lc '{SSH_PATH_OVERRIDE}"
|
||||
"nvidia-smi --query-gpu=memory.total,name --format=csv,noheader,nounits'"
|
||||
)
|
||||
# Last resort: call nvidia-smi by absolute path. Some hosts have a login
|
||||
@@ -92,7 +100,7 @@ def _detect_nvidia():
|
||||
# Also handles WSL where nvidia-smi lives at /usr/lib/wsl/lib/ — a path
|
||||
# that may not be in the server process's PATH.
|
||||
if not out:
|
||||
for _p in ("/usr/bin/nvidia-smi", "/usr/local/bin/nvidia-smi", "/usr/local/cuda/bin/nvidia-smi", "/usr/lib/wsl/lib/nvidia-smi"):
|
||||
for _p in NVIDIA_PATH_CANDIDATES:
|
||||
# Use list form so subprocess.run (local) resolves the absolute path
|
||||
# correctly instead of treating the whole string as an executable name.
|
||||
if _remote_host:
|
||||
@@ -590,6 +598,19 @@ def _detect_windows():
|
||||
_cache_by_host = {} # host -> (timestamp, result)
|
||||
|
||||
|
||||
def _cache_key(host: str, ssh_port: str, platform_name: str):
|
||||
"""Build a stable cache key that isolates remote SSH context.
|
||||
|
||||
Same host aliases can have different hardware due to visibility, forwarding etc.
|
||||
To avoid using the wrong cached hardware info, include the SSH port and platform in the cache key.
|
||||
"""
|
||||
return (
|
||||
host or "_local",
|
||||
str(ssh_port or ""),
|
||||
str(platform_name or "").lower(),
|
||||
)
|
||||
|
||||
|
||||
def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||
"""Detect system hardware: RAM, CPU, GPU. Cached per host (hardware rarely
|
||||
changes, and probing a remote host over SSH is slow). Pass fresh=True to
|
||||
@@ -599,7 +620,7 @@ def detect_system(host="", ssh_port="", platform="", fresh=False):
|
||||
"""
|
||||
global _remote_host, _remote_port, _remote_platform
|
||||
|
||||
cache_key = host or "_local"
|
||||
cache_key = _cache_key(host, ssh_port, platform)
|
||||
now = time.time()
|
||||
if not fresh and cache_key in _cache_by_host:
|
||||
ts, cached = _cache_by_host[cache_key]
|
||||
|
||||
Reference in New Issue
Block a user