Improve Ollama setup and model endpoint handling

This commit is contained in:
pewdiepie-archdaemon
2026-06-01 10:00:15 +09:00
parent 051751adcd
commit fc7f107b22
22 changed files with 982 additions and 131 deletions
+4 -1
View File
@@ -482,7 +482,10 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
}
return {"ok": False, "message": f"ntfy returned HTTP {r.status_code} from {full_url}: {r.text[:200]}"}
except Exception as e:
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}"[:300]}
hint = ""
if parsed.hostname not in ("127.0.0.1", "localhost"):
hint = " If this is Docker Compose ntfy, set NTFY_BIND to that host/Tailscale IP and NTFY_BASE_URL to the same server URL in .env, then recreate ntfy."
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}.{hint}"[:500]}
# All other presets: GET against a known health endpoint.
# Fall back to detecting from name if preset is missing.
+2 -1
View File
@@ -902,7 +902,8 @@ def setup_calendar_routes() -> APIRouter:
lines.append(f"DTSTART:{ev.dtstart.strftime('%Y%m%dT%H%M%S')}")
lines.append(f"DTEND:{ev.dtend.strftime('%Y%m%dT%H%M%S')}")
if ev.description:
lines.append(f"DESCRIPTION:{ev.description.replace(chr(10), '\\n')}")
escaped_desc = ev.description.replace(chr(10), "\\n")
lines.append(f"DESCRIPTION:{escaped_desc}")
if ev.location:
lines.append(f"LOCATION:{ev.location}")
if ev.rrule:
+46
View File
@@ -4,6 +4,7 @@ import asyncio
import json
import time
import logging
from datetime import datetime
from typing import Dict, Any, AsyncGenerator, List
from fastapi import APIRouter, Request, HTTPException, Form, Query
@@ -17,6 +18,7 @@ from src.agent_loop import stream_agent_loop
from src import agent_runs
from src.model_context import estimate_tokens
from src.chat_helpers import coerce_message_and_session
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url
from src.prompt_security import untrusted_context_message
from core.exceptions import SessionNotFoundError
from src.auth_helpers import get_current_user
@@ -87,6 +89,46 @@ def _message_needs_tools(text: str) -> bool:
return any(p.search(text) for p in _TOOL_INTENT_PATTERNS)
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
if not session_url or not endpoint_base:
return False
sess = session_url.rstrip("/")
base = _normalize_base(endpoint_base).rstrip("/")
variants = {
base,
base + "/chat/completions",
build_chat_url(base).rstrip("/"),
}
return sess in variants or sess.startswith(base + "/")
def _clear_orphaned_session_endpoint(sess) -> bool:
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
if not getattr(sess, "endpoint_url", ""):
return False
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for ep in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
return False
db_session = db.query(DBSession).filter(DBSession.id == sess.id).first()
if db_session:
db_session.endpoint_url = ""
db_session.model = ""
db_session.updated_at = datetime.utcnow()
db.commit()
sess.endpoint_url = ""
sess.model = ""
sess.headers = {}
return True
except Exception:
db.rollback()
return False
finally:
db.close()
def setup_chat_routes(
session_manager,
chat_handler,
@@ -121,6 +163,8 @@ def setup_chat_routes(
sess = session_manager.get_session(session)
except KeyError:
raise HTTPException(404, f"Session '{session}' not found")
if _clear_orphaned_session_endpoint(sess):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Same allowed_models + daily-cap gate as chat_stream (mirror so the
# non-streaming path can't be used to bypass).
@@ -259,6 +303,8 @@ def setup_chat_routes(
# but BEFORE loading. Prevents cross-user session hijack.
_verify_session_owner(request, session)
sess = session_manager.get_session(session)
if _clear_orphaned_session_endpoint(sess):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
except SessionNotFoundError as e:
raise HTTPException(404, str(e))
except (ValueError, ValidationError):
+167 -11
View File
@@ -6,12 +6,13 @@ import json
import time as _time
import logging
import httpx
from datetime import datetime
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from core.database import SessionLocal, ModelEndpoint
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
from core.middleware import require_admin
from src.llm_core import _detect_provider, ANTHROPIC_MODELS
from src.settings import load_settings as _load_settings, save_settings as _save_settings
@@ -301,6 +302,21 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
logger.warning(f"Failed to probe {url} with API key: {e}")
return []
logger.warning(f"Failed to probe {url}: {e}")
# Older Ollama builds and some proxies expose native /api/tags even when
# the OpenAI-compatible /v1/models path is unavailable.
try:
parsed = urlparse(base)
if parsed.port == 11434 or "ollama" in (parsed.hostname or "").lower():
root = base[:-3].rstrip("/") if base.endswith("/v1") else base
r = httpx.get(root + "/api/tags", timeout=timeout)
r.raise_for_status()
data = r.json()
models = [m.get("name") or m.get("model") for m in (data.get("models") or []) if m.get("name") or m.get("model")]
if models:
return models
except Exception as e:
logger.debug(f"Ollama /api/tags probe failed for {base}: {e}")
# Fall back to curated list if the provider has a URL-based match (e.g. z.ai has no /models endpoint)
curated_key = _match_provider_curated(base, None)
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
@@ -310,6 +326,51 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
return []
def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> Dict[str, Any]:
"""Reachability probe that does not require installed/listed models."""
from src.endpoint_resolver import resolve_url
base = resolve_url(_normalize_base(base_url))
headers = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
url = base + "/models"
try:
r = httpx.get(url, headers=headers, timeout=timeout)
if 300 <= r.status_code < 400:
loc = r.headers.get("location", "")
if loc.startswith("/login") or "/login" in loc:
return {
"reachable": False,
"status_code": r.status_code,
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
}
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
if r.status_code < 500:
return {"reachable": r.status_code < 400, "status_code": r.status_code, "error": None if r.status_code < 400 else f"HTTP {r.status_code}"}
except Exception as e:
last_error = str(e)[:120]
else:
last_error = f"HTTP {r.status_code}"
try:
parsed = urlparse(base)
if parsed.port == 11434 or "ollama" in (parsed.hostname or "").lower():
root = base[:-3].rstrip("/") if base.endswith("/v1") else base
for path in ("/api/version", "/api/tags"):
try:
r = httpx.get(root + path, timeout=timeout)
if r.status_code < 400:
return {"reachable": True, "status_code": r.status_code, "error": None}
last_error = f"HTTP {r.status_code}"
except Exception as e:
last_error = str(e)[:120]
except Exception:
pass
return {"reachable": False, "status_code": None, "error": last_error}
def setup_model_routes(model_discovery):
router = APIRouter(prefix="/api")
@@ -549,15 +610,16 @@ def setup_model_routes(model_discovery):
db.close()
async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]:
url = base.rstrip("/") + "/models"
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
t0 = _time.time()
try:
async with httpx.AsyncClient(timeout=1.5) as client:
r = await client.get(url, headers=headers)
models = _probe_endpoint(base, api_key, timeout=2.5)
lat = round((_time.time() - t0) * 1000)
return {"alive": r.status_code < 400, "latency_ms": lat,
"status_code": r.status_code, "error": None if r.status_code < 400 else f"HTTP {r.status_code}"}
return {
"alive": bool(models),
"latency_ms": lat,
"status_code": 200 if models else None,
"error": None if models else "No models found",
}
except Exception as e:
return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]}
@@ -789,6 +851,12 @@ def setup_model_routes(model_discovery):
except Exception:
pass
visible = [m for m in all_models if m not in hidden]
status = "online" if all_models else "offline"
ping = None
if not all_models and r.is_enabled:
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"):
status = "empty"
results.append({
"id": r.id,
"name": r.name,
@@ -797,7 +865,9 @@ def setup_model_routes(model_discovery):
"is_enabled": r.is_enabled,
"models": visible,
"hidden_count": len(hidden),
"online": len(all_models) > 0,
"online": status != "offline",
"status": status,
"ping_error": (ping or {}).get("error") if ping else None,
"model_type": getattr(r, "model_type", None) or "llm",
"supports_tools": getattr(r, "supports_tools", None),
})
@@ -840,7 +910,11 @@ def setup_model_routes(model_discovery):
should_probe = require_model_list or not _truthy(skip_probe)
# Quick model list fetch (1s timeout — if endpoint is slow, it'll update on next refresh)
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=1) if should_probe else []
_probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else []
ping = {"reachable": False, "error": None}
if should_probe and not model_ids:
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout)
if require_model_list and not model_ids:
raise HTTPException(400, "No models found for that provider/key")
@@ -876,6 +950,7 @@ def setup_model_routes(model_discovery):
settings["default_model"] = model_ids[0] if model_ids else ""
_save_settings(settings)
_invalidate_models_cache()
_local_probe_cache["data"] = None
finally:
db.close()
@@ -883,8 +958,38 @@ def setup_model_routes(model_discovery):
return {
"id": ep_id,
"name": name.strip(),
"base_url": base_url,
"models": model_ids,
"online": len(model_ids) > 0,
"online": bool(model_ids) or bool(ping.get("reachable")),
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None,
}
@router.post("/model-endpoints/test")
def test_model_endpoint(
request: Request,
base_url: str = Form(...),
api_key: str = Form(""),
):
require_admin(request)
base_url = base_url.strip().rstrip("/")
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
if base_url.endswith(suffix):
base_url = base_url[:-len(suffix)].rstrip("/")
if not base_url:
raise HTTPException(400, "Base URL is required")
from src.endpoint_resolver import resolve_url
base_url = resolve_url(base_url)
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
return {
"base_url": base_url,
"online": bool(models) or bool(ping.get("reachable")),
"status": "online" if models else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None,
"models": models,
"count": len(models),
}
@router.get("/model-endpoints/{ep_id}/probe")
@@ -1175,6 +1280,49 @@ def setup_model_routes(model_discovery):
_save_settings(settings)
return cleared
def _session_uses_endpoint_url(session_url: str, base_url: str) -> bool:
if not session_url or not base_url:
return False
sess = session_url.rstrip("/")
base = _normalize_base(base_url).rstrip("/")
variants = {
base,
base + "/chat/completions",
build_chat_url(base).rstrip("/"),
}
return sess in variants or sess.startswith(base + "/")
def _clear_sessions_for_endpoint(db, base_url: str) -> int:
cleared = 0
rows = db.query(DbSession).filter(DbSession.endpoint_url.isnot(None)).all()
for row in rows:
if _session_uses_endpoint_url(row.endpoint_url or "", base_url):
row.endpoint_url = ""
row.model = ""
row.updated_at = datetime.utcnow()
cleared += 1
return cleared
def _clear_loaded_sessions_for_endpoint(base_url: str) -> int:
try:
from src.ai_interaction import get_session_manager
manager = get_session_manager()
except Exception:
manager = None
if not manager:
return 0
cleared = 0
try:
for sess in list(getattr(manager, "sessions", {}).values()):
if _session_uses_endpoint_url(getattr(sess, "endpoint_url", "") or "", base_url):
sess.endpoint_url = ""
sess.model = ""
sess.headers = {}
cleared += 1
except Exception:
return cleared
return cleared
@router.get("/model-endpoints/{ep_id}/dependents")
def get_endpoint_dependents(ep_id: str, request: Request):
"""Check which settings depend on this endpoint."""
@@ -1191,10 +1339,18 @@ def setup_model_routes(model_discovery):
raise HTTPException(404, "Endpoint not found")
# Clean up any settings that reference this endpoint
cleared = _clear_settings_for_endpoint(ep_id)
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
db.delete(ep)
db.commit()
_invalidate_models_cache()
return {"deleted": True, "cleared_settings": cleared}
_local_probe_cache["data"] = None
return {
"deleted": True,
"cleared_settings": cleared,
"cleared_sessions": cleared_sessions,
"cleared_loaded_sessions": cleared_loaded_sessions,
}
finally:
db.close()
+9 -1
View File
@@ -284,11 +284,19 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
db.close()
# Switch model/endpoint mid-session
if model is not None and endpoint_url is not None:
if endpoint_id:
from core.database import ModelEndpoint
_db = SessionLocal()
try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
if not ep:
raise HTTPException(400, "Model endpoint no longer exists")
finally:
_db.close()
session.model = model
session.endpoint_url = endpoint_url
# Update auth headers from the endpoint's stored API key
if endpoint_id:
from core.database import ModelEndpoint
_db = SessionLocal()
try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
+12 -2
View File
@@ -4,8 +4,6 @@ import asyncio
import json
import logging
import os
import pty
import fcntl
import shlex
import shutil
import uuid
@@ -13,6 +11,13 @@ import tempfile
from pathlib import Path
from typing import Dict, Any
try:
import fcntl
import pty
except ImportError:
fcntl = None
pty = None
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@@ -97,6 +102,11 @@ async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, An
async def _generate_pty(cmd: str, timeout: int, request: Request):
"""Run command in a pseudo-TTY so tqdm/progress bars work natively."""
if pty is None or fcntl is None:
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'PTY streaming is not available on Windows'})}\n\n"
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
return
loop = asyncio.get_event_loop()
master_fd, slave_fd = pty.openpty()