mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-15 17:25:26 -04:00
test(diffusion-server): exercise security middleware wiring (#3214)
This commit is contained in:
committed by
GitHub
parent
92300b5d67
commit
9ad6a2809e
+26
-13
@@ -62,10 +62,6 @@ app = FastAPI(title="Diffusion Server", lifespan=lifespan)
|
||||
_DEFAULT_ALLOWED_HOSTS = ["127.0.0.1", "localhost", "::1"]
|
||||
_DEFAULT_CORS_ORIGINS: list = [] # default-deny
|
||||
|
||||
# Install defaults at module load so importing the app for tests / direct
|
||||
# uvicorn invocation still benefits from the Host-header allowlist.
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=list(_DEFAULT_ALLOWED_HOSTS))
|
||||
|
||||
|
||||
def _compute_allowed_hosts(bind_host: str, extras=None) -> list:
|
||||
"""Allowed Host header values: the bind address + loopback variants +
|
||||
@@ -91,6 +87,31 @@ def _compute_cors_origins(extras=None) -> list:
|
||||
return seen
|
||||
|
||||
|
||||
def _configure_security_middleware(application, allowed_hosts, allowed_origins):
|
||||
"""Replace `application`'s user middleware stack with the diffusion server
|
||||
security middleware: the TrustedHost allowlist and, when origins are
|
||||
supplied, CORS. Used at module load and by the __main__ CLI path before
|
||||
serving starts. Raises before mutating if the middleware stack has already
|
||||
been built. Order is preserved: TrustedHost first, then CORS (added last ->
|
||||
outermost)."""
|
||||
if application.middleware_stack is not None:
|
||||
raise RuntimeError("security middleware must be configured before the app starts serving")
|
||||
application.user_middleware.clear()
|
||||
application.add_middleware(TrustedHostMiddleware, allowed_hosts=list(allowed_hosts))
|
||||
if allowed_origins:
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=list(allowed_origins),
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
|
||||
# Install defaults at module load so importing the app for tests / direct
|
||||
# uvicorn invocation still benefits from the Host-header allowlist.
|
||||
_configure_security_middleware(app, _DEFAULT_ALLOWED_HOSTS, _DEFAULT_CORS_ORIGINS)
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
model: str = ""
|
||||
prompt: str
|
||||
@@ -1141,15 +1162,7 @@ if __name__ == "__main__":
|
||||
# here is safe.
|
||||
final_hosts = _compute_allowed_hosts(_args.host, _args.allowed_host)
|
||||
final_origins = _compute_cors_origins(_args.allowed_origin)
|
||||
app.user_middleware.clear()
|
||||
app.add_middleware(TrustedHostMiddleware, allowed_hosts=final_hosts)
|
||||
if final_origins:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=final_origins,
|
||||
allow_methods=["GET", "POST", "OPTIONS"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
_configure_security_middleware(app, final_hosts, final_origins)
|
||||
logger.info("security middleware: allowed_hosts=%s allowed_origins=%s",
|
||||
final_hosts, final_origins or "(none — default-deny)")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user