test(diffusion-server): exercise security middleware wiring (#3214)

This commit is contained in:
Alexandre Teixeira
2026-06-07 22:42:11 +01:00
committed by GitHub
parent 92300b5d67
commit 9ad6a2809e
2 changed files with 199 additions and 101 deletions
+173 -88
View File
@@ -12,14 +12,15 @@ Host-header allowlist as a positive defense. These tests pin the allowlist
helpers + Starlette's middleware behavior so a future change can't silently
re-open the hole.
The tests run against a tiny synthetic FastAPI app that uses the same
``TrustedHostMiddleware`` + ``CORSMiddleware`` wiring as diffusion_server.
That keeps the test out of the torch / diffusers import path while still
covering the live middleware code paths.
The tests AST-extract the security helpers — including the real
``_configure_security_middleware`` wiring — from diffusion_server.py and run
them against a fresh FastAPI app. That keeps the tests out of the torch /
diffusers import path while still exercising the production middleware wiring
instead of a hand-rebuilt copy.
"""
import ast
import importlib.util
import os
from pathlib import Path
import pytest
@@ -28,29 +29,49 @@ import pytest
_SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py"
_EXPECTED_NAMES = (
"_DEFAULT_ALLOWED_HOSTS",
"_DEFAULT_CORS_ORIGINS",
"_compute_allowed_hosts",
"_compute_cors_origins",
"_configure_security_middleware",
)
def _load_helpers():
"""Import the pure allowlist helpers from diffusion_server.py without
triggering its torch / diffusers imports. We compile just the helper
block (everything between the `app =` line and the `class ImageRequest`
line) so heavy deps stay quarantined behind the if-False import guard.
"""
src = _SCRIPT.read_text(encoding="utf-8")
# The helpers live between the two markers, both inserted by the security
# fix. They depend only on the `_DEFAULT_ALLOWED_HOSTS` / `_DEFAULT_CORS_ORIGINS`
# module-level lists, which we materialise here.
start_marker = "_DEFAULT_ALLOWED_HOSTS = "
end_marker = "class ImageRequest("
i = src.index(start_marker)
j = src.index(end_marker)
helper_block = src[i:j]
ns: dict = {"list": list}
# Strip the `app.add_middleware(...)` line — the helpers don't need it
# and it would force a torch import via fastapi.responses.
helper_block = "\n".join(
line for line in helper_block.splitlines()
if not line.startswith("app.add_middleware")
)
exec(compile(helper_block, str(_SCRIPT), "exec"), ns)
"""Extract the security helpers from diffusion_server.py via AST so the
tests exercise the production wiring without importing the module (which
would pull in torch / diffusers). Only the named top-level definitions are
compiled into a fresh module; everything else — including the heavy
imports — is left out. A renamed or removed helper fails loudly here."""
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
tree = ast.parse(_SCRIPT.read_text(encoding="utf-8"))
wanted: dict = {}
for node in tree.body:
if isinstance(node, ast.FunctionDef) and node.name in _EXPECTED_NAMES:
wanted[node.name] = node
elif isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id in _EXPECTED_NAMES:
wanted[target.id] = node
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
if node.target.id in _EXPECTED_NAMES:
wanted[node.target.id] = node
missing = [name for name in _EXPECTED_NAMES if name not in wanted]
assert not missing, f"diffusion_server.py is missing expected helpers: {missing}"
module = ast.Module(body=[wanted[name] for name in _EXPECTED_NAMES], type_ignores=[])
ast.fix_missing_locations(module)
ns: dict = {
"TrustedHostMiddleware": TrustedHostMiddleware,
"CORSMiddleware": CORSMiddleware,
"RuntimeError": RuntimeError,
"list": list,
}
exec(compile(module, str(_SCRIPT), "exec"), ns)
return ns
@@ -77,6 +98,15 @@ def test_compute_allowed_hosts_does_not_add_wildcard():
assert "*" not in out, "wildcard host would re-open the DNS-rebinding hole"
def test_compute_allowed_hosts_preserves_explicit_wildcard():
# Behavior preservation: a wildcard is not added by default, but an
# operator who explicitly passes one is taken at their word (deduped,
# stripped, stable order). This pins current behavior, not policy.
ns = _load_helpers()
out = ns["_compute_allowed_hosts"]("127.0.0.1", extras=["*", " lan.example ", "*"])
assert out == ["127.0.0.1", "localhost", "::1", "*", "lan.example"]
def test_compute_cors_origins_default_deny():
ns = _load_helpers()
out = ns["_compute_cors_origins"]()
@@ -99,6 +129,15 @@ def test_compute_cors_origins_honours_explicit_extras():
assert out == ["http://localhost:7000"]
def test_compute_cors_origins_preserves_explicit_wildcard():
# Behavior preservation: a wildcard is not the default, but an operator
# who explicitly passes one is taken at their word (deduped, stripped,
# stable order). This pins current behavior, not policy.
ns = _load_helpers()
out = ns["_compute_cors_origins"](extras=["*", " http://localhost:7000 ", "*"])
assert out == ["*", "http://localhost:7000"]
# ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ─────
@@ -134,61 +173,65 @@ def _asgi_get(app, url, headers=None):
return asyncio.run(_run())
def _configured_app(ns, allowed_origins, route_called=None):
"""Fresh FastAPI app wired by the production `_configure_security_middleware`
with a loopback Host allowlist, plus a minimal route so accepted requests
can assert 200. If `route_called` is given, the route sets
``route_called["hit"] = True`` so callers can prove whether the inner app
was reached."""
from fastapi import FastAPI
app = FastAPI()
ns["_configure_security_middleware"](
app, ns["_compute_allowed_hosts"]("127.0.0.1"), allowed_origins
)
@app.get("/")
def root():
if route_called is not None:
route_called["hit"] = True
return {"ok": True}
return app
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_trusted_host_middleware_rejects_attacker_host():
"""A request with an attacker-controlled Host header (the DNS-rebinding
surface) must be rejected by the middleware before reaching any route."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware # noqa: F401 (parity import)
from starlette.middleware.trustedhost import TrustedHostMiddleware
surface) must be rejected by the production wiring before any route runs."""
ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
route_called = {"hit": False}
app = _configured_app(ns, [], route_called=route_called)
app = FastAPI()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
@app.get("/health")
def health():
return {"status": "ok"}
# Legitimate request (Host: 127.0.0.1) goes through.
ok = _asgi_get(app, "http://127.0.0.1/health")
# Legitimate request (Host: 127.0.0.1) reaches the route.
ok = _asgi_get(app, "http://127.0.0.1/")
assert ok.status_code == 200
# Attacker-controlled hostname (DNS-rebinding scenario) is rejected.
bad = _asgi_get(app, "http://evil.example.com/health")
assert route_called["hit"] is True
# Attacker-controlled hostname (DNS-rebinding scenario) is rejected before
# the route runs.
route_called["hit"] = False
bad = _asgi_get(app, "http://evil.example.com/")
assert bad.status_code == 400
assert route_called["hit"] is False
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_cors_default_deny_does_not_emit_wildcard_acao():
"""Without CORSMiddleware installed, the server must not advertise
Access-Control-Allow-Origin at all (definitely not the wildcard)."""
from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware
"""Default-deny CORS (no --allowed-origin) must not advertise any
Access-Control-Allow-Origin, so a browser blocks cross-origin readers."""
ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
# Default-deny CORS: no CORSMiddleware. Mirrors diffusion_server's behavior
# when no --allowed-origin flags are passed.
cors_origins = ns["_compute_cors_origins"]()
assert cors_origins == []
app = FastAPI()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
@app.get("/v1/models")
def list_models():
return {"data": []}
app = _configured_app(ns, cors_origins)
# Host is allowed, so the request itself succeeds — but the response must
# carry no ACAO, so a real browser would block the attacker page from
# reading the body.
resp = _asgi_get(
app,
"http://127.0.0.1/v1/models",
headers={"Origin": "https://evil.example.com"},
app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"}
)
assert resp.status_code == 200
acao = resp.headers.get("access-control-allow-origin")
assert acao is None or acao == "", (
f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, "
@@ -200,41 +243,83 @@ def test_cors_default_deny_does_not_emit_wildcard_acao():
def test_explicit_cors_origin_does_not_widen_to_wildcard():
"""Even when the operator opts in to one cross-origin, that single origin
must not unlock a wildcard reflection for other origins."""
ns = _load_helpers()
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
app = _configured_app(ns, cors_origins)
# Allowed origin: ACAO echoes that origin (NOT '*').
ok = _asgi_get(
app, "http://127.0.0.1/", headers={"Origin": "http://localhost:7000"}
)
assert ok.status_code == 200
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"
# Foreign origin: ACAO must NOT echo it, must NOT be '*'.
bad = _asgi_get(
app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"}
)
bad_acao = bad.headers.get("access-control-allow-origin")
assert bad_acao != "*"
assert bad_acao != "https://evil.example.com"
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_preserves_order():
"""CORS is added last so it wraps TrustedHost (outermost). The production
order must be user_middleware == [CORSMiddleware, TrustedHostMiddleware];
default-deny installs the Host allowlist alone."""
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
ns = _load_helpers()
with_cors = _configured_app(ns, ns["_compute_cors_origins"](extras=["http://localhost:7000"]))
assert [m.cls for m in with_cors.user_middleware] == [CORSMiddleware, TrustedHostMiddleware]
default_deny = _configured_app(ns, [])
assert [m.cls for m in default_deny.user_middleware] == [TrustedHostMiddleware]
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_is_idempotent_before_serving():
"""Re-running configuration (module-load defaults, then CLI override)
replaces the stack rather than accumulating duplicate middleware."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
app = FastAPI()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
ns["_configure_security_middleware"](app, allowed, [])
ns["_configure_security_middleware"](
app, allowed, ns["_compute_cors_origins"](extras=["http://localhost:7000"])
)
@app.get("/v1/models")
def list_models():
return {"data": []}
classes = [m.cls for m in app.user_middleware]
assert classes == [CORSMiddleware, TrustedHostMiddleware]
assert classes.count(TrustedHostMiddleware) == 1
# Allowed origin: ACAO echoes that origin (NOT '*').
ok = _asgi_get(
app,
"http://127.0.0.1/v1/models",
headers={"Origin": "http://localhost:7000"},
)
assert ok.status_code == 200
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"
# Foreign origin: ACAO must NOT echo it, must NOT be '*'.
bad = _asgi_get(
app,
"http://127.0.0.1/v1/models",
headers={"Origin": "https://evil.example.com"},
)
bad_acao = bad.headers.get("access-control-allow-origin")
assert bad_acao != "*"
assert bad_acao != "https://evil.example.com"
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_rejects_late_call():
"""Once the middleware stack is built, the helper must raise before
mutating user_middleware so a late reconfigure can't silently no-op."""
from fastapi import FastAPI
ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
app = FastAPI()
ns["_configure_security_middleware"](app, allowed, [])
before = list(app.user_middleware)
# Simulate the app having started serving (stack built lazily on first req).
app.middleware_stack = app.build_middleware_stack()
assert app.middleware_stack is not None
with pytest.raises(RuntimeError):
ns["_configure_security_middleware"](app, ["lan.example"], [])
# Guard fired before mutating: user_middleware is untouched.
assert list(app.user_middleware) == before