mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Harden session endpoint owner scope (#1308)
This commit is contained in:
+28
-14
@@ -72,13 +72,17 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
return sess in variants or sess.startswith(base + "/")
|
||||
|
||||
|
||||
def _clear_orphaned_session_endpoint(sess) -> bool:
|
||||
def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
|
||||
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
|
||||
if not getattr(sess, "endpoint_url", ""):
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for ep in endpoints:
|
||||
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
|
||||
return False
|
||||
@@ -118,7 +122,7 @@ def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
|
||||
return wanted in {str(item).strip() for item in models}
|
||||
|
||||
|
||||
def _is_image_generation_session(sess) -> bool:
|
||||
def _is_image_generation_session(sess, owner: str | None = None) -> bool:
|
||||
"""Whether this chat session should bypass text chat and generate images.
|
||||
|
||||
Model-name prefixes are explicit image models. Endpoint type is only used
|
||||
@@ -137,7 +141,11 @@ def _is_image_generation_session(sess) -> bool:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for endpoint in endpoints:
|
||||
if (getattr(endpoint, "model_type", None) or "llm") != "image":
|
||||
continue
|
||||
@@ -152,7 +160,7 @@ def _is_image_generation_session(sess) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _recover_empty_session_model(sess, session_id: str) -> bool:
|
||||
def _recover_empty_session_model(sess, session_id: str, owner: str | None = None) -> bool:
|
||||
"""Re-populate sess.model from the matching endpoint's cached models.
|
||||
|
||||
Covers the window between endpoint setup and the first chat send: the
|
||||
@@ -172,7 +180,11 @@ def _recover_empty_session_model(sess, session_id: str) -> bool:
|
||||
# cached model is the most defensible default.
|
||||
ep = None
|
||||
if getattr(sess, "endpoint_url", ""):
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for cand in endpoints:
|
||||
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
|
||||
ep = cand
|
||||
@@ -251,13 +263,14 @@ def setup_chat_routes(
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session '{session}' not found")
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
owner = get_current_user(request)
|
||||
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
|
||||
# Empty model + live endpoint = setup race (Issue #587). Repair from
|
||||
# the endpoint's cached model list before privilege checks, which
|
||||
# otherwise see "" and behave inconsistently with the allowlist.
|
||||
_recover_empty_session_model(sess, session)
|
||||
_recover_empty_session_model(sess, session, owner=owner)
|
||||
if not getattr(sess, "model", "").strip():
|
||||
raise HTTPException(
|
||||
400,
|
||||
@@ -401,7 +414,8 @@ def setup_chat_routes(
|
||||
# but BEFORE loading. Prevents cross-user session hijack.
|
||||
_verify_session_owner(request, session)
|
||||
sess = session_manager.get_session(session)
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
owner = get_current_user(request)
|
||||
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
# Issue #587: picker shows a model from the endpoint cache but
|
||||
# s.model never made it onto the DB row (first-send race after
|
||||
@@ -409,7 +423,7 @@ def setup_chat_routes(
|
||||
# the first cached model off the matching endpoint so the
|
||||
# upstream isn't called with model="" (which surfaces as a
|
||||
# generic 401/503).
|
||||
_recover_empty_session_model(sess, session)
|
||||
_recover_empty_session_model(sess, session, owner=owner)
|
||||
if not getattr(sess, "model", "").strip():
|
||||
raise HTTPException(
|
||||
400,
|
||||
@@ -431,7 +445,7 @@ def setup_chat_routes(
|
||||
_enforce_chat_privileges(request, sess)
|
||||
|
||||
# Ensure session has auth headers
|
||||
resolve_session_auth(sess, session)
|
||||
resolve_session_auth(sess, session, owner=get_current_user(request))
|
||||
|
||||
# Check for research_pending BEFORE mode persist overwrites it
|
||||
do_research = str(use_research).lower() == "true"
|
||||
@@ -768,7 +782,7 @@ def setup_chat_routes(
|
||||
# output. Resolved once per request.
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_chat_fallback_candidates
|
||||
_fallback_candidates = resolve_chat_fallback_candidates()
|
||||
_fallback_candidates = resolve_chat_fallback_candidates(owner=_user)
|
||||
except Exception:
|
||||
_fallback_candidates = []
|
||||
|
||||
@@ -781,7 +795,7 @@ def setup_chat_routes(
|
||||
_model_info["character_name"] = ctx.preset.character_name
|
||||
yield f'data: {json.dumps(_model_info)}\n\n'
|
||||
|
||||
if _is_image_generation_session(sess):
|
||||
if _is_image_generation_session(sess, owner=_user):
|
||||
from src.settings import get_setting
|
||||
if not get_setting("image_gen_enabled", True):
|
||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||
@@ -792,7 +806,7 @@ def setup_chat_routes(
|
||||
_user_msg = message or ""
|
||||
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
|
||||
yield ": heartbeat\n\n"
|
||||
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session)
|
||||
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session, owner=_user)
|
||||
_img_output = _img_result.get("results", _img_result.get("error", ""))
|
||||
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
|
||||
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):
|
||||
|
||||
Reference in New Issue
Block a user