fix(models): stabilize proxy endpoint refresh behavior

* fix: support large proxy model endpoint refresh

Large OpenAI-compatible proxy endpoints can expose hundreds of models and make /v1/models slow. Treating those endpoints like local model servers caused model picker opens and background probes to repeatedly hit /models, producing timeouts and making otherwise usable endpoints appear offline.

Make model endpoint discovery cached-first for normal UI usage, add explicit proxy/API classification and refresh policy fields, exclude proxy/API endpoints from aggressive local probing, and preserve cached models when refresh fails.

Manual Test/Add/Refresh actions still fetch the full model list with longer timeouts so users can intentionally import large proxy model lists without blocking normal model picker usage.

* fix: preserve endpoint ping status semantics
This commit is contained in:
Yuri
2026-06-04 00:56:11 -03:00
committed by GitHub
parent eee2167502
commit a2e691da2b
10 changed files with 1323 additions and 231 deletions
+443 -164
View File
@@ -11,7 +11,7 @@ import httpx
from datetime import datetime
from typing import List, Dict, Any, Optional
from urllib.parse import urlparse, urlunparse
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request
from fastapi import APIRouter, HTTPException, Form, Query, Body, Request, Response
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from core.database import SessionLocal, ModelEndpoint, Session as DbSession
@@ -335,6 +335,141 @@ def _truthy(value: str | None) -> bool:
return (value or "").strip().lower() in ("true", "1", "yes", "on")
_ENDPOINT_KINDS = {"auto", "local", "api", "proxy"}
_REFRESH_MODES = {"auto", "manual", "disabled"}
def _normalize_endpoint_kind(value: Any) -> str:
kind = str(value or "auto").strip().lower()
return kind if kind in _ENDPOINT_KINDS else "auto"
def _normalize_refresh_mode(value: Any, endpoint_kind: str = "auto") -> str:
mode = str(value or "").strip().lower()
kind = _normalize_endpoint_kind(endpoint_kind)
if mode in ("manual", "disabled"):
return mode
if mode == "auto" and kind != "proxy":
return "auto"
# Proxies default to manual cached-first behavior. Normal local/API
# endpoints keep automatic bounded refreshes.
return "manual" if kind == "proxy" else "auto"
def _endpoint_kind(ep: Any) -> str:
return _normalize_endpoint_kind(getattr(ep, "endpoint_kind", None))
def _endpoint_refresh_mode(ep: Any, endpoint_kind: str | None = None) -> str:
return _normalize_refresh_mode(getattr(ep, "model_refresh_mode", None), endpoint_kind or _endpoint_kind(ep))
def _endpoint_refresh_interval(ep: Any, category: str) -> float:
raw = getattr(ep, "model_refresh_interval", None)
try:
val = int(raw) if raw is not None else 0
except Exception:
val = 0
if val > 0:
return float(max(30, val))
return 60.0 if category == "local" else 3600.0
def _endpoint_refresh_timeout(ep: Any, category: str) -> float:
raw = getattr(ep, "model_refresh_timeout", None)
try:
val = int(raw) if raw is not None else 0
except Exception:
val = 0
if val > 0:
return float(max(1, min(30, val)))
return 2.5 if category == "local" else 2.0
def _manual_refresh_timeout(ep: Any, category: str, requested: Any = None) -> float:
"""Timeout for explicit user-triggered model-list refreshes.
Background refreshes stay short. A manual refresh is the one path where a
large proxy may legitimately need 15-30s to aggregate its catalog.
"""
requested_val = _parse_positive_int(requested, minimum=1, maximum=60)
if requested_val is not None:
return float(requested_val)
stored = _parse_positive_int(getattr(ep, "model_refresh_timeout", None), minimum=1, maximum=60)
if category == "local":
return float(stored) if stored is not None else _endpoint_refresh_timeout(ep, category)
return float(max(stored or 30, 30))
def _parse_model_list(raw: Any) -> List[str]:
"""Return a sanitized list of model ids from JSON/list/comma text."""
if raw is None:
return []
value = raw
if isinstance(value, str):
text = value.strip()
if not text:
return []
try:
parsed = json.loads(text)
if isinstance(parsed, list):
value = parsed
else:
value = re.split(r"[\n,]+", text)
except Exception:
value = re.split(r"[\n,]+", text)
if not isinstance(value, list):
return []
out = []
seen = set()
for item in value:
mid = str(item or "").strip()
if not mid or mid in seen:
continue
seen.add(mid)
out.append(mid)
return out
def _parse_positive_int(raw: Any, *, minimum: int = 1, maximum: int = 86400) -> Optional[int]:
try:
val = int(str(raw).strip())
except Exception:
return None
if val < minimum:
return None
return min(val, maximum)
def _explicit_model_list_timeout(base_url: str, endpoint_kind: str = "auto", requested: Any = None) -> float:
"""Timeout for explicit user-triggered model-list fetches during setup."""
requested_val = _parse_positive_int(requested, minimum=1, maximum=60)
if requested_val is not None:
return float(requested_val)
kind = _normalize_endpoint_kind(endpoint_kind)
category = _classify_endpoint(base_url, kind)
if kind in ("api", "proxy") or category == "api":
return 30.0
return 3.0 if _is_ollama_base(base_url) else 2.0
def _cached_model_ids(ep: Any) -> List[str]:
return _parse_model_list(getattr(ep, "cached_models", None))
def _hidden_model_ids(ep: Any) -> set:
return set(_parse_model_list(getattr(ep, "hidden_models", None)))
def _is_ollama_base(base_url: str) -> bool:
try:
parsed = urlparse(base_url)
host = (parsed.hostname or "").lower()
return parsed.port == 11434 or "ollama" in host
except Exception:
return "ollama" in (base_url or "").lower()
# Prefixes/substrings for models that are NOT chat-completions-capable
_NON_CHAT_PREFIXES = (
"dall-e", "tts-", "whisper", "text-embedding", "embedding",
@@ -441,10 +576,15 @@ _PRIVATE_PREFIXES = ("10.", "172.16.", "172.17.", "172.18.", "172.19.",
_TAILSCALE_RE = re.compile(r"^100\.(6[4-9]|[7-9]\d|1[01]\d|12[0-7])\.")
def _classify_endpoint(base_url: str) -> str:
def _classify_endpoint(base_url: str, endpoint_kind: str = "auto") -> str:
"""Return 'local' if the endpoint URL points to a private/local address, else 'api'.
Includes the Tailscale CGNAT range (100.64.0.0/10) so tailnet-hosted
servers (e.g. Cookbook serve endpoints) get reachability-probed too."""
kind = _normalize_endpoint_kind(endpoint_kind)
if kind == "local":
return "local"
if kind in ("api", "proxy"):
return "api"
try:
host = urlparse(base_url).hostname or ""
if host in _LOCAL_HOSTS or host.startswith(_PRIVATE_PREFIXES):
@@ -456,6 +596,21 @@ def _classify_endpoint(base_url: str) -> str:
return "api"
def _effective_endpoint_kind(ep: Any, base_url: str) -> str:
"""Return explicit kind, with a legacy proxy heuristic for keyed /v1 URLs."""
kind = _endpoint_kind(ep)
if kind != "auto":
return kind
if getattr(ep, "api_key", None) and not _is_ollama_base(base_url):
try:
path = (urlparse(base_url).path or "").rstrip("/")
if path.endswith("/v1") or "/openai" in path:
return "proxy"
except Exception:
pass
return "auto"
def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> List[str]:
"""Probe a base URL's /models endpoint and return list of model IDs.
@@ -546,30 +701,18 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
"""Reachability probe that does not require installed/listed models."""
from src.endpoint_resolver import resolve_url
base = resolve_url(_normalize_base(base_url))
headers = {}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
headers = build_headers(api_key, base)
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
# /api/tags. The OpenAI-style GET base + "/models" returns 404 when the
# base is the host root or the native /api root (e.g. http://localhost:11434,
# http://localhost:11434/api) because /models lives under /v1 there. Treat
# 4xx on a port-11434 / Ollama-named base as "try the native paths" rather
# than as a definitive offline verdict — Ollama is reachable, it just
# doesn't speak OpenAI on that prefix. Without this gate the quickstart
# marks an alive Ollama as offline whenever cached_models is empty (issue
# #1025): _probe_endpoint() falls through to /api/tags on the same 404, but
# _ping_endpoint() was returning before that fallback could run.
# /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
# /models as a generic health check because large proxy catalogs can be slow.
parsed_base = urlparse(base)
looks_like_ollama = (
parsed_base.port == 11434
or "ollama" in (parsed_base.hostname or "").lower()
)
url = base + "/models"
last_error: Optional[str] = None
try:
r = httpx.get(url, headers=headers, timeout=timeout)
def _result_from_response(r) -> Dict[str, Any]:
if 300 <= r.status_code < 400:
loc = r.headers.get("location", "")
if loc.startswith("/login") or "/login" in loc:
@@ -579,13 +722,15 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
"error": "That is Odysseus, not a model server. Use the Ollama URL, usually http://host.docker.internal:11434/v1 in Docker.",
}
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code} redirect"}
if r.status_code < 400:
return {"reachable": True, "status_code": r.status_code, "error": None}
if r.status_code < 500 and not looks_like_ollama:
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
last_error = f"HTTP {r.status_code}"
except Exception as e:
last_error = str(e)[:120]
if 200 <= r.status_code < 300:
return {
"reachable": True,
"status_code": r.status_code,
"error": None,
}
return {"reachable": False, "status_code": r.status_code, "error": f"HTTP {r.status_code}"}
last_error: Optional[str] = None
try:
if looks_like_ollama:
@@ -597,14 +742,21 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
for path in ("/api/version", "/api/tags"):
try:
r = httpx.get(root + path, timeout=timeout)
if r.status_code < 400:
return {"reachable": True, "status_code": r.status_code, "error": None}
last_error = f"HTTP {r.status_code}"
result = _result_from_response(r)
if result["reachable"]:
return result
last_error = result.get("error")
except Exception as e:
last_error = str(e)[:120]
except Exception:
pass
try:
r = httpx.get(base, headers=headers, timeout=timeout)
return _result_from_response(r)
except Exception as e:
last_error = str(e)[:120]
return {"reachable": False, "status_code": None, "error": last_error}
@@ -715,17 +867,71 @@ def setup_model_routes(model_discovery):
flip)."""
_models_cache.clear()
# Track endpoints that have failed recently so we back off probing dead ones.
_probe_failures = {} # ep_id → (last_fail_ts, consecutive_fails)
# Track model-list refreshes by URL+key. This prevents repeated picker/API
# opens from starting duplicate /models probes, and gives slow/offline
# providers a cooldown after failures.
_refresh_state: Dict[str, Dict[str, Any]] = {}
_refresh_inflight = {"v": False} # coarse single-flight guard
_REFRESH_FAILURE_BASE = 300.0
_REFRESH_FAILURE_MAX = 3600.0
def _refresh_caches_bg():
"""Background thread: re-probe all endpoints in PARALLEL with a tight
timeout, skipping endpoints that have been failing repeatedly.
def _refresh_key(base: str, api_key: Optional[str]) -> str:
return f"{base.rstrip('/')}\x00{api_key or ''}"
Was the cause of gradual server degradation: sequential 3s-timeout
probes against many endpoints (some offline) tied up the threadpool
for 15-30s every cache cycle, eventually exhausting it."""
def _ts(value: Any) -> float:
try:
return float(value.timestamp()) if value else 0.0
except Exception:
return 0.0
def _failure_delay(fails: int) -> float:
if fails <= 0:
return 0.0
return min(_REFRESH_FAILURE_BASE * (2 ** max(0, fails - 1)), _REFRESH_FAILURE_MAX)
def _should_refresh_endpoint(ep: Any, now: float, force: bool = False) -> tuple[bool, Dict[str, Any]]:
base = _normalize_base(getattr(ep, "base_url", "") or "")
kind = _effective_endpoint_kind(ep, base)
category = _classify_endpoint(base, kind)
mode = _endpoint_refresh_mode(ep, kind)
cached = _cached_model_ids(ep)
key = _refresh_key(base, getattr(ep, "api_key", None))
state = _refresh_state.get(key, {})
info = {
"id": getattr(ep, "id", ""),
"base": base,
"api_key": getattr(ep, "api_key", None),
"kind": kind,
"category": category,
"mode": mode,
"key": key,
"timeout": _endpoint_refresh_timeout(ep, category),
}
if not base:
return False, info
if state.get("inflight"):
return False, info
if mode in ("manual", "disabled") and not force:
return False, info
fails = int(state.get("fail_count") or 0)
if fails and not force:
last_failure = float(state.get("last_failure") or 0.0)
if now - last_failure < _failure_delay(fails):
return False, info
if cached and not force:
interval = _endpoint_refresh_interval(ep, category)
last_good = float(state.get("last_success") or 0.0) or _ts(getattr(ep, "updated_at", None)) or _ts(getattr(ep, "created_at", None))
if last_good and now - last_good < interval:
return False, info
return True, info
def _refresh_caches_bg(force: bool = False):
"""Background thread: safely refresh model caches with per-base single-flight.
The public /api/models path stays cached-first. This refresh never clears
a non-empty cached model list on timeout/failure, and proxy/manual
endpoints are skipped unless explicitly forced."""
import threading
if _refresh_inflight["v"]:
return # already running
@@ -735,44 +941,63 @@ def setup_model_routes(model_discovery):
try:
from concurrent.futures import ThreadPoolExecutor, as_completed
db = SessionLocal()
changed = False
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
# Skip endpoints that have failed 3+ times in a row in the last 5 min
now = _time.time()
to_probe = []
groups: Dict[str, Dict[str, Any]] = {}
for ep in endpoints:
ts, fails = _probe_failures.get(ep.id, (0, 0))
if fails >= 3 and (now - ts) < 300:
ok, info = _should_refresh_endpoint(ep, now, force=force)
if not ok:
continue
to_probe.append(ep)
groups.setdefault(info["key"], {
"base": info["base"],
"api_key": info["api_key"],
"timeout": info["timeout"],
"endpoint_ids": [],
})["endpoint_ids"].append(info["id"])
def _probe_one(ep):
base = _normalize_base(ep.base_url)
for key in groups:
st = _refresh_state.setdefault(key, {})
st["inflight"] = True
st["last_attempt"] = now
def _probe_one(key: str, data: Dict[str, Any]):
try:
ids = _probe_endpoint(base, ep.api_key, timeout=2)
return ep, ids, None
ids = _probe_endpoint(data["base"], data.get("api_key"), timeout=data.get("timeout") or 2)
return key, data["endpoint_ids"], ids, None
except Exception as e:
return ep, None, e
return key, data["endpoint_ids"], None, e
if to_probe:
# Bounded parallelism — 8 concurrent probes is plenty
with ThreadPoolExecutor(max_workers=min(8, len(to_probe))) as pool:
futures = [pool.submit(_probe_one, ep) for ep in to_probe]
if groups:
with ThreadPoolExecutor(max_workers=min(4, len(groups))) as pool:
futures = [pool.submit(_probe_one, key, data) for key, data in groups.items()]
for fut in as_completed(futures):
ep, ids, err = fut.result()
key, endpoint_ids, ids, err = fut.result()
st = _refresh_state.setdefault(key, {})
if ids:
ep.cached_models = json.dumps(ids)
_probe_failures.pop(ep.id, None)
for ep_id in endpoint_ids:
ep_obj = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if ep_obj:
ep_obj.cached_models = json.dumps(ids)
changed = True
st["last_success"] = _time.time()
st["fail_count"] = 0
st.pop("last_failure", None)
else:
prev = _probe_failures.get(ep.id, (0, 0))
_probe_failures[ep.id] = (_time.time(), prev[1] + 1)
st["last_failure"] = _time.time()
st["fail_count"] = int(st.get("fail_count") or 0) + 1
st["inflight"] = False
db.commit()
finally:
db.close()
_invalidate_models_cache()
if changed:
_invalidate_models_cache()
except Exception:
pass
finally:
for st in _refresh_state.values():
st["inflight"] = False
_refresh_inflight["v"] = False
threading.Thread(target=_do, daemon=True).start()
@@ -804,24 +1029,15 @@ def setup_model_routes(model_discovery):
base = _normalize_base(ep.base_url)
provider = _detect_provider(base)
# Use cached models — background refresh keeps them updated
model_ids = []
if ep.cached_models:
try:
model_ids = json.loads(ep.cached_models)
except Exception:
pass
model_ids = _cached_model_ids(ep)
ep_model_type = getattr(ep, "model_type", None) or "llm"
# Filter out hidden (probe-failed) models
hidden = set()
if ep.hidden_models:
try:
hidden = set(json.loads(ep.hidden_models))
except Exception:
pass
hidden = _hidden_model_ids(ep)
model_ids = [m for m in model_ids if m not in hidden]
# Build correct URL based on provider
chat_url = build_chat_url(base)
category = _classify_endpoint(base)
kind = _effective_endpoint_kind(ep, base)
category = _classify_endpoint(base, kind)
if model_ids:
curated_key = _match_provider_curated(base, None)
@@ -837,6 +1053,7 @@ def setup_model_routes(model_discovery):
"endpoint_id": ep.id,
"endpoint_name": ep.name,
"category": category,
"endpoint_kind": kind,
"model_type": ep_model_type,
})
else:
@@ -852,6 +1069,7 @@ def setup_model_routes(model_discovery):
"endpoint_id": ep.id,
"endpoint_name": ep.name,
"category": category,
"endpoint_kind": kind,
"model_type": ep_model_type,
"offline": True,
})
@@ -898,11 +1116,11 @@ def setup_model_routes(model_discovery):
result = _fetch_models(owner=owner, is_admin=_is_admin)
_models_cache[_cache_key] = {"data": result, "time": now}
# Kick off background refresh to update caches from live endpoints
_refresh_caches_bg()
_refresh_caches_bg(force=refresh)
return result
# Brief cache for local-probe results so picker-open doesn't hammer
# /v1/models every time. 8s TTL — long enough to amortize cost,
# endpoint health checks every time. 8s TTL — long enough to amortize cost,
# short enough that a freshly-killed local server shows as offline
# within ~8s of the user noticing.
_LOCAL_PROBE_TTL = 8.0
@@ -912,7 +1130,7 @@ def setup_model_routes(model_discovery):
async def probe_local_endpoints(request: Request):
"""Fast parallel reachability check for LOCAL endpoints only.
Cloud endpoints (api.openai.com, api.anthropic.com, etc.) are
assumed up. Local endpoints get a 1.5s /models probe so the UI
assumed up. Local endpoints get a 1.5s cheap reachability probe so the UI
can dim stale entries pointing at dead vLLM servers. Returns
{ep_id: {alive, latency_ms, error}}."""
require_admin(request)
@@ -924,36 +1142,44 @@ def setup_model_routes(model_discovery):
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
local_eps = [
(ep.id, _normalize_base(ep.base_url), ep.api_key)
for ep in endpoints
if _classify_endpoint(_normalize_base(ep.base_url)) == "local"
]
local_eps = []
for ep in endpoints:
base = _normalize_base(ep.base_url)
kind = _effective_endpoint_kind(ep, base)
if _classify_endpoint(base, kind) == "local":
local_eps.append((ep.id, base, ep.api_key))
finally:
db.close()
async def _probe_one(ep_id: str, base: str, api_key: Optional[str]) -> Dict[str, Any]:
grouped: Dict[str, Dict[str, Any]] = {}
for ep_id, base, api_key in local_eps:
key = _refresh_key(base, api_key)
grouped.setdefault(key, {"base": base, "api_key": api_key, "endpoint_ids": []})["endpoint_ids"].append(ep_id)
async def _probe_one(data: Dict[str, Any]) -> Dict[str, Any]:
t0 = _time.time()
try:
models = _probe_endpoint(base, api_key, timeout=2.5)
import asyncio as _asyncio
ping = await _asyncio.to_thread(_ping_endpoint, data["base"], data.get("api_key"), 1.5)
lat = round((_time.time() - t0) * 1000)
return {
"alive": bool(models),
"alive": bool(ping.get("reachable")),
"latency_ms": lat,
"status_code": 200 if models else None,
"error": None if models else "No models found",
"status_code": ping.get("status_code"),
"error": ping.get("error"),
}
except Exception as e:
return {"alive": False, "latency_ms": None, "status_code": None, "error": str(e)[:120]}
import asyncio as _asyncio
results_list = await _asyncio.gather(
*[_probe_one(eid, base, key) for eid, base, key in local_eps],
*[_probe_one(data) for data in grouped.values()],
return_exceptions=False,
)
results: Dict[str, Any] = {}
for (eid, _, _), r in zip(local_eps, results_list):
results[eid] = r
for data, r in zip(grouped.values(), results_list):
for eid in data["endpoint_ids"]:
results[eid] = r
_local_probe_cache["data"] = results
_local_probe_cache["time"] = now
@@ -973,50 +1199,28 @@ def setup_model_routes(model_discovery):
for ep in endpoints:
base = _normalize_base(ep.base_url)
provider = _detect_provider(base)
kind = _effective_endpoint_kind(ep, base)
cached_count = len(_cached_model_ids(ep))
entry = {
"id": ep.id,
"name": ep.name,
"base_url": base,
"provider": provider,
"category": _classify_endpoint(base),
"category": _classify_endpoint(base, kind),
"endpoint_kind": kind,
}
if provider == "anthropic":
# Anthropic has no /models endpoint; just check connectivity
try:
t0 = _time.time()
r = httpx.get(base.rstrip("/"), timeout=5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
entry["status"] = "online"
entry["model_count"] = len(ANTHROPIC_MODELS)
except Exception as e:
entry["latency_ms"] = None
entry["status"] = "offline"
entry["error"] = str(e)
entry["model_count"] = 0
else:
url = build_models_url(base)
headers = build_headers(ep.api_key, base)
try:
t0 = _time.time()
r = httpx.get(url, headers=headers, timeout=5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
models = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
entry["status"] = "online"
entry["model_count"] = len(models)
except Exception as e:
if "latency_ms" not in entry:
entry["latency_ms"] = None
entry["status"] = "offline"
entry["error"] = str(e)
entry["model_count"] = 0
try:
t0 = _time.time()
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
entry["error"] = ping.get("error")
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
except Exception as e:
entry["latency_ms"] = None
entry["status"] = "online" if cached_count else "offline"
entry["error"] = str(e)
entry["model_count"] = cached_count
results.append(entry)
return {"endpoints": results}
@@ -1165,19 +1369,8 @@ def setup_model_routes(model_discovery):
rows = db.query(ModelEndpoint).order_by(ModelEndpoint.created_at).all()
results = []
for r in rows:
# Use cached model list to avoid slow probe on every load
all_models = []
if r.cached_models:
try:
all_models = json.loads(r.cached_models)
except Exception:
pass
hidden = set()
if r.hidden_models:
try:
hidden = set(json.loads(r.hidden_models))
except Exception:
pass
all_models = _cached_model_ids(r)
hidden = _hidden_model_ids(r)
pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
visible = _visible_models(all_models, r.hidden_models, pinned)
# Endpoint counts as reachable if it has any model — including
@@ -1188,6 +1381,8 @@ def setup_model_routes(model_discovery):
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"):
status = "empty"
base = _normalize_base(r.base_url)
kind = _effective_endpoint_kind(r, base)
results.append({
"id": r.id,
"name": r.name,
@@ -1202,6 +1397,11 @@ def setup_model_routes(model_discovery):
"ping_error": (ping or {}).get("error") if ping else None,
"model_type": getattr(r, "model_type", None) or "llm",
"supports_tools": getattr(r, "supports_tools", None),
"endpoint_kind": kind,
"category": _classify_endpoint(base, kind),
"model_refresh_mode": _endpoint_refresh_mode(r, kind),
"model_refresh_interval": getattr(r, "model_refresh_interval", None),
"model_refresh_timeout": getattr(r, "model_refresh_timeout", None),
})
return results
finally:
@@ -1216,6 +1416,10 @@ def setup_model_routes(model_discovery):
skip_probe: str = Form("false"),
require_models: str = Form("false"),
model_type: str = Form("llm"),
endpoint_kind: str = Form("auto"),
model_refresh_mode: str = Form(""),
model_refresh_interval: str = Form(""),
model_refresh_timeout: str = Form(""),
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
container_local: str = Form("false"),
@@ -1240,8 +1444,15 @@ def setup_model_routes(model_discovery):
if not name.strip():
name = base_url.replace("http://", "").replace("https://", "").split("/")[0]
requested_kind = _normalize_endpoint_kind(endpoint_kind)
refresh_mode = _normalize_refresh_mode(model_refresh_mode, requested_kind)
refresh_interval = _parse_positive_int(model_refresh_interval, minimum=30, maximum=86400)
refresh_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
require_model_list = _truthy(require_models)
should_probe = require_model_list or not _truthy(skip_probe)
should_probe = (
require_model_list or requested_kind in ("api", "proxy") or not _truthy(skip_probe)
)
explicit_timeout = _explicit_model_list_timeout(base_url, requested_kind, refresh_timeout)
# Dedupe: if an endpoint with the same base_url already exists and
# is reachable by the caller (shared or owned by them), return it
@@ -1259,6 +1470,7 @@ def setup_model_routes(model_discovery):
.first()
)
if existing:
changed = False
# Persist any incoming pinned IDs onto the existing row. An
# empty/omitted form field must not wipe previously pinned IDs.
_incoming_pinned = _normalize_model_ids(pinned_models)
@@ -1268,15 +1480,45 @@ def setup_model_routes(model_discovery):
_incoming_pinned,
)
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
changed = True
existing_kind_for_probe = requested_kind if requested_kind != "auto" else _effective_endpoint_kind(existing, base_url)
if requested_kind != "auto" and _endpoint_kind(existing) == "auto":
existing.endpoint_kind = requested_kind
changed = True
if model_refresh_mode or (requested_kind == "proxy" and _endpoint_refresh_mode(existing, requested_kind) != refresh_mode):
existing.model_refresh_mode = refresh_mode
changed = True
if refresh_interval is not None:
existing.model_refresh_interval = refresh_interval
changed = True
if refresh_timeout is not None:
existing.model_refresh_timeout = refresh_timeout
changed = True
if api_key.strip() and not existing.api_key:
existing.api_key = api_key.strip()
changed = True
if should_probe:
probed_models = _probe_endpoint(
base_url,
(api_key.strip() or existing.api_key or None),
timeout=_explicit_model_list_timeout(base_url, existing_kind_for_probe, refresh_timeout),
)
if probed_models:
existing.cached_models = json.dumps(probed_models)
changed = True
if changed:
_db_dedup.commit()
_invalidate_models_cache()
_local_probe_cache["data"] = None
existing_models = _cached_model_ids(existing)
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
existing_kind = _effective_endpoint_kind(existing, existing.base_url)
return {
"id": existing.id,
"name": existing.name,
"base_url": existing.base_url,
"models": _visible_models(
getattr(existing, "cached_models", None),
existing_models,
getattr(existing, "hidden_models", None),
existing.pinned_models,
),
@@ -1284,16 +1526,16 @@ def setup_model_routes(model_discovery):
"online": True,
"status": "online",
"existing": True,
"endpoint_kind": existing_kind,
"category": _classify_endpoint(existing.base_url, existing_kind),
}
finally:
_db_dedup.close()
# Quick model list fetch (1s timeout — if endpoint is slow, it'll update on next refresh)
_probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 1
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout) if should_probe else []
model_ids = _probe_endpoint(base_url, api_key.strip() or None, timeout=explicit_timeout) if should_probe else []
ping = {"reachable": False, "error": None}
if should_probe and not model_ids:
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=_probe_timeout)
if (should_probe or requested_kind in ("api", "proxy")) and not model_ids:
ping = _ping_endpoint(base_url, api_key.strip() or None, timeout=min(explicit_timeout, 2.0))
if require_model_list and not model_ids:
raise HTTPException(400, _model_endpoint_error_message(base_url, ping))
@@ -1317,6 +1559,10 @@ def setup_model_routes(model_discovery):
api_key=api_key.strip() or None,
is_enabled=True,
model_type=model_type.strip() if model_type else "llm",
endpoint_kind=requested_kind,
model_refresh_mode=refresh_mode,
model_refresh_interval=refresh_interval,
model_refresh_timeout=refresh_timeout,
cached_models=json.dumps(model_ids) if model_ids else None,
pinned_models=json.dumps(_pinned) if _pinned else None,
supports_tools=_st,
@@ -1349,6 +1595,8 @@ def setup_model_routes(model_discovery):
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None,
"endpoint_kind": requested_kind,
"category": _classify_endpoint(base_url, requested_kind),
}
@router.post("/model-endpoints/test")
@@ -1356,6 +1604,8 @@ def setup_model_routes(model_discovery):
request: Request,
base_url: str = Form(...),
api_key: str = Form(""),
endpoint_kind: str = Form("auto"),
model_refresh_timeout: str = Form(""),
):
require_admin(request)
base_url = _normalize_base(base_url)
@@ -1364,9 +1614,11 @@ def setup_model_routes(model_discovery):
from src.endpoint_resolver import resolve_url
base_url = resolve_url(base_url)
base_url = _rewrite_loopback_for_docker(base_url)
probe_timeout = 3 if (":11434" in base_url or "ollama" in base_url.lower()) else 2
requested_kind = _normalize_endpoint_kind(endpoint_kind)
configured_timeout = _parse_positive_int(model_refresh_timeout, minimum=1, maximum=60)
probe_timeout = _explicit_model_list_timeout(base_url, requested_kind, configured_timeout)
models = _probe_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=probe_timeout)
ping = {"reachable": True, "error": None} if models else _ping_endpoint(base_url, api_key.strip() or None, timeout=min(probe_timeout, 2.0))
return {
"base_url": base_url,
"online": bool(models) or bool(ping.get("reachable")),
@@ -1374,6 +1626,8 @@ def setup_model_routes(model_discovery):
"ping_error": ping.get("error") if ping else None,
"models": models,
"count": len(models),
"endpoint_kind": requested_kind,
"category": _classify_endpoint(base_url, requested_kind),
}
@router.get("/model-endpoints/{ep_id}/probe")
@@ -1415,7 +1669,8 @@ def setup_model_routes(model_discovery):
ep_obj = db2.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if ep_obj:
ep_obj.hidden_models = json.dumps(failed) if failed else None
ep_obj.cached_models = json.dumps(all_models) if all_models else None
if all_models:
ep_obj.cached_models = json.dumps(all_models)
db2.commit()
finally:
db2.close()
@@ -1426,7 +1681,13 @@ def setup_model_routes(model_discovery):
return StreamingResponse(_stream(), media_type="text/event-stream")
@router.get("/model-endpoints/{ep_id}/models")
def list_endpoint_models(ep_id: str, request: Request):
def list_endpoint_models(
ep_id: str,
request: Request,
response: Response,
refresh: bool = False,
refresh_timeout: Optional[int] = Query(None, ge=1, le=60),
):
"""List all discovered models for an endpoint with hidden/visible state."""
require_admin(request)
db = SessionLocal()
@@ -1434,23 +1695,28 @@ def setup_model_routes(model_discovery):
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if not ep:
raise HTTPException(404, "Endpoint not found")
hidden = set()
if ep.hidden_models:
hidden = _hidden_model_ids(ep)
all_models = _cached_model_ids(ep)
if refresh:
base = _normalize_base(ep.base_url)
kind = _effective_endpoint_kind(ep, base)
category = _classify_endpoint(base, kind)
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
try:
hidden = set(json.loads(ep.hidden_models))
except Exception:
pass
# Try live probe, fall back to cached. Pinned IDs are admin-entered
# and persist regardless of probe results — never overwritten here.
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
if all_models:
ep.cached_models = json.dumps(all_models)
db.commit()
elif ep.cached_models:
try:
all_models = json.loads(ep.cached_models)
except Exception:
pass
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
except Exception as exc:
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
probed = []
if probed:
all_models = probed
ep.cached_models = json.dumps(all_models)
db.commit()
_invalidate_models_cache()
response.headers["X-Model-Refresh-Status"] = "refreshed"
response.headers["X-Model-Refresh-Count"] = str(len(probed))
else:
response.headers["X-Model-Refresh-Status"] = "failed"
response.headers["X-Model-Refresh-Warning"] = "Model refresh failed or returned no models; kept cached models."
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
pinned_set = set(pinned)
return [
@@ -1502,7 +1768,6 @@ def setup_model_routes(model_discovery):
@router.get("/default-chat")
def get_default_chat(request: Request):
import json as _json
# SECURITY: resolve the default endpoint + model from the CALLER's
# per-user prefs ONLY. We deliberately do NOT fall back to the
# global `default_model` / `default_endpoint_id` in settings.json
@@ -1635,6 +1900,16 @@ def setup_model_routes(model_discovery):
if "pinned_models" in body:
_pinned = _normalize_model_ids(body["pinned_models"])
ep.pinned_models = json.dumps(_pinned) if _pinned else None
if "endpoint_kind" in body:
ep.endpoint_kind = _normalize_endpoint_kind(body.get("endpoint_kind"))
if "model_refresh_mode" in body:
ep.model_refresh_mode = _normalize_refresh_mode(body.get("model_refresh_mode"), _endpoint_kind(ep))
if "model_refresh_interval" in body:
interval = _parse_positive_int(body.get("model_refresh_interval"), minimum=30, maximum=86400)
ep.model_refresh_interval = interval
if "model_refresh_timeout" in body:
timeout = _parse_positive_int(body.get("model_refresh_timeout"), minimum=1, maximum=60)
ep.model_refresh_timeout = timeout
# Rotating an API key used to require DELETE+POST, which wiped
# endpoint_url/model from every session referencing the old base
# URL. Allow in-place updates so the admin can change the key
@@ -1664,6 +1939,10 @@ def setup_model_routes(model_discovery):
"model_type": ep.model_type,
"base_url": ep.base_url,
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
"endpoint_kind": getattr(ep, "endpoint_kind", None) or "auto",
"model_refresh_mode": getattr(ep, "model_refresh_mode", None) or "auto",
"model_refresh_interval": getattr(ep, "model_refresh_interval", None),
"model_refresh_timeout": getattr(ep, "model_refresh_timeout", None),
}
finally:
db.close()