fix(security): close DNS-rebinding hole on diffusion_server (wildcard CORS + missing Host check) (#347)

* fix(security): close DNS-rebinding hole on diffusion_server

scripts/diffusion_server.py used to ship `allow_origins=["*"]` with the
default `--host=127.0.0.1` bind. Combined, that left the OpenAI-compatible
image API reachable from any browser tab via DNS-rebinding: an attacker page
resolves its own domain to 127.0.0.1 mid-fetch, the browser forwards the
request to the loopback server, the server processes it (no Host check), and
the wildcard CORS reply lets the attacker page read the result + drive the
GPU. CWE-346 + CWE-942 + CWE-352 (DNS-rebinding bridge).

Fix:
  - Drop the wildcard CORS at module load (default-deny).
  - Install `TrustedHostMiddleware` with a loopback allowlist so DNS-rebound
    requests are rejected by the middleware before any route runs.
  - Add additive `--allowed-host` / `--allowed-origin` CLI flags so operators
    who need browser access on a specific origin can opt in explicitly without
    re-introducing the wildcard.

Tests: tests/test_diffusion_server_security.py (9 cases) pin the allowlist
helpers, the default-deny CORS behavior, and the live middleware paths via
Starlette's TestClient.

Detected by Aeon + semgrep + manual review.
Severity: medium.
CWE-346 / CWE-942 / CWE-352.

* test(diffusion-server): drive ASGI app via httpx, not TestClient portal

The TrustedHost/CORS integration tests used `with TestClient(app) as
client:`, whose context-manager form spins up an anyio blocking portal to
run the app lifespan. Under the repo's pytest setup (anyio plugin active, a
stray asyncio_mode option, no pytest-asyncio) that portal deadlocks —
`test_trusted_host_middleware_rejects_attacker_host` hung indefinitely in
review before emitting any assertion output.

Replace the TestClient usage with a tiny _asgi_get() helper that drives the
ASGI app over httpx.ASGITransport on a fresh event loop (asyncio.run). No
portal, no lifespan, no dependency on the host project's async test plugins.
Host is taken from the request URL so TrustedHostMiddleware sees the exact
hostname under test; Origin goes through headers. Assertions are unchanged.

Focused test now passes in 0.12s; full file 9 passed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: aeonframework <aeonframework@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@aaronjmars
2026-06-06 18:34:39 -04:00
committed by GitHub
parent b03d934ec6
commit 108ee1e32b
2 changed files with 303 additions and 1 deletions
+63 -1
View File
@@ -34,6 +34,7 @@ import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel
logging.basicConfig(level=logging.INFO)
@@ -52,7 +53,42 @@ async def lifespan(application):
app = FastAPI(title="Diffusion Server", lifespan=lifespan)
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Conservative defaults — server is designed for server-to-server use from
# the Odysseus backend. Wildcard CORS + the 127.0.0.1 default bind used to
# leave the server reachable via DNS-rebinding from any browser tab on the
# same host. The CLI flags below extend these allowlists for operators who
# need browser access; the safe defaults handle the common case.
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
# Install defaults at module load so importing the app for tests / direct
# uvicorn invocation still benefits from the Host-header allowlist.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(_DEFAULT_ALLOWED_HOSTS))
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
"""Allowed Host header values: the bind address + loopback variants +
any operator-supplied --allowed-host values. Duplicates and empty
strings are dropped; order is stable for predictable middleware setup."""
seen = []
for h in (bind_host, *_DEFAULT_ALLOWED_HOSTS, *(extras or [])):
h = (h or "").strip()
if h and h not in seen:
seen.append(h)
return seen
def _compute_cors_origins(extras=None) -> list:
"""CORS allowlist: default-deny (empty), extended only by explicit
--allowed-origin values. Server-to-server callers don't set an Origin
header so they're unaffected; this only narrows browser access."""
seen = []
for o in (*_DEFAULT_CORS_ORIGINS, *(extras or [])):
o = (o or "").strip()
if o and o not in seen:
seen.append(o)
return seen
class ImageRequest(BaseModel):
@@ -1089,7 +1125,33 @@ if __name__ == "__main__":
parser.add_argument("--attention-slicing", action="store_true", help="Enable attention slicing")
parser.add_argument("--vae-slicing", action="store_true", help="Enable VAE slicing")
parser.add_argument("--harmonize-gpu", type=int, default=None, help="GPU index for harmonize/img2img (default: same as main)")
parser.add_argument("--allowed-host", action="append", default=[],
help="Additional Host header value to accept (DNS-rebinding allowlist). "
"Can be repeated. Loopback values are always included.")
parser.add_argument("--allowed-origin", action="append", default=[],
help="Additional CORS origin to allow. Can be repeated. Defaults to "
"no cross-origin access — only pass this if you need a browser "
"on a specific origin to call the server.")
_args = parser.parse_args()
# Replace the module-load middleware stack with the CLI-configured one so
# operator-supplied --allowed-host / --allowed-origin values take effect
# before the first request is served. user_middleware is consulted lazily
# when the middleware stack is built on the first request, so mutating it
# here is safe.
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
final_origins = _compute_cors_origins(_args.allowed_origin)
app.user_middleware.clear()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=final_hosts)
if final_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=final_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
final_hosts, final_origins or "(none — default-deny)")
app.state.model_path = _args.model
uvicorn.run(app, host=_args.host, port=_args.port)