mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
fix(image): patch realesrgan torchvision compatibility (#4110)
This commit is contained in:
@@ -19,6 +19,7 @@ from src.upload_limits import (
|
|||||||
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
|
GALLERY_TRANSFORM_UPLOAD_MAX_BYTES,
|
||||||
)
|
)
|
||||||
from src.constants import GENERATED_IMAGES_DIR
|
from src.constants import GENERATED_IMAGES_DIR
|
||||||
|
from src.optional_deps import patch_realesrgan_torchvision_compat
|
||||||
|
|
||||||
from routes.gallery_helpers import (
|
from routes.gallery_helpers import (
|
||||||
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size,
|
||||||
@@ -1467,6 +1468,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
img_bytes = base64.b64decode(image_b64)
|
img_bytes = base64.b64decode(image_b64)
|
||||||
src = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
src = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||||
try:
|
try:
|
||||||
|
patch_realesrgan_torchvision_compat()
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return {"error": "realesrgan not installed. Install it from Cookbook → Dependencies (search 'realesrgan')."}
|
return {"error": "realesrgan not installed. Install it from Cookbook → Dependencies (search 'realesrgan')."}
|
||||||
@@ -1516,6 +1518,7 @@ def setup_gallery_routes() -> APIRouter:
|
|||||||
img_bytes = base64.b64decode(image_b64)
|
img_bytes = base64.b64decode(image_b64)
|
||||||
src = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
src = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||||
try:
|
try:
|
||||||
|
patch_realesrgan_torchvision_compat()
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Shell routes — user-facing command execution endpoint."""
|
"""Shell routes — user-facing command execution endpoint."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -14,6 +15,7 @@ from collections import namedtuple
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from core.platform_compat import IS_APPLE_SILICON, which_tool
|
from core.platform_compat import IS_APPLE_SILICON, which_tool
|
||||||
|
from src.optional_deps import prepare_optional_dependency_import
|
||||||
|
|
||||||
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
# POSIX-only: `pty`/`fcntl` transitively import `termios`, which does NOT exist
|
||||||
# on Windows, so importing them unconditionally crashed app startup there
|
# on Windows, so importing them unconditionally crashed app startup there
|
||||||
@@ -149,6 +151,11 @@ def _pip_dist_name(pkg: dict) -> str:
|
|||||||
return (pkg.get("name") or "").replace("_", "-")
|
return (pkg.get("name") or "").replace("_", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def _import_optional_dependency_for_status(name: str):
|
||||||
|
prepare_optional_dependency_import(name)
|
||||||
|
return importlib.import_module(name)
|
||||||
|
|
||||||
|
|
||||||
def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
def _package_installed_from_probe(name: str, probe: dict) -> bool:
|
||||||
"""Return whether an optional dependency is usable by Cookbook.
|
"""Return whether an optional dependency is usable by Cookbook.
|
||||||
|
|
||||||
@@ -970,7 +977,6 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
"""
|
"""
|
||||||
_require_admin(request)
|
_require_admin(request)
|
||||||
_reject_cross_site(request)
|
_reject_cross_site(request)
|
||||||
import importlib
|
|
||||||
import importlib.metadata as importlib_metadata
|
import importlib.metadata as importlib_metadata
|
||||||
import shlex
|
import shlex
|
||||||
import json as _json
|
import json as _json
|
||||||
@@ -1202,7 +1208,7 @@ def setup_shell_routes() -> APIRouter:
|
|||||||
pkg["status_note"] = _package_status_note("vllm", probe)
|
pkg["status_note"] = _package_status_note("vllm", probe)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
importlib.import_module(pkg["name"])
|
_import_optional_dependency_for_status(pkg["name"])
|
||||||
importlib_metadata.version(_pip_dist_name(pkg))
|
importlib_metadata.version(_pip_dist_name(pkg))
|
||||||
pkg["installed"] = True
|
pkg["installed"] = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
"""Compatibility helpers for optional third-party dependencies."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
|
||||||
|
def patch_realesrgan_torchvision_compat() -> None:
|
||||||
|
"""Restore the torchvision import path expected by BasicSR/Real-ESRGAN."""
|
||||||
|
module_name = "torchvision.transforms.functional_tensor"
|
||||||
|
if module_name in sys.modules:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from torchvision.transforms import functional
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
rgb_to_grayscale = getattr(functional, "rgb_to_grayscale", None)
|
||||||
|
if rgb_to_grayscale is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
shim = types.ModuleType(module_name)
|
||||||
|
shim.rgb_to_grayscale = rgb_to_grayscale
|
||||||
|
shim.__getattr__ = lambda name: getattr(functional, name)
|
||||||
|
sys.modules[module_name] = shim
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_optional_dependency_import(name: str) -> None:
|
||||||
|
"""Apply known import-time compatibility shims before probing a package."""
|
||||||
|
if name == "realesrgan":
|
||||||
|
patch_realesrgan_torchvision_compat()
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
from src.optional_deps import (
|
||||||
|
patch_realesrgan_torchvision_compat,
|
||||||
|
prepare_optional_dependency_import,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_realesrgan_patch_restores_removed_functional_tensor_module(monkeypatch):
|
||||||
|
for name in list(sys.modules):
|
||||||
|
if name.startswith("torchvision"):
|
||||||
|
monkeypatch.delitem(sys.modules, name, raising=False)
|
||||||
|
|
||||||
|
sentinel = object()
|
||||||
|
torchvision = types.ModuleType("torchvision")
|
||||||
|
transforms = types.ModuleType("torchvision.transforms")
|
||||||
|
functional = types.ModuleType("torchvision.transforms.functional")
|
||||||
|
functional.rgb_to_grayscale = sentinel
|
||||||
|
transforms.functional = functional
|
||||||
|
torchvision.transforms = transforms
|
||||||
|
monkeypatch.setitem(sys.modules, "torchvision", torchvision)
|
||||||
|
monkeypatch.setitem(sys.modules, "torchvision.transforms", transforms)
|
||||||
|
monkeypatch.setitem(sys.modules, "torchvision.transforms.functional", functional)
|
||||||
|
|
||||||
|
patch_realesrgan_torchvision_compat()
|
||||||
|
|
||||||
|
shim = sys.modules["torchvision.transforms.functional_tensor"]
|
||||||
|
assert shim.rgb_to_grayscale is sentinel
|
||||||
|
assert shim.rgb_to_grayscale is functional.rgb_to_grayscale
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_optional_dependency_import_scopes_patch_to_realesrgan(monkeypatch):
|
||||||
|
import src.optional_deps as optional_deps
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
optional_deps,
|
||||||
|
"patch_realesrgan_torchvision_compat",
|
||||||
|
lambda: calls.append("patched"),
|
||||||
|
)
|
||||||
|
|
||||||
|
prepare_optional_dependency_import("diffusers")
|
||||||
|
assert calls == []
|
||||||
|
|
||||||
|
prepare_optional_dependency_import("realesrgan")
|
||||||
|
assert calls == ["patched"]
|
||||||
@@ -13,6 +13,7 @@ import pytest
|
|||||||
|
|
||||||
from routes.shell_routes import (
|
from routes.shell_routes import (
|
||||||
_find_line_break,
|
_find_line_break,
|
||||||
|
_import_optional_dependency_for_status,
|
||||||
_running_in_container,
|
_running_in_container,
|
||||||
_docker_row_status,
|
_docker_row_status,
|
||||||
_package_installed_from_probe,
|
_package_installed_from_probe,
|
||||||
@@ -376,6 +377,26 @@ class TestPackageProbeStatus:
|
|||||||
assert "add_user_install_bins_to_path()" in script
|
assert "add_user_install_bins_to_path()" in script
|
||||||
assert "shutil.which(b)" in script
|
assert "shutil.which(b)" in script
|
||||||
|
|
||||||
|
def test_status_import_prepares_optional_dependency(self, monkeypatch):
|
||||||
|
import routes.shell_routes as shell_routes
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
monkeypatch.setattr(
|
||||||
|
shell_routes,
|
||||||
|
"prepare_optional_dependency_import",
|
||||||
|
lambda name: calls.append(name),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
shell_routes.importlib,
|
||||||
|
"import_module",
|
||||||
|
lambda name: SimpleNamespace(__name__=name),
|
||||||
|
)
|
||||||
|
|
||||||
|
module = _import_optional_dependency_for_status("realesrgan")
|
||||||
|
|
||||||
|
assert module.__name__ == "realesrgan"
|
||||||
|
assert calls == ["realesrgan"]
|
||||||
|
|
||||||
|
|
||||||
class TestSshBaseArgv:
|
class TestSshBaseArgv:
|
||||||
def test_basic_host_no_port(self):
|
def test_basic_host_no_port(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user