mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 18:25:26 -04:00
feat(models): support pinned endpoint model IDs
This commit is contained in:
committed by
GitHub
parent
1284b14a13
commit
145f4fd2b4
@@ -334,6 +334,7 @@ class ModelEndpoint(TimestampMixin, Base):
|
|||||||
is_enabled = Column(Boolean, default=True)
|
is_enabled = Column(Boolean, default=True)
|
||||||
hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing
|
hidden_models = Column(Text, nullable=True) # JSON list of model IDs that failed probing
|
||||||
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
|
cached_models = Column(Text, nullable=True) # JSON list of last-known model IDs (avoids probe on list)
|
||||||
|
pinned_models = Column(Text, nullable=True) # JSON list of admin-pinned model IDs (manual, may not appear in /v1/models)
|
||||||
model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
|
model_type = Column(String, nullable=True, default="llm") # "llm" or "image"
|
||||||
# Whether models on this endpoint accept OpenAI-style function
|
# Whether models on this endpoint accept OpenAI-style function
|
||||||
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
|
# schemas + emit `tool_calls`. Auto-detected at Cookbook auto-
|
||||||
@@ -856,6 +857,24 @@ def _migrate_add_cached_models_column():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
|
logging.getLogger(__name__).warning(f"cached_models migration failed: {e}")
|
||||||
|
|
||||||
|
def _migrate_add_pinned_models_column():
    """Add pinned_models column to model_endpoints if it doesn't exist.

    Best-effort migration against the SQLite file named by ``DATABASE_URL``.
    It returns silently when the database file is absent, does nothing when
    the column is already present, and logs (never raises) on any failure.
    The connection is always closed, including when a statement fails.
    """
    import sqlite3

    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        # No database file on disk, so there is no table to alter.
        return

    log = logging.getLogger(__name__)
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.execute("PRAGMA table_info(model_endpoints)")
            columns = [row[1] for row in cursor.fetchall()]
            # An empty column list means the table itself is missing; skip it.
            if columns and "pinned_models" not in columns:
                conn.execute("ALTER TABLE model_endpoints ADD COLUMN pinned_models TEXT")
                conn.commit()
                log.info("Migrated: added 'pinned_models' column to model_endpoints")
        finally:
            # Close even when PRAGMA/ALTER raises, so the handle is not leaked.
            conn.close()
    except Exception as e:
        log.warning(f"pinned_models migration failed: {e}")
|
||||||
|
|
||||||
def _migrate_add_notes_sort_order():
|
def _migrate_add_notes_sort_order():
|
||||||
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
|
"""Add sort_order, image_url, repeat columns to notes if they don't exist."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
@@ -1511,6 +1530,7 @@ def init_db():
|
|||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
_migrate_add_hidden_models_column()
|
_migrate_add_hidden_models_column()
|
||||||
_migrate_add_cached_models_column()
|
_migrate_add_cached_models_column()
|
||||||
|
_migrate_add_pinned_models_column()
|
||||||
_migrate_add_notes_sort_order()
|
_migrate_add_notes_sort_order()
|
||||||
_migrate_add_model_type_column()
|
_migrate_add_model_type_column()
|
||||||
_migrate_add_model_endpoint_owner_column()
|
_migrate_add_model_endpoint_owner_column()
|
||||||
|
|||||||
+129
-25
@@ -633,13 +633,68 @@ def _model_endpoint_error_message(base_url: str, ping: Dict[str, Any] = None) ->
|
|||||||
return "No models found for that provider/key."
|
return "No models found for that provider/key."
|
||||||
|
|
||||||
|
|
||||||
def _visible_models(cached_models, hidden_models):
|
def _normalize_model_ids(value):
    """Turn a model-ID input into a clean, ordered list of unique strings.

    ``value`` may be a list, a string holding a JSON list, or a plain string
    whose IDs are separated by commas and/or newlines. Any other input
    (including ``None``) yields an empty list. Entries are whitespace-trimmed;
    blank and non-string entries are dropped; duplicates keep their first
    position.
    """
    if isinstance(value, str):
        stripped = value.strip()
        if not stripped:
            return []
        try:
            decoded = json.loads(stripped)
        except Exception:
            decoded = None
        # Only a decoded JSON *list* is used as-is; everything else is treated
        # as a comma/newline separated string.
        if isinstance(decoded, list):
            candidates = decoded
        else:
            candidates = re.split(r"[,\n]", stripped)
    elif isinstance(value, list):
        candidates = value
    else:
        return []

    trimmed = (entry.strip() for entry in candidates if isinstance(entry, str))
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(entry for entry in trimmed if entry))
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_model_ids(*lists):
    """Join several model-ID lists into one, keeping first-seen order.

    Falsy arguments (``None``, empty lists) are skipped, non-string entries
    are ignored, and each ID appears at most once in the result.
    """
    ordered = {}
    for group in lists:
        if not group:
            continue
        for model_id in group:
            if isinstance(model_id, str):
                # setdefault keeps the position of the first occurrence.
                ordered.setdefault(model_id, None)
    return list(ordered)
|
||||||
|
|
||||||
|
|
||||||
|
def _visible_models(cached_models, hidden_models, pinned_models=None):
    """Return cached plus pinned model IDs with hidden ones removed.

    Each argument is passed through ``_normalize_model_ids`` first. Cached IDs
    come first, followed by pinned IDs not already present; anything listed in
    ``hidden_models`` is then filtered out. The result is an ordered list
    without duplicates.
    """
    candidates = _merge_model_ids(
        _normalize_model_ids(cached_models),
        _normalize_model_ids(pinned_models),
    )
    if hidden_models:
        excluded = set(_normalize_model_ids(hidden_models))
        candidates = [model_id for model_id in candidates if model_id not in excluded]
    return candidates
|
||||||
|
|
||||||
|
|
||||||
def setup_model_routes(model_discovery):
|
def setup_model_routes(model_discovery):
|
||||||
@@ -1123,10 +1178,13 @@ def setup_model_routes(model_discovery):
|
|||||||
hidden = set(json.loads(r.hidden_models))
|
hidden = set(json.loads(r.hidden_models))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
visible = [m for m in all_models if m not in hidden]
|
pinned = _normalize_model_ids(getattr(r, "pinned_models", None))
|
||||||
status = "online" if all_models else "offline"
|
visible = _visible_models(all_models, r.hidden_models, pinned)
|
||||||
|
# Endpoint counts as reachable if it has any model — including
|
||||||
|
# admin-pinned IDs that a probe would never surface.
|
||||||
|
status = "online" if (all_models or pinned) else "offline"
|
||||||
ping = None
|
ping = None
|
||||||
if not all_models and r.is_enabled:
|
if not all_models and not pinned and r.is_enabled:
|
||||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||||
if ping.get("reachable"):
|
if ping.get("reachable"):
|
||||||
status = "empty"
|
status = "empty"
|
||||||
@@ -1137,6 +1195,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"has_key": bool(r.api_key),
|
"has_key": bool(r.api_key),
|
||||||
"is_enabled": r.is_enabled,
|
"is_enabled": r.is_enabled,
|
||||||
"models": visible,
|
"models": visible,
|
||||||
|
"pinned_models": pinned,
|
||||||
"hidden_count": len(hidden),
|
"hidden_count": len(hidden),
|
||||||
"online": status != "offline",
|
"online": status != "offline",
|
||||||
"status": status,
|
"status": status,
|
||||||
@@ -1158,6 +1217,7 @@ def setup_model_routes(model_discovery):
|
|||||||
require_models: str = Form("false"),
|
require_models: str = Form("false"),
|
||||||
model_type: str = Form("llm"),
|
model_type: str = Form("llm"),
|
||||||
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
|
supports_tools: str = Form(""), # "true"/"false"/"" (unknown)
|
||||||
|
pinned_models: str = Form(""), # admin-pinned IDs: list/JSON/comma/newline
|
||||||
container_local: str = Form("false"),
|
container_local: str = Form("false"),
|
||||||
# Default `shared=true` → endpoints are visible to all users (the
|
# Default `shared=true` → endpoints are visible to all users (the
|
||||||
# app's historical behaviour). Admins can pass `shared=false` to
|
# app's historical behaviour). Admins can pass `shared=false` to
|
||||||
@@ -1199,11 +1259,28 @@ def setup_model_routes(model_discovery):
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
|
# Persist any incoming pinned IDs onto the existing row. An
|
||||||
|
# empty/omitted form field must not wipe previously pinned IDs.
|
||||||
|
_incoming_pinned = _normalize_model_ids(pinned_models)
|
||||||
|
if _incoming_pinned:
|
||||||
|
_merged_pinned = _merge_model_ids(
|
||||||
|
_normalize_model_ids(getattr(existing, "pinned_models", None)),
|
||||||
|
_incoming_pinned,
|
||||||
|
)
|
||||||
|
existing.pinned_models = json.dumps(_merged_pinned) if _merged_pinned else None
|
||||||
|
_db_dedup.commit()
|
||||||
|
_invalidate_models_cache()
|
||||||
|
_existing_pinned = _normalize_model_ids(getattr(existing, "pinned_models", None))
|
||||||
return {
|
return {
|
||||||
"id": existing.id,
|
"id": existing.id,
|
||||||
"name": existing.name,
|
"name": existing.name,
|
||||||
"base_url": existing.base_url,
|
"base_url": existing.base_url,
|
||||||
"models": json.loads(existing.cached_models) if existing.cached_models else [],
|
"models": _visible_models(
|
||||||
|
getattr(existing, "cached_models", None),
|
||||||
|
getattr(existing, "hidden_models", None),
|
||||||
|
existing.pinned_models,
|
||||||
|
),
|
||||||
|
"pinned_models": _existing_pinned,
|
||||||
"online": True,
|
"online": True,
|
||||||
"status": "online",
|
"status": "online",
|
||||||
"existing": True,
|
"existing": True,
|
||||||
@@ -1225,6 +1302,7 @@ def setup_model_routes(model_discovery):
|
|||||||
try:
|
try:
|
||||||
_st_raw = (supports_tools or "").strip().lower()
|
_st_raw = (supports_tools or "").strip().lower()
|
||||||
_st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None)
|
_st = True if _st_raw in ("true", "1", "yes") else (False if _st_raw in ("false", "0", "no") else None)
|
||||||
|
_pinned = _normalize_model_ids(pinned_models)
|
||||||
# Stamp owner so the picker only shows this endpoint to the admin
|
# Stamp owner so the picker only shows this endpoint to the admin
|
||||||
# who added it. Pass `shared=true` to mark it null-owner (visible
|
# who added it. Pass `shared=true` to mark it null-owner (visible
|
||||||
# to all users), preserving the pre-fix "everyone sees everything"
|
# to all users), preserving the pre-fix "everyone sees everything"
|
||||||
@@ -1240,6 +1318,7 @@ def setup_model_routes(model_discovery):
|
|||||||
is_enabled=True,
|
is_enabled=True,
|
||||||
model_type=model_type.strip() if model_type else "llm",
|
model_type=model_type.strip() if model_type else "llm",
|
||||||
cached_models=json.dumps(model_ids) if model_ids else None,
|
cached_models=json.dumps(model_ids) if model_ids else None,
|
||||||
|
pinned_models=json.dumps(_pinned) if _pinned else None,
|
||||||
supports_tools=_st,
|
supports_tools=_st,
|
||||||
owner=_owner_val,
|
owner=_owner_val,
|
||||||
)
|
)
|
||||||
@@ -1265,9 +1344,10 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep_id,
|
"id": ep_id,
|
||||||
"name": name.strip(),
|
"name": name.strip(),
|
||||||
"base_url": base_url,
|
"base_url": base_url,
|
||||||
"models": model_ids,
|
"models": _merge_model_ids(model_ids, _pinned),
|
||||||
"online": bool(model_ids) or bool(ping.get("reachable")),
|
"pinned_models": _pinned,
|
||||||
"status": "online" if model_ids else ("empty" if ping.get("reachable") else "offline"),
|
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||||
|
"status": "online" if (model_ids or _pinned) else ("empty" if ping.get("reachable") else "offline"),
|
||||||
"ping_error": ping.get("error") if ping else None,
|
"ping_error": ping.get("error") if ping else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1360,7 +1440,8 @@ def setup_model_routes(model_discovery):
|
|||||||
hidden = set(json.loads(ep.hidden_models))
|
hidden = set(json.loads(ep.hidden_models))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Try live probe, fall back to cached
|
# Try live probe, fall back to cached. Pinned IDs are admin-entered
|
||||||
|
# and persist regardless of probe results — never overwritten here.
|
||||||
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
|
all_models = _probe_endpoint(ep.base_url, ep.api_key, timeout=3)
|
||||||
if all_models:
|
if all_models:
|
||||||
ep.cached_models = json.dumps(all_models)
|
ep.cached_models = json.dumps(all_models)
|
||||||
@@ -1370,18 +1451,28 @@ def setup_model_routes(model_discovery):
|
|||||||
all_models = json.loads(ep.cached_models)
|
all_models = json.loads(ep.cached_models)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
|
||||||
|
pinned_set = set(pinned)
|
||||||
return [
|
return [
|
||||||
{"id": m, "display": m.split("/")[-1], "is_hidden": m in hidden}
|
{
|
||||||
for m in all_models
|
"id": m,
|
||||||
|
"display": m.split("/")[-1],
|
||||||
|
"is_hidden": m in hidden,
|
||||||
|
"is_pinned": m in pinned_set,
|
||||||
|
}
|
||||||
|
for m in _merge_model_ids(all_models, pinned)
|
||||||
]
|
]
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@router.patch("/model-endpoints/{ep_id}/models")
|
@router.patch("/model-endpoints/{ep_id}/models")
|
||||||
async def update_hidden_models(ep_id: str, request: Request):
|
async def update_hidden_models(ep_id: str, request: Request):
|
||||||
"""Bulk update hidden models list for an endpoint.
|
"""Bulk update hidden and/or pinned model lists for an endpoint.
|
||||||
|
|
||||||
Expects JSON body: {"hidden": ["model-id-1", "model-id-2"]}
|
Expects JSON body with optional keys:
|
||||||
|
{"hidden": ["model-id-1", ...], "pinned_models": ["deploy-id", ...]}
|
||||||
|
Each key is updated only when present, so callers can patch one list
|
||||||
|
without clobbering the other.
|
||||||
"""
|
"""
|
||||||
require_admin(request)
|
require_admin(request)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
@@ -1390,13 +1481,22 @@ def setup_model_routes(model_discovery):
|
|||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found")
|
raise HTTPException(404, "Endpoint not found")
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
hidden = body.get("hidden", [])
|
if not isinstance(body, dict):
|
||||||
if not isinstance(hidden, list):
|
raise HTTPException(400, "Body must be a JSON object")
|
||||||
raise HTTPException(400, "hidden must be a list of model IDs")
|
if "hidden" in body:
|
||||||
ep.hidden_models = json.dumps(hidden) if hidden else None
|
hidden = body.get("hidden")
|
||||||
|
if not isinstance(hidden, list):
|
||||||
|
raise HTTPException(400, "hidden must be a list of model IDs")
|
||||||
|
ep.hidden_models = json.dumps(hidden) if hidden else None
|
||||||
|
# Accept either "pinned" or "pinned_models" for the manual IDs list.
|
||||||
|
if "pinned_models" in body or "pinned" in body:
|
||||||
|
pinned = _normalize_model_ids(body.get("pinned_models", body.get("pinned")))
|
||||||
|
ep.pinned_models = json.dumps(pinned) if pinned else None
|
||||||
db.commit()
|
db.commit()
|
||||||
_invalidate_models_cache()
|
_invalidate_models_cache()
|
||||||
return {"id": ep_id, "hidden_count": len(hidden)}
|
hidden_count = len(json.loads(ep.hidden_models)) if ep.hidden_models else 0
|
||||||
|
pinned_count = len(json.loads(ep.pinned_models)) if ep.pinned_models else 0
|
||||||
|
return {"id": ep_id, "hidden_count": hidden_count, "pinned_count": pinned_count}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1494,9 +1594,9 @@ def setup_model_routes(model_discovery):
|
|||||||
return {"endpoint_id": "", "endpoint_url": "", "model": ""}
|
return {"endpoint_id": "", "endpoint_url": "", "model": ""}
|
||||||
base = _normalize_base(ep.base_url)
|
base = _normalize_base(ep.base_url)
|
||||||
chat_url = build_chat_url(base)
|
chat_url = build_chat_url(base)
|
||||||
if not model and getattr(ep, "cached_models", None):
|
if not model and (getattr(ep, "cached_models", None) or getattr(ep, "pinned_models", None)):
|
||||||
try:
|
try:
|
||||||
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None))
|
visible = _visible_models(ep.cached_models, getattr(ep, "hidden_models", None), getattr(ep, "pinned_models", None))
|
||||||
if visible:
|
if visible:
|
||||||
model = visible[0]
|
model = visible[0]
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -1532,6 +1632,9 @@ def setup_model_routes(model_discovery):
|
|||||||
ep.name = body["name"].strip() or ep.name
|
ep.name = body["name"].strip() or ep.name
|
||||||
if "model_type" in body and isinstance(body["model_type"], str):
|
if "model_type" in body and isinstance(body["model_type"], str):
|
||||||
ep.model_type = body["model_type"].strip() or ep.model_type
|
ep.model_type = body["model_type"].strip() or ep.model_type
|
||||||
|
if "pinned_models" in body:
|
||||||
|
_pinned = _normalize_model_ids(body["pinned_models"])
|
||||||
|
ep.pinned_models = json.dumps(_pinned) if _pinned else None
|
||||||
# Rotating an API key used to require DELETE+POST, which wiped
|
# Rotating an API key used to require DELETE+POST, which wiped
|
||||||
# endpoint_url/model from every session referencing the old base
|
# endpoint_url/model from every session referencing the old base
|
||||||
# URL. Allow in-place updates so the admin can change the key
|
# URL. Allow in-place updates so the admin can change the key
|
||||||
@@ -1560,6 +1663,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"name": ep.name,
|
"name": ep.name,
|
||||||
"model_type": ep.model_type,
|
"model_type": ep.model_type,
|
||||||
"base_url": ep.base_url,
|
"base_url": ep.base_url,
|
||||||
|
"pinned_models": _normalize_model_ids(getattr(ep, "pinned_models", None)),
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
"""Tests for model route helper functions — pure logic, no server needed."""
|
"""Tests for model route helper functions — pure logic, no server needed."""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -29,6 +32,8 @@ import src.endpoint_resolver as endpoint_resolver
|
|||||||
from routes.model_routes import (
|
from routes.model_routes import (
|
||||||
_match_provider_curated,
|
_match_provider_curated,
|
||||||
_curate_models,
|
_curate_models,
|
||||||
|
_visible_models,
|
||||||
|
_normalize_model_ids,
|
||||||
_is_chat_model,
|
_is_chat_model,
|
||||||
_classify_endpoint,
|
_classify_endpoint,
|
||||||
_probe_endpoint,
|
_probe_endpoint,
|
||||||
@@ -470,3 +475,342 @@ class TestDockerHostGatewayReachable:
|
|||||||
|
|
||||||
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
|
monkeypatch.setattr(model_routes.socket, "getaddrinfo", _fail)
|
||||||
assert model_routes._docker_host_gateway_reachable() is False
|
assert model_routes._docker_host_gateway_reachable() is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── pinned model IDs: normalization helper ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeModelIds:
    """Checks for the ``_normalize_model_ids`` coercion helper."""

    def test_list_passthrough_trims_and_dedupes(self):
        """A list input is trimmed, de-duplicated and stripped of blanks."""
        raw = [" a ", "a", "b", ""]
        assert _normalize_model_ids(raw) == ["a", "b"]

    def test_json_string_list(self):
        """A JSON-encoded list string is decoded and de-duplicated."""
        encoded = '["x", "y", "x"]'
        assert _normalize_model_ids(encoded) == ["x", "y"]

    def test_comma_and_newline_string(self):
        """Commas and newlines both act as separators in a plain string."""
        separated = "a, b\n c ,a"
        assert _normalize_model_ids(separated) == ["a", "b", "c"]

    def test_none_and_empty(self):
        """None, empty and whitespace-only inputs all give an empty list."""
        for blank in (None, "", "   "):
            assert _normalize_model_ids(blank) == []

    def test_non_string_values_ignored(self):
        """Non-string list entries are dropped."""
        mixed = [1, "ok", None, {"a": 1}]
        assert _normalize_model_ids(mixed) == ["ok"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── pinned model IDs: _visible_models merge ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestVisibleModelsPinned:
    """Checks that ``_visible_models`` merges pinned IDs with cached ones."""

    def test_includes_pinned_not_in_cached(self):
        """A pinned ID missing from the cache is appended to the result."""
        assert _visible_models(["a"], None, ["deploy-1"]) == ["a", "deploy-1"]

    def test_cached_plus_pinned_dedup_preserves_order(self):
        """An ID present in both lists appears once, at its cached position."""
        cached, pinned = ["a", "b"], ["b", "c"]
        assert _visible_models(cached, None, pinned) == ["a", "b", "c"]

    def test_hidden_can_hide_a_pinned_model(self):
        """Hiding wins over pinning."""
        result = _visible_models(["a"], ["deploy-1"], ["deploy-1"])
        assert result == ["a"]

    def test_accepts_json_string_inputs(self):
        """All three arguments may be JSON-encoded list strings."""
        result = _visible_models('["a"]', '["a"]', '["b"]')
        assert result == ["b"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── pinned model IDs: route behaviour ──
|
||||||
|
|
||||||
|
# Building the router exercises FastAPI's Form() routes, which require
|
||||||
|
# python-multipart. The test env ships without it, so register a minimal stub
|
||||||
|
# (mirrors tests/test_review_regressions.py) only when it's genuinely missing.
|
||||||
|
# Register a placeholder module only when python_multipart is neither already
# loaded nor importable.
if "python_multipart" not in sys.modules:
    try:
        import python_multipart  # noqa: F401
    except ImportError:
        _mp_stub = types.ModuleType("python_multipart")
        setattr(_mp_stub, "__version__", "0.0.13")
        sys.modules[_mp_stub.__name__] = _mp_stub
|
||||||
|
|
||||||
|
|
||||||
|
class _PinnedFakeQuery:
    """Query stand-in that ignores filters/ordering and serves fixed rows."""

    def __init__(self, rows):
        # Copy so later changes to the caller's iterable do not leak in.
        self.rows = [*rows]

    def filter(self, *conditions):
        """Accept any conditions and return this query unchanged."""
        return self

    def order_by(self, *args):
        """Accept any ordering and return this query unchanged."""
        return self

    def first(self):
        """Return the first stored row, or None when there are none."""
        if self.rows:
            return self.rows[0]
        return None

    def all(self):
        """Return a fresh list holding every stored row."""
        return [*self.rows]
|
||||||
|
|
||||||
|
|
||||||
|
class _PinnedFakeDb:
    """Session stand-in that records added rows and counts commits."""

    def __init__(self, rows):
        self.rows = rows
        self.added = []
        self.committed = 0

    def query(self, model):
        """Return a fake query over the stored rows, whatever model is asked for."""
        return _PinnedFakeQuery(self.rows)

    def add(self, row):
        """Remember a row that the code under test tried to insert."""
        self.added.append(row)

    def commit(self):
        """Count one commit call."""
        self.committed = self.committed + 1

    def close(self):
        """No-op; nothing to release."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeCol:
    """Stand-in for a column object in query expressions.

    Equality, ``is_``, ``|`` and ``desc`` all hand back the same object, so
    expressions built from it evaluate without a real SQLAlchemy column.
    """

    # Explicitly unhashable, matching a class that overrides __eq__.
    __hash__ = None

    def __eq__(self, other):
        """Return this object instead of a boolean."""
        return self

    def is_(self, other):
        """Return this object for any ``is_`` comparison."""
        return self

    def __or__(self, other):
        """Return this object for any ``|`` combination."""
        return self

    def desc(self):
        """Return this object as its own descending form."""
        return self
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingEndpoint:
    """ModelEndpoint stand-in whose constructor kwargs become attributes.

    The class-level ``_FakeCol`` attributes let the class itself be used in
    query expressions. Any same-named keyword given to the constructor
    shadows the class attribute on that instance.
    """

    id = _FakeCol()
    base_url = _FakeCol()
    owner = _FakeCol()

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
|
||||||
|
|
||||||
|
|
||||||
|
class _PinnedFakeRequest:
    """Request stand-in exposing ``headers`` and an awaitable ``json()``."""

    def __init__(self, body=None, headers=None):
        # Only None falls back to an empty dict for the body; other falsy
        # bodies are kept. Any falsy headers value becomes an empty dict.
        if body is None:
            body = {}
        self._body = body
        self.headers = headers if headers else {}

    async def json(self):
        """Return the body supplied at construction time."""
        return self._body
|
||||||
|
|
||||||
|
|
||||||
|
def _get_route(path, method):
    """Build the model router and return the handler for ``method path``.

    Raises AssertionError when no registered route matches both the path and
    the HTTP method.
    """
    from routes.model_routes import setup_model_routes

    router = setup_model_routes(model_discovery=None)
    for candidate in router.routes:
        if getattr(candidate, "path", "") != path:
            continue
        if method not in getattr(candidate, "methods", set()):
            continue
        return candidate.endpoint
    raise AssertionError(f"{method} {path} not found")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_endpoint(**kwargs):
    """Return a SimpleNamespace endpoint row; keyword arguments override defaults."""
    defaults = {
        "id": "ep1",
        "name": "EP",
        "base_url": "http://localhost:9999/v1",
        "api_key": None,
        "is_enabled": True,
        "hidden_models": None,
        "cached_models": None,
        "pinned_models": None,
        "model_type": "llm",
        "supports_tools": None,
    }
    return SimpleNamespace(**{**defaults, **kwargs})
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_models_saves_pinned_models(monkeypatch):
    """PATCH with ``pinned_models`` stores a de-duplicated list and reports its size."""
    row = _make_endpoint()
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    handler = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")

    payload = {"pinned_models": ["deploy-1", "deploy-1", "deploy-2"]}
    outcome = asyncio.run(handler("ep1", _PinnedFakeRequest(body=payload)))

    assert json.loads(row.pinned_models) == ["deploy-1", "deploy-2"]
    assert outcome["pinned_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_models_pinned_does_not_clobber_hidden(monkeypatch):
    """Patching only ``pinned_models`` leaves the stored hidden list intact."""
    row = _make_endpoint(hidden_models=json.dumps(["hide-me"]))
    fake_db = _PinnedFakeDb([row])
    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    handler = _get_route("/api/model-endpoints/{ep_id}/models", "PATCH")

    payload = {"pinned_models": ["deploy-1"]}
    asyncio.run(handler("ep1", _PinnedFakeRequest(body=payload)))

    assert json.loads(row.hidden_models) == ["hide-me"]
    assert json.loads(row.pinned_models) == ["deploy-1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_models_returns_pinned_when_probe_empty(monkeypatch):
    """With an empty probe and no cache, GET still lists the pinned ID, flagged as pinned."""
    row = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([row])
    for attr, stub in (
        ("SessionLocal", lambda: fake_db),
        ("require_admin", lambda request: None),
        ("_probe_endpoint", lambda *a, **k: []),
    ):
        monkeypatch.setattr(model_routes, attr, stub)
    handler = _get_route("/api/model-endpoints/{ep_id}/models", "GET")

    listing = handler("ep1", _PinnedFakeRequest())

    assert [entry["id"] for entry in listing] == ["deploy-1"]
    assert listing[0]["is_pinned"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_reprobe_preserves_pinned_models(monkeypatch):
    """Draining the probe stream refreshes cached_models and leaves pinned_models alone."""
    row = _make_endpoint(pinned_models=json.dumps(["deploy-1"]))
    fake_db = _PinnedFakeDb([row])
    stubs = {
        "SessionLocal": lambda: fake_db,
        "require_admin": lambda request: None,
        "_probe_endpoint": lambda *a, **k: ["m1"],
        "_is_chat_model": lambda m: True,
        "_probe_single_model": lambda *a, **k: {"status": "ok"},
    }
    for attr, stub in stubs.items():
        monkeypatch.setattr(model_routes, attr, stub)
    handler = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")

    stream = handler("ep1", _PinnedFakeRequest())

    async def _consume():
        # Exhaust the streaming body so the route runs to completion.
        async for _chunk in stream.body_iterator:
            pass

    asyncio.run(_consume())

    # Admin-pinned IDs must survive the probe; the cache reflects the probe result.
    assert json.loads(row.pinned_models) == ["deploy-1"]
    assert json.loads(row.cached_models) == ["m1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_visible_models_handles_malformed_strings():
    """Non-JSON strings are split on commas/newlines rather than raising."""
    # "{bad json" fails JSON decoding, so it is kept as a single literal ID;
    # the hidden string "b" removes "b" from the cached "a,b".
    merged = _visible_models("a,b", "b", "{bad json")
    assert isinstance(merged, list)
    assert merged == ["a", "{bad json"]

    assert _visible_models("", None, "") == []
    assert _visible_models("only-cached", None, None) == ["only-cached"]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_form_kwargs(**overrides):
    """Return string values for the form parameters of create_model_endpoint.

    The route is invoked as a plain function in these tests, so FastAPI never
    substitutes the Form() defaults; real strings have to be passed instead.
    Keyword arguments replace or extend the defaults.
    """
    defaults = {
        "name": "",
        "api_key": "",
        "skip_probe": "true",  # avoid any network probe in unit tests
        "require_models": "false",
        "model_type": "llm",
        "supports_tools": "",
        "pinned_models": "",
        "container_local": "false",
        "shared": "true",
    }
    return {**defaults, **overrides}
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_create_deps(monkeypatch, db):
    """Stub out what the POST create route touches so it runs against ``db``.

    Replaces the session factory, admin check, endpoint model class, URL
    helpers, settings loader, URL resolver and current-user lookup with
    pass-through or fixed-value stand-ins.
    """
    import src.auth_helpers as auth_helpers

    route_stubs = {
        "SessionLocal": lambda: db,
        "require_admin": lambda request: None,
        "ModelEndpoint": _RecordingEndpoint,
        "_normalize_base": lambda b: b,
        "_rewrite_loopback_for_docker": lambda b, **k: b,
        "_load_settings": lambda: {"default_endpoint_id": "exists"},
    }
    for attr, stub in route_stubs.items():
        monkeypatch.setattr(model_routes, attr, stub)
    monkeypatch.setattr(endpoint_resolver, "resolve_url", lambda u: u)
    monkeypatch.setattr(auth_helpers, "get_current_user", lambda req: None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_creates_endpoint_with_pinned_models(monkeypatch):
    """POST on a fresh base URL stores and returns the normalized pinned IDs."""
    fake_db = _PinnedFakeDb([])  # no existing row → fresh create path
    _patch_create_deps(monkeypatch, fake_db)
    create = _get_route("/api/model-endpoints", "POST")
    form = _create_form_kwargs(pinned_models="deploy-1, deploy-1\ndeploy-2")

    response = create(_PinnedFakeRequest(), base_url="http://host:1234/v1", **form)

    expected = ["deploy-1", "deploy-2"]
    assert response["pinned_models"] == expected
    assert response["models"] == expected
    assert response["online"] is True
    # Exactly one row was added, carrying the pinned list as JSON.
    assert len(fake_db.added) == 1
    assert json.loads(fake_db.added[0].pinned_models) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_dedupe_existing_merges_and_returns_pinned(monkeypatch):
    """POST for an already-known endpoint appends new pins to the stored ones."""
    stored = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        hidden_models=None,
        pinned_models=json.dumps(["old-pin"]),
    )
    fake_db = _PinnedFakeDb([stored])
    _patch_create_deps(monkeypatch, fake_db)
    create = _get_route("/api/model-endpoints", "POST")
    form = _create_form_kwargs(pinned_models="new-pin")

    response = create(_PinnedFakeRequest(), base_url="http://host:1234/v1", **form)

    assert response["existing"] is True
    # The incoming pin is merged after the existing one, not substituted for it.
    merged_pins = ["old-pin", "new-pin"]
    assert json.loads(stored.pinned_models) == merged_pins
    assert response["pinned_models"] == merged_pins
    # Visible models are cached IDs followed by pinned IDs (nothing hidden here).
    assert response["models"] == ["m1", "old-pin", "new-pin"]
    # The dedupe path must not insert a second row.
    assert fake_db.added == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_dedupe_existing_does_not_clobber_pinned_when_omitted(monkeypatch):
    """POST without pinned_models keeps the stored pins and commits nothing."""
    stored = _make_endpoint(
        cached_models=json.dumps(["m1"]),
        pinned_models=json.dumps(["keep-me"]),
    )
    fake_db = _PinnedFakeDb([stored])
    _patch_create_deps(monkeypatch, fake_db)
    create = _get_route("/api/model-endpoints", "POST")
    form = _create_form_kwargs()  # pinned_models defaults to ""

    response = create(_PinnedFakeRequest(), base_url="http://host:1234/v1", **form)

    assert json.loads(stored.pinned_models) == ["keep-me"]
    assert response["pinned_models"] == ["keep-me"]
    assert fake_db.committed == 0  # nothing to persist
|
||||||
|
|||||||
Reference in New Issue
Block a user