mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
test(diffusion-server): exercise security middleware wiring (#3214)
This commit is contained in:
committed by
GitHub
parent
92300b5d67
commit
9ad6a2809e
@@ -12,14 +12,15 @@ Host-header allowlist as a positive defense. These tests pin the allowlist
|
||||
helpers + Starlette's middleware behavior so a future change can't silently
|
||||
re-open the hole.
|
||||
|
||||
The tests run against a tiny synthetic FastAPI app that uses the same
|
||||
``TrustedHostMiddleware`` + ``CORSMiddleware`` wiring as diffusion_server.
|
||||
That keeps the test out of the torch / diffusers import path while still
|
||||
covering the live middleware code paths.
|
||||
The tests AST-extract the security helpers — including the real
|
||||
``_configure_security_middleware`` wiring — from diffusion_server.py and run
|
||||
them against a fresh FastAPI app. That keeps the tests out of the torch /
|
||||
diffusers import path while still exercising the production middleware wiring
|
||||
instead of a hand-rebuilt copy.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -28,29 +29,49 @@ import pytest
|
||||
_SCRIPT = Path(__file__).resolve().parent.parent / "scripts" / "diffusion_server.py"
|
||||
|
||||
|
||||
_EXPECTED_NAMES = (
|
||||
"_DEFAULT_ALLOWED_HOSTS",
|
||||
"_DEFAULT_CORS_ORIGINS",
|
||||
"_compute_allowed_hosts",
|
||||
"_compute_cors_origins",
|
||||
"_configure_security_middleware",
|
||||
)
|
||||
|
||||
|
||||
def _load_helpers():
|
||||
"""Import the pure allowlist helpers from diffusion_server.py without
|
||||
triggering its torch / diffusers imports. We compile just the helper
|
||||
block (everything between the `app =` line and the `class ImageRequest`
|
||||
line) so heavy deps stay quarantined behind the if-False import guard.
|
||||
"""
|
||||
src = _SCRIPT.read_text(encoding="utf-8")
|
||||
# The helpers live between the two markers, both inserted by the security
|
||||
# fix. They depend only on the `_DEFAULT_ALLOWED_HOSTS` / `_DEFAULT_CORS_ORIGINS`
|
||||
# module-level lists, which we materialise here.
|
||||
start_marker = "_DEFAULT_ALLOWED_HOSTS = "
|
||||
end_marker = "class ImageRequest("
|
||||
i = src.index(start_marker)
|
||||
j = src.index(end_marker)
|
||||
helper_block = src[i:j]
|
||||
ns: dict = {"list": list}
|
||||
# Strip the `app.add_middleware(...)` line — the helpers don't need it
|
||||
# and it would force a torch import via fastapi.responses.
|
||||
helper_block = "\n".join(
|
||||
line for line in helper_block.splitlines()
|
||||
if not line.startswith("app.add_middleware")
|
||||
)
|
||||
exec(compile(helper_block, str(_SCRIPT), "exec"), ns)
|
||||
"""Extract the security helpers from diffusion_server.py via AST so the
|
||||
tests exercise the production wiring without importing the module (which
|
||||
would pull in torch / diffusers). Only the named top-level definitions are
|
||||
compiled into a fresh module; everything else — including the heavy
|
||||
imports — is left out. A renamed or removed helper fails loudly here."""
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
tree = ast.parse(_SCRIPT.read_text(encoding="utf-8"))
|
||||
wanted: dict = {}
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.FunctionDef) and node.name in _EXPECTED_NAMES:
|
||||
wanted[node.name] = node
|
||||
elif isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id in _EXPECTED_NAMES:
|
||||
wanted[target.id] = node
|
||||
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
|
||||
if node.target.id in _EXPECTED_NAMES:
|
||||
wanted[node.target.id] = node
|
||||
|
||||
missing = [name for name in _EXPECTED_NAMES if name not in wanted]
|
||||
assert not missing, f"diffusion_server.py is missing expected helpers: {missing}"
|
||||
|
||||
module = ast.Module(body=[wanted[name] for name in _EXPECTED_NAMES], type_ignores=[])
|
||||
ast.fix_missing_locations(module)
|
||||
ns: dict = {
|
||||
"TrustedHostMiddleware": TrustedHostMiddleware,
|
||||
"CORSMiddleware": CORSMiddleware,
|
||||
"RuntimeError": RuntimeError,
|
||||
"list": list,
|
||||
}
|
||||
exec(compile(module, str(_SCRIPT), "exec"), ns)
|
||||
return ns
|
||||
|
||||
|
||||
@@ -77,6 +98,15 @@ def test_compute_allowed_hosts_does_not_add_wildcard():
|
||||
assert "*" not in out, "wildcard host would re-open the DNS-rebinding hole"
|
||||
|
||||
|
||||
def test_compute_allowed_hosts_preserves_explicit_wildcard():
|
||||
# Behavior preservation: a wildcard is not added by default, but an
|
||||
# operator who explicitly passes one is taken at their word (deduped,
|
||||
# stripped, stable order). This pins current behavior, not policy.
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_allowed_hosts"]("127.0.0.1", extras=["*", " lan.example ", "*"])
|
||||
assert out == ["127.0.0.1", "localhost", "::1", "*", "lan.example"]
|
||||
|
||||
|
||||
def test_compute_cors_origins_default_deny():
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_cors_origins"]()
|
||||
@@ -99,6 +129,15 @@ def test_compute_cors_origins_honours_explicit_extras():
|
||||
assert out == ["http://localhost:7000"]
|
||||
|
||||
|
||||
def test_compute_cors_origins_preserves_explicit_wildcard():
|
||||
# Behavior preservation: a wildcard is not the default, but an operator
|
||||
# who explicitly passes one is taken at their word (deduped, stripped,
|
||||
# stable order). This pins current behavior, not policy.
|
||||
ns = _load_helpers()
|
||||
out = ns["_compute_cors_origins"](extras=["*", " http://localhost:7000 ", "*"])
|
||||
assert out == ["*", "http://localhost:7000"]
|
||||
|
||||
|
||||
# ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ─────
|
||||
|
||||
|
||||
@@ -134,61 +173,65 @@ def _asgi_get(app, url, headers=None):
|
||||
return asyncio.run(_run())
|
||||
|
||||
|
||||
def _configured_app(ns, allowed_origins, route_called=None):
|
||||
"""Fresh FastAPI app wired by the production `_configure_security_middleware`
|
||||
with a loopback Host allowlist, plus a minimal route so accepted requests
|
||||
can assert 200. If `route_called` is given, the route sets
|
||||
``route_called["hit"] = True`` so callers can prove whether the inner app
|
||||
was reached."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
ns["_configure_security_middleware"](
|
||||
app, ns["_compute_allowed_hosts"]("127.0.0.1"), allowed_origins
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
def root():
|
||||
if route_called is not None:
|
||||
route_called["hit"] = True
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_trusted_host_middleware_rejects_attacker_host():
|
||||
"""A request with an attacker-controlled Host header (the DNS-rebinding
|
||||
surface) must be rejected by the middleware before reaching any route."""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: F401 (parity import)
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
surface) must be rejected by the production wiring before any route runs."""
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
route_called = {"hit": False}
|
||||
app = _configured_app(ns, [], route_called=route_called)
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
# Legitimate request (Host: 127.0.0.1) goes through.
|
||||
ok = _asgi_get(app, "http://127.0.0.1/health")
|
||||
# Legitimate request (Host: 127.0.0.1) reaches the route.
|
||||
ok = _asgi_get(app, "http://127.0.0.1/")
|
||||
assert ok.status_code == 200
|
||||
# Attacker-controlled hostname (DNS-rebinding scenario) is rejected.
|
||||
bad = _asgi_get(app, "http://evil.example.com/health")
|
||||
assert route_called["hit"] is True
|
||||
# Attacker-controlled hostname (DNS-rebinding scenario) is rejected before
|
||||
# the route runs.
|
||||
route_called["hit"] = False
|
||||
bad = _asgi_get(app, "http://evil.example.com/")
|
||||
assert bad.status_code == 400
|
||||
assert route_called["hit"] is False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_cors_default_deny_does_not_emit_wildcard_acao():
|
||||
"""Without CORSMiddleware installed, the server must not advertise
|
||||
Access-Control-Allow-Origin at all (definitely not the wildcard)."""
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
"""Default-deny CORS (no --allowed-origin) must not advertise any
|
||||
Access-Control-Allow-Origin, so a browser blocks cross-origin readers."""
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
# Default-deny CORS: no CORSMiddleware. Mirrors diffusion_server's behavior
|
||||
# when no --allowed-origin flags are passed.
|
||||
cors_origins = ns["_compute_cors_origins"]()
|
||||
assert cors_origins == []
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
|
||||
|
||||
@app.get("/v1/models")
|
||||
def list_models():
|
||||
return {"data": []}
|
||||
app = _configured_app(ns, cors_origins)
|
||||
|
||||
# Host is allowed, so the request itself succeeds — but the response must
|
||||
# carry no ACAO, so a real browser would block the attacker page from
|
||||
# reading the body.
|
||||
resp = _asgi_get(
|
||||
app,
|
||||
"http://127.0.0.1/v1/models",
|
||||
headers={"Origin": "https://evil.example.com"},
|
||||
app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
acao = resp.headers.get("access-control-allow-origin")
|
||||
assert acao is None or acao == "", (
|
||||
f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, "
|
||||
@@ -200,41 +243,83 @@ def test_cors_default_deny_does_not_emit_wildcard_acao():
|
||||
def test_explicit_cors_origin_does_not_widen_to_wildcard():
|
||||
"""Even when the operator opts in to one cross-origin, that single origin
|
||||
must not unlock a wildcard reflection for other origins."""
|
||||
ns = _load_helpers()
|
||||
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
|
||||
|
||||
app = _configured_app(ns, cors_origins)
|
||||
|
||||
# Allowed origin: ACAO echoes that origin (NOT '*').
|
||||
ok = _asgi_get(
|
||||
app, "http://127.0.0.1/", headers={"Origin": "http://localhost:7000"}
|
||||
)
|
||||
assert ok.status_code == 200
|
||||
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"
|
||||
# Foreign origin: ACAO must NOT echo it, must NOT be '*'.
|
||||
bad = _asgi_get(
|
||||
app, "http://127.0.0.1/", headers={"Origin": "https://evil.example.com"}
|
||||
)
|
||||
bad_acao = bad.headers.get("access-control-allow-origin")
|
||||
assert bad_acao != "*"
|
||||
assert bad_acao != "https://evil.example.com"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_configure_security_middleware_preserves_order():
|
||||
"""CORS is added last so it wraps TrustedHost (outermost). The production
|
||||
order must be user_middleware == [CORSMiddleware, TrustedHostMiddleware];
|
||||
default-deny installs the Host allowlist alone."""
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
ns = _load_helpers()
|
||||
|
||||
with_cors = _configured_app(ns, ns["_compute_cors_origins"](extras=["http://localhost:7000"]))
|
||||
assert [m.cls for m in with_cors.user_middleware] == [CORSMiddleware, TrustedHostMiddleware]
|
||||
|
||||
default_deny = _configured_app(ns, [])
|
||||
assert [m.cls for m in default_deny.user_middleware] == [TrustedHostMiddleware]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_configure_security_middleware_is_idempotent_before_serving():
|
||||
"""Re-running configuration (module-load defaults, then CLI override)
|
||||
replaces the stack rather than accumulating duplicate middleware."""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
cors_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
ns["_configure_security_middleware"](app, allowed, [])
|
||||
ns["_configure_security_middleware"](
|
||||
app, allowed, ns["_compute_cors_origins"](extras=["http://localhost:7000"])
|
||||
)
|
||||
|
||||
@app.get("/v1/models")
|
||||
def list_models():
|
||||
return {"data": []}
|
||||
classes = [m.cls for m in app.user_middleware]
|
||||
assert classes == [CORSMiddleware, TrustedHostMiddleware]
|
||||
assert classes.count(TrustedHostMiddleware) == 1
|
||||
|
||||
# Allowed origin: ACAO echoes that origin (NOT '*').
|
||||
ok = _asgi_get(
|
||||
app,
|
||||
"http://127.0.0.1/v1/models",
|
||||
headers={"Origin": "http://localhost:7000"},
|
||||
)
|
||||
assert ok.status_code == 200
|
||||
assert ok.headers.get("access-control-allow-origin") == "http://localhost:7000"
|
||||
# Foreign origin: ACAO must NOT echo it, must NOT be '*'.
|
||||
bad = _asgi_get(
|
||||
app,
|
||||
"http://127.0.0.1/v1/models",
|
||||
headers={"Origin": "https://evil.example.com"},
|
||||
)
|
||||
bad_acao = bad.headers.get("access-control-allow-origin")
|
||||
assert bad_acao != "*"
|
||||
assert bad_acao != "https://evil.example.com"
|
||||
|
||||
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
|
||||
def test_configure_security_middleware_rejects_late_call():
|
||||
"""Once the middleware stack is built, the helper must raise before
|
||||
mutating user_middleware so a late reconfigure can't silently no-op."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
ns = _load_helpers()
|
||||
allowed = ns["_compute_allowed_hosts"]("127.0.0.1")
|
||||
|
||||
app = FastAPI()
|
||||
ns["_configure_security_middleware"](app, allowed, [])
|
||||
before = list(app.user_middleware)
|
||||
|
||||
# Simulate the app having started serving (stack built lazily on first req).
|
||||
app.middleware_stack = app.build_middleware_stack()
|
||||
assert app.middleware_stack is not None
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
ns["_configure_security_middleware"](app, ["lan.example"], [])
|
||||
# Guard fired before mutating: user_middleware is untouched.
|
||||
assert list(app.user_middleware) == before
|
||||
|
||||
Reference in New Issue
Block a user