mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
Compare commits
3 Commits
218b9ecbc8
...
d5603ee575
| Author | SHA1 | Date | |
|---|---|---|---|
| d5603ee575 | |||
| 9c00da6d1c | |||
| d1a5a7d680 |
@@ -503,6 +503,7 @@ api_key_manager = components["api_key_manager"]
|
||||
preset_manager = components["preset_manager"]
|
||||
chat_processor = components["chat_processor"]
|
||||
research_handler = components["research_handler"]
|
||||
app.state.research_handler = research_handler
|
||||
chat_handler = components["chat_handler"]
|
||||
model_discovery = components["model_discovery"]
|
||||
skills_manager = components["skills_manager"]
|
||||
|
||||
@@ -366,6 +366,10 @@ def _ssh_exec_argv(
|
||||
strict_host_key_checking: bool | None = None,
|
||||
) -> list[str]:
|
||||
"""Build a consistent ssh argv for remote command execution."""
|
||||
remote_value = str(remote or "").strip()
|
||||
remote_host = remote_value.rsplit("@", 1)[-1]
|
||||
if not remote_value or remote_value.startswith("-") or not remote_host or remote_host.startswith("-"):
|
||||
raise ValueError("Invalid SSH remote host")
|
||||
argv = ["ssh"]
|
||||
if connect_timeout is not None:
|
||||
argv.extend(["-o", f"ConnectTimeout={int(connect_timeout)}"])
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import re
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
_REMOTE_HOST_RE = re.compile(
|
||||
r"^(?:[A-Za-z0-9][A-Za-z0-9._-]*@)?[A-Za-z0-9][A-Za-z0-9._-]*$"
|
||||
)
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
|
||||
|
||||
def validate_remote_host(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _REMOTE_HOST_RE.match(v):
|
||||
raise HTTPException(
|
||||
400,
|
||||
"Invalid remote_host — must be host or user@host, no SSH option syntax",
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
def validate_ssh_port(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _SSH_PORT_RE.fullmatch(str(v)):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = int(v)
|
||||
if port < 1 or port > 65535:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
return str(port)
|
||||
@@ -367,6 +367,20 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename user prefs %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# In-flight deep-research tasks live in the process-local
|
||||
# ResearchHandler registry. They are not covered by the persisted JSON
|
||||
# migration above, but the research routes filter and cancel by this
|
||||
# owner field while the job is running. Do this before sweeping
|
||||
# completed JSON files so a job that finishes during the rename saves
|
||||
# with the new owner or is caught by the disk sweep below.
|
||||
try:
|
||||
rh = getattr(request.app.state, "research_handler", None)
|
||||
rename_owner = getattr(rh, "rename_owner", None)
|
||||
if callable(rename_owner):
|
||||
rename_owner(old_username, new_username)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rename active research tasks %s -> %s: %s", old_username, new_username, e)
|
||||
|
||||
# deep_research: each completed report is a standalone JSON file with
|
||||
# an `owner` field. research_routes filters by d.get("owner") == user,
|
||||
# so a stale owner makes every report invisible to the renamed user.
|
||||
|
||||
@@ -11,6 +11,7 @@ import shlex
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,16 +31,12 @@ _LOCAL_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
_OLLAMA_MODEL_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:/-]{0,200}$")
|
||||
# Include pattern is a glob: allow typical safe glyphs only.
|
||||
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
|
||||
# Remote host: either `user@host` or plain `host` (alias is allowed), where host
|
||||
# is a safe DNS-like token or a short SSH config alias.
|
||||
_REMOTE_HOST_RE = re.compile(r"^(?:[A-Za-z0-9._-]+@)?[A-Za-z0-9._-]+$")
|
||||
# HF tokens and API tokens are url-safe base64-like.
|
||||
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
|
||||
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
|
||||
# Anything beyond plain alphanumerics + dash + underscore could break out
|
||||
# of the shell/PowerShell contexts the value lands in.
|
||||
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
|
||||
# A download target directory. Absolute or ~-relative path; safe path glyphs
|
||||
# only (no quotes or shell metacharacters). Spaces are allowed because command
|
||||
@@ -85,14 +82,6 @@ def _validate_include(v: str | None) -> str | None:
|
||||
return v
|
||||
|
||||
|
||||
def _validate_remote_host(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _REMOTE_HOST_RE.match(v):
|
||||
raise HTTPException(400, "Invalid remote_host — must be host or user@host, no SSH option syntax")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_token(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
@@ -120,17 +109,6 @@ def _validate_local_dir(v: str | None) -> str | None:
|
||||
return v
|
||||
|
||||
|
||||
def _validate_ssh_port(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _SSH_PORT_RE.fullmatch(str(v)):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = int(v)
|
||||
if port < 1 or port > 65535:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
return str(port)
|
||||
|
||||
|
||||
def _validate_gpus(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
|
||||
+43
-31
@@ -19,6 +19,7 @@ from src.constants import COOKBOOK_STATE_FILE
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
from core.platform_compat import (
|
||||
IS_WINDOWS,
|
||||
detached_popen_kwargs,
|
||||
@@ -33,9 +34,8 @@ from routes.shell_routes import TMUX_LOG_DIR
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from routes.cookbook_helpers import (
|
||||
_SSH_PORT_RE, _REMOTE_HOST_RE, _SESSION_ID_RE,
|
||||
_validate_repo_id, _validate_serve_model_id, _validate_include, _validate_remote_host, _validate_token,
|
||||
_validate_local_dir, _validate_ssh_port, _validate_gpus, _shell_path,
|
||||
_SESSION_ID_RE, _validate_repo_id, _validate_serve_model_id, _validate_include, _validate_token,
|
||||
_validate_local_dir, _validate_gpus, _shell_path,
|
||||
_ps_squote, _bash_squote, _validate_serve_cmd, _parse_serve_phase,
|
||||
_safe_env_prefix, _local_tooling_path_export, _append_serve_preflight_exit_lines,
|
||||
_append_serve_exit_code_lines, _append_llama_cpp_linux_accel_build_lines, _cached_model_scan_script,
|
||||
@@ -407,8 +407,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
else:
|
||||
_validate_repo_id(req.repo_id)
|
||||
_validate_include(req.include)
|
||||
_validate_remote_host(req.remote_host)
|
||||
req.ssh_port = _validate_ssh_port(req.ssh_port)
|
||||
validate_remote_host(req.remote_host)
|
||||
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||
req.local_dir = _validate_local_dir(req.local_dir)
|
||||
req.hf_token = "" if is_ollama_download else (req.hf_token or _load_stored_hf_token())
|
||||
_validate_token(req.hf_token)
|
||||
@@ -739,9 +739,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# Validate shell-bound inputs, matching the sibling list_gpus endpoint —
|
||||
# `host`/`ssh_port` are interpolated into an ssh command below, so an
|
||||
# unvalidated value (e.g. "x'; rm -rf ~ #") would be command injection.
|
||||
host = _validate_remote_host(host)
|
||||
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
host = validate_remote_host(host)
|
||||
ssh_port = validate_ssh_port(ssh_port)
|
||||
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model_dirs = []
|
||||
@@ -890,11 +889,16 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
# listening" check without requiring ss/netstat/nmap.
|
||||
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||
if ssh_port and str(ssh_port) != "22":
|
||||
if not _SSH_PORT_RE.match(str(ssh_port)):
|
||||
try:
|
||||
ssh_port = validate_ssh_port(ssh_port)
|
||||
except HTTPException:
|
||||
return None
|
||||
ssh_base.extend(["-p", str(ssh_port)])
|
||||
host_arg = remote
|
||||
if not _REMOTE_HOST_RE.match(host_arg):
|
||||
try:
|
||||
host_arg = validate_remote_host(remote)
|
||||
except HTTPException:
|
||||
return None
|
||||
if not host_arg:
|
||||
return None
|
||||
probe_ports = " ".join(str(start_port + i) for i in range(max_offset + 1))
|
||||
script = (
|
||||
@@ -1197,8 +1201,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
"""
|
||||
require_admin(request)
|
||||
# Defence-in-depth: reject values that could break out of shell contexts.
|
||||
_validate_remote_host(req.remote_host)
|
||||
req.ssh_port = _validate_ssh_port(req.ssh_port)
|
||||
validate_remote_host(req.remote_host)
|
||||
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||
req.gpus = _validate_gpus(req.gpus)
|
||||
req.hf_token = req.hf_token or _load_stored_hf_token()
|
||||
_validate_token(req.hf_token)
|
||||
@@ -1638,12 +1642,11 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
async def server_setup(request: Request, req: SetupRequest):
|
||||
"""Install required dependencies on a remote server via SSH."""
|
||||
require_admin(request)
|
||||
host = _validate_remote_host(req.host)
|
||||
host = validate_remote_host(req.host)
|
||||
if not host:
|
||||
raise HTTPException(400, "host is required")
|
||||
port = req.ssh_port
|
||||
if port is not None and port != "" and not re.fullmatch(r"\d{1,5}", port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = validate_ssh_port(port)
|
||||
pf = f"-p {port} " if port and port != "22" else ""
|
||||
|
||||
# Detect platform: Windows first (echo %OS% → Windows_NT), then Termux, then Linux
|
||||
@@ -1887,9 +1890,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
`busy` is True when free_mb/total_mb < 0.5.
|
||||
"""
|
||||
require_admin(request)
|
||||
host = _validate_remote_host(host)
|
||||
if ssh_port is not None and ssh_port != "" and not _SSH_PORT_RE.fullmatch(ssh_port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
host = validate_remote_host(host)
|
||||
ssh_port = validate_ssh_port(ssh_port)
|
||||
gpu_query = "nvidia-smi --query-gpu=index,name,memory.free,memory.total,memory.used,utilization.gpu,uuid --format=csv,noheader,nounits"
|
||||
nvidia_error = None
|
||||
try:
|
||||
@@ -2046,9 +2048,8 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
sig = (req.signal or "TERM").upper()
|
||||
if sig not in ("TERM", "KILL", "INT"):
|
||||
raise HTTPException(400, "signal must be TERM, KILL, or INT")
|
||||
host = _validate_remote_host(req.host)
|
||||
if req.ssh_port and not _SSH_PORT_RE.fullmatch(req.ssh_port):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
host = validate_remote_host(req.host)
|
||||
req.ssh_port = validate_ssh_port(req.ssh_port)
|
||||
kill_cmd = f"kill -{sig} {req.pid}"
|
||||
try:
|
||||
if host:
|
||||
@@ -2382,14 +2383,19 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
host = (srv.get("host") or "").strip()
|
||||
if not host:
|
||||
continue # local-only entry; the /proc scan handles it
|
||||
if not _REMOTE_HOST_RE.match(host):
|
||||
try:
|
||||
host = validate_remote_host(host)
|
||||
except HTTPException:
|
||||
continue
|
||||
sport = str(srv.get("port") or "").strip()
|
||||
ssh_base = ["ssh", "-o", "ConnectTimeout=4", "-o", "StrictHostKeyChecking=no"]
|
||||
if sport and sport != "22":
|
||||
if not _SSH_PORT_RE.match(sport):
|
||||
try:
|
||||
sport = validate_ssh_port(sport)
|
||||
except HTTPException:
|
||||
continue
|
||||
ssh_base.extend(["-p", sport])
|
||||
if sport != "22":
|
||||
ssh_base.extend(["-p", sport])
|
||||
|
||||
try:
|
||||
ls = subprocess.run(
|
||||
@@ -2743,12 +2749,18 @@ def setup_cookbook_routes() -> APIRouter:
|
||||
if not _SESSION_ID_RE.match(session_id):
|
||||
logger.warning(f"Skipping task with unsafe session_id: {session_id!r}")
|
||||
continue
|
||||
if remote and not _REMOTE_HOST_RE.match(remote):
|
||||
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
|
||||
continue
|
||||
if _tport and not _SSH_PORT_RE.match(str(_tport)):
|
||||
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
|
||||
continue
|
||||
if remote:
|
||||
try:
|
||||
remote = validate_remote_host(remote)
|
||||
except HTTPException:
|
||||
logger.warning(f"Skipping task with unsafe remoteHost: {remote!r}")
|
||||
continue
|
||||
if _tport:
|
||||
try:
|
||||
_tport = validate_ssh_port(str(_tport))
|
||||
except HTTPException:
|
||||
logger.warning(f"Skipping task with unsafe sshPort: {_tport!r}")
|
||||
continue
|
||||
if task_platform == "windows" and remote:
|
||||
# Windows: check PID file + Get-Process, read log tail
|
||||
sd = "$env:TEMP\\odysseus-sessions"
|
||||
|
||||
+23
-3
@@ -1,7 +1,9 @@
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
|
||||
|
||||
# Backends the manual hardware simulator accepts. Must stay a subset of what
|
||||
@@ -11,6 +13,14 @@ from fastapi import APIRouter
|
||||
_MANUAL_BACKENDS = {"cuda", "rocm", "metal", "cpu_x86", "cpu_arm"}
|
||||
|
||||
|
||||
def _validate_detection_target(host: str = "", ssh_port: str = "") -> tuple[str, str]:
|
||||
host_value = validate_remote_host(host) or ""
|
||||
port_value = validate_ssh_port(ssh_port) or ""
|
||||
if port_value and not host_value:
|
||||
raise HTTPException(400, "ssh_port requires host")
|
||||
return host_value, port_value
|
||||
|
||||
|
||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||
"""Manual hardware is a "what if I had this setup" simulator —
|
||||
REPLACES the detected hardware entirely instead of adding to it.
|
||||
@@ -105,6 +115,7 @@ def setup_hwfit_routes():
|
||||
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
||||
fresh=true bypasses the per-host cache (the Rescan button)."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
|
||||
@router.get("/models")
|
||||
@@ -118,6 +129,7 @@ def setup_hwfit_routes():
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.fit import rank_models
|
||||
from services.hwfit.models import get_models, model_catalog_path
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
@@ -165,8 +177,14 @@ def setup_hwfit_routes():
|
||||
system["gpu_name"] = g["name"]
|
||||
system["active_group"] = {**g, "use_count": n}
|
||||
|
||||
if gpu_count != "":
|
||||
n = int(gpu_count)
|
||||
# Parse the optional count defensively (matches the gpu_group guard
|
||||
# above): a non-numeric query param previously raised ValueError ->
|
||||
# HTTP 500. A malformed value is ignored, same as omitting it.
|
||||
try:
|
||||
n = int(gpu_count) if gpu_count != "" else None
|
||||
except ValueError:
|
||||
n = None
|
||||
if n is not None:
|
||||
if n == 0:
|
||||
# RAM-only mode: rank against system memory, offload allowed.
|
||||
system["has_gpu"] = False
|
||||
@@ -229,6 +247,7 @@ def setup_hwfit_routes():
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.models import get_models
|
||||
from services.hwfit.profiles import compute_serve_profiles
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
if system.get("error"):
|
||||
return {"system": system, "profiles": [], "error": system["error"]}
|
||||
@@ -279,6 +298,7 @@ def setup_hwfit_routes():
|
||||
"""Rank image generation models against detected hardware."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.image_models import rank_image_models
|
||||
host, ssh_port = _validate_detection_target(host, ssh_port)
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
|
||||
@@ -221,6 +221,22 @@ class ResearchHandler:
|
||||
# Task registry — background research with persistence
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def rename_owner(self, old_owner: str, new_owner: str) -> int:
|
||||
"""Move in-flight research tasks from one owner key to another."""
|
||||
old_key = str(old_owner or "").strip().lower()
|
||||
new_key = str(new_owner or "").strip().lower()
|
||||
if not old_key or not new_key:
|
||||
return 0
|
||||
|
||||
changed = 0
|
||||
for entry in list(self._active_tasks.values()):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if str(entry.get("owner", "")).strip().lower() == old_key:
|
||||
entry["owner"] = new_key
|
||||
changed += 1
|
||||
return changed
|
||||
|
||||
def start_research(
|
||||
self,
|
||||
session_id: str,
|
||||
|
||||
@@ -26,7 +26,6 @@ from routes.cookbook_helpers import (
|
||||
_validate_repo_id,
|
||||
_validate_serve_cmd,
|
||||
_validate_serve_model_id,
|
||||
_validate_ssh_port,
|
||||
_shell_path,
|
||||
run_ssh_command_async,
|
||||
)
|
||||
@@ -106,12 +105,6 @@ def test_safe_env_prefix_accepts_powershell_activation_path():
|
||||
)
|
||||
|
||||
|
||||
def test_validate_ssh_port_rejects_shell_payload():
|
||||
with pytest.raises(HTTPException):
|
||||
_validate_ssh_port("22; touch /tmp/pwned")
|
||||
assert _validate_ssh_port("2222") == "2222"
|
||||
|
||||
|
||||
def test_validate_local_dir_accepts_external_drive_paths_with_spaces():
|
||||
path = "/Volumes/T7 2TB/AI Models/llamacpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
"""GET /api/hwfit/models must not 500 on a non-numeric gpu_count.
|
||||
|
||||
The handler did `n = int(gpu_count)` with no guard, so `?gpu_count=abc` (or any
|
||||
non-integer) raised ValueError -> HTTP 500. A malformed count is now ignored,
|
||||
matching how the neighbouring gpu_group param is already parsed.
|
||||
"""
|
||||
from routes.hwfit_routes import setup_hwfit_routes
|
||||
|
||||
|
||||
def _get_models():
|
||||
router = setup_hwfit_routes()
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "").endswith("/models") and "GET" in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
raise AssertionError("hwfit /models route not found")
|
||||
|
||||
|
||||
def test_non_numeric_gpu_count_does_not_raise():
|
||||
handler = _get_models()
|
||||
# Previously raised ValueError (HTTP 500); now degrades to a normal ranking.
|
||||
result = handler(gpu_count="abc")
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_numeric_gpu_count_still_accepted():
|
||||
handler = _get_models()
|
||||
result = handler(gpu_count="0")
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_non_numeric_manual_gpu_count_does_not_raise():
|
||||
# manual_gpu_count is the other count param on this endpoint (the hardware
|
||||
# simulator in _apply_manual_hardware). A non-numeric value must also degrade
|
||||
# (default to 1) rather than 500, so the endpoint's count parsing is fully
|
||||
# covered.
|
||||
handler = _get_models()
|
||||
result = handler(manual_mode="gpu", manual_gpu_count="abc")
|
||||
assert isinstance(result, dict)
|
||||
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.platform_compat import _ssh_exec_argv
|
||||
from routes.hwfit_routes import setup_hwfit_routes
|
||||
|
||||
|
||||
def _endpoint(path: str):
|
||||
router = setup_hwfit_routes()
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "") == path:
|
||||
return route.endpoint
|
||||
raise AssertionError(f"{path} route not found")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path,kwargs",
|
||||
[
|
||||
("/api/hwfit/system", {}),
|
||||
("/api/hwfit/models", {"limit": 1}),
|
||||
("/api/hwfit/profiles", {"model": "demo"}),
|
||||
("/api/hwfit/image-models", {}),
|
||||
],
|
||||
)
|
||||
def test_hwfit_routes_reject_ssh_option_host(path, kwargs):
|
||||
endpoint = _endpoint(path)
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
endpoint(host="-oProxyCommand=sh", ssh_port="22", **kwargs)
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
|
||||
def test_hwfit_routes_reject_port_without_host():
|
||||
endpoint = _endpoint("/api/hwfit/system")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
endpoint(host="", ssh_port="2222")
|
||||
|
||||
assert exc.value.status_code == 400
|
||||
|
||||
|
||||
def test_ssh_argv_rejects_option_shaped_remote():
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_exec_argv("-oProxyCommand=sh", "22", remote_cmd="true")
|
||||
with pytest.raises(ValueError):
|
||||
_ssh_exec_argv("alice@-oProxyCommand=sh", "22", remote_cmd="true")
|
||||
@@ -11,7 +11,10 @@ owner column, but three file-backed / in-memory stores are left stale:
|
||||
research_routes filters by `d.get("owner") == user`, making every report
|
||||
invisible after rename.
|
||||
|
||||
3. data/memory.json — a flat array where every entry has an `owner` field;
|
||||
3. research_handler._active_tasks — in-flight research jobs carry the same
|
||||
owner key while status/cancel/active routes filter by it.
|
||||
|
||||
4. data/memory.json — a flat array where every entry has an `owner` field;
|
||||
memory_manager.load(owner=user) filters on it, so all memories vanish.
|
||||
|
||||
Regression coverage: these bugs are invisible in unit tests that mock the DB
|
||||
@@ -64,10 +67,11 @@ def rename_endpoint(monkeypatch, tmp_path):
|
||||
return _route(ar.setup_auth_routes(am), "rename_user"), am, tmp_path
|
||||
|
||||
|
||||
def _request(tmp_path, session_manager=None, token="t"):
|
||||
def _request(tmp_path, session_manager=None, token="t", research_handler=None):
|
||||
state = SimpleNamespace(
|
||||
invalidate_token_cache=lambda: None,
|
||||
session_manager=session_manager,
|
||||
research_handler=research_handler,
|
||||
)
|
||||
return SimpleNamespace(
|
||||
cookies={"odysseus_session": token},
|
||||
@@ -234,6 +238,108 @@ def test_rename_no_deep_research_dir_does_not_crash(rename_endpoint):
|
||||
assert res["ok"] is True
|
||||
|
||||
|
||||
def test_rename_updates_active_research_task_owner(rename_endpoint):
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
|
||||
from routes.research_routes import setup_research_routes
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
rh = ResearchHandler.__new__(ResearchHandler)
|
||||
rh._active_tasks = {
|
||||
"alice-task": {
|
||||
"owner": "Alice",
|
||||
"status": "running",
|
||||
"query": "q",
|
||||
"progress": {},
|
||||
"started_at": 1,
|
||||
},
|
||||
"carol-task": {
|
||||
"owner": "carol",
|
||||
"status": "running",
|
||||
"query": "q2",
|
||||
"progress": {},
|
||||
"started_at": 2,
|
||||
},
|
||||
}
|
||||
|
||||
asyncio.run(endpoint(
|
||||
"alice",
|
||||
SimpleNamespace(username="alice2"),
|
||||
_request(tmp_path, research_handler=rh),
|
||||
))
|
||||
|
||||
assert rh._active_tasks["alice-task"]["owner"] == "alice2"
|
||||
assert rh._active_tasks["carol-task"]["owner"] == "carol"
|
||||
|
||||
router = setup_research_routes(rh)
|
||||
active = next(
|
||||
r.endpoint for r in router.routes
|
||||
if getattr(r, "path", "") == "/api/research/active"
|
||||
)
|
||||
|
||||
alice2 = asyncio.run(active(
|
||||
SimpleNamespace(state=SimpleNamespace(current_user="alice2")),
|
||||
))
|
||||
alice = asyncio.run(active(
|
||||
SimpleNamespace(state=SimpleNamespace(current_user="alice")),
|
||||
))
|
||||
|
||||
assert [item["session_id"] for item in alice2["active"]] == ["alice-task"]
|
||||
assert alice["active"] == []
|
||||
|
||||
|
||||
def test_research_handler_rename_owner_canonicalizes_new_owner():
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
rh = ResearchHandler.__new__(ResearchHandler)
|
||||
rh._active_tasks = {
|
||||
"task": {"owner": "Alice", "status": "running"},
|
||||
}
|
||||
|
||||
changed = rh.rename_owner("alice", "Alice2")
|
||||
assert changed == 1
|
||||
assert rh._active_tasks["task"]["owner"] == "alice2"
|
||||
|
||||
|
||||
def test_research_handler_rename_owner_uses_auth_lower_contract_not_casefold():
|
||||
from src.research_handler import ResearchHandler
|
||||
|
||||
rh = ResearchHandler.__new__(ResearchHandler)
|
||||
rh._active_tasks = {
|
||||
"task-strasse": {"owner": "strasse", "status": "running"},
|
||||
"task-sharp-s": {"owner": "straße", "status": "running"},
|
||||
}
|
||||
|
||||
changed = rh.rename_owner("straße", "renamed")
|
||||
|
||||
assert changed == 1
|
||||
assert rh._active_tasks["task-strasse"]["owner"] == "strasse"
|
||||
assert rh._active_tasks["task-sharp-s"]["owner"] == "renamed"
|
||||
|
||||
|
||||
def test_rename_updates_active_research_before_completed_json_sweep(rename_endpoint):
|
||||
endpoint, _am, tmp_path = rename_endpoint
|
||||
|
||||
dr_dir = tmp_path / "deep_research"
|
||||
dr_dir.mkdir()
|
||||
report = dr_dir / "race-window.json"
|
||||
report.write_text(json.dumps({"owner": "alice", "status": "done"}), encoding="utf-8")
|
||||
owner_seen_by_active_hook = []
|
||||
|
||||
class FakeResearchHandler:
|
||||
def rename_owner(self, _old, _new):
|
||||
owner_seen_by_active_hook.append(json.loads(report.read_text(encoding="utf-8"))["owner"])
|
||||
|
||||
asyncio.run(endpoint(
|
||||
"alice",
|
||||
SimpleNamespace(username="alice2"),
|
||||
_request(tmp_path, research_handler=FakeResearchHandler()),
|
||||
))
|
||||
|
||||
assert owner_seen_by_active_hook == ["alice"]
|
||||
assert json.loads(report.read_text(encoding="utf-8"))["owner"] == "alice2"
|
||||
|
||||
|
||||
def test_rename_research_respects_custom_data_dir(monkeypatch, tmp_path):
|
||||
"""DEEP_RESEARCH_DIR (which honours ODYSSEUS_DATA_DIR) is used, not a
|
||||
hardcoded relative path. Before the fix, setting ODYSSEUS_DATA_DIR made
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from routes._validators import validate_remote_host, validate_ssh_port
|
||||
|
||||
|
||||
def test_validate_ssh_port_rejects_shell_payload():
|
||||
for port in ["22;id", "$(id)", "-p 22", "0", "65536"]:
|
||||
with pytest.raises(HTTPException):
|
||||
validate_ssh_port(port)
|
||||
assert validate_ssh_port("2222") == "2222"
|
||||
|
||||
|
||||
def test_validate_remote_host_rejects_ssh_option_shape():
|
||||
for host in [
|
||||
"-oProxyCommand=sh",
|
||||
"alice@-oProxyCommand=sh",
|
||||
"--",
|
||||
"-p2222",
|
||||
]:
|
||||
with pytest.raises(HTTPException):
|
||||
validate_remote_host(host)
|
||||
assert validate_remote_host("alice@gpu-box_1") == "alice@gpu-box_1"
|
||||
Reference in New Issue
Block a user