mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Harden session endpoint owner scope (#1308)
This commit is contained in:
+114
-31
@@ -58,23 +58,71 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["sessions"])
|
||||
|
||||
def _pick_endpoint_for_sort():
|
||||
|
||||
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||
if not callable(is_admin):
|
||||
return False
|
||||
try:
|
||||
return bool(is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _reject_raw_endpoint_url_for_non_admin(
|
||||
request: Request,
|
||||
user: str | None,
|
||||
endpoint_id: str | None,
|
||||
endpoint_url: str | None,
|
||||
) -> None:
|
||||
"""Require registered endpoints for signed-in non-admin session changes."""
|
||||
if endpoint_id and endpoint_id.strip():
|
||||
return
|
||||
if not endpoint_url:
|
||||
return
|
||||
# Raw URLs make the server dial whatever host the request supplies. For
|
||||
# non-admin users, require a saved endpoint row so normal owner scoping and
|
||||
# endpoint validation have already happened.
|
||||
if user and not _current_user_is_admin(request, user):
|
||||
raise HTTPException(403, "Choose a registered model endpoint")
|
||||
|
||||
|
||||
def _persist_session_headers(session_id: str, headers: dict | None) -> None:
|
||||
"""Persist endpoint auth headers for DB-backed session metadata."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.headers = headers or {}
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _pick_endpoint_for_sort(owner=None):
|
||||
"""Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default."""
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
# Try utility endpoint first (what the user configured for background tasks)
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||
if url and model:
|
||||
return url, model, headers
|
||||
# Fall back to task endpoint
|
||||
try:
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=owner)
|
||||
if url and model:
|
||||
return url, model, headers
|
||||
except Exception:
|
||||
pass
|
||||
# Fall back to default
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||
if url and model:
|
||||
return url, model, headers
|
||||
return None, None, None
|
||||
@@ -197,11 +245,41 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
endpoint_id: str = Form(""),
|
||||
):
|
||||
skip_val = str(skip_validation).lower() == "true"
|
||||
user = get_current_user(request)
|
||||
endpoint_api_key = ""
|
||||
endpoint_base_url = ""
|
||||
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||
if endpoint_id and endpoint_id.strip():
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
q = _db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.id == endpoint_id.strip(),
|
||||
ModelEndpoint.is_enabled == True,
|
||||
)
|
||||
if user:
|
||||
q = owner_filter(q, ModelEndpoint, user)
|
||||
endpoint_row = q.first()
|
||||
if not endpoint_row:
|
||||
raise HTTPException(400, "Model endpoint no longer exists")
|
||||
endpoint_base_url = endpoint_row.base_url or ""
|
||||
endpoint_api_key = endpoint_row.api_key or ""
|
||||
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
|
||||
finally:
|
||||
_db.close()
|
||||
|
||||
if not endpoint_url and not skip_val:
|
||||
raise HTTPException(400, "endpoint_url is required (choose from /api/models)")
|
||||
|
||||
model_to_use = model
|
||||
request_api_key = api_key.strip() if api_key else ""
|
||||
effective_api_key = request_api_key or endpoint_api_key
|
||||
validation_headers = None
|
||||
if effective_api_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
validation_headers = build_headers(effective_api_key, endpoint_base_url or endpoint_url)
|
||||
|
||||
if skip_val:
|
||||
# skip_validation = trust the caller and do NOT probe /v1/models.
|
||||
@@ -212,7 +290,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
elif not model_to_use:
|
||||
from src.llm_core import list_model_ids
|
||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
|
||||
headers=validation_headers)
|
||||
if not ids:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
# Default to the first CHAT model — endpoints often list embedding/
|
||||
@@ -227,7 +305,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
import os as _os
|
||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
|
||||
headers=validation_headers)
|
||||
if not avail:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
if model_to_use not in avail:
|
||||
@@ -252,22 +330,15 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
owner=user,
|
||||
)
|
||||
# Set auth headers for custom API-key endpoints
|
||||
resolved_key = api_key.strip() if api_key else ""
|
||||
resolved_key = request_api_key
|
||||
resolved_base = endpoint_url
|
||||
if not resolved_key and endpoint_id and endpoint_id.strip():
|
||||
from core.database import ModelEndpoint
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id.strip()).first()
|
||||
if ep and ep.api_key:
|
||||
resolved_key = ep.api_key
|
||||
resolved_base = ep.base_url
|
||||
finally:
|
||||
_db.close()
|
||||
if not resolved_key and endpoint_api_key:
|
||||
resolved_key = endpoint_api_key
|
||||
resolved_base = endpoint_base_url
|
||||
if resolved_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
session.headers = build_headers(resolved_key, resolved_base)
|
||||
session_manager.save_sessions()
|
||||
_persist_session_headers(sid, session.headers)
|
||||
# Fire webhook (sync-safe)
|
||||
if webhook_manager:
|
||||
webhook_manager.fire_and_forget("session.created", {
|
||||
@@ -313,27 +384,38 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
db.close()
|
||||
# Switch model/endpoint mid-session
|
||||
if model is not None and endpoint_url is not None:
|
||||
user = get_current_user(request)
|
||||
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||
endpoint_api_key = ""
|
||||
endpoint_base_url = ""
|
||||
if endpoint_id:
|
||||
from core.database import ModelEndpoint
|
||||
from src.auth_helpers import owner_filter
|
||||
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
q = _db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.id == endpoint_id,
|
||||
ModelEndpoint.is_enabled == True,
|
||||
)
|
||||
if user:
|
||||
q = owner_filter(q, ModelEndpoint, user)
|
||||
ep = q.first()
|
||||
if not ep:
|
||||
raise HTTPException(400, "Model endpoint no longer exists")
|
||||
endpoint_base_url = ep.base_url or ""
|
||||
endpoint_api_key = ep.api_key or ""
|
||||
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
|
||||
finally:
|
||||
_db.close()
|
||||
session.model = model
|
||||
session.endpoint_url = endpoint_url
|
||||
# Update auth headers from the endpoint's stored API key
|
||||
if endpoint_id:
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
||||
if ep and ep.api_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
session.headers = build_headers(ep.api_key, ep.base_url)
|
||||
finally:
|
||||
_db.close()
|
||||
if endpoint_api_key:
|
||||
from src.endpoint_resolver import build_headers
|
||||
session.headers = build_headers(endpoint_api_key, endpoint_base_url)
|
||||
else:
|
||||
session.headers = {}
|
||||
# Persist to DB
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -341,6 +423,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
if db_session:
|
||||
db_session.model = model
|
||||
db_session.endpoint_url = endpoint_url
|
||||
db_session.headers = session.headers or {}
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
finally:
|
||||
@@ -754,7 +837,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
||||
if not url or not model:
|
||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||
if not url or not model:
|
||||
@@ -954,9 +1037,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
|
||||
# Pick an endpoint — prefer admin-configured task endpoint
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
url, model, headers = resolve_task_endpoint()
|
||||
url, model, headers = resolve_task_endpoint(owner=user)
|
||||
if not url:
|
||||
url, model, headers = _pick_endpoint_for_sort()
|
||||
url, model, headers = _pick_endpoint_for_sort(owner=user)
|
||||
if not url:
|
||||
raise HTTPException(503, "No available model endpoint for auto-sort")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user