feat(models): support pinned endpoint model IDs

This commit is contained in:
Alexandre Teixeira
2026-06-03 13:00:07 +01:00
committed by GitHub
parent 1284b14a13
commit 145f4fd2b4
3 changed files with 493 additions and 25 deletions
+20
View File
@@ -334,6 +334,7 @@ class ModelEndpoint(TimestampMixin, Base):
is_enabled = Column(Boolean, default=True) is_enabled = Column(Boolean, default=True)
hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list) cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
model_type = Column(String, nullable=True, default="llm") # "llm" or "image" model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
# Whether models on this endpoint accept OpenAI-style function # Whether models on this endpoint accept OpenAI-style function
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto- # schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
@@ -856,6 +857,24 @@ def _migrate_add_cached_models_column():
except Exception as e: except Exception as e:
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}") logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
def _migrate_add_pinned_models_column():
    """Add the ``pinned_models`` column to ``model_endpoints`` if it is missing.

    Best-effort SQLite migration. It does nothing when the database file does
    not exist, when ``PRAGMA table_info`` reports no columns (e.g. the table
    has not been created), or when the column is already present. Any failure
    is logged as a warning instead of being raised.
    """
    import sqlite3

    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return
    log = logging.getLogger(__name__)
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.execute("PRAGMA table_info(model_endpoints)")
            # Row index 1 of each table_info row is the column name.
            columns = [row[1] for row in cursor.fetchall()]
            if columns and "pinned_models" not in columns:
                conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
                conn.commit()
                log.info("Migrated: added 'pinned_models' column to model_endpoints")
        finally:
            # Close even when a statement above raises, so the connection
            # handle is never leaked on the failure path.
            conn.close()
    except Exception as e:
        log.warning(f"pinned_models migration failed: {e}")
def _migrate_add_notes_sort_order(): def _migrate_add_notes_sort_order():
"""Add sort_order, image_url, repeat columns to notes if they don't exist.""" """Add sort_order, image_url, repeat columns to notes if they don't exist."""
import sqlite3 import sqlite3
@@ -1511,6 +1530,7 @@ def init_db():
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
_migrate_add_hidden_models_column() _migrate_add_hidden_models_column()
_migrate_add_cached_models_column() _migrate_add_cached_models_column()
_migrate_add_pinned_models_column()
_migrate_add_notes_sort_order() _migrate_add_notes_sort_order()
_migrate_add_model_type_column() _migrate_add_model_type_column()
_migrate_add_model_endpoint_owner_column() _migrate_add_model_endpoint_owner_column()
+126 -22
View File
@@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
return "No models found for that provider/key." return "No models found for that provider/key."
def _normalize_model_ids(value):
    """Turn a model-ID input into a clean, ordered list of unique strings.

    ``value`` may be ``None``, a list, a JSON-encoded list string, or a
    comma/newline separated string. Each string entry is whitespace-trimmed;
    empty entries and non-string entries are dropped, and duplicates are
    removed while keeping first-seen order. Any other input type yields ``[]``.
    """
    if value is None:
        return []

    candidates = value
    if isinstance(value, str):
        stripped = value.strip()
        if not stripped:
            return []
        try:
            decoded = json.loads(stripped)
        except Exception:
            decoded = None
        if isinstance(decoded, list):
            candidates = decoded
        else:
            # Not a JSON list: treat the text as a comma/newline separated list.
            candidates = re.split(r"[,\n]", stripped)

    if not isinstance(candidates, list):
        return []

    trimmed = (entry.strip() for entry in candidates if isinstance(entry, str))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(entry for entry in trimmed if entry))
def _merge_model_ids(*lists):
    """Join several model-ID lists into one, keeping first-seen order.

    Falsy arguments (``None``, empty lists) are skipped, non-string entries
    are ignored, and repeated IDs are kept only at their first position.
    """
    string_ids = (
        model_id
        for group in lists
        for model_id in (group or [])
        if isinstance(model_id, str)
    )
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(string_ids))
def _visible_models(cached_models, hidden_models, pinned_models=None):
    """Return cached + pinned model IDs with hidden IDs removed.

    Each argument is passed through ``_normalize_model_ids`` before use, so
    the same input forms that helper accepts are accepted here. Cached IDs
    come first, followed by pinned IDs not already listed. The result is an
    ordered, de-duplicated list.
    """
    candidates = _merge_model_ids(
        _normalize_model_ids(cached_models),
        _normalize_model_ids(pinned_models),
    )
    if hidden_models:
        excluded = set(_normalize_model_ids(hidden_models))
        candidates = [model_id for model_id in candidates if model_id not in excluded]
    return candidates
def setup_model_routes(model_discovery): def setup_model_routes(model_discovery):
@@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery):
hidden = set(json.loads(r.hidden_models)) hidden = set(json.loads(r.hidden_models))
except Exception: except Exception:
pass pass
visible = [m for m in all_models if m not in hidden] pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
status = "online" if all_models else "offline" visible = _visible_models(all_models, r.hidden_models, pinned)
# Endpoint counts as reachable if it has any model — including
# admin-pinned IDs that a probe would never surface.
status = "online" if (all_models or pinned) else "offline"
ping = None ping = None
if not all_models and r.is_enabled: if not all_models and not pinned and r.is_enabled:
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"): if ping.get("reachable"):
status = "empty" status = "empty"
@@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery):
"has_key": bool(r.api_key), "has_key": bool(r.api_key),
"is_enabled": r.is_enabled, "is_enabled": r.is_enabled,
"models": visible, "models": visible,
"pinned_models": pinned,
"hidden_count": len(hidden), "hidden_count": len(hidden),
"online": status != "offline", "online": status != "offline",
"status": status, "status": status,
@@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery):
require_models: str = Form("false"), require_models: str = Form("false"),
model_type: str = Form("llm"), model_type: str = Form("llm"),
supports_tools: str = Form(""), # "true"/"false"/"" (unknown) supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
container_local: str = Form("false"), container_local: str = Form("false"),
# Default `shared=true` → endpoints are visible to all users (the # Default `shared=true` → endpoints are visible to all users (the
# app's historical behaviour). Admins can pass `shared=false` to # app's historical behaviour). Admins can pass `shared=false` to
@@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery):
.first() .first()
) )
if existing: if existing:
# Persist any incoming pinned IDs onto the existing row. An
# empty/omitted form field must not wipe previously pinned IDs.
_incoming_pinned = _normalize_model_ids(pinned_models)
if _incoming_pinned:
_merged_pinned = _merge_model_ids(
_normalize_model_ids(getattr(existing, "pinned_models", None)),
_incoming_pinned,
)
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
_db_dedup.commit()
_invalidate_models_cache()
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
return { return {
"id": existing.id, "id": existing.id,
"name": existing.name, "name": existing.name,
"base_url": existing.base_url, "base_url": existing.base_url,
"models": json.loads(existing.cached_models) if existing.cached_models else [], "models": _visible_models(
getattr(existing, "cached_models", None),
getattr(existing, "hidden_models", None),
existing.pinned_models,
),
"pinned_models": _existing_pinned,
"online": True, "online": True,
"status": "online", "status": "online",
"existing": True, "existing": True,
@@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery):
try: try:
_st_raw = (supports_tools or "").strip().lower() _st_raw = (supports_tools or "").strip().lower()
_st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None) _st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None)
_pinned = _normalize_model_ids(pinned_models)
# Stamp owner so the picker only shows this endpoint to the admin # Stamp owner so the picker only shows this endpoint to the admin
# who added it. Pass `shared=true` to mark it null-owner (visible # who added it. Pass `shared=true` to mark it null-owner (visible
# to all users), preserving the pre-fix "everyone sees everything" # to all users), preserving the pre-fix "everyone sees everything"
@@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery):
is_enabled=True, is_enabled=True,
model_type=model_type.strip() if model_type else "llm", model_type=model_type.strip() if model_type else "llm",
cached_models=json.dumps(model_ids) if model_ids else None, cached_models=json.dumps(model_ids) if model_ids else None,
pinned_models=json.dumps(_pinned) if _pinned else None,
supports_tools=_st, supports_tools=_st,
owner=_owner_val, owner=_owner_val,
) )
@@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery):
"id": ep_id, "id": ep_id,
"name": name.strip(), "name": name.strip(),
"base_url": base_url, "base_url": base_url,
"models": model_ids, "models": _merge_model_ids(model_ids, _pinned),
"online": bool(model_ids) or bool(ping.get("reachable")), "pinned_models": _pinned,
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"), "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
"ping_error": ping.get("error") if ping else None, "ping_error": ping.get("error") if ping else None,
} }
@@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery):
hidden = set(json.loads(ep.hidden_models)) hidden = set(json.loads(ep.hidden_models))
except Exception: except Exception:
pass pass
# Try live probe, fall back to cached # Try live probe, fall back to cached. Pinned IDs are admin-entered
# and persist regardless of probe results — never overwritten here.
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3) all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
if all_models: if all_models:
ep.cached_models = json.dumps(all_models) ep.cached_models = json.dumps(all_models)
@@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery):
all_models = json.loads(ep.cached_models) all_models = json.loads(ep.cached_models)
except Exception: except Exception:
pass pass
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
pinned_set = set(pinned)
return [ return [
{"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden} {
for m in all_models "id": m,
"display": m.split("/")[-1],
"is_hidden": m in hidden,
"is_pinned": m in pinned_set,
}
for m in _merge_model_ids(all_models, pinned)
] ]
finally: finally:
db.close() db.close()
@router.patch("/model-endpoints/{ep_id}/models") @router.patch("/model-endpoints/{ep_id}/models")
async def update_hidden_models(ep_id: str, request: Request): async def update_hidden_models(ep_id: str, request: Request):
"""Bulk update hidden models list for an endpoint. """Bulk update hidden and/or pinned model lists for an endpoint.
Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]} Expects JSON body with optional keys:
{"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]}
Each key is updated only when present, so callers can patch one list
without clobbering the other.
""" """
require_admin(request) require_admin(request)
db = SessionLocal() db = SessionLocal()
@@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery):
if not ep: if not ep:
raise HTTPException(404, "Endpoint not found") raise HTTPException(404, "Endpoint not found")
body = await request.json() body = await request.json()
hidden = body.get("hidden", []) if not isinstance(body, dict):
raise HTTPException(400, "Body must be a JSON object")
if "hidden" in body:
hidden = body.get("hidden")
if not isinstance(hidden, list): if not isinstance(hidden, list):
raise HTTPException(400, "hidden must be a list of model IDs") raise HTTPException(400, "hidden must be a list of model IDs")
ep.hidden_models = json.dumps(hidden) if hidden else None ep.hidden_models = json.dumps(hidden) if hidden else None
# Accept either "pinned" or "pinned_models" for the manual IDs list.
if "pinned_models" in body or "pinned" in body:
pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned")))
ep.pinned_models = json.dumps(pinned) if pinned else None
db.commit() db.commit()
_invalidate_models_cache() _invalidate_models_cache()
return {"id": ep_id, "hidden_count": len(hidden)} hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0
pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0
return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count}
finally: finally:
db.close() db.close()
@@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery):
return {"endpoint_id": "", "endpoint_url": "", "model": ""} return {"endpoint_id": "", "endpoint_url": "", "model": ""}
base = _normalize_base(ep.base_url) base = _normalize_base(ep.base_url)
chat_url = build_chat_url(base) chat_url = build_chat_url(base)
if not model and getattr(ep, "cached_models", None): if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)):
try: try:
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None)) visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None))
if visible: if visible:
model = visible[0] model = visible[0]
except Exception: except Exception:
@@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery):
ep.name = body["name"].strip() or ep.name ep.name = body["name"].strip() or ep.name
if "model_type" in body and isinstance(body["model_type"], str): if "model_type" in body and isinstance(body["model_type"], str):
ep.model_type = body["model_type"].strip() or ep.model_type ep.model_type = body["model_type"].strip() or ep.model_type
if "pinned_models" in body:
_pinned = _normalize_model_ids(body["pinned_models"])
ep.pinned_models = json.dumps(_pinned) if _pinned else None
# Rotating an API key used to require DELETE+POST, which wiped # Rotating an API key used to require DELETE+POST, which wiped
# endpoint_url/model from every session referencing the old base # endpoint_url/model from every session referencing the old base
# URL. Allow in-place updates so the admin can change the key # URL. Allow in-place updates so the admin can change the key
@@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery):
"name": ep.name, "name": ep.name,
"model_type": ep.model_type, "model_type": ep.model_type,
"base_url": ep.base_url, "base_url": ep.base_url,
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
} }
finally: finally:
db.close() db.close()
+344
View File
@@ -1,6 +1,9 @@
"""Tests for model route helper functions — pure logic, no server needed.""" """Tests for model route helper functions — pure logic, no server needed."""
import asyncio
import json
import sys import sys
import types import types
from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import httpx import httpx
@@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver
from routes.model_routes import ( from routes.model_routes import (
_match_provider_curated, _match_provider_curated,
_curate_models, _curate_models,
_visible_models,
_normalize_model_ids,
_is_chat_model, _is_chat_model,
_classify_endpoint, _classify_endpoint,
_probe_endpoint, _probe_endpoint,
@@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable:
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail) monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
assert model_routes._docker_host_gateway_reachable() is False assert model_routes._docker_host_gateway_reachable() is False
# ── pinned model IDs: normalization helper ──
class TestNormalizeModelIds:
    """Input-shape coverage for ``_normalize_model_ids``."""

    def test_list_passthrough_trims_and_dedupes(self):
        raw = [" a ", "a", "b", ""]
        assert _normalize_model_ids(raw) == ["a", "b"]

    def test_json_string_list(self):
        encoded = '["x", "y", "x"]'
        assert _normalize_model_ids(encoded) == ["x", "y"]

    def test_comma_and_newline_string(self):
        separated = "a, b\n c ,a"
        assert _normalize_model_ids(separated) == ["a", "b", "c"]

    def test_none_and_empty(self):
        for blank in (None, "", " "):
            assert _normalize_model_ids(blank) == []

    def test_non_string_values_ignored(self):
        mixed = [1, "ok", None, {"a": 1}]
        assert _normalize_model_ids(mixed) == ["ok"]
# ── pinned model IDs: _visible_models merge ──
class TestVisibleModelsPinned:
    """How ``_visible_models`` combines cached, pinned and hidden IDs."""

    def test_includes_pinned_not_in_cached(self):
        assert _visible_models(["a"], None, ["deploy-1"]) == ["a", "deploy-1"]

    def test_cached_plus_pinned_dedup_preserves_order(self):
        cached, pinned = ["a", "b"], ["b", "c"]
        assert _visible_models(cached, None, pinned) == ["a", "b", "c"]

    def test_hidden_can_hide_a_pinned_model(self):
        # The same ID is both pinned and hidden; hidden wins.
        assert _visible_models(["a"], ["deploy-1"], ["deploy-1"]) == ["a"]

    def test_accepts_json_string_inputs(self):
        assert _visible_models('["a"]', '["a"]', '["b"]') == ["b"]
# ── pinned model IDs: route behaviour ──
# Building the router exercises FastAPI's Form() routes, which require
# python-multipart. The test env ships without it, so register a minimal stub
# (mirrors tests/test_review_regressions.py) only when it's genuinely missing.
if "python_multipart" not in sys.modules:
    try:
        import python_multipart  # noqa: F401
    except ImportError:
        # Package is not installed: register an empty placeholder module under
        # the same name so later imports of it succeed.
        _mp_stub = types.ModuleType("python_multipart")
        sys.modules["python_multipart"] = _mp_stub
        # NOTE(review): presumably the consumer inspects ``__version__`` at
        # import time; confirm against the FastAPI version in use.
        _mp_stub.__version__ = "0.0.13"
class _PinnedFakeQuery:
    """Query stand-in that serves a fixed row list and ignores all criteria."""

    def __init__(self, rows):
        # Copy so later changes to the caller's sequence do not leak in.
        self.rows = list(rows)

    def filter(self, *conditions):
        """Accept and ignore filter conditions; return self for chaining."""
        return self

    def order_by(self, *args):
        """Accept and ignore ordering arguments; return self for chaining."""
        return self

    def first(self):
        """Return the first stored row, or ``None`` when there are none."""
        return next(iter(self.rows), None)

    def all(self):
        """Return a fresh list containing every stored row."""
        return self.rows[:]
class _PinnedFakeDb:
    """Session stand-in that records ``add`` and ``commit`` calls.

    ``rows`` is what every query returns, ``added`` collects objects passed
    to ``add``, and ``committed`` counts ``commit`` calls.
    """

    def __init__(self, rows):
        self.rows = rows
        self.added = []
        self.committed = 0

    def query(self, model):
        """Return a fake query over the stored rows, whatever ``model`` is."""
        return _PinnedFakeQuery(self.rows)

    def add(self, row):
        """Record ``row`` as added."""
        self.added.append(row)

    def commit(self):
        """Count one more commit."""
        self.committed += 1

    def close(self):
        """No-op; there is nothing to release."""
        pass
class _FakeCol:
    """Column placeholder whose operators all evaluate to the column itself.

    This lets the dedupe query expressions be built without a real
    SQLAlchemy column. ``__eq__`` is overridden, so instances are made
    explicitly unhashable.
    """

    __hash__ = None

    def __eq__(self, other):
        # ``col == value`` yields the column rather than a bool.
        return self

    def __or__(self, other):
        # ``col | other`` yields the column.
        return self

    def is_(self, other):
        return self

    def desc(self):
        return self
class _RecordingEndpoint:
    """ModelEndpoint replacement that keeps its constructor kwargs as attributes.

    The class-level fake columns allow the class itself to be used in the
    dedupe lookup expressions; attributes set in ``__init__`` shadow them on
    each instance.
    """

    id = _FakeCol()
    base_url = _FakeCol()
    owner = _FakeCol()

    def __init__(self, **kwargs):
        # Store every keyword argument directly on the instance.
        vars(self).update(kwargs)
class _PinnedFakeRequest:
    """Request stand-in exposing ``headers`` and an awaitable ``json()``."""

    def __init__(self, body=None, headers=None):
        # Only ``None`` falls back to an empty body; other falsy bodies are kept.
        self._body = {} if body is None else body
        self.headers = headers or {}

    async def json(self):
        """Return the body supplied at construction time."""
        return self._body
def _get_route(path, method):
    """Build the model router and return the handler for ``method`` on ``path``.

    Raises ``AssertionError`` when no registered route matches both the path
    and the HTTP method.
    """
    from routes.model_routes import setup_model_routes

    router = setup_model_routes(model_discovery=None)
    not_found = object()
    handler = next(
        (
            candidate.endpoint
            for candidate in router.routes
            if getattr(candidate, "path", "") == path
            and method in getattr(candidate, "methods", set())
        ),
        not_found,
    )
    if handler is not_found:
        raise AssertionError(f"{method} {path} not found")
    return handler
def _make_endpoint(**kwargs):
    """Return a ``SimpleNamespace`` endpoint row; ``kwargs`` override defaults."""
    defaults = {
        "id": "ep1",
        "name": "EP",
        "base_url": "http://localhost:9999/v1",
        "api_key": None,
        "is_enabled": True,
        "hidden_models": None,
        "cached_models": None,
        "pinned_models": None,
        "model_type": "llm",
        "supports_tools": None,
    }
    return SimpleNamespace(**{**defaults, **kwargs})
def test_patch_models_saves_pinned_models(monkeypatch):
    """PATCH with ``pinned_models`` stores the de-duplicated list on the row."""
    row = _make_endpoint()
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)

    patch_models = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
    payload = {"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]}
    result = asyncio.run(patch_models("ep1", _PinnedFakeRequest(body=payload)))

    assert json.loads(row.pinned_models) == ["deploy-1", "deploy-2"]
    assert result["pinned_count"] == 2
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
    """A PATCH carrying only ``pinned_models`` leaves ``hidden_models`` intact."""
    row = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)

    patch_models = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")
    payload = {"pinned_models": ["deploy-1"]}
    asyncio.run(patch_models("ep1", _PinnedFakeRequest(body=payload)))

    assert json.loads(row.hidden_models) == ["hide-me"]
    assert json.loads(row.pinned_models) == ["deploy-1"]
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
    """With an empty probe and no cache, the listing still contains pinned IDs."""
    row = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: [])

    list_models = _get_route("/api/model-endpoints/{ep_id}/models", "GET")
    listing = list_models("ep1", _PinnedFakeRequest())

    assert [entry["id"] for entry in listing] == ["deploy-1"]
    assert listing[0]["is_pinned"] is True
def test_reprobe_preserves_pinned_models(monkeypatch):
    """Running the probe route updates the cache but leaves pinned IDs alone."""
    row = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["m1"])
    monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
    monkeypatch.setattr(
        model_routes, "_probe_single_model", lambda *a, **k: {"status": "ok"}
    )

    probe = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")
    streaming = probe("ep1", _PinnedFakeRequest())

    async def _consume():
        # Exhaust the streamed body so the route's work runs to completion.
        async for _chunk in streaming.body_iterator:
            pass

    asyncio.run(_consume())

    # Probe rewrites cached/hidden but must never touch admin-pinned IDs.
    assert json.loads(row.pinned_models) == ["deploy-1"]
    assert json.loads(row.cached_models) == ["m1"]
def test_visible_models_handles_malformed_strings():
    """Non-JSON strings are split on commas/newlines instead of raising."""
    # cached "a,b" splits into two IDs, hidden "b" removes one, and the
    # malformed pinned string survives as a single literal ID.
    visible = _visible_models("a,b", "b", "{bad json")
    assert isinstance(visible, list)
    assert visible == ["a", "{bad json"]

    assert _visible_models("", None, "") == []
    assert _visible_models("only-cached", None, None) == ["only-cached"]
def _create_form_kwargs(**overrides):
    """Return plain-string values for the route's Form() parameters.

    The route is invoked as an ordinary function in these tests, which skips
    FastAPI's form parsing, so each Form() default has to be replaced with a
    real string. ``overrides`` replace or extend the defaults.
    """
    defaults = {
        "name": "",
        "api_key": "",
        "skip_probe": "true",  # avoid any network probe in unit tests
        "require_models": "false",
        "model_type": "llm",
        "supports_tools": "",
        "pinned_models": "",
        "container_local": "false",
        "shared": "true",
    }
    return {**defaults, **overrides}
def _patch_create_deps(monkeypatch, db):
    """Replace the collaborators the create route uses with in-memory stubs.

    ``db`` is returned by the patched ``SessionLocal``; URL helpers become
    identity functions, admin/user lookups return ``None``, and
    ``ModelEndpoint`` is swapped for ``_RecordingEndpoint``.
    """
    import src.auth_helpers as auth_helpers

    route_stubs = {
        "SessionLocal": lambda: db,
        "require_admin": lambda request: None,
        "ModelEndpoint": _RecordingEndpoint,
        "_normalize_base": lambda b: b,
        "_rewrite_loopback_for_docker": lambda b, **k: b,
        "_load_settings": lambda: {"default_endpoint_id": "exists"},
    }
    for attr_name, stub in route_stubs.items():
        monkeypatch.setattr(model_routes, attr_name, stub)
    monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
    monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
    """Creating an endpoint with pinned IDs returns and persists them."""
    fake_db = _PinnedFakeDb([])  # no existing row → fresh create path
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")

    form = _create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2")
    response = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    expected = ["deploy-1", "deploy-2"]
    assert response["pinned_models"] == expected
    assert response["models"] == expected
    assert response["online"] is True
    # Persisted onto the created row.
    assert len(fake_db.added) == 1
    assert json.loads(fake_db.added[0].pinned_models) == expected
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
    """POST matching an existing row appends new pins to that row's pins."""
    row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        hidden_models=None,
        pinned_models=json.dumps(["old-pin"]),
    )
    fake_db = _PinnedFakeDb([row])
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")

    form = _create_form_kwargs(pinned_models="new-pin")
    response = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    merged_pins = ["old-pin", "new-pin"]
    assert response["existing"] is True
    # Incoming pin merged onto the existing pins (no clobber, order preserved).
    assert json.loads(row.pinned_models) == merged_pins
    assert response["pinned_models"] == merged_pins
    # models = cached + pinned - hidden, visible merged list.
    assert response["models"] == ["m1", "old-pin", "new-pin"]
    # No new row created on the dedupe path.
    assert fake_db.added == []
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
    """An empty ``pinned_models`` form field keeps the existing row's pins."""
    row = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        pinned_models=json.dumps(["keep-me"]),
    )
    fake_db = _PinnedFakeDb([row])
    _patch_create_deps(monkeypatch, fake_db)
    create_endpoint = _get_route("/api/model-endpoints", "POST")

    form = _create_form_kwargs()  # pinned_models defaults to ""
    response = create_endpoint(
        _PinnedFakeRequest(),
        base_url="http://host:1234/v1",
        **form,
    )

    assert json.loads(row.pinned_models) == ["keep-me"]
    assert response["pinned_models"] == ["keep-me"]
    assert fake_db.committed == 0  # nothing to persist