mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-16 01:35:36 -04:00
fix(models): stabilize proxy endpoint refresh behavior
* fix: support large proxy model endpoint refresh Large OpenAI-compatible proxy endpoints can expose hundreds of models and make /v1/models slow. Treating those endpoints like local model servers caused model picker opens and background probes to repeatedly hit /models, producing timeouts and making otherwise usable endpoints appear offline. Make model endpoint discovery cached-first for normal UI usage, add explicit proxy/API classification and refresh policy fields, exclude proxy/API endpoints from aggressive local probing, and preserve cached models when refresh fails. Manual Test/Add/Refresh actions still fetch the full model list with longer timeouts so users can intentionally import large proxy model lists without blocking normal model picker usage. * fix: preserve endpoint ping status semantics
This commit is contained in:
+443
-164
@@ -11,7 +11,7 @@ import httpx
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
|
||||
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request, Response
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import StreamingResponse
|
||||
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
|
||||
@@ -335,6 +335,141 @@ def _truthy(value: str | None) -> bool:
|
||||
return (value or "").strip().lower() in ("true", "1", "yes", "on")
|
||||
|
||||
|
||||
_ENDPOINT_KINDS = {"auto", "local", "api", "proxy"}
|
||||
_REFRESH_MODES = {"auto", "manual", "disabled"}
|
||||
|
||||
|
||||
def _normalize_endpoint_kind(value: Any) -> str:
|
||||
kind = str(value or "auto").strip().lower()
|
||||
return kind if kind in _ENDPOINT_KINDS else "auto"
|
||||
|
||||
|
||||
def _normalize_refresh_mode(value: Any, endpoint_kind: str = "auto") -> str:
|
||||
mode = str(value or "").strip().lower()
|
||||
kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
if mode in ("manual", "disabled"):
|
||||
return mode
|
||||
if mode == "auto" and kind != "proxy":
|
||||
return "auto"
|
||||
# Proxies default to manual cached-first behavior. Normal local/API
|
||||
# endpoints keep automatic bounded refreshes.
|
||||
return "manual" if kind == "proxy" else "auto"
|
||||
|
||||
|
||||
def _endpoint_kind(ep: Any) -> str:
|
||||
return _normalize_endpoint_kind(getattr(ep, "endpoint_kind", None))
|
||||
|
||||
|
||||
def _endpoint_refresh_mode(ep: Any, endpoint_kind: str | None = None) -> str:
|
||||
return _normalize_refresh_mode(getattr(ep, "model_refresh_mode", None), endpoint_kind or _endpoint_kind(ep))
|
||||
|
||||
|
||||
def _endpoint_refresh_interval(ep: Any, category: str) -> float:
|
||||
raw = getattr(ep, "model_refresh_interval", None)
|
||||
try:
|
||||
val = int(raw) if raw is not None else 0
|
||||
except Exception:
|
||||
val = 0
|
||||
if val > 0:
|
||||
return float(max(30, val))
|
||||
return 60.0 if category == "local" else 3600.0
|
||||
|
||||
|
||||
def _endpoint_refresh_timeout(ep: Any, category: str) -> float:
|
||||
raw = getattr(ep, "model_refresh_timeout", None)
|
||||
try:
|
||||
val = int(raw) if raw is not None else 0
|
||||
except Exception:
|
||||
val = 0
|
||||
if val > 0:
|
||||
return float(max(1, min(30, val)))
|
||||
return 2.5 if category == "local" else 2.0
|
||||
|
||||
|
||||
def _manual_refresh_timeout(ep: Any, category: str, requested: Any = None) -> float:
|
||||
"""Timeout for explicit user-triggered model-list refreshes.
|
||||
|
||||
Background refreshes stay short. A manual refresh is the one path where a
|
||||
large proxy may legitimately need 15-30s to aggregate its catalog.
|
||||
"""
|
||||
requested_val = _parse_positive_int(requested, minimum=1, maximum=60)
|
||||
if requested_val is not None:
|
||||
return float(requested_val)
|
||||
stored = _parse_positive_int(getattr(ep, "model_refresh_timeout", None), minimum=1, maximum=60)
|
||||
if category == "local":
|
||||
return float(stored) if stored is not None else _endpoint_refresh_timeout(ep, category)
|
||||
return float(max(stored or 30, 30))
|
||||
|
||||
|
||||
def _parse_model_list(raw: Any) -> List[str]:
|
||||
"""Return a sanitized list of model ids from JSON/list/comma text."""
|
||||
if raw is None:
|
||||
return []
|
||||
value = raw
|
||||
if isinstance(value, str):
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, list):
|
||||
value = parsed
|
||||
else:
|
||||
value = re.split(r"[\n,]+", text)
|
||||
except Exception:
|
||||
value = re.split(r"[\n,]+", text)
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
out = []
|
||||
seen = set()
|
||||
for item in value:
|
||||
mid = str(item or "").strip()
|
||||
if not mid or mid in seen:
|
||||
continue
|
||||
seen.add(mid)
|
||||
out.append(mid)
|
||||
return out
|
||||
|
||||
|
||||
def _parse_positive_int(raw: Any, *, minimum: int = 1, maximum: int = 86400) -> Optional[int]:
|
||||
try:
|
||||
val = int(str(raw).strip())
|
||||
except Exception:
|
||||
return None
|
||||
if val < minimum:
|
||||
return None
|
||||
return min(val, maximum)
|
||||
|
||||
|
||||
def _explicit_model_list_timeout(base_url: str, endpoint_kind: str = "auto", requested: Any = None) -> float:
|
||||
"""Timeout for explicit user-triggered model-list fetches during setup."""
|
||||
requested_val = _parse_positive_int(requested, minimum=1, maximum=60)
|
||||
if requested_val is not None:
|
||||
return float(requested_val)
|
||||
kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
category = _classify_endpoint(base_url, kind)
|
||||
if kind in ("api", "proxy") or category == "api":
|
||||
return 30.0
|
||||
return 3.0 if _is_ollama_base(base_url) else 2.0
|
||||
|
||||
|
||||
def _cached_model_ids(ep: Any) -> List[str]:
|
||||
return _parse_model_list(getattr(ep, "cached_models", None))
|
||||
|
||||
|
||||
def _hidden_model_ids(ep: Any) -> set:
|
||||
return set(_parse_model_list(getattr(ep, "hidden_models", None)))
|
||||
|
||||
|
||||
def _is_ollama_base(base_url: str) -> bool:
|
||||
try:
|
||||
parsed = urlparse(base_url)
|
||||
host = (parsed.hostname or "").lower()
|
||||
return parsed.port == 11434 or "ollama" in host
|
||||
except Exception:
|
||||
return "ollama" in (base_url or "").lower()
|
||||
|
||||
|
||||
# Prefixes/substrings for models that are NOT chat-completions-capable
|
||||
_NON_CHAT_PREFIXES = (
|
||||
"dall-e", "tts-", "whisper", "text-embedding", "embedding",
|
||||
@@ -441,10 +576,15 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
|
||||
_TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.")
|
||||
|
||||
|
||||
def _classify_endpoint(base_url: str) -> str:
|
||||
def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str:
|
||||
"""Return 'local' if the endpoint URL points to a private/local address, else 'api'.
|
||||
Includes the Tailscale CGNAT range (100.64.0.0/10) so tailnet-hosted
|
||||
servers (e.g. Cookbook serve endpoints) get reachability-probed too."""
|
||||
kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
if kind == "local":
|
||||
return "local"
|
||||
if kind in ("api", "proxy"):
|
||||
return "api"
|
||||
try:
|
||||
host = urlparse(base_url).hostname or ""
|
||||
if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES):
|
||||
@@ -456,6 +596,21 @@ def _classify_endpoint(base_url: str) -> str:
|
||||
return "api"
|
||||
|
||||
|
||||
def _effective_endpoint_kind(ep: Any, base_url: str) -> str:
|
||||
"""Return explicit kind, with a legacy proxy heuristic for keyed /v1 URLs."""
|
||||
kind = _endpoint_kind(ep)
|
||||
if kind != "auto":
|
||||
return kind
|
||||
if getattr(ep, "api_key", None) and not _is_ollama_base(base_url):
|
||||
try:
|
||||
path = (urlparse(base_url).path or "").rstrip("/")
|
||||
if path.endswith("/v1") or "/openai" in path:
|
||||
return "proxy"
|
||||
except Exception:
|
||||
pass
|
||||
return "auto"
|
||||
|
||||
|
||||
|
||||
def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> List[str]:
|
||||
"""Probe a base URL's /models endpoint and return list of model IDs.
|
||||
@@ -546,30 +701,18 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
"""Reachability probe that does not require installed/listed models."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
|
||||
# /api/tags. The OpenAI-style GET base + "/models" returns 404 when the
|
||||
# base is the host root or the native /api root (e.g. http://localhost:11434,
|
||||
# http://localhost:11434/api) because /models lives under /v1 there. Treat
|
||||
# 4xx on a port-11434 / Ollama-named base as "try the native paths" rather
|
||||
# than as a definitive offline verdict — Ollama is reachable, it just
|
||||
# doesn't speak OpenAI on that prefix. Without this gate the quickstart
|
||||
# marks an alive Ollama as offline whenever cached_models is empty (issue
|
||||
# #1025): _probe_endpoint() falls through to /api/tags on the same 404, but
|
||||
# _ping_endpoint() was returning before that fallback could run.
|
||||
# /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
|
||||
# /models as a generic health check because large proxy catalogs can be slow.
|
||||
parsed_base = urlparse(base)
|
||||
looks_like_ollama = (
|
||||
parsed_base.port == 11434
|
||||
or "ollama" in (parsed_base.hostname or "").lower()
|
||||
)
|
||||
|
||||
url = base + "/models"
|
||||
last_error: Optional[str] = None
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout)
|
||||
def _result_from_response(r) -> Dict[str, Any]:
|
||||
if 300 <= r.status_code < 400:
|
||||
loc = r.headers.get("location", "")
|
||||
if loc.startswith("/login") or "/login" in loc:
|
||||
@@ -579,13 +722,15 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
|
||||
}
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
|
||||
if r.status_code < 400:
|
||||
return {"reachable": True, "status_code": r.status_code, "error": None}
|
||||
if r.status_code < 500 and not looks_like_ollama:
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
if 200 <= r.status_code < 300:
|
||||
return {
|
||||
"reachable": True,
|
||||
"status_code": r.status_code,
|
||||
"error": None,
|
||||
}
|
||||
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
|
||||
|
||||
last_error: Optional[str] = None
|
||||
|
||||
try:
|
||||
if looks_like_ollama:
|
||||
@@ -597,14 +742,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
||||
for path in ("/api/version", "/api/tags"):
|
||||
try:
|
||||
r = httpx.get(root + path, timeout=timeout)
|
||||
if r.status_code < 400:
|
||||
return {"reachable": True, "status_code": r.status_code, "error": None}
|
||||
last_error = f"HTTP {r.status_code}"
|
||||
result = _result_from_response(r)
|
||||
if result["reachable"]:
|
||||
return result
|
||||
last_error = result.get("error")
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
r = httpx.get(base, headers=headers, timeout=timeout)
|
||||
return _result_from_response(r)
|
||||
except Exception as e:
|
||||
last_error = str(e)[:120]
|
||||
|
||||
return {"reachable": False, "status_code": None, "error": last_error}
|
||||
|
||||
|
||||
@@ -715,17 +867,71 @@ def setup_model_routes(model_discovery):
|
||||
flip)."""
|
||||
_models_cache.clear()
|
||||
|
||||
# Track endpoints that have failed recently so we back off probing dead ones.
|
||||
_probe_failures = {} # ep_id → (last_fail_ts, consecutive_fails)
|
||||
# Track model-list refreshes by URL+key. This prevents repeated picker/API
|
||||
# opens from starting duplicate /models probes, and gives slow/offline
|
||||
# providers a cooldown after failures.
|
||||
_refresh_state: Dict[str, Dict[str, Any]] = {}
|
||||
_refresh_inflight = {"v": False} # coarse single-flight guard
|
||||
_REFRESH_FAILURE_BASE = 300.0
|
||||
_REFRESH_FAILURE_MAX = 3600.0
|
||||
|
||||
def _refresh_caches_bg():
|
||||
"""Background thread: re-probe all endpoints in PARALLEL with a tight
|
||||
timeout, skipping endpoints that have been failing repeatedly.
|
||||
def _refresh_key(base: str, api_key: Optional[str]) -> str:
|
||||
return f"{base.rstrip('/')}\x00{api_key or ''}"
|
||||
|
||||
Was the cause of gradual server degradation: sequential 3s-timeout
|
||||
probes against many endpoints (some offline) tied up the threadpool
|
||||
for 15-30s every cache cycle, eventually exhausting it."""
|
||||
def _ts(value: Any) -> float:
|
||||
try:
|
||||
return float(value.timestamp()) if value else 0.0
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _failure_delay(fails: int) -> float:
|
||||
if fails <= 0:
|
||||
return 0.0
|
||||
return min(_REFRESH_FAILURE_BASE * (2 ** max(0, fails - 1)), _REFRESH_FAILURE_MAX)
|
||||
|
||||
def _should_refresh_endpoint(ep: Any, now: float, force: bool = False) -> tuple[bool, Dict[str, Any]]:
|
||||
base = _normalize_base(getattr(ep, "base_url", "") or "")
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
category = _classify_endpoint(base, kind)
|
||||
mode = _endpoint_refresh_mode(ep, kind)
|
||||
cached = _cached_model_ids(ep)
|
||||
key = _refresh_key(base, getattr(ep, "api_key", None))
|
||||
state = _refresh_state.get(key, {})
|
||||
|
||||
info = {
|
||||
"id": getattr(ep, "id", ""),
|
||||
"base": base,
|
||||
"api_key": getattr(ep, "api_key", None),
|
||||
"kind": kind,
|
||||
"category": category,
|
||||
"mode": mode,
|
||||
"key": key,
|
||||
"timeout": _endpoint_refresh_timeout(ep, category),
|
||||
}
|
||||
if not base:
|
||||
return False, info
|
||||
if state.get("inflight"):
|
||||
return False, info
|
||||
if mode in ("manual", "disabled") and not force:
|
||||
return False, info
|
||||
fails = int(state.get("fail_count") or 0)
|
||||
if fails and not force:
|
||||
last_failure = float(state.get("last_failure") or 0.0)
|
||||
if now - last_failure < _failure_delay(fails):
|
||||
return False, info
|
||||
if cached and not force:
|
||||
interval = _endpoint_refresh_interval(ep, category)
|
||||
last_good = float(state.get("last_success") or 0.0) or _ts(getattr(ep, "updated_at", None)) or _ts(getattr(ep, "created_at", None))
|
||||
if last_good and now - last_good < interval:
|
||||
return False, info
|
||||
return True, info
|
||||
|
||||
def _refresh_caches_bg(force: bool = False):
|
||||
"""Background thread: safely refresh model caches with per-base single-flight.
|
||||
|
||||
The public /api/models path stays cached-first. This refresh never clears
|
||||
a non-empty cached model list on timeout/failure, and proxy/manual
|
||||
endpoints are skipped unless explicitly forced."""
|
||||
import threading
|
||||
if _refresh_inflight["v"]:
|
||||
return # already running
|
||||
@@ -735,44 +941,63 @@ def setup_model_routes(model_discovery):
|
||||
try:
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
db = SessionLocal()
|
||||
changed = False
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
# Skip endpoints that have failed 3+ times in a row in the last 5 min
|
||||
now = _time.time()
|
||||
to_probe = []
|
||||
groups: Dict[str, Dict[str, Any]] = {}
|
||||
for ep in endpoints:
|
||||
ts, fails = _probe_failures.get(ep.id, (0, 0))
|
||||
if fails >= 3 and (now - ts) < 300:
|
||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||
if not ok:
|
||||
continue
|
||||
to_probe.append(ep)
|
||||
groups.setdefault(info["key"], {
|
||||
"base": info["base"],
|
||||
"api_key": info["api_key"],
|
||||
"timeout": info["timeout"],
|
||||
"endpoint_ids": [],
|
||||
})["endpoint_ids"].append(info["id"])
|
||||
|
||||
def _probe_one(ep):
|
||||
base = _normalize_base(ep.base_url)
|
||||
for key in groups:
|
||||
st = _refresh_state.setdefault(key, {})
|
||||
st["inflight"] = True
|
||||
st["last_attempt"] = now
|
||||
|
||||
def _probe_one(key: str, data: Dict[str, Any]):
|
||||
try:
|
||||
ids = _probe_endpoint(base, ep.api_key, timeout=2)
|
||||
return ep, ids, None
|
||||
ids = _probe_endpoint(data["base"], data.get("api_key"), timeout=data.get("timeout") or 2)
|
||||
return key, data["endpoint_ids"], ids, None
|
||||
except Exception as e:
|
||||
return ep, None, e
|
||||
return key, data["endpoint_ids"], None, e
|
||||
|
||||
if to_probe:
|
||||
# Bounded parallelism — 8 concurrent probes is plenty
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(to_probe))) as pool:
|
||||
futures = [pool.submit(_probe_one, ep) for ep in to_probe]
|
||||
if groups:
|
||||
with ThreadPoolExecutor(max_workers=min(4, len(groups))) as pool:
|
||||
futures = [pool.submit(_probe_one, key, data) for key, data in groups.items()]
|
||||
for fut in as_completed(futures):
|
||||
ep, ids, err = fut.result()
|
||||
key, endpoint_ids, ids, err = fut.result()
|
||||
st = _refresh_state.setdefault(key, {})
|
||||
if ids:
|
||||
ep.cached_models = json.dumps(ids)
|
||||
_probe_failures.pop(ep.id, None)
|
||||
for ep_id in endpoint_ids:
|
||||
ep_obj = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep_obj:
|
||||
ep_obj.cached_models = json.dumps(ids)
|
||||
changed = True
|
||||
st["last_success"] = _time.time()
|
||||
st["fail_count"] = 0
|
||||
st.pop("last_failure", None)
|
||||
else:
|
||||
prev = _probe_failures.get(ep.id, (0, 0))
|
||||
_probe_failures[ep.id] = (_time.time(), prev[1] + 1)
|
||||
st["last_failure"] = _time.time()
|
||||
st["fail_count"] = int(st.get("fail_count") or 0) + 1
|
||||
st["inflight"] = False
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
_invalidate_models_cache()
|
||||
if changed:
|
||||
_invalidate_models_cache()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
for st in _refresh_state.values():
|
||||
st["inflight"] = False
|
||||
_refresh_inflight["v"] = False
|
||||
threading.Thread(target=_do, daemon=True).start()
|
||||
|
||||
@@ -804,24 +1029,15 @@ def setup_model_routes(model_discovery):
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
# Use cached models — background refresh keeps them updated
|
||||
model_ids = []
|
||||
if ep.cached_models:
|
||||
try:
|
||||
model_ids = json.loads(ep.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
model_ids = _cached_model_ids(ep)
|
||||
ep_model_type = getattr(ep, "model_type", None) or "llm"
|
||||
# Filter out hidden (probe-failed) models
|
||||
hidden = set()
|
||||
if ep.hidden_models:
|
||||
try:
|
||||
hidden = set(json.loads(ep.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
hidden = _hidden_model_ids(ep)
|
||||
model_ids = [m for m in model_ids if m not in hidden]
|
||||
# Build correct URL based on provider
|
||||
chat_url = build_chat_url(base)
|
||||
category = _classify_endpoint(base)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
category = _classify_endpoint(base, kind)
|
||||
|
||||
if model_ids:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
@@ -837,6 +1053,7 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_id": ep.id,
|
||||
"endpoint_name": ep.name,
|
||||
"category": category,
|
||||
"endpoint_kind": kind,
|
||||
"model_type": ep_model_type,
|
||||
})
|
||||
else:
|
||||
@@ -852,6 +1069,7 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_id": ep.id,
|
||||
"endpoint_name": ep.name,
|
||||
"category": category,
|
||||
"endpoint_kind": kind,
|
||||
"model_type": ep_model_type,
|
||||
"offline": True,
|
||||
})
|
||||
@@ -898,11 +1116,11 @@ def setup_model_routes(model_discovery):
|
||||
result = _fetch_models(owner=owner, is_admin=_is_admin)
|
||||
_models_cache[_cache_key] = {"data": result, "time": now}
|
||||
# Kick off background refresh to update caches from live endpoints
|
||||
_refresh_caches_bg()
|
||||
_refresh_caches_bg(force=refresh)
|
||||
return result
|
||||
|
||||
# Brief cache for local-probe results so picker-open doesn't hammer
|
||||
# /v1/models every time. 8s TTL — long enough to amortize cost,
|
||||
# endpoint health checks every time. 8s TTL — long enough to amortize cost,
|
||||
# short enough that a freshly-killed local server shows as offline
|
||||
# within ~8s of the user noticing.
|
||||
_LOCAL_PROBE_TTL = 8.0
|
||||
@@ -912,7 +1130,7 @@ def setup_model_routes(model_discovery):
|
||||
async def probe_local_endpoints(request: Request):
|
||||
"""Fast parallel reachability check for LOCAL endpoints only.
|
||||
Cloud endpoints (api.openai.com, api.anthropic.com, etc.) are
|
||||
assumed up. Local endpoints get a 1.5s /models probe so the UI
|
||||
assumed up. Local endpoints get a 1.5s cheap reachability probe so the UI
|
||||
can dim stale entries pointing at dead vLLM servers. Returns
|
||||
{ep_id: {alive, latency_ms, error}}."""
|
||||
require_admin(request)
|
||||
@@ -924,36 +1142,44 @@ def setup_model_routes(model_discovery):
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
local_eps = [
|
||||
(ep.id, _normalize_base(ep.base_url), ep.api_key)
|
||||
for ep in endpoints
|
||||
if _classify_endpoint(_normalize_base(ep.base_url)) == "local"
|
||||
]
|
||||
local_eps = []
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
if _classify_endpoint(base, kind) == "local":
|
||||
local_eps.append((ep.id, base, ep.api_key))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]:
|
||||
grouped: Dict[str, Dict[str, Any]] = {}
|
||||
for ep_id, base, api_key in local_eps:
|
||||
key = _refresh_key(base, api_key)
|
||||
grouped.setdefault(key, {"base": base, "api_key": api_key, "endpoint_ids": []})["endpoint_ids"].append(ep_id)
|
||||
|
||||
async def _probe_one(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
t0 = _time.time()
|
||||
try:
|
||||
models = _probe_endpoint(base, api_key, timeout=2.5)
|
||||
import asyncio as _asyncio
|
||||
ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 1.5)
|
||||
lat = round((_time.time() - t0) * 1000)
|
||||
return {
|
||||
"alive": bool(models),
|
||||
"alive": bool(ping.get("reachable")),
|
||||
"latency_ms": lat,
|
||||
"status_code": 200 if models else None,
|
||||
"error": None if models else "No models found",
|
||||
"status_code": ping.get("status_code"),
|
||||
"error": ping.get("error"),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]}
|
||||
|
||||
import asyncio as _asyncio
|
||||
results_list = await _asyncio.gather(
|
||||
*[_probe_one(eid, base, key) for eid, base, key in local_eps],
|
||||
*[_probe_one(data) for data in grouped.values()],
|
||||
return_exceptions=False,
|
||||
)
|
||||
results: Dict[str, Any] = {}
|
||||
for (eid, _, _), r in zip(local_eps, results_list):
|
||||
results[eid] = r
|
||||
for data, r in zip(grouped.values(), results_list):
|
||||
for eid in data["endpoint_ids"]:
|
||||
results[eid] = r
|
||||
|
||||
_local_probe_cache["data"] = results
|
||||
_local_probe_cache["time"] = now
|
||||
@@ -973,50 +1199,28 @@ def setup_model_routes(model_discovery):
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
provider = _detect_provider(base)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
cached_count = len(_cached_model_ids(ep))
|
||||
entry = {
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": base,
|
||||
"provider": provider,
|
||||
"category": _classify_endpoint(base),
|
||||
"category": _classify_endpoint(base, kind),
|
||||
"endpoint_kind": kind,
|
||||
}
|
||||
if provider == "anthropic":
|
||||
# Anthropic has no /models endpoint; just check connectivity
|
||||
try:
|
||||
t0 = _time.time()
|
||||
r = httpx.get(base.rstrip("/"), timeout=5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online"
|
||||
entry["model_count"] = len(ANTHROPIC_MODELS)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "offline"
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = 0
|
||||
else:
|
||||
url = build_models_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
try:
|
||||
t0 = _time.time()
|
||||
r = httpx.get(url, headers=headers, timeout=5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
entry["status"] = "online"
|
||||
entry["model_count"] = len(models)
|
||||
except Exception as e:
|
||||
if "latency_ms" not in entry:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "offline"
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = 0
|
||||
try:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
entry["error"] = str(e)
|
||||
entry["model_count"] = cached_count
|
||||
results.append(entry)
|
||||
|
||||
return {"endpoints": results}
|
||||
@@ -1165,19 +1369,8 @@ def setup_model_routes(model_discovery):
|
||||
rows = db.query(ModelEndpoint).order_by(ModelEndpoint.created_at).all()
|
||||
results = []
|
||||
for r in rows:
|
||||
# Use cached model list to avoid slow probe on every load
|
||||
all_models = []
|
||||
if r.cached_models:
|
||||
try:
|
||||
all_models = json.loads(r.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
hidden = set()
|
||||
if r.hidden_models:
|
||||
try:
|
||||
hidden = set(json.loads(r.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
all_models = _cached_model_ids(r)
|
||||
hidden = _hidden_model_ids(r)
|
||||
pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
|
||||
visible = _visible_models(all_models, r.hidden_models, pinned)
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
@@ -1188,6 +1381,8 @@ def setup_model_routes(model_discovery):
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
base = _normalize_base(r.base_url)
|
||||
kind = _effective_endpoint_kind(r, base)
|
||||
results.append({
|
||||
"id": r.id,
|
||||
"name": r.name,
|
||||
@@ -1202,6 +1397,11 @@ def setup_model_routes(model_discovery):
|
||||
"ping_error": (ping or {}).get("error") if ping else None,
|
||||
"model_type": getattr(r, "model_type", None) or "llm",
|
||||
"supports_tools": getattr(r, "supports_tools", None),
|
||||
"endpoint_kind": kind,
|
||||
"category": _classify_endpoint(base, kind),
|
||||
"model_refresh_mode": _endpoint_refresh_mode(r, kind),
|
||||
"model_refresh_interval": getattr(r, "model_refresh_interval", None),
|
||||
"model_refresh_timeout": getattr(r, "model_refresh_timeout", None),
|
||||
})
|
||||
return results
|
||||
finally:
|
||||
@@ -1216,6 +1416,10 @@ def setup_model_routes(model_discovery):
|
||||
skip_probe: str = Form("false"),
|
||||
require_models: str = Form("false"),
|
||||
model_type: str = Form("llm"),
|
||||
endpoint_kind: str = Form("auto"),
|
||||
model_refresh_mode: str = Form(""),
|
||||
model_refresh_interval: str = Form(""),
|
||||
model_refresh_timeout: str = Form(""),
|
||||
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
|
||||
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
|
||||
container_local: str = Form("false"),
|
||||
@@ -1240,8 +1444,15 @@ def setup_model_routes(model_discovery):
|
||||
if not name.strip():
|
||||
name = base_url.replace("http://", "").replace("https://", "").split("/")[0]
|
||||
|
||||
requested_kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
refresh_mode = _normalize_refresh_mode(model_refresh_mode, requested_kind)
|
||||
refresh_interval = _parse_positive_int(model_refresh_interval, minimum=30, maximum=86400)
|
||||
refresh_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
|
||||
require_model_list = _truthy(require_models)
|
||||
should_probe = require_model_list or not _truthy(skip_probe)
|
||||
should_probe = (
|
||||
require_model_list or requested_kind in ("api", "proxy") or not _truthy(skip_probe)
|
||||
)
|
||||
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
|
||||
|
||||
# Dedupe: if an endpoint with the same base_url already exists and
|
||||
# is reachable by the caller (shared or owned by them), return it
|
||||
@@ -1259,6 +1470,7 @@ def setup_model_routes(model_discovery):
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
changed = False
|
||||
# Persist any incoming pinned IDs onto the existing row. An
|
||||
# empty/omitted form field must not wipe previously pinned IDs.
|
||||
_incoming_pinned = _normalize_model_ids(pinned_models)
|
||||
@@ -1268,15 +1480,45 @@ def setup_model_routes(model_discovery):
|
||||
_incoming_pinned,
|
||||
)
|
||||
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
|
||||
changed = True
|
||||
existing_kind_for_probe = requested_kind if requested_kind != "auto" else _effective_endpoint_kind(existing, base_url)
|
||||
if requested_kind != "auto" and _endpoint_kind(existing) == "auto":
|
||||
existing.endpoint_kind = requested_kind
|
||||
changed = True
|
||||
if model_refresh_mode or (requested_kind == "proxy" and _endpoint_refresh_mode(existing, requested_kind) != refresh_mode):
|
||||
existing.model_refresh_mode = refresh_mode
|
||||
changed = True
|
||||
if refresh_interval is not None:
|
||||
existing.model_refresh_interval = refresh_interval
|
||||
changed = True
|
||||
if refresh_timeout is not None:
|
||||
existing.model_refresh_timeout = refresh_timeout
|
||||
changed = True
|
||||
if api_key.strip() and not existing.api_key:
|
||||
existing.api_key = api_key.strip()
|
||||
changed = True
|
||||
if should_probe:
|
||||
probed_models = _probe_endpoint(
|
||||
base_url,
|
||||
(api_key.strip() or existing.api_key or None),
|
||||
timeout=_explicit_model_list_timeout(base_url, existing_kind_for_probe, refresh_timeout),
|
||||
)
|
||||
if probed_models:
|
||||
existing.cached_models = json.dumps(probed_models)
|
||||
changed = True
|
||||
if changed:
|
||||
_db_dedup.commit()
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
existing_models = _cached_model_ids(existing)
|
||||
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
|
||||
existing_kind = _effective_endpoint_kind(existing, existing.base_url)
|
||||
return {
|
||||
"id": existing.id,
|
||||
"name": existing.name,
|
||||
"base_url": existing.base_url,
|
||||
"models": _visible_models(
|
||||
getattr(existing, "cached_models", None),
|
||||
existing_models,
|
||||
getattr(existing, "hidden_models", None),
|
||||
existing.pinned_models,
|
||||
),
|
||||
@@ -1284,16 +1526,16 @@ def setup_model_routes(model_discovery):
|
||||
"online": True,
|
||||
"status": "online",
|
||||
"existing": True,
|
||||
"endpoint_kind": existing_kind,
|
||||
"category": _classify_endpoint(existing.base_url, existing_kind),
|
||||
}
|
||||
finally:
|
||||
_db_dedup.close()
|
||||
|
||||
# Quick model list fetch (1s timeout — if endpoint is slow, it'll update on next refresh)
|
||||
_probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1
|
||||
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else []
|
||||
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=explicit_timeout) if should_probe else []
|
||||
ping = {"reachable": False, "error": None}
|
||||
if should_probe and not model_ids:
|
||||
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout)
|
||||
if (should_probe or requested_kind in ("api", "proxy")) and not model_ids:
|
||||
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=min(explicit_timeout, 2.0))
|
||||
if require_model_list and not model_ids:
|
||||
raise HTTPException(400, _model_endpoint_error_message(base_url, ping))
|
||||
|
||||
@@ -1317,6 +1559,10 @@ def setup_model_routes(model_discovery):
|
||||
api_key=api_key.strip() or None,
|
||||
is_enabled=True,
|
||||
model_type=model_type.strip() if model_type else "llm",
|
||||
endpoint_kind=requested_kind,
|
||||
model_refresh_mode=refresh_mode,
|
||||
model_refresh_interval=refresh_interval,
|
||||
model_refresh_timeout=refresh_timeout,
|
||||
cached_models=json.dumps(model_ids) if model_ids else None,
|
||||
pinned_models=json.dumps(_pinned) if _pinned else None,
|
||||
supports_tools=_st,
|
||||
@@ -1349,6 +1595,8 @@ def setup_model_routes(model_discovery):
|
||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
"endpoint_kind": requested_kind,
|
||||
"category": _classify_endpoint(base_url, requested_kind),
|
||||
}
|
||||
|
||||
@router.post("/model-endpoints/test")
|
||||
@@ -1356,6 +1604,8 @@ def setup_model_routes(model_discovery):
|
||||
request: Request,
|
||||
base_url: str = Form(...),
|
||||
api_key: str = Form(""),
|
||||
endpoint_kind: str = Form("auto"),
|
||||
model_refresh_timeout: str = Form(""),
|
||||
):
|
||||
require_admin(request)
|
||||
base_url = _normalize_base(base_url)
|
||||
@@ -1364,9 +1614,11 @@ def setup_model_routes(model_discovery):
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base_url = resolve_url(base_url)
|
||||
base_url = _rewrite_loopback_for_docker(base_url)
|
||||
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
|
||||
requested_kind = _normalize_endpoint_kind(endpoint_kind)
|
||||
configured_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
|
||||
probe_timeout = _explicit_model_list_timeout(base_url, requested_kind, configured_timeout)
|
||||
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
|
||||
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=min(probe_timeout, 2.0))
|
||||
return {
|
||||
"base_url": base_url,
|
||||
"online": bool(models) or bool(ping.get("reachable")),
|
||||
@@ -1374,6 +1626,8 @@ def setup_model_routes(model_discovery):
|
||||
"ping_error": ping.get("error") if ping else None,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
"endpoint_kind": requested_kind,
|
||||
"category": _classify_endpoint(base_url, requested_kind),
|
||||
}
|
||||
|
||||
@router.get("/model-endpoints/{ep_id}/probe")
|
||||
@@ -1415,7 +1669,8 @@ def setup_model_routes(model_discovery):
|
||||
ep_obj = db2.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep_obj:
|
||||
ep_obj.hidden_models = json.dumps(failed) if failed else None
|
||||
ep_obj.cached_models = json.dumps(all_models) if all_models else None
|
||||
if all_models:
|
||||
ep_obj.cached_models = json.dumps(all_models)
|
||||
db2.commit()
|
||||
finally:
|
||||
db2.close()
|
||||
@@ -1426,7 +1681,13 @@ def setup_model_routes(model_discovery):
|
||||
return StreamingResponse(_stream(), media_type="text/event-stream")
|
||||
|
||||
@router.get("/model-endpoints/{ep_id}/models")
|
||||
def list_endpoint_models(ep_id: str, request: Request):
|
||||
def list_endpoint_models(
|
||||
ep_id: str,
|
||||
request: Request,
|
||||
response: Response,
|
||||
refresh: bool = False,
|
||||
refresh_timeout: Optional[int] = Query(None, ge=1, le=60),
|
||||
):
|
||||
"""List all discovered models for an endpoint with hidden/visible state."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
@@ -1434,23 +1695,28 @@ def setup_model_routes(model_discovery):
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
hidden = set()
|
||||
if ep.hidden_models:
|
||||
hidden = _hidden_model_ids(ep)
|
||||
all_models = _cached_model_ids(ep)
|
||||
if refresh:
|
||||
base = _normalize_base(ep.base_url)
|
||||
kind = _effective_endpoint_kind(ep, base)
|
||||
category = _classify_endpoint(base, kind)
|
||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||
try:
|
||||
hidden = set(json.loads(ep.hidden_models))
|
||||
except Exception:
|
||||
pass
|
||||
# Try live probe, fall back to cached. Pinned IDs are admin-entered
|
||||
# and persist regardless of probe results — never overwritten here.
|
||||
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
|
||||
if all_models:
|
||||
ep.cached_models = json.dumps(all_models)
|
||||
db.commit()
|
||||
elif ep.cached_models:
|
||||
try:
|
||||
all_models = json.loads(ep.cached_models)
|
||||
except Exception:
|
||||
pass
|
||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
||||
except Exception as exc:
|
||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||
probed = []
|
||||
if probed:
|
||||
all_models = probed
|
||||
ep.cached_models = json.dumps(all_models)
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
response.headers["X-Model-Refresh-Status"] = "refreshed"
|
||||
response.headers["X-Model-Refresh-Count"] = str(len(probed))
|
||||
else:
|
||||
response.headers["X-Model-Refresh-Status"] = "failed"
|
||||
response.headers["X-Model-Refresh-Warning"] = "Model refresh failed or returned no models; kept cached models."
|
||||
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
|
||||
pinned_set = set(pinned)
|
||||
return [
|
||||
@@ -1502,7 +1768,6 @@ def setup_model_routes(model_discovery):
|
||||
|
||||
@router.get("/default-chat")
|
||||
def get_default_chat(request: Request):
|
||||
import json as _json
|
||||
# SECURITY: resolve the default endpoint + model from the CALLER's
|
||||
# per-user prefs ONLY. We deliberately do NOT fall back to the
|
||||
# global `default_model` / `default_endpoint_id` in settings.json
|
||||
@@ -1635,6 +1900,16 @@ def setup_model_routes(model_discovery):
|
||||
if "pinned_models" in body:
|
||||
_pinned = _normalize_model_ids(body["pinned_models"])
|
||||
ep.pinned_models = json.dumps(_pinned) if _pinned else None
|
||||
if "endpoint_kind" in body:
|
||||
ep.endpoint_kind = _normalize_endpoint_kind(body.get("endpoint_kind"))
|
||||
if "model_refresh_mode" in body:
|
||||
ep.model_refresh_mode = _normalize_refresh_mode(body.get("model_refresh_mode"), _endpoint_kind(ep))
|
||||
if "model_refresh_interval" in body:
|
||||
interval = _parse_positive_int(body.get("model_refresh_interval"), minimum=30, maximum=86400)
|
||||
ep.model_refresh_interval = interval
|
||||
if "model_refresh_timeout" in body:
|
||||
timeout = _parse_positive_int(body.get("model_refresh_timeout"), minimum=1, maximum=60)
|
||||
ep.model_refresh_timeout = timeout
|
||||
# Rotating an API key used to require DELETE+POST, which wiped
|
||||
# endpoint_url/model from every session referencing the old base
|
||||
# URL. Allow in-place updates so the admin can change the key
|
||||
@@ -1664,6 +1939,10 @@ def setup_model_routes(model_discovery):
|
||||
"model_type": ep.model_type,
|
||||
"base_url": ep.base_url,
|
||||
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
|
||||
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
|
||||
"model_refresh_interval": getattr(ep, "model_refresh_interval", None),
|
||||
"model_refresh_timeout": getattr(ep, "model_refresh_timeout", None),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user