mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
Improve Ollama setup and model endpoint handling
This commit is contained in:
@@ -4,6 +4,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, AsyncGenerator, List
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Form, Query
|
||||
@@ -17,6 +18,7 @@ from src.agent_loop import stream_agent_loop
|
||||
from src import agent_runs
|
||||
from src.model_context import estimate_tokens
|
||||
from src.chat_helpers import coerce_message_and_session
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from core.exceptions import SessionNotFoundError
|
||||
from src.auth_helpers import get_current_user
|
||||
@@ -87,6 +89,46 @@ def _message_needs_tools(text: str) -> bool:
|
||||
return any(p.search(text) for p in _TOOL_INTENT_PATTERNS)
|
||||
|
||||
|
||||
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
if not session_url or not endpoint_base:
|
||||
return False
|
||||
sess = session_url.rstrip("/")
|
||||
base = _normalize_base(endpoint_base).rstrip("/")
|
||||
variants = {
|
||||
base,
|
||||
base + "/chat/completions",
|
||||
build_chat_url(base).rstrip("/"),
|
||||
}
|
||||
return sess in variants or sess.startswith(base + "/")
|
||||
|
||||
|
||||
def _clear_orphaned_session_endpoint(sess) -> bool:
|
||||
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
|
||||
if not getattr(sess, "endpoint_url", ""):
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for ep in endpoints:
|
||||
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
|
||||
return False
|
||||
db_session = db.query(DBSession).filter(DBSession.id == sess.id).first()
|
||||
if db_session:
|
||||
db_session.endpoint_url = ""
|
||||
db_session.model = ""
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
sess.endpoint_url = ""
|
||||
sess.model = ""
|
||||
sess.headers = {}
|
||||
return True
|
||||
except Exception:
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def setup_chat_routes(
|
||||
session_manager,
|
||||
chat_handler,
|
||||
@@ -121,6 +163,8 @@ def setup_chat_routes(
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session '{session}' not found")
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
|
||||
# Same allowed_models + daily-cap gate as chat_stream (mirror so the
|
||||
# non-streaming path can't be used to bypass).
|
||||
@@ -259,6 +303,8 @@ def setup_chat_routes(
|
||||
# but BEFORE loading. Prevents cross-user session hijack.
|
||||
_verify_session_owner(request, session)
|
||||
sess = session_manager.get_session(session)
|
||||
if _clear_orphaned_session_endpoint(sess):
|
||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||
except SessionNotFoundError as e:
|
||||
raise HTTPException(404, str(e))
|
||||
except (ValueError, ValidationError):
|
||||
|
||||
Reference in New Issue
Block a user