mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-30 00:22:10 -04:00
Harden session endpoint owner scope (#1308)
This commit is contained in:
+47
-15
@@ -148,8 +148,9 @@ async def auto_name_session(session_manager, sess):
|
||||
if not first_msg:
|
||||
return
|
||||
|
||||
owner = getattr(sess, "owner", None)
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
if not t_model:
|
||||
logger.debug("[auto-name] No model provided, skipping")
|
||||
@@ -311,7 +312,24 @@ def fire_message_event(request, webhook_manager, session_id: str, sess, message:
|
||||
fire_event("message_sent", user)
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str):
|
||||
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
if not session_url or not endpoint_base:
|
||||
return False
|
||||
try:
|
||||
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||
|
||||
sess_url = session_url.rstrip("/")
|
||||
base = normalize_base(endpoint_base).rstrip("/")
|
||||
return sess_url in {
|
||||
base,
|
||||
base + "/chat/completions",
|
||||
build_chat_url(base).rstrip("/"),
|
||||
}
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||
@@ -320,19 +338,33 @@ def resolve_session_auth(sess, session_id: str):
|
||||
return
|
||||
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
db = SessionLocal()
|
||||
try:
|
||||
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else ""
|
||||
if domain:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first()
|
||||
if ep and ep.api_key:
|
||||
sess.headers = build_headers(ep.api_key, ep.base_url)
|
||||
db.query(DBSession).filter(DBSession.id == session_id).update(
|
||||
{"headers": json.dumps(sess.headers)}
|
||||
)
|
||||
db.commit()
|
||||
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||
if not target_url:
|
||||
return
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if owner:
|
||||
# Missing headers usually means "recover from the saved endpoint".
|
||||
# Scope that lookup to the session owner, otherwise two users
|
||||
# with similar endpoint URLs can borrow each other's API key.
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
for ep in q.all():
|
||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||
continue
|
||||
if not ep.api_key:
|
||||
return
|
||||
base = normalize_base(ep.base_url or "")
|
||||
sess.headers = build_headers(ep.api_key, base)
|
||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
update_q = update_q.filter(DBSession.owner == owner)
|
||||
update_q.update({"headers": sess.headers})
|
||||
db.commit()
|
||||
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
||||
return
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
@@ -806,7 +838,7 @@ def run_post_response_tasks(
|
||||
from services.memory.memory_extractor import extract_and_store
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
asyncio.create_task(extract_and_store(
|
||||
sess, memory_manager, memory_vector,
|
||||
@@ -843,7 +875,7 @@ def run_post_response_tasks(
|
||||
from services.memory.skill_extractor import maybe_extract_skill
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
s_url, s_model, s_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||
)
|
||||
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
|
||||
asyncio.create_task(maybe_extract_skill(
|
||||
|
||||
Reference in New Issue
Block a user