Files
odysseus/tests/test_diffusion_server_security.py
T
@aaronjmars 108ee1e32b fix(security): close DNS-rebinding hole on diffusion_server (wildcard CORS + missing Host check) (#347)
* fix(security): close DNS-rebinding hole on diffusion_server

scripts/diffusion_server.py used to ship `allow_origins=["*"]` with the
default `--host=127.0.0.1` bind. Combined, that left the OpenAI-compatible
image API reachable from any browser tab via DNS-rebinding: an attacker page
resolves its own domain to 127.0.0.1 mid-fetch, the browser forwards the
request to the loopback server, the server processes it (no Host check), and
the wildcard CORS reply lets the attacker page read the result + drive the
GPU. CWE-346 + CWE-942 + CWE-352 (DNS-rebinding bridge).

Fix:
  - Drop the wildcard CORS at module load (default-deny).
  - Install `TrustedHostMiddleware` with a loopback allowlist so DNS-rebound
    requests are rejected by the middleware before any route runs.
  - Add additive `--allowed-host` / `--allowed-origin` CLI flags so operators
    who need browser access on a specific origin can opt in explicitly without
    re-introducing the wildcard.

Tests: tests/test_diffusion_server_security.py (9 cases) pin the allowlist
helpers, the default-deny CORS behavior, and the live middleware paths via
Starlette's TestClient.

Detected by Aeon + semgrep + manual review.
Severity: medium.
CWE-346 / CWE-942 / CWE-352.

* test(diffusion-server): drive ASGI app via httpx, not TestClient portal

The TrustedHost/CORS integration tests used `with TestClient(app) as
client:`, whose context-manager form spins up an anyio blocking portal to
run the app lifespan. Under the repo's pytest setup (anyio plugin active, a
stray asyncio_mode option, no pytest-asyncio) that portal deadlocks —
`test_trusted_host_middleware_rejects_attacker_host` hung indefinitely in
review before emitting any assertion output.

Replace the TestClient usage with a tiny _asgi_get() helper that drives the
ASGI app over httpx.ASGITransport on a fresh event loop (asyncio.run). No
portal, no lifespan, no dependency on the host project's async test plugins.
Host is taken from the request URL so TrustedHostMiddleware sees the exact
hostname under test; Origin goes through headers. Assertions are unchanged.

Focused test now passes in 0.12s; full file 9 passed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: aeonframework <aeonframework@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-06 23:34:39 +01:00

241 lines
9.1 KiB
Python

"""Pin the diffusion_server DNS-rebinding + wildcard-CORS regression.
Background: scripts/diffusion_server.py used to ship `allow_origins=["*"]`
with the default `--host=127.0.0.1` bind. Combined, that left the OpenAI-
compatible image API reachable from any browser tab via DNS-rebinding: an
attacker page resolves its own domain to 127.0.0.1 mid-fetch, the browser
forwards the request to the loopback server, and the wildcard CORS reply
lets the attacker page read the result + drive the GPU.
The fix narrows CORS to default-deny and adds a TrustedHostMiddleware
Host-header allowlist as a positive defense. These tests pin the allowlist
helpers + Starlette's middleware behavior so a future change can't silently
re-open the hole.
The tests run against a tiny synthetic FastAPI app that uses the same
``TrustedHostMiddleware`` + ``CORSMiddleware`` wiring as diffusion_server.
That keeps the test out of the torch / diffusers import path while still
covering the live middleware code paths.
"""
import importlib.util
import os
from pathlib import Path
import pytest
# Location of the server script whose helper block these tests load as text:
# two directory levels up from this file, then scripts/diffusion_server.py.
_SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py"
def _load_helpers():
    """Load the allowlist helpers from diffusion_server.py without importing it.

    The script is read as text and only the slice that starts at the
    ``_DEFAULT_ALLOWED_HOSTS = `` assignment and stops just before
    ``class ImageRequest(`` is compiled and executed. Anything the script
    does outside that slice never runs; the intent is to keep its heavy
    dependencies (torch / diffusers) out of the test's import path.

    Returns:
        The namespace dict the slice was executed in. Tests look helpers up
        by name, e.g. ``ns["_compute_allowed_hosts"]``.

    Raises:
        ValueError: if the start marker is missing from the script, or the
            end marker does not occur after it.
    """
    src = _SCRIPT.read_text(encoding="utf-8")
    start_marker = "_DEFAULT_ALLOWED_HOSTS = "
    end_marker = "class ImageRequest("

    def _locate(marker, begin=0):
        # Fail with a message naming the marker and file instead of the bare
        # "substring not found" that str.index() would raise.
        pos = src.find(marker, begin)
        if pos < 0:
            raise ValueError(
                f"marker {marker!r} not found in {_SCRIPT}; "
                "update the markers in _load_helpers() to match the script"
            )
        return pos

    start = _locate(start_marker)
    # Search for the end marker only after the start marker, so an earlier
    # occurrence cannot produce an empty slice.
    end = _locate(end_marker, start)

    # Drop the `app.add_middleware(...)` line: the helpers don't need it.
    # NOTE(review): only lines that *start* with the call are removed; if the
    # script ever spreads the call over several lines, the continuation lines
    # would remain in the slice — confirm it stays a one-liner.
    kept_lines = [
        line
        for line in src[start:end].splitlines()
        if not line.startswith("app.add_middleware")
    ]

    ns: dict = {"list": list}
    exec(compile("\n".join(kept_lines), str(_SCRIPT), "exec"), ns)
    return ns
def test_compute_allowed_hosts_includes_loopback_and_bind_host():
    """The bind host and each loopback name must be in the Host allowlist."""
    compute_allowed_hosts = _load_helpers()["_compute_allowed_hosts"]
    hosts = compute_allowed_hosts("0.0.0.0")
    for expected in ("0.0.0.0", "127.0.0.1", "localhost", "::1"):
        assert expected in hosts
def test_compute_allowed_hosts_dedupes_and_strips():
    """Duplicates and blank entries collapse; first-seen order is kept."""
    compute_allowed_hosts = _load_helpers()["_compute_allowed_hosts"]
    # The bind host repeats a default, one extra repeats a default, and two
    # extras are blank: each unique value must appear exactly once, in order.
    extra_hosts = ["localhost", "", " ", "lan.example"]
    expected = ["127.0.0.1", "localhost", "::1", "lan.example"]
    assert compute_allowed_hosts("127.0.0.1", extras=extra_hosts) == expected
def test_compute_allowed_hosts_does_not_add_wildcard():
    """The default Host allowlist must never contain the ``*`` wildcard."""
    compute_allowed_hosts = _load_helpers()["_compute_allowed_hosts"]
    hosts = compute_allowed_hosts("127.0.0.1")
    assert "*" not in hosts, "wildcard host would re-open the DNS-rebinding hole"
def test_compute_cors_origins_default_deny():
    """With no arguments the CORS origin allowlist is empty."""
    compute_cors_origins = _load_helpers()["_compute_cors_origins"]
    origins = compute_cors_origins()
    assert origins == [], "default CORS allowlist must be empty (no cross-origin)"
def test_compute_cors_origins_does_not_default_to_wildcard():
    """Regression guard for the original ``allow_origins=['*']``.

    Neither ``extras=None`` nor ``extras=[]`` may yield a wildcard origin.
    """
    compute_cors_origins = _load_helpers()["_compute_cors_origins"]
    for no_extras in (None, []):
        assert "*" not in compute_cors_origins(extras=no_extras)
def test_compute_cors_origins_honours_explicit_extras():
    """Explicit origins are kept, minus blank and duplicate entries."""
    compute_cors_origins = _load_helpers()["_compute_cors_origins"]
    requested = ["http://localhost:7000", "", "http://localhost:7000"]
    assert compute_cors_origins(extras=requested) == ["http://localhost:7000"]
# ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ─────
def _starlette_available() -> bool:
    """Return True when the live-middleware tests have everything they import.

    Despite the name, this probes ``starlette``, ``fastapi`` and ``httpx``:
    the guarded tests import FastAPI and drive the app through httpx, so a
    starlette-only check would let them fail with ImportError instead of
    being skipped when one of the other two packages is absent.

    Returns:
        True only if all three top-level packages can be located.
    """
    required = ("starlette", "fastapi", "httpx")
    return all(importlib.util.find_spec(name) is not None for name in required)
def _asgi_get(app, url, headers=None):
    """Issue one GET to an ASGI ``app`` in-process and return the response.

    The request goes through ``httpx.ASGITransport`` inside a coroutine that
    is run to completion with ``asyncio.run``, i.e. on a fresh event loop.

    ``starlette.testclient.TestClient`` is avoided on purpose: used as a
    context manager it starts an ``anyio`` blocking portal to run the
    lifespan, and that portal deadlocked under some pytest / anyio / asyncio
    setups — the focused Host-header test hung indefinitely during review
    (see PR #347). Calling the ASGI app directly needs no portal and no
    lifespan, so it does not depend on the host project's async test plugins.

    Args:
        app: the ASGI application to call.
        url: full request URL. Its hostname becomes the request ``Host``, so
            the TrustedHost allowlist sees exactly the name under test.
        headers: optional extra request headers (``Origin`` and friends);
            ``None`` or an empty mapping sends no extras.

    Returns:
        The ``httpx`` response object for the request.
    """
    import asyncio

    import httpx

    request_headers = headers or {}

    async def _fetch():
        client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app))
        async with client:
            return await client.get(url, headers=request_headers)

    return asyncio.run(_fetch())
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_trusted_host_middleware_rejects_attacker_host():
    """An attacker-controlled Host header must be turned away by the
    middleware, while a loopback Host still reaches the route."""
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware  # noqa: F401 (parity import)
    from starlette.middleware.trustedhost import TrustedHostMiddleware

    helpers = _load_helpers()
    app = FastAPI()
    app.add_middleware(
        TrustedHostMiddleware,
        allowed_hosts=helpers["_compute_allowed_hosts"]("127.0.0.1"),
    )

    @app.get("/health")
    def health():
        return {"status": "ok"}

    expectations = (
        # Legitimate request (Host: 127.0.0.1) goes through.
        ("http://127.0.0.1/health", 200),
        # Attacker-controlled hostname (DNS-rebinding scenario) is rejected.
        ("http://evil.example.com/health", 400),
    )
    for url, expected_status in expectations:
        assert _asgi_get(app, url).status_code == expected_status
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_cors_default_deny_does_not_emit_wildcard_acao():
    """With no CORSMiddleware installed, a successful response must carry no
    Access-Control-Allow-Origin header at all (and certainly not ``*``)."""
    from fastapi import FastAPI
    from starlette.middleware.trustedhost import TrustedHostMiddleware

    ns = _load_helpers()
    allowed = ns["_compute_allowed_hosts"]("127.0.0.1")

    # Default-deny CORS: the helper returns no origins, so this app installs
    # no CORSMiddleware — the case where no --allowed-origin flag is passed.
    cors_origins = ns["_compute_cors_origins"]()
    assert cors_origins == []

    app = FastAPI()
    app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)

    @app.get("/v1/models")
    def list_models():
        return {"data": []}

    resp = _asgi_get(
        app,
        "http://127.0.0.1/v1/models",
        headers={"Origin": "https://evil.example.com"},
    )

    # The Host is allowed, so the request itself must succeed. Checking this
    # keeps the header assertion below from passing vacuously on a response
    # that was rejected by the Host check, which would carry no ACAO either.
    assert resp.status_code == 200

    # No ACAO on the response means a real browser would block the attacker
    # page from reading the body.
    acao = resp.headers.get("access-control-allow-origin")
    assert not acao, (
        f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, "
        "so any non-empty default fails this gate"
    )
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_explicit_cors_origin_does_not_widen_to_wildcard():
    """Opting in to one cross-origin must not unlock wildcard or reflected
    Access-Control-Allow-Origin values for any other origin."""
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    from starlette.middleware.trustedhost import TrustedHostMiddleware

    ns = _load_helpers()
    allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
    cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])

    app = FastAPI()
    app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["Authorization", "Content-Type"],
    )

    @app.get("/v1/models")
    def list_models():
        return {"data": []}

    # Allowed origin: ACAO echoes that origin (NOT '*').
    ok = _asgi_get(
        app,
        "http://127.0.0.1/v1/models",
        headers={"Origin": "http://localhost:7000"},
    )
    assert ok.status_code == 200
    assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"

    # Foreign origin: the request is a plain GET, so it is still served, but
    # ACAO must NOT echo the origin and must NOT be '*'.
    bad = _asgi_get(
        app,
        "http://127.0.0.1/v1/models",
        headers={"Origin": "https://evil.example.com"},
    )
    # Pin the status so the negative header checks are made against a real
    # route response rather than an error response.
    assert bad.status_code == 200
    bad_acao = bad.headers.get("access-control-allow-origin")
    assert bad_acao != "*"
    assert bad_acao != "https://evil.example.com"