Files
odysseus/tests/test_diffusion_server_security.py

326 lines
13 KiB
Python

"""Pin the diffusion_server DNS-rebinding + wildcard-CORS regression.
Background: scripts/diffusion_server.py used to ship `allow_origins=["*"]`
with the default `--host=127.0.0.1` bind. Combined, that left the OpenAI-
compatible image API reachable from any browser tab via DNS-rebinding: an
attacker page resolves its own domain to 127.0.0.1 mid-fetch, the browser
forwards the request to the loopback server, and the wildcard CORS reply
lets the attacker page read the result + drive the GPU.
The fix narrows CORS to default-deny and adds a TrustedHostMiddleware
Host-header allowlist as a positive defense. These tests pin the allowlist
helpers + Starlette's middleware behavior so a future change can't silently
re-open the hole.
The tests AST-extract the security helpers — including the real
``_configure_security_middleware`` wiring — from diffusion_server.py and run
them against a fresh FastAPI app. That keeps the tests out of the torch /
diffusers import path while still exercising the production middleware wiring
instead of a hand-rebuilt copy.
"""
import ast
import importlib.util
from pathlib import Path
import pytest
# Location of the production server script: two directories above this test
# file, then scripts/diffusion_server.py.
_SCRIPT = Path(__file__).resolve().parent.parent.joinpath(
    "scripts", "diffusion_server.py"
)

# Top-level names the tests pull out of the script. The order matters to
# _load_helpers, which compiles the definitions in exactly this sequence
# (constants first, then the functions that use them).
_EXPECTED_NAMES = (
    "_DEFAULT_ALLOWED_HOSTS",
    "_DEFAULT_CORS_ORIGINS",
    "_compute_allowed_hosts",
    "_compute_cors_origins",
    "_configure_security_middleware",
)
def _load_helpers():
    """Compile the security helpers out of diffusion_server.py and return them.

    The script is parsed with ``ast`` rather than imported, so only the
    top-level definitions named in ``_EXPECTED_NAMES`` are executed; the
    script's own imports (torch / diffusers among them) never run. The
    selected definitions are compiled, in ``_EXPECTED_NAMES`` order, into a
    fresh namespace that is pre-seeded with the two middleware classes.

    Returns:
        The namespace dict holding the extracted helpers and constants.

    Raises:
        AssertionError: if any expected name is not defined at the top level
            of the script (for example after a rename or removal).
    """
    from fastapi.middleware.cors import CORSMiddleware
    from starlette.middleware.trustedhost import TrustedHostMiddleware

    source_tree = ast.parse(_SCRIPT.read_text(encoding="utf-8"))

    found: dict = {}
    for stmt in source_tree.body:
        # Work out which plain names this statement binds, if any.
        if isinstance(stmt, ast.FunctionDef):
            bound_names = [stmt.name]
        elif isinstance(stmt, ast.Assign):
            bound_names = [t.id for t in stmt.targets if isinstance(t, ast.Name)]
        elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
            bound_names = [stmt.target.id]
        else:
            continue
        # A later definition of the same name replaces an earlier one.
        for bound in bound_names:
            if bound in _EXPECTED_NAMES:
                found[bound] = stmt

    missing = [name for name in _EXPECTED_NAMES if name not in found]
    assert not missing, f"diffusion_server.py is missing expected helpers: {missing}"

    ordered_nodes = [found[name] for name in _EXPECTED_NAMES]
    module = ast.Module(body=ordered_nodes, type_ignores=[])
    ast.fix_missing_locations(module)

    # Names made available to the extracted code when it executes.
    namespace: dict = dict(
        TrustedHostMiddleware=TrustedHostMiddleware,
        CORSMiddleware=CORSMiddleware,
        RuntimeError=RuntimeError,
        list=list,
    )
    code = compile(module, str(_SCRIPT), "exec")
    exec(code, namespace)
    return namespace
def test_compute_allowed_hosts_includes_loopback_and_bind_host():
    """A 0.0.0.0 bind keeps the bind host and every loopback alias allowed."""
    compute_allowed_hosts = _load_helpers()["_compute_allowed_hosts"]
    hosts = compute_allowed_hosts("0.0.0.0")
    for expected in ("0.0.0.0", "127.0.0.1", "localhost", "::1"):
        assert expected in hosts
def test_compute_allowed_hosts_dedupes_and_strips():
    """Duplicates and blank extras collapse; first-seen order is kept."""
    helpers = _load_helpers()
    # The bind host repeats a default, one extra repeats a default, and two
    # extras are blank: each unique value must appear exactly once.
    extras = ["localhost", "", " ", "lan.example"]
    result = helpers["_compute_allowed_hosts"]("127.0.0.1", extras=extras)
    expected = ["127.0.0.1", "localhost", "::1", "lan.example"]
    assert result == expected
def test_compute_allowed_hosts_does_not_add_wildcard():
    """With no extras, the computed Host allowlist never contains '*'."""
    compute_allowed_hosts = _load_helpers()["_compute_allowed_hosts"]
    hosts = compute_allowed_hosts("127.0.0.1")
    assert "*" not in hosts, "wildcard host would re-open the DNS-rebinding hole"
def test_compute_allowed_hosts_preserves_explicit_wildcard():
    """An operator-supplied '*' is kept (deduped, stripped, stable order).

    This pins the helper's current behavior rather than endorsing a policy:
    the wildcard is never added by default, but an explicit one is honoured.
    """
    helpers = _load_helpers()
    extras = ["*", " lan.example ", "*"]
    result = helpers["_compute_allowed_hosts"]("127.0.0.1", extras=extras)
    expected = ["127.0.0.1", "localhost", "::1", "*", "lan.example"]
    assert result == expected
def test_compute_cors_origins_default_deny():
    """Called with no arguments, the CORS helper yields an empty allowlist."""
    origins = _load_helpers()["_compute_cors_origins"]()
    assert origins == [], "default CORS allowlist must be empty (no cross-origin)"
def test_compute_cors_origins_does_not_default_to_wildcard():
    """Regression guard for the original ``allow_origins=['*']`` default.

    Neither ``extras=None`` nor ``extras=[]`` may produce a wildcard origin.
    """
    compute_cors_origins = _load_helpers()["_compute_cors_origins"]
    for no_extras in (None, []):
        assert "*" not in compute_cors_origins(extras=no_extras)
def test_compute_cors_origins_honours_explicit_extras():
    """Explicit origins are returned once each, with blank entries dropped."""
    compute_cors_origins = _load_helpers()["_compute_cors_origins"]
    extras = ["http://localhost:7000", "", "http://localhost:7000"]
    result = compute_cors_origins(extras=extras)
    assert result == ["http://localhost:7000"]
def test_compute_cors_origins_preserves_explicit_wildcard():
    """An operator-supplied '*' origin is kept (deduped, stripped, ordered).

    This pins the helper's current behavior rather than endorsing a policy:
    the wildcard is never the default, but an explicit one is honoured.
    """
    compute_cors_origins = _load_helpers()["_compute_cors_origins"]
    extras = ["*", " http://localhost:7000 ", "*"]
    result = compute_cors_origins(extras=extras)
    assert result == ["*", "http://localhost:7000"]
# ── Live middleware integration: TrustedHostMiddleware + CORSMiddleware ─────
def _starlette_available() -> bool:
    """Return True when the ``starlette`` package can be located for import.

    Only the import machinery is consulted; the package is not imported.
    """
    spec = importlib.util.find_spec("starlette")
    return spec is not None
def _asgi_get(app, url, headers=None):
    """Issue one in-process GET against an ASGI ``app`` and return the response.

    The request goes through httpx's ``ASGITransport`` on an event loop
    created by ``asyncio.run``, so no socket is opened.

    ``starlette.testclient.TestClient`` is intentionally not used: its
    context-manager form starts an ``anyio`` blocking portal to run the
    lifespan, and that deadlocked under some pytest / anyio / asyncio setups
    (the Host-header test hung during review, see PR #347). Calling the app
    directly needs neither a portal nor a lifespan.

    Args:
        app: the ASGI application to call.
        url: full request URL; the ``Host`` seen by the app comes from it, so
            the TrustedHost allowlist is checked against this hostname.
        headers: optional extra request headers (``Origin`` and similar).

    Returns:
        The httpx response for the request.
    """
    import asyncio

    import httpx

    request_headers = headers or {}

    async def _fetch():
        client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app))
        async with client:
            return await client.get(url, headers=request_headers)

    return asyncio.run(_fetch())
def _configured_app(ns, allowed_origins, route_called=None):
    """Build a FastAPI app secured by the production middleware wiring.

    The app is passed to the extracted ``_configure_security_middleware``
    together with the Host allowlist computed for a ``127.0.0.1`` bind, and
    then gets a single ``GET /`` route returning ``{"ok": True}``.

    Args:
        ns: namespace returned by ``_load_helpers``.
        allowed_origins: CORS origin list handed to the production helper.
        route_called: optional dict; when given, the route stores
            ``route_called["hit"] = True`` so a test can prove whether the
            inner app was reached.

    Returns:
        The configured FastAPI application.
    """
    from fastapi import FastAPI

    app = FastAPI()
    configure = ns["_configure_security_middleware"]
    loopback_hosts = ns["_compute_allowed_hosts"]("127.0.0.1")
    configure(app, loopback_hosts, allowed_origins)

    def root():
        if route_called is not None:
            route_called["hit"] = True
        return {"ok": True}

    # Same registration the decorator form performs, spelled out explicitly.
    app.get("/")(root)
    return app
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_trusted_host_middleware_rejects_attacker_host():
    """An attacker-controlled Host (the DNS-rebinding surface) is refused by
    the production wiring before the route handler gets to run."""
    tracker = {"hit": False}
    app = _configured_app(_load_helpers(), [], route_called=tracker)

    # Allowed Host (127.0.0.1): the request reaches the route.
    accepted = _asgi_get(app, "http://127.0.0.1/")
    assert accepted.status_code == 200
    assert tracker["hit"] is True

    # Foreign Host: rejected with 400 and the route is never entered.
    tracker["hit"] = False
    rejected = _asgi_get(app, "http://evil.example.com/")
    assert rejected.status_code == 400
    assert tracker["hit"] is False
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_cors_default_deny_does_not_emit_wildcard_acao():
    """With the default (empty) CORS allowlist, responses carry no
    Access-Control-Allow-Origin header, so browsers block cross-origin reads."""
    ns = _load_helpers()
    default_origins = ns["_compute_cors_origins"]()
    assert default_origins == []

    app = _configured_app(ns, default_origins)

    # The Host is on the allowlist, so the request succeeds; what matters is
    # that the reply gives a foreign Origin no permission to read it.
    attacker_headers = {"Origin": "https://evil.example.com"}
    resp = _asgi_get(app, "http://127.0.0.1/", headers=attacker_headers)
    assert resp.status_code == 200

    acao = resp.headers.get("access-control-allow-origin")
    assert acao in (None, ""), (
        f"unexpected ACAO header: {acao!r} — the regression was wildcard CORS, "
        f"so any non-empty default fails this gate"
    )
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_explicit_cors_origin_does_not_widen_to_wildcard():
    """Opting in to one origin grants access to that origin only; other
    origins are neither echoed back nor answered with '*'."""
    ns = _load_helpers()
    opted_in = "http://localhost:7000"
    foreign = "https://evil.example.com"
    app = _configured_app(ns, ns["_compute_cors_origins"](extras=[opted_in]))

    # The opted-in origin is echoed back verbatim (not as '*').
    ok = _asgi_get(app, "http://127.0.0.1/", headers={"Origin": opted_in})
    assert ok.status_code == 200
    assert ok.headers.get("access-control-allow-origin") == opted_in

    # Any other origin gets neither an echo nor a wildcard.
    bad = _asgi_get(app, "http://127.0.0.1/", headers={"Origin": foreign})
    bad_acao = bad.headers.get("access-control-allow-origin")
    assert bad_acao != "*"
    assert bad_acao != foreign
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_preserves_order():
    """Pin the installed middleware order.

    With a CORS origin configured, ``user_middleware`` must read
    ``[CORSMiddleware, TrustedHostMiddleware]`` (CORS outermost); with the
    default-deny origin list only the Host allowlist is installed.
    """
    from fastapi.middleware.cors import CORSMiddleware
    from starlette.middleware.trustedhost import TrustedHostMiddleware

    ns = _load_helpers()

    def installed(app):
        return [entry.cls for entry in app.user_middleware]

    opted_in = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
    assert installed(_configured_app(ns, opted_in)) == [
        CORSMiddleware,
        TrustedHostMiddleware,
    ]
    assert installed(_configured_app(ns, [])) == [TrustedHostMiddleware]
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_is_idempotent_before_serving():
    """Configuring twice (defaults first, then an override) must replace the
    middleware stack instead of stacking duplicates on top of it."""
    from fastapi import FastAPI
    from fastapi.middleware.cors import CORSMiddleware
    from starlette.middleware.trustedhost import TrustedHostMiddleware

    ns = _load_helpers()
    configure = ns["_configure_security_middleware"]
    allowed = ns["_compute_allowed_hosts"]("127.0.0.1")

    app = FastAPI()
    configure(app, allowed, [])
    override_origins = ns["_compute_cors_origins"](extras=["http://localhost:7000"])
    configure(app, allowed, override_origins)

    classes = [entry.cls for entry in app.user_middleware]
    assert classes == [CORSMiddleware, TrustedHostMiddleware]
    assert classes.count(TrustedHostMiddleware) == 1
@pytest.mark.skipif(not _starlette_available(), reason="starlette not installed")
def test_configure_security_middleware_rejects_late_call():
    """After the middleware stack has been built, reconfiguring must raise
    RuntimeError and leave ``user_middleware`` exactly as it was."""
    from fastapi import FastAPI

    ns = _load_helpers()
    configure = ns["_configure_security_middleware"]
    allowed = ns["_compute_allowed_hosts"]("127.0.0.1")

    app = FastAPI()
    configure(app, allowed, [])
    snapshot = list(app.user_middleware)

    # Stand in for "the app has started serving": build the stack by hand.
    app.middleware_stack = app.build_middleware_stack()
    assert app.middleware_stack is not None

    with pytest.raises(RuntimeError):
        configure(app, ["lan.example"], [])

    # The guard must fire before any mutation takes place.
    assert list(app.user_middleware) == snapshot