From 9ad6a2809e29e3e6f28034d79c6851bfd64b369e Mon Sep 17 00:00:00 2001 From: Alexandre Teixeira <111787685+alteixeira20@users.noreply.github.com> Date: Sun, 7 Jun 2026 22:42:11 +0100 Subject: [PATCH] test(diffusion-server): exercise security middleware wiring (#3214) --- scripts/diffusion_server.py | 39 ++-- tests/test_diffusion_server_security.py | 261 ++++++++++++++++-------- 2 files changed, 199 insertions(+), 101 deletions(-) diff --git a/scripts/diffusion_server.py b/scripts/diffusion_server.py index 281ce2c6d..71da9ed0c 100644 --- a/scripts/diffusion_server.py +++ b/scripts/diffusion_server.py @@ -62,10 +62,6 @@ app = FastAPI(title="Diffusion Server", lifespan=lifespan) _DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"] _DEFAULT_CORS_ORIGINS: list = [] # default-deny -# Install defaults at module load so importing the app for tests / direct -# uvicorn invocation still benefits from the Host-header allowlist. -app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(_DEFAULT_ALLOWED_HOSTS)) - def _compute_allowed_hosts(bind_host: str, extras=None) -> list: """Allowed Host header values: the bind address + loopback variants + @@ -91,6 +87,31 @@ def _compute_cors_origins(extras=None) -> list: return seen +def _configure_security_middleware(application, allowed_hosts, allowed_origins): + """Replace `application`'s user middleware stack with the diffusion server + security middleware: the TrustedHost allowlist and, when origins are + supplied, CORS. Used at module load and by the __main__ CLI path before + serving starts. Raises before mutating if the middleware stack has already + been built. Order is preserved: TrustedHost first, then CORS (added last -> + outermost).""" + if application.middleware_stack is not None: + raise RuntimeError("security middleware must be configured before the app starts serving") + application.user_middleware.clear() + application.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts)) + if allowed_origins: + application.add_middleware( + CORSMiddleware, + allow_origins=list(allowed_origins), + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["Authorization", "Content-Type"], + ) + + +# Install defaults at module load so importing the app for tests / direct +# uvicorn invocation still benefits from the Host-header allowlist. +_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS) + + class ImageRequest(BaseModel): model: str = "" prompt: str @@ -1141,15 +1162,7 @@ if __name__ == "__main__": # here is safe. final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host) final_origins = _compute_cors_origins(_args.allowed_origin) - app.user_middleware.clear() - app.add_middleware(TrustedHostMiddleware, allowed_hosts=final_hosts) - if final_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=final_origins, - allow_methods=["GET", "POST", "OPTIONS"], - allow_headers=["Authorization", "Content-Type"], - ) + _configure_security_middleware(app, final_hosts, final_origins) logger.info("security middleware: allowed_hosts=%s allowed_origins=%s", final_hosts, final_origins or "(none — default-deny)") diff --git a/tests/test_diffusion_server_security.py b/tests/test_diffusion_server_security.py index f18972ff0..ba1253d6e 100644 --- a/tests/test_diffusion_server_security.py +++ b/tests/test_diffusion_server_security.py @@ -12,14 +12,15 @@ Host-header allowlist as a positive defense. These tests pin the allowlist helpers + Starlette's middleware behavior so a future change can't silently re-open the hole. -The tests run against a tiny synthetic FastAPI app that uses the same -``TrustedHostMiddleware`` + ``CORSMiddleware`` wiring as diffusion_server. -That keeps the test out of the torch / diffusers import path while still -covering the live middleware code paths. +The tests AST-extract the security helpers — including the real +``_configure_security_middleware`` wiring — from diffusion_server.py and run +them against a fresh FastAPI app. That keeps the tests out of the torch / +diffusers import path while still exercising the production middleware wiring +instead of a hand-rebuilt copy. """ +import ast import importlib.util -import os from pathlib import Path import pytest @@ -28,29 +29,49 @@ import pytest _SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py" +_EXPECTED_NAMES = ( + "_DEFAULT_ALLOWED_HOSTS", + "_DEFAULT_CORS_ORIGINS", + "_compute_allowed_hosts", + "_compute_cors_origins", + "_configure_security_middleware", +) + + def _load_helpers(): - """Import the pure allowlist helpers from diffusion_server.py without - triggering its torch / diffusers imports. We compile just the helper - block (everything between the `app =` line and the `class ImageRequest` - line) so heavy deps stay quarantined behind the if-False import guard. - """ - src = _SCRIPT.read_text(encoding="utf-8") - # The helpers live between the two markers, both inserted by the security - # fix. They depend only on the `_DEFAULT_ALLOWED_HOSTS` / `_DEFAULT_CORS_ORIGINS` - # module-level lists, which we materialise here. - start_marker = "_DEFAULT_ALLOWED_HOSTS = " - end_marker = "class ImageRequest(" - i = src.index(start_marker) - j = src.index(end_marker) - helper_block = src[i:j] - ns: dict = {"list": list} - # Strip the `app.add_middleware(...)` line — the helpers don't need it - # and it would force a torch import via fastapi.responses. - helper_block = "\n".join( - line for line in helper_block.splitlines() - if not line.startswith("app.add_middleware") - ) - exec(compile(helper_block, str(_SCRIPT), "exec"), ns) + """Extract the security helpers from diffusion_server.py via AST so the + tests exercise the production wiring without importing the module (which + would pull in torch / diffusers). Only the named top-level definitions are + compiled into a fresh module; everything else — including the heavy + imports — is left out. A renamed or removed helper fails loudly here.""" + from fastapi.middleware.cors import CORSMiddleware + from starlette.middleware.trustedhost import TrustedHostMiddleware + + tree = ast.parse(_SCRIPT.read_text(encoding="utf-8")) + wanted: dict = {} + for node in tree.body: + if isinstance(node, ast.FunctionDef) and node.name in _EXPECTED_NAMES: + wanted[node.name] = node + elif isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id in _EXPECTED_NAMES: + wanted[target.id] = node + elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + if node.target.id in _EXPECTED_NAMES: + wanted[node.target.id] = node + + missing = [name for name in _EXPECTED_NAMES if name not in wanted] + assert not missing, f"diffusion_server.py is missing expected helpers: {missing}" + + module = ast.Module(body=[wanted[name] for name in _EXPECTED_NAMES], type_ignores=[]) + ast.fix_missing_locations(module) + ns: dict = { + "TrustedHostMiddleware": TrustedHostMiddleware, + "CORSMiddleware": CORSMiddleware, + "RuntimeError": RuntimeError, + "list": list, + } + exec(compile(module, str(_SCRIPT), "exec"), ns) return ns @@ -77,6 +98,15 @@ def test_compute_allowed_hosts_does_not_add_wildcard(): assert "*" not in out, "wildcard host would re-open the DNS-rebinding hole" +def test_compute_allowed_hosts_preserves_explicit_wildcard(): + # Behavior preservation: a wildcard is not added by default, but an + # operator who explicitly passes one is taken at their word (deduped, + # stripped, stable order). This pins current behavior, not policy. + ns = _load_helpers() + out = ns["_compute_allowed_hosts"]("127.0.0.1", extras=["*", " lan.example ", "*"]) + assert out == ["127.0.0.1", "localhost", "::1", "*", "lan.example"] + + def test_compute_cors_origins_default_deny(): ns = _load_helpers() out = ns["_compute_cors_origins"]() @@ -99,6 +129,15 @@ def test_compute_cors_origins_honours_explicit_extras(): assert out == ["http://localhost:7000"] +def test_compute_cors_origins_preserves_explicit_wildcard(): + # Behavior preservation: a wildcard is not the default, but an operator + # who explicitly passes one is taken at their word (deduped, stripped, + # stable order). This pins current behavior, not policy. + ns = _load_helpers() + out = ns["_compute_cors_origins"](extras=["*", " http://localhost:7000 ", "*"]) + assert out == ["*", "http://localhost:7000"] + + # ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ───── @@ -134,61 +173,65 @@ def _asgi_get(app, url, headers=None): return asyncio.run(_run()) +def _configured_app(ns, allowed_origins, route_called=None): + """Fresh FastAPI app wired by the production `_configure_security_middleware` + with a loopback Host allowlist, plus a minimal route so accepted requests + can assert 200. If `route_called` is given, the route sets + ``route_called["hit"] = True`` so callers can prove whether the inner app + was reached.""" + from fastapi import FastAPI + + app = FastAPI() + ns["_configure_security_middleware"]( + app, ns["_compute_allowed_hosts"]("127.0.0.1"), allowed_origins + ) + + @app.get("/") + def root(): + if route_called is not None: + route_called["hit"] = True + return {"ok": True} + + return app + + @pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") def test_trusted_host_middleware_rejects_attacker_host(): """A request with an attacker-controlled Host header (the DNS-rebinding - surface) must be rejected by the middleware before reaching any route.""" - from fastapi import FastAPI - from fastapi.middleware.cors import CORSMiddleware # noqa: F401 (parity import) - from starlette.middleware.trustedhost import TrustedHostMiddleware - + surface) must be rejected by the production wiring before any route runs.""" ns = _load_helpers() - allowed = ns["_compute_allowed_hosts"]("127.0.0.1") + route_called = {"hit": False} + app = _configured_app(ns, [], route_called=route_called) - app = FastAPI() - app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed) - - @app.get("/health") - def health(): - return {"status": "ok"} - - # Legitimate request (Host: 127.0.0.1) goes through. - ok = _asgi_get(app, "http://127.0.0.1/health") + # Legitimate request (Host: 127.0.0.1) reaches the route. + ok = _asgi_get(app, "http://127.0.0.1/") assert ok.status_code == 200 - # Attacker-controlled hostname (DNS-rebinding scenario) is rejected. - bad = _asgi_get(app, "http://evil.example.com/health") + assert route_called["hit"] is True + # Attacker-controlled hostname (DNS-rebinding scenario) is rejected before + # the route runs. + route_called["hit"] = False + bad = _asgi_get(app, "http://evil.example.com/") assert bad.status_code == 400 + assert route_called["hit"] is False @pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") def test_cors_default_deny_does_not_emit_wildcard_acao(): - """Without CORSMiddleware installed, the server must not advertise - Access-Control-Allow-Origin at all (definitely not the wildcard).""" - from fastapi import FastAPI - from starlette.middleware.trustedhost import TrustedHostMiddleware - + """Default-deny CORS (no --allowed-origin) must not advertise any + Access-Control-Allow-Origin, so a browser blocks cross-origin readers.""" ns = _load_helpers() - allowed = ns["_compute_allowed_hosts"]("127.0.0.1") - # Default-deny CORS: no CORSMiddleware. Mirrors diffusion_server's behavior - # when no --allowed-origin flags are passed. cors_origins = ns["_compute_cors_origins"]() assert cors_origins == [] - app = FastAPI() - app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed) - - @app.get("/v1/models") - def list_models(): - return {"data": []} + app = _configured_app(ns, cors_origins) # Host is allowed, so the request itself succeeds — but the response must # carry no ACAO, so a real browser would block the attacker page from # reading the body. resp = _asgi_get( - app, - "http://127.0.0.1/v1/models", - headers={"Origin": "https://evil.example.com"}, + app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"} ) + assert resp.status_code == 200 acao = resp.headers.get("access-control-allow-origin") assert acao is None or acao == "", ( f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, " @@ -200,41 +243,83 @@ def test_cors_default_deny_does_not_emit_wildcard_acao(): def test_explicit_cors_origin_does_not_widen_to_wildcard(): """Even when the operator opts in to one cross-origin, that single origin must not unlock a wildcard reflection for other origins.""" + ns = _load_helpers() + cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"]) + + app = _configured_app(ns, cors_origins) + + # Allowed origin: ACAO echoes that origin (NOT '*'). + ok = _asgi_get( + app, "http://127.0.0.1/", headers={"Origin": "http://localhost:7000"} + ) + assert ok.status_code == 200 + assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000" + # Foreign origin: ACAO must NOT echo it, must NOT be '*'. + bad = _asgi_get( + app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"} + ) + bad_acao = bad.headers.get("access-control-allow-origin") + assert bad_acao != "*" + assert bad_acao != "https://evil.example.com" + + +@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") +def test_configure_security_middleware_preserves_order(): + """CORS is added last so it wraps TrustedHost (outermost). The production + order must be user_middleware == [CORSMiddleware, TrustedHostMiddleware]; + default-deny installs the Host allowlist alone.""" + from fastapi.middleware.cors import CORSMiddleware + from starlette.middleware.trustedhost import TrustedHostMiddleware + + ns = _load_helpers() + + with_cors = _configured_app(ns, ns["_compute_cors_origins"](extras=["http://localhost:7000"])) + assert [m.cls for m in with_cors.user_middleware] == [CORSMiddleware, TrustedHostMiddleware] + + default_deny = _configured_app(ns, []) + assert [m.cls for m in default_deny.user_middleware] == [TrustedHostMiddleware] + + +@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") +def test_configure_security_middleware_is_idempotent_before_serving(): + """Re-running configuration (module-load defaults, then CLI override) + replaces the stack rather than accumulating duplicate middleware.""" from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware ns = _load_helpers() allowed = ns["_compute_allowed_hosts"]("127.0.0.1") - cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"]) app = FastAPI() - app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed) - app.add_middleware( - CORSMiddleware, - allow_origins=cors_origins, - allow_methods=["GET", "POST", "OPTIONS"], - allow_headers=["Authorization", "Content-Type"], + ns["_configure_security_middleware"](app, allowed, []) + ns["_configure_security_middleware"]( + app, allowed, ns["_compute_cors_origins"](extras=["http://localhost:7000"]) ) - @app.get("/v1/models") - def list_models(): - return {"data": []} + classes = [m.cls for m in app.user_middleware] + assert classes == [CORSMiddleware, TrustedHostMiddleware] + assert classes.count(TrustedHostMiddleware) == 1 - # Allowed origin: ACAO echoes that origin (NOT '*'). - ok = _asgi_get( - app, - "http://127.0.0.1/v1/models", - headers={"Origin": "http://localhost:7000"}, - ) - assert ok.status_code == 200 - assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000" - # Foreign origin: ACAO must NOT echo it, must NOT be '*'. - bad = _asgi_get( - app, - "http://127.0.0.1/v1/models", - headers={"Origin": "https://evil.example.com"}, - ) - bad_acao = bad.headers.get("access-control-allow-origin") - assert bad_acao != "*" - assert bad_acao != "https://evil.example.com" + +@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed") +def test_configure_security_middleware_rejects_late_call(): + """Once the middleware stack is built, the helper must raise before + mutating user_middleware so a late reconfigure can't silently no-op.""" + from fastapi import FastAPI + + ns = _load_helpers() + allowed = ns["_compute_allowed_hosts"]("127.0.0.1") + + app = FastAPI() + ns["_configure_security_middleware"](app, allowed, []) + before = list(app.user_middleware) + + # Simulate the app having started serving (stack built lazily on first req). + app.middleware_stack = app.build_middleware_stack() + assert app.middleware_stack is not None + + with pytest.raises(RuntimeError): + ns["_configure_security_middleware"](app, ["lan.example"], []) + # Guard fired before mutating: user_middleware is untouched. + assert list(app.user_middleware) == before