Harden session endpoint owner scope (#1308)

This commit is contained in:
Vykos
2026-06-02 19:40:22 +02:00
committed by GitHub
parent 80de69ebb0
commit 4771d80eb2
6 changed files with 261 additions and 71 deletions
+28 -14
View File
@@ -72,13 +72,17 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
return sess in variants or sess.startswith(base + "/")
def _clear_orphaned_session_endpoint(sess) -> bool:
def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
if not getattr(sess, "endpoint_url", ""):
return False
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for ep in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
return False
@@ -118,7 +122,7 @@ def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
return wanted in {str(item).strip() for item in models}
def _is_image_generation_session(sess) -> bool:
def _is_image_generation_session(sess, owner: str | None = None) -> bool:
"""Whether this chat session should bypass text chat and generate images.
Model-name prefixes are explicit image models. Endpoint type is only used
@@ -137,7 +141,11 @@ def _is_image_generation_session(sess) -> bool:
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for endpoint in endpoints:
if (getattr(endpoint, "model_type", None) or "llm") != "image":
continue
@@ -152,7 +160,7 @@ def _is_image_generation_session(sess) -> bool:
return False
def _recover_empty_session_model(sess, session_id: str) -> bool:
def _recover_empty_session_model(sess, session_id: str, owner: str | None = None) -> bool:
"""Re-populate sess.model from the matching endpoint's cached models.
Covers the window between endpoint setup and the first chat send: the
@@ -172,7 +180,11 @@ def _recover_empty_session_model(sess, session_id: str) -> bool:
# cached model is the most defensible default.
ep = None
if getattr(sess, "endpoint_url", ""):
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for cand in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
ep = cand
@@ -251,13 +263,14 @@ def setup_chat_routes(
sess = session_manager.get_session(session)
except KeyError:
raise HTTPException(404, f"Session '{session}' not found")
if _clear_orphaned_session_endpoint(sess):
owner = get_current_user(request)
if _clear_orphaned_session_endpoint(sess, owner=owner):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Empty model + live endpoint = setup race (Issue #587). Repair from
# the endpoint's cached model list before privilege checks, which
# otherwise see "" and behave inconsistently with the allowlist.
_recover_empty_session_model(sess, session)
_recover_empty_session_model(sess, session, owner=owner)
if not getattr(sess, "model", "").strip():
raise HTTPException(
400,
@@ -401,7 +414,8 @@ def setup_chat_routes(
# but BEFORE loading. Prevents cross-user session hijack.
_verify_session_owner(request, session)
sess = session_manager.get_session(session)
if _clear_orphaned_session_endpoint(sess):
owner = get_current_user(request)
if _clear_orphaned_session_endpoint(sess, owner=owner):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Issue #587: picker shows a model from the endpoint cache but
# s.model never made it onto the DB row (first-send race after
@@ -409,7 +423,7 @@ def setup_chat_routes(
# the first cached model off the matching endpoint so the
# upstream isn't called with model="" (which surfaces as a
# generic 401/503).
_recover_empty_session_model(sess, session)
_recover_empty_session_model(sess, session, owner=owner)
if not getattr(sess, "model", "").strip():
raise HTTPException(
400,
@@ -431,7 +445,7 @@ def setup_chat_routes(
_enforce_chat_privileges(request, sess)
# Ensure session has auth headers
resolve_session_auth(sess, session)
resolve_session_auth(sess, session, owner=get_current_user(request))
# Check for research_pending BEFORE mode persist overwrites it
do_research = str(use_research).lower() == "true"
@@ -768,7 +782,7 @@ def setup_chat_routes(
# output. Resolved once per request.
try:
from src.endpoint_resolver import resolve_chat_fallback_candidates
_fallback_candidates = resolve_chat_fallback_candidates()
_fallback_candidates = resolve_chat_fallback_candidates(owner=_user)
except Exception:
_fallback_candidates = []
@@ -781,7 +795,7 @@ def setup_chat_routes(
_model_info["character_name"] = ctx.preset.character_name
yield f'data: {json.dumps(_model_info)}\n\n'
if _is_image_generation_session(sess):
if _is_image_generation_session(sess, owner=_user):
from src.settings import get_setting
if not get_setting("image_gen_enabled", True):
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
@@ -792,7 +806,7 @@ def setup_chat_routes(
_user_msg = message or ""
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
yield ": heartbeat\n\n"
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session)
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session, owner=_user)
_img_output = _img_result.get("results", _img_result.get("error", ""))
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):