mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
fix(platform): Improve WSL SSH remote compatibility (#3316)
* fix(platform): add WSL compatibility functions and path translation fix(cookbook): enhance model scan script to support additional HuggingFace cache paths fix(hardware): improve cache key generation for remote SSH context test(tests): add tests for WSL detection and path translation functionality * fix(cookbook): prefer prebuilt wheels for llama-cpp-python and normalize package aliases * fix: enable StrictHostKeyChecking in nvidia probe refactor: consolidate ssh & powershell command execution to utility functions in core module refactor: consolidate nvidia path candidates in to single variables in core module tests: add tests for new utility functions * fix: correct wrong variable name
This commit is contained in:
@@ -161,6 +161,29 @@ _WINDOWS_BASH_RELATIVE_PATHS = (
|
||||
("usr", "bin", "bash.exe"),
|
||||
)
|
||||
|
||||
# Paths to add to the remote SSH probe command to find tools like nvidia-smi that may not be on PATH.
|
||||
_SSH_PATH_MEMBERS = (
|
||||
"/usr/bin",
|
||||
"/usr/local/bin",
|
||||
"/usr/local/cuda/bin",
|
||||
"/usr/lib/wsl/lib"
|
||||
)
|
||||
# Fallback locations for nvidia-smi on WSL and other Linux distros where it may not be on PATH.
|
||||
NVIDIA_PATH_CANDIDATES = (
|
||||
"/usr/bin/nvidia-smi",
|
||||
"/usr/local/bin/nvidia-smi",
|
||||
"/usr/local/cuda/bin/nvidia-smi",
|
||||
"/usr/lib/wsl/lib/nvidia-smi",
|
||||
)
|
||||
|
||||
|
||||
def _ssh_path_override() -> str:
|
||||
"""Build the PATH export snippet used for remote SSH shell probes."""
|
||||
return f"export PATH=\"$PATH:{':'.join(_SSH_PATH_MEMBERS)}\"; "
|
||||
|
||||
|
||||
SSH_PATH_OVERRIDE = _ssh_path_override()
|
||||
|
||||
|
||||
def _windows_bash_fallbacks() -> List[str]:
|
||||
roots: List[str] = []
|
||||
@@ -268,3 +291,156 @@ def run_script_argv(script_path) -> List[str]:
|
||||
comspec = os.environ.get("ComSpec", "cmd.exe")
|
||||
return [comspec, "/c", str(script_path)]
|
||||
return ["sh", str(script_path)]
|
||||
|
||||
|
||||
def is_wsl() -> bool:
|
||||
"""True if running inside Windows Subsystem for Linux (WSL)."""
|
||||
import sys
|
||||
if sys.platform.startswith("linux") or os.name == "posix":
|
||||
try:
|
||||
with open("/proc/version", "r") as f:
|
||||
if "microsoft" in f.read().lower():
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def translate_path(path_str: str) -> str:
|
||||
"""Translate a path (possibly a Windows path) to the current OS format.
|
||||
|
||||
Particularly handles Windows paths (e.g. C:\\foo or C:/foo) when running
|
||||
under WSL, translating them to /mnt/c/foo.
|
||||
Also handles standard path normalization to avoid string breakages.
|
||||
"""
|
||||
if not path_str:
|
||||
return path_str
|
||||
|
||||
if is_wsl():
|
||||
path_str = path_str.replace("\\", "/")
|
||||
import re
|
||||
m = re.match(r"^([a-zA-Z]):(.*)", path_str)
|
||||
if m:
|
||||
drive = m.group(1).lower()
|
||||
rest = m.group(2)
|
||||
if not rest.startswith("/"):
|
||||
rest = "/" + rest
|
||||
return f"/mnt/{drive}{rest}"
|
||||
|
||||
try:
|
||||
return str(Path(path_str).resolve())
|
||||
except Exception:
|
||||
return path_str
|
||||
|
||||
|
||||
def get_wsl_windows_user_profile() -> Optional[str]:
|
||||
"""Retrieve the Windows host User Profile path from inside WSL."""
|
||||
if not is_wsl():
|
||||
return None
|
||||
try:
|
||||
r = run_wsl_windows_powershell("Write-Output $env:USERPROFILE", timeout=5)
|
||||
if r.returncode == 0 and r.stdout.strip():
|
||||
return translate_path(r.stdout.strip())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
users_dir = "/mnt/c/Users"
|
||||
if os.path.isdir(users_dir):
|
||||
for entry in os.listdir(users_dir):
|
||||
if entry not in ("All Users", "Default", "Default User", "desktop.ini", "Public"):
|
||||
path = os.path.join(users_dir, entry)
|
||||
if os.path.isdir(path):
|
||||
return path
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _ssh_exec_argv(
|
||||
remote: str,
|
||||
ssh_port: str | None,
|
||||
*,
|
||||
remote_cmd: str | None = None,
|
||||
connect_timeout: int | None = None,
|
||||
strict_host_key_checking: bool | None = None,
|
||||
) -> list[str]:
|
||||
"""Build a consistent ssh argv for remote command execution."""
|
||||
argv = ["ssh"]
|
||||
if connect_timeout is not None:
|
||||
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||
if strict_host_key_checking is not None:
|
||||
argv.extend(
|
||||
[
|
||||
"-o",
|
||||
"StrictHostKeyChecking=yes"
|
||||
if strict_host_key_checking
|
||||
else "StrictHostKeyChecking=no",
|
||||
]
|
||||
)
|
||||
if ssh_port and ssh_port != "22":
|
||||
argv.extend(["-p", str(ssh_port)])
|
||||
argv.append(remote)
|
||||
if remote_cmd is not None:
|
||||
argv.append(remote_cmd)
|
||||
return argv
|
||||
|
||||
|
||||
def run_ssh_command(
|
||||
remote: str,
|
||||
ssh_port: str | None,
|
||||
remote_cmd: str,
|
||||
*,
|
||||
timeout: float,
|
||||
connect_timeout: int | None = None,
|
||||
strict_host_key_checking: bool | None = None,
|
||||
text: bool = True,
|
||||
) -> subprocess.CompletedProcess:
|
||||
"""Run an ssh command with centralized timeout and stderr/stdout capture."""
|
||||
return subprocess.run(
|
||||
_ssh_exec_argv(
|
||||
remote,
|
||||
ssh_port,
|
||||
remote_cmd=remote_cmd,
|
||||
connect_timeout=connect_timeout,
|
||||
strict_host_key_checking=strict_host_key_checking,
|
||||
),
|
||||
timeout=timeout,
|
||||
capture_output=True,
|
||||
text=text,
|
||||
)
|
||||
|
||||
|
||||
def _windows_powershell_argv(
|
||||
command: str,
|
||||
*,
|
||||
no_profile: bool = True,
|
||||
non_interactive: bool = True,
|
||||
) -> List[str]:
|
||||
argv: List[str] = ["powershell.exe"]
|
||||
if no_profile:
|
||||
argv.append("-NoProfile")
|
||||
if non_interactive:
|
||||
argv.append("-NonInteractive")
|
||||
argv.extend(["-Command", command])
|
||||
return argv
|
||||
|
||||
|
||||
def run_wsl_windows_powershell(
|
||||
command: str,
|
||||
*,
|
||||
timeout: float = 5,
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
"""Run a PowerShell command on the Windows host from WSL.
|
||||
|
||||
Raises ``RuntimeError`` when called outside WSL.
|
||||
"""
|
||||
|
||||
if not is_wsl():
|
||||
raise RuntimeError("run_wsl_windows_powershell is only supported in WSL")
|
||||
return subprocess.run(
|
||||
_windows_powershell_argv(command),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user