mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
fix(security): close DNS-rebinding hole on diffusion_server (wildcard CORS + missing Host check) (#347)
* fix(security): close DNS-rebinding hole on diffusion_server
scripts/diffusion_server.py used to ship `allow_origins=["*"]` with the
default `--host=127.0.0.1` bind. Combined, that left the OpenAI-compatible
image API reachable from any browser tab via DNS-rebinding: an attacker page
resolves its own domain to 127.0.0.1 mid-fetch, the browser forwards the
request to the loopback server, the server processes it (no Host check), and
the wildcard CORS reply lets the attacker page read the result + drive the
GPU. CWE-346 + CWE-942 + CWE-352 (DNS-rebinding bridge).
Fix:
- Drop the wildcard CORS at module load (default-deny).
- Install `TrustedHostMiddleware` with a loopback allowlist so DNS-rebound
requests are rejected by the middleware before any route runs.
- Add additive `--allowed-host` / `--allowed-origin` CLI flags so operators
who need browser access on a specific origin can opt in explicitly without
re-introducing the wildcard.
Tests: tests/test_diffusion_server_security.py (9 cases) pin the allowlist
helpers, the default-deny CORS behavior, and the live middleware paths via
Starlette's TestClient.
Detected by Aeon + semgrep + manual review.
Severity: medium.
CWE-346 / CWE-942 / CWE-352.
* test(diffusion-server): drive ASGI app via httpx, not TestClient portal
The TrustedHost/CORS integration tests used `with TestClient(app) as
client:`, whose context-manager form spins up an anyio blocking portal to
run the app lifespan. Under the repo's pytest setup (anyio plugin active, a
stray asyncio_mode option, no pytest-asyncio) that portal deadlocks —
`test_trusted_host_middleware_rejects_attacker_host` hung indefinitely in
review before emitting any assertion output.
Replace the TestClient usage with a tiny _asgi_get() helper that drives the
ASGI app over httpx.ASGITransport on a fresh event loop (asyncio.run). No
portal, no lifespan, no dependency on the host project's async test plugins.
Host is taken from the request URL so TrustedHostMiddleware sees the exact
hostname under test; Origin goes through headers. Assertions are unchanged.
Focused test now passes in 0.12s; full file 9 passed.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: aeonframework <aeonframework@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,6 +34,7 @@ import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -52,7 +53,42 @@ async def lifespan(application):
|
||||
|
||||
|
||||
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
# Conservative defaults — server is designed for server-to-server use from
|
||||
# the Odysseus backend. Wildcard CORS + the 127.0.0.1 default bind used to
|
||||
# leave the server reachable via DNS-rebinding from any browser tab on the
|
||||
# same host. The CLI flags below extend these allowlists for operators who
|
||||
# need browser access; the safe defaults handle the common case.
|
||||
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
|
||||
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
|
||||
|
||||
# Install defaults at module load so importing the app for tests / direct
|
||||
# uvicorn invocation still benefits from the Host-header allowlist.
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(_DEFAULT_ALLOWED_HOSTS))
|
||||
|
||||
|
||||
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
|
||||
"""Allowed Host header values: the bind address + loopback variants +
|
||||
any operator-supplied --allowed-host values. Duplicates and empty
|
||||
strings are dropped; order is stable for predictable middleware setup."""
|
||||
seen = []
|
||||
for h in (bind_host, *_DEFAULT_ALLOWED_HOSTS, *(extras or [])):
|
||||
h = (h or "").strip()
|
||||
if h and h not in seen:
|
||||
seen.append(h)
|
||||
return seen
|
||||
|
||||
|
||||
def _compute_cors_origins(extras=None) -> list:
|
||||
"""CORS allowlist: default-deny (empty), extended only by explicit
|
||||
--allowed-origin values. Server-to-server callers don't set an Origin
|
||||
header so they're unaffected; this only narrows browser access."""
|
||||
seen = []
|
||||
for o in (*_DEFAULT_CORS_ORIGINS, *(extras or [])):
|
||||
o = (o or "").strip()
|
||||
if o and o not in seen:
|
||||
seen.append(o)
|
||||
return seen
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
@@ -1089,7 +1125,33 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
|
||||
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
|
||||
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
|
||||
parser.add_argument("--allowed-host", action="append", default=[],
|
||||
help="Additional Host header value to accept (DNS-rebinding allowlist). "
|
||||
"Can be repeated. Loopback values are always included.")
|
||||
parser.add_argument("--allowed-origin", action="append", default=[],
|
||||
help="Additional CORS origin to allow. Can be repeated. Defaults to "
|
||||
"no cross-origin access — only pass this if you need a browser "
|
||||
"on a specific origin to call the server.")
|
||||
_args = parser.parse_args()
|
||||
|
||||
# Replace the module-load middleware stack with the CLI-configured one so
|
||||
# operator-supplied --allowed-host / --allowed-origin values take effect
|
||||
# before the first request is served. user_middleware is consulted lazily
|
||||
# when the middleware stack is built on the first request, so mutating it
|
||||
# here is safe.
|
||||
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
|
||||
final_origins = _compute_cors_origins(_args.allowed_origin)
|
||||
app.user_middleware.clear()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=final_hosts)
|
||||
if final_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=final_origins,
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
|
||||
final_hosts, final_origins or "(none — default-deny)")
|
||||
|
||||
app.state.model_path = _args.model
|
||||
uvicorn.run(app, host=_args.host, port=_args.port)
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
"""Pin the diffusion_server DNS-rebinding + wildcard-CORS regression.
|
||||
|
||||
Background: scripts/diffusion_server.py used to ship `allow_origins=["*"]`
|
||||
with the default `--host=127.0.0.1` bind. Combined, that left the OpenAI-
|
||||
compatible image API reachable from any browser tab via DNS-rebinding: an
|
||||
attacker page resolves its own domain to 127.0.0.1 mid-fetch, the browser
|
||||
forwards the request to the loopback server, and the wildcard CORS reply
|
||||
lets the attacker page read the result + drive the GPU.
|
||||
|
||||
The fix narrows CORS to default-deny and adds a TrustedHostMiddleware
|
||||
Host-header allowlist as a positive defense. These tests pin the allowlist
|
||||
helpers + Starlette's middleware behavior so a future change can't silently
|
||||
re-open the hole.
|
||||
|
||||
The tests run against a tiny synthetic FastAPI app that uses the same
|
||||
``TrustedHostMiddleware`` + ``CORSMiddleware`` wiring as diffusion_server.
|
||||
That keeps the test out of the torch / diffusers import path while still
|
||||
covering the live middleware code paths.
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
_SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py"
|
||||
|
||||
|
||||
def _load_helpers():
|
||||
"""Import the pure allowlist helpers from diffusion_server.py without
|
||||
triggering its torch / diffusers imports. We compile just the helper
|
||||
block (everything between the `app =` line and the `class ImageRequest`
|
||||
line) so heavy deps stay quarantined behind the if-False import guard.
|
||||
"""
|
||||
src = _SCRIPT.read_text(encoding="utf-8")
|
||||
# The helpers live between the two markers, both inserted by the security
|
||||
# fix. They depend only on the `_DEFAULT_ALLOWED_HOSTS` / `_DEFAULT_CORS_ORIGINS`
|
||||
# module-level lists, which we materialise here.
|
||||
start_marker = "_DEFAULT_ALLOWED_HOSTS = "
|
||||
end_marker = "class ImageRequest("
|
||||
i = src.index(start_marker)
|
||||
j = src.index(end_marker)
|
||||
helper_block = src[i:j]
|
||||
ns: dict = {"list": list}
|
||||
# Strip the `app.add_middleware(...)` line — the helpers don't need it
|
||||
# and it would force a torch import via fastapi.responses.
|
||||
helper_block = "\n".join(
|
||||
line for line in helper_block.splitlines()
|
||||
if not line.startswith("app.add_middleware")
|
||||
)
|
||||
exec(compile(helper_block, str(_SCRIPT), "exec"), ns)
|
||||
return ns
|
||||
|
||||
|
||||
def test_compute_allowed_hosts_includes_loopback_and_bind_host():
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_allowed_hosts"]("0.0.0.0")
|
||||
assert "0.0.0.0" in out
|
||||
assert "127.0.0.1" in out
|
||||
assert "localhost" in out
|
||||
assert "::1" in out
|
||||
|
||||
|
||||
def test_compute_allowed_hosts_dedupes_and_strips():
|
||||
ns = _load_helpers()
|
||||
# Bind host duplicates a default + an extra duplicates a default + blanks
|
||||
# all collapse into one entry per unique value, preserving stable order.
|
||||
out = ns["_compute_allowed_hosts"]("127.0.0.1", extras=["localhost", "", " ", "lan.example"])
|
||||
assert out == ["127.0.0.1", "localhost", "::1", "lan.example"]
|
||||
|
||||
|
||||
def test_compute_allowed_hosts_does_not_add_wildcard():
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
assert "*" not in out, "wildcard host would re-open the DNS-rebinding hole"
|
||||
|
||||
|
||||
def test_compute_cors_origins_default_deny():
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_cors_origins"]()
|
||||
assert out == [], "default CORS allowlist must be empty (no cross-origin)"
|
||||
|
||||
|
||||
def test_compute_cors_origins_does_not_default_to_wildcard():
|
||||
"""Regression: the original code shipped allow_origins=['*']. The fix
|
||||
must NOT bring that back even when the operator passes nothing."""
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_cors_origins"](extras=None)
|
||||
assert "*" not in out
|
||||
out2 = ns["_compute_cors_origins"](extras=[])
|
||||
assert "*" not in out2
|
||||
|
||||
|
||||
def test_compute_cors_origins_honours_explicit_extras():
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_cors_origins"](extras=["http://localhost:7000", "", "http://localhost:7000"])
|
||||
assert out == ["http://localhost:7000"]
|
||||
|
||||
|
||||
# ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ─────
|
||||
|
||||
|
||||
def _starlette_available() -> bool:
|
||||
return importlib.util.find_spec("starlette") is not None
|
||||
|
||||
|
||||
def _asgi_get(app, url, headers=None):
|
||||
"""Drive a single GET against an ASGI ``app`` over httpx's in-process
|
||||
``ASGITransport`` on a fresh event loop.
|
||||
|
||||
This deliberately avoids ``starlette.testclient.TestClient``: its
|
||||
context-manager form spins up an ``anyio`` blocking portal (to run the
|
||||
lifespan), which deadlocks under some pytest / anyio / asyncio test
|
||||
configurations — the focused Host-header test hung indefinitely during
|
||||
review (see PR #347). A direct ASGI call needs neither a portal nor a
|
||||
lifespan, so it stays reliable regardless of the host project's async
|
||||
test plugins.
|
||||
|
||||
The request ``Host`` is derived from ``url`` so the TrustedHost allowlist
|
||||
sees exactly the hostname under test; ``Origin`` and friends go through
|
||||
``headers``.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
|
||||
async def _run():
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport) as client:
|
||||
return await client.get(url, headers=headers or {})
|
||||
|
||||
return asyncio.run(_run())
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_trusted_host_middleware_rejects_attacker_host():
|
||||
"""A request with an attacker-controlled Host header (the DNS-rebinding
|
||||
surface) must be rejected by the middleware before reaching any route."""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: F401 (parity import)
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# Legitimate request (Host: 127.0.0.1) goes through.
|
||||
ok = _asgi_get(app, "http://127.0.0.1/health")
|
||||
assert ok.status_code == 200
|
||||
# Attacker-controlled hostname (DNS-rebinding scenario) is rejected.
|
||||
bad = _asgi_get(app, "http://evil.example.com/health")
|
||||
assert bad.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_cors_default_deny_does_not_emit_wildcard_acao():
|
||||
"""Without CORSMiddleware installed, the server must not advertise
|
||||
Access-Control-Allow-Origin at all (definitely not the wildcard)."""
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
# Default-deny CORS: no CORSMiddleware. Mirrors diffusion_server's behavior
|
||||
# when no --allowed-origin flags are passed.
|
||||
cors_origins = ns["_compute_cors_origins"]()
|
||||
assert cors_origins == []
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
|
||||
|
||||
@app.get("/v1/models")
|
||||
def list_models():
|
||||
return {"data": []}
|
||||
|
||||
# Host is allowed, so the request itself succeeds — but the response must
|
||||
# carry no ACAO, so a real browser would block the attacker page from
|
||||
# reading the body.
|
||||
resp = _asgi_get(
|
||||
app,
|
||||
"http://127.0.0.1/v1/models",
|
||||
headers={"Origin": "https://evil.example.com"},
|
||||
)
|
||||
acao = resp.headers.get("access-control-allow-origin")
|
||||
assert acao is None or acao == "", (
|
||||
f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, "
|
||||
f"so any non-empty default fails this gate"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_explicit_cors_origin_does_not_widen_to_wildcard():
|
||||
"""Even when the operator opts in to one cross-origin, that single origin
|
||||
must not unlock a wildcard reflection for other origins."""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
@app.get("/v1/models")
|
||||
def list_models():
|
||||
return {"data": []}
|
||||
|
||||
# Allowed origin: ACAO echoes that origin (NOT '*').
|
||||
ok = _asgi_get(
|
||||
app,
|
||||
"http://127.0.0.1/v1/models",
|
||||
headers={"Origin": "http://localhost:7000"},
|
||||
)
|
||||
assert ok.status_code == 200
|
||||
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"
|
||||
# Foreign origin: ACAO must NOT echo it, must NOT be '*'.
|
||||
bad = _asgi_get(
|
||||
app,
|
||||
"http://127.0.0.1/v1/models",
|
||||
headers={"Origin": "https://evil.example.com"},
|
||||
)
|
||||
bad_acao = bad.headers.get("access-control-allow-origin")
|
||||
assert bad_acao != "*"
|
||||
assert bad_acao != "https://evil.example.com"
|
||||
Reference in New Issue
Block a user