test(diffusion-server): exercise security middleware wiring (#3214)

This commit is contained in:
Alexandre Teixeira
2026-06-07 22:42:11 +01:00
committed by GitHub
parent 92300b5d67
commit 9ad6a2809e
2 changed files with 199 additions and 101 deletions
+26 -13
View File
@@ -62,10 +62,6 @@ app = FastAPI(title="Diffusion Server", lifespan=lifespan)
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
# Install defaults at module load so importing the app for tests / direct
# uvicorn invocation still benefits from the Host-header allowlist.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(_DEFAULT_ALLOWED_HOSTS))
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
"""Allowed Host header values: the bind address + loopback variants +
@@ -91,6 +87,31 @@ def _compute_cors_origins(extras=None) -> list:
return seen
def _configure_security_middleware(application, allowed_hosts, allowed_origins):
"""Replace `application`'s user middleware stack with the diffusion server
security middleware: the TrustedHost allowlist and, when origins are
supplied, CORS. Used at module load and by the __main__ CLI path before
serving starts. Raises before mutating if the middleware stack has already
been built. Order is preserved: TrustedHost first, then CORS (added last ->
outermost)."""
if application.middleware_stack is not None:
raise RuntimeError("security middleware must be configured before the app starts serving")
application.user_middleware.clear()
application.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts))
if allowed_origins:
application.add_middleware(
CORSMiddleware,
allow_origins=list(allowed_origins),
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
# Install defaults at module load so importing the app for tests / direct
# uvicorn invocation still benefits from the Host-header allowlist.
_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS)
class ImageRequest(BaseModel):
model: str = ""
prompt: str
@@ -1141,15 +1162,7 @@ if __name__ == "__main__":
# here is safe.
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
final_origins = _compute_cors_origins(_args.allowed_origin)
app.user_middleware.clear()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=final_hosts)
if final_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=final_origins,
allow_methods=["GET", "POST", "OPTIONS"],
allow_headers=["Authorization", "Content-Type"],
)
_configure_security_middleware(app, final_hosts, final_origins)
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
final_hosts, final_origins or "(none — default-deny)")