test(diffusion-server): exercise security middleware wiring (#3214)

This commit is contained in:
Alexandre Teixeira
2026-06-07 22:42:11 +01:00
committed by GitHub
parent 92300b5d67
commit 9ad6a2809e
2 changed files with 199 additions and 101 deletions
+26 -13
View File
@@ -62,10 +62,6 @@ app = FastAPI(title="Diffusion Server", lifespan=lifespan)
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"] _DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
_DEFAULT_CORS_ORIGINS: list = [] # default-deny _DEFAULT_CORS_ORIGINS: list = [] # default-deny
# Install defaults at module load so importing the app for tests / direct
# uvicorn invocation still benefits from the Host-header allowlist.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(_DEFAULT_ALLOWED_HOSTS))
def _compute_allowed_hosts(bind_host: str, extras=None) -> list: def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
"""Allowed Host header values: the bind address + loopback variants + """Allowed Host header values: the bind address + loopback variants +
@@ -91,6 +87,31 @@ def _compute_cors_origins(extras=None) -> list:
return seen return seen
def _configure_security_middleware(application, allowed_hosts, allowed_origins):
"""Replace `application`'s user middleware stack with the diffusion server
security middleware: the TrustedHost allowlist and, when origins are
supplied, CORS. Used at module load and by the __main__ CLI path before
serving starts. Raises before mutating if the middleware stack has already
been built. Order is preserved: TrustedHost first, then CORS (added last ->
outermost)."""
if application.middleware_stack is not None:
raise RuntimeError("security middleware must be configured before the app starts serving")
application.user_middleware.clear()
application.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts))
if allowed_origins:
application.add_middleware(
CORSMiddleware,
allow_origins=list(allowed_origins),
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
# Install defaults at module load so importing the app for tests / direct
# uvicorn invocation still benefits from the Host-header allowlist.
_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS)
class ImageRequest(BaseModel): class ImageRequest(BaseModel):
model: str = "" model: str = ""
prompt: str prompt: str
@@ -1141,15 +1162,7 @@ if __name__ == "__main__":
# here is safe. # here is safe.
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host) final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
final_origins = _compute_cors_origins(_args.allowed_origin) final_origins = _compute_cors_origins(_args.allowed_origin)
app.user_middleware.clear() _configure_security_middleware(app, final_hosts, final_origins)
app.add_middleware(TrustedHostMiddleware, allowed_hosts=final_hosts)
if final_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=final_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s", logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
final_hosts, final_origins or "(none — default-deny)") final_hosts, final_origins or "(none — default-deny)")
+173 -88
View File
@@ -12,14 +12,15 @@ Host-header allowlist as a positive defense. These tests pin the allowlist
helpers + Starlette's middleware behavior so a future change can't silently helpers + Starlette's middleware behavior so a future change can't silently
re-open the hole. re-open the hole.
The tests run against a tiny synthetic FastAPI app that uses the same The tests AST-extract the security helpers — including the real
``TrustedHostMiddleware`` + ``CORSMiddleware`` wiring as diffusion_server. ``_configure_security_middleware`` wiring — from diffusion_server.py and run
That keeps the test out of the torch / diffusers import path while still them against a fresh FastAPI app. That keeps the tests out of the torch /
covering the live middleware code paths. diffusers import path while still exercising the production middleware wiring
instead of a hand-rebuilt copy.
""" """
import ast
import importlib.util import importlib.util
import os
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -28,29 +29,49 @@ import pytest
_SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py" _SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py"
_EXPECTED_NAMES = (
"_DEFAULT_ALLOWED_HOSTS",
"_DEFAULT_CORS_ORIGINS",
"_compute_allowed_hosts",
"_compute_cors_origins",
"_configure_security_middleware",
)
def _load_helpers(): def _load_helpers():
"""Import the pure allowlist helpers from diffusion_server.py without """Extract the security helpers from diffusion_server.py via AST so the
triggering its torch / diffusers imports. We compile just the helper tests exercise the production wiring without importing the module (which
block (everything between the `app =` line and the `class ImageRequest` would pull in torch / diffusers). Only the named top-level definitions are
line) so heavy deps stay quarantined behind the if-False import guard. compiled into a fresh module; everything else — including the heavy
""" imports — is left out. A renamed or removed helper fails loudly here."""
src = _SCRIPT.read_text(encoding="utf-8") from fastapi.middleware.cors import CORSMiddleware
# The helpers live between the two markers, both inserted by the security from starlette.middleware.trustedhost import TrustedHostMiddleware
# fix. They depend only on the `_DEFAULT_ALLOWED_HOSTS` / `_DEFAULT_CORS_ORIGINS`
# module-level lists, which we materialise here. tree = ast.parse(_SCRIPT.read_text(encoding="utf-8"))
start_marker = "_DEFAULT_ALLOWED_HOSTS = " wanted: dict = {}
end_marker = "class ImageRequest(" for node in tree.body:
i = src.index(start_marker) if isinstance(node, ast.FunctionDef) and node.name in _EXPECTED_NAMES:
j = src.index(end_marker) wanted[node.name] = node
helper_block = src[i:j] elif isinstance(node, ast.Assign):
ns: dict = {"list": list} for target in node.targets:
# Strip the `app.add_middleware(...)` line — the helpers don't need it if isinstance(target, ast.Name) and target.id in _EXPECTED_NAMES:
# and it would force a torch import via fastapi.responses. wanted[target.id] = node
helper_block = "\n".join( elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
line for line in helper_block.splitlines() if node.target.id in _EXPECTED_NAMES:
if not line.startswith("app.add_middleware") wanted[node.target.id] = node
)
exec(compile(helper_block, str(_SCRIPT), "exec"), ns) missing = [name for name in _EXPECTED_NAMES if name not in wanted]
assert not missing, f"diffusion_server.py is missing expected helpers: {missing}"
module = ast.Module(body=[wanted[name] for name in _EXPECTED_NAMES], type_ignores=[])
ast.fix_missing_locations(module)
ns: dict = {
"TrustedHostMiddleware": TrustedHostMiddleware,
"CORSMiddleware": CORSMiddleware,
"RuntimeError": RuntimeError,
"list": list,
}
exec(compile(module, str(_SCRIPT), "exec"), ns)
return ns return ns
@@ -77,6 +98,15 @@ def test_compute_allowed_hosts_does_not_add_wildcard():
assert "*" not in out, "wildcard host would re-open the DNS-rebinding hole" assert "*" not in out, "wildcard host would re-open the DNS-rebinding hole"
def test_compute_allowed_hosts_preserves_explicit_wildcard():
# Behavior preservation: a wildcard is not added by default, but an
# operator who explicitly passes one is taken at their word (deduped,
# stripped, stable order). This pins current behavior, not policy.
ns = _load_helpers()
out = ns["_compute_allowed_hosts"]("127.0.0.1", extras=["*", " lan.example ", "*"])
assert out == ["127.0.0.1", "localhost", "::1", "*", "lan.example"]
def test_compute_cors_origins_default_deny(): def test_compute_cors_origins_default_deny():
ns = _load_helpers() ns = _load_helpers()
out = ns["_compute_cors_origins"]() out = ns["_compute_cors_origins"]()
@@ -99,6 +129,15 @@ def test_compute_cors_origins_honours_explicit_extras():
assert out == ["http://localhost:7000"] assert out == ["http://localhost:7000"]
def test_compute_cors_origins_preserves_explicit_wildcard():
# Behavior preservation: a wildcard is not the default, but an operator
# who explicitly passes one is taken at their word (deduped, stripped,
# stable order). This pins current behavior, not policy.
ns = _load_helpers()
out = ns["_compute_cors_origins"](extras=["*", " http://localhost:7000 ", "*"])
assert out == ["*", "http://localhost:7000"]
# ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ───── # ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ─────
@@ -134,61 +173,65 @@ def _asgi_get(app, url, headers=None):
return asyncio.run(_run()) return asyncio.run(_run())
def _configured_app(ns, allowed_origins, route_called=None):
"""Fresh FastAPI app wired by the production `_configure_security_middleware`
with a loopback Host allowlist, plus a minimal route so accepted requests
can assert 200. If `route_called` is given, the route sets
``route_called["hit"] = True`` so callers can prove whether the inner app
was reached."""
from fastapi import FastAPI
app = FastAPI()
ns["_configure_security_middleware"](
app, ns["_compute_allowed_hosts"]("127.0.0.1"), allowed_origins
)
@app.get("/")
def root():
if route_called is not None:
route_called["hit"] = True
return {"ok": True}
return app
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") @pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_trusted_host_middleware_rejects_attacker_host(): def test_trusted_host_middleware_rejects_attacker_host():
"""A request with an attacker-controlled Host header (the DNS-rebinding """A request with an attacker-controlled Host header (the DNS-rebinding
surface) must be rejected by the middleware before reaching any route.""" surface) must be rejected by the production wiring before any route runs."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware # noqa: F401 (parity import)
from starlette.middleware.trustedhost import TrustedHostMiddleware
ns = _load_helpers() ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1") route_called = {"hit": False}
app = _configured_app(ns, [], route_called=route_called)
app = FastAPI() # Legitimate request (Host: 127.0.0.1) reaches the route.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed) ok = _asgi_get(app, "http://127.0.0.1/")
@app.get("/health")
def health():
return {"status": "ok"}
# Legitimate request (Host: 127.0.0.1) goes through.
ok = _asgi_get(app, "http://127.0.0.1/health")
assert ok.status_code == 200 assert ok.status_code == 200
# Attacker-controlled hostname (DNS-rebinding scenario) is rejected. assert route_called["hit"] is True
bad = _asgi_get(app, "http://evil.example.com/health") # Attacker-controlled hostname (DNS-rebinding scenario) is rejected before
# the route runs.
route_called["hit"] = False
bad = _asgi_get(app, "http://evil.example.com/")
assert bad.status_code == 400 assert bad.status_code == 400
assert route_called["hit"] is False
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") @pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_cors_default_deny_does_not_emit_wildcard_acao(): def test_cors_default_deny_does_not_emit_wildcard_acao():
"""Without CORSMiddleware installed, the server must not advertise """Default-deny CORS (no --allowed-origin) must not advertise any
Access-Control-Allow-Origin at all (definitely not the wildcard).""" Access-Control-Allow-Origin, so a browser blocks cross-origin readers."""
from fastapi import FastAPI
from starlette.middleware.trustedhost import TrustedHostMiddleware
ns = _load_helpers() ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
# Default-deny CORS: no CORSMiddleware. Mirrors diffusion_server's behavior
# when no --allowed-origin flags are passed.
cors_origins = ns["_compute_cors_origins"]() cors_origins = ns["_compute_cors_origins"]()
assert cors_origins == [] assert cors_origins == []
app = FastAPI() app = _configured_app(ns, cors_origins)
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
@app.get("/v1/models")
def list_models():
return {"data": []}
# Host is allowed, so the request itself succeeds — but the response must # Host is allowed, so the request itself succeeds — but the response must
# carry no ACAO, so a real browser would block the attacker page from # carry no ACAO, so a real browser would block the attacker page from
# reading the body. # reading the body.
resp = _asgi_get( resp = _asgi_get(
app, app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"}
"http://127.0.0.1/v1/models",
headers={"Origin": "https://evil.example.com"},
) )
assert resp.status_code == 200
acao = resp.headers.get("access-control-allow-origin") acao = resp.headers.get("access-control-allow-origin")
assert acao is None or acao == "", ( assert acao is None or acao == "", (
f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, " f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, "
@@ -200,41 +243,83 @@ def test_cors_default_deny_does_not_emit_wildcard_acao():
def test_explicit_cors_origin_does_not_widen_to_wildcard(): def test_explicit_cors_origin_does_not_widen_to_wildcard():
"""Even when the operator opts in to one cross-origin, that single origin """Even when the operator opts in to one cross-origin, that single origin
must not unlock a wildcard reflection for other origins.""" must not unlock a wildcard reflection for other origins."""
ns = _load_helpers()
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
app = _configured_app(ns, cors_origins)
# Allowed origin: ACAO echoes that origin (NOT '*').
ok = _asgi_get(
app, "http://127.0.0.1/", headers={"Origin": "http://localhost:7000"}
)
assert ok.status_code == 200
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"
# Foreign origin: ACAO must NOT echo it, must NOT be '*'.
bad = _asgi_get(
app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"}
)
bad_acao = bad.headers.get("access-control-allow-origin")
assert bad_acao != "*"
assert bad_acao != "https://evil.example.com"
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_preserves_order():
"""CORS is added last so it wraps TrustedHost (outermost). The production
order must be user_middleware == [CORSMiddleware, TrustedHostMiddleware];
default-deny installs the Host allowlist alone."""
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
ns = _load_helpers()
with_cors = _configured_app(ns, ns["_compute_cors_origins"](extras=["http://localhost:7000"]))
assert [m.cls for m in with_cors.user_middleware] == [CORSMiddleware, TrustedHostMiddleware]
default_deny = _configured_app(ns, [])
assert [m.cls for m in default_deny.user_middleware] == [TrustedHostMiddleware]
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_is_idempotent_before_serving():
"""Re-running configuration (module-load defaults, then CLI override)
replaces the stack rather than accumulating duplicate middleware."""
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware
ns = _load_helpers() ns = _load_helpers()
allowed = ns["_compute_allowed_hosts"]("127.0.0.1") allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
app = FastAPI() app = FastAPI()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed) ns["_configure_security_middleware"](app, allowed, [])
app.add_middleware( ns["_configure_security_middleware"](
CORSMiddleware, app, allowed, ns["_compute_cors_origins"](extras=["http://localhost:7000"])
allow_origins=cors_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
) )
@app.get("/v1/models") classes = [m.cls for m in app.user_middleware]
def list_models(): assert classes == [CORSMiddleware, TrustedHostMiddleware]
return {"data": []} assert classes.count(TrustedHostMiddleware) == 1
# Allowed origin: ACAO echoes that origin (NOT '*').
ok = _asgi_get( @pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
app, def test_configure_security_middleware_rejects_late_call():
"http://127.0.0.1/v1/models", """Once the middleware stack is built, the helper must raise before
headers={"Origin": "http://localhost:7000"}, mutating user_middleware so a late reconfigure can't silently no-op."""
) from fastapi import FastAPI
assert ok.status_code == 200
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000" ns = _load_helpers()
# Foreign origin: ACAO must NOT echo it, must NOT be '*'. allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
bad = _asgi_get(
app, app = FastAPI()
"http://127.0.0.1/v1/models", ns["_configure_security_middleware"](app, allowed, [])
headers={"Origin": "https://evil.example.com"}, before = list(app.user_middleware)
)
bad_acao = bad.headers.get("access-control-allow-origin") # Simulate the app having started serving (stack built lazily on first req).
assert bad_acao != "*" app.middleware_stack = app.build_middleware_stack()
assert bad_acao != "https://evil.example.com" assert app.middleware_stack is not None
with pytest.raises(RuntimeError):
ns["_configure_security_middleware"](app, ["lan.example"], [])
# Guard fired before mutating: user_middleware is untouched.
assert list(app.user_middleware) == before