mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 02:05:22 -04:00
Fix model endpoint route test regressions
This commit is contained in:
+102
-38
@@ -4,6 +4,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
|
import hashlib
|
||||||
import socket
|
import socket
|
||||||
import time as _time
|
import time as _time
|
||||||
import logging
|
import logging
|
||||||
@@ -502,9 +503,53 @@ def _is_chat_model(model_id: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_detect_provider(base_url: str) -> str:
|
||||||
|
"""Best-effort provider detection that must not break endpoint probing."""
|
||||||
|
try:
|
||||||
|
return _detect_provider(base_url)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Provider detection failed for %s: %s", base_url, exc)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_build_models_url(base_url: str) -> str:
|
||||||
|
"""Build a /models URL without letting optional provider imports break probes."""
|
||||||
|
try:
|
||||||
|
return build_models_url(base_url)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Model URL detection failed for %s: %s", base_url, exc)
|
||||||
|
return f"{(base_url or '').rstrip('/')}/models"
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_build_headers(api_key: Optional[str], base_url: str) -> dict:
|
||||||
|
"""Build auth headers without letting optional provider imports break probes."""
|
||||||
|
try:
|
||||||
|
return build_headers(api_key, base_url)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Header detection failed for %s: %s", base_url, exc)
|
||||||
|
return {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_discovery_only_provider(provider: str) -> bool:
|
||||||
|
return provider == "chatgpt-subscription"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_probe_key(ep) -> Optional[str]:
|
||||||
|
"""API key/bearer to probe an endpoint with."""
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||||
|
return key
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||||
provider = _detect_provider(base)
|
provider = _safe_detect_provider(base)
|
||||||
|
if _is_discovery_only_provider(provider):
|
||||||
|
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Say OK"},
|
{"role": "user", "content": "Say OK"},
|
||||||
@@ -523,12 +568,12 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1
|
|||||||
elif provider == "ollama":
|
elif provider == "ollama":
|
||||||
from src.llm_core import _build_ollama_payload
|
from src.llm_core import _build_ollama_payload
|
||||||
target_url = build_chat_url(base)
|
target_url = build_chat_url(base)
|
||||||
h = build_headers(api_key, base)
|
h = _safe_build_headers(api_key, base)
|
||||||
h["Content-Type"] = "application/json"
|
h["Content-Type"] = "application/json"
|
||||||
payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools)
|
payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools)
|
||||||
else:
|
else:
|
||||||
target_url = build_chat_url(base)
|
target_url = build_chat_url(base)
|
||||||
h = build_headers(api_key, base)
|
h = _safe_build_headers(api_key, base)
|
||||||
h["Content-Type"] = "application/json"
|
h["Content-Type"] = "application/json"
|
||||||
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
|
from src.llm_core import _uses_max_completion_tokens, _restricts_temperature
|
||||||
_max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens"
|
_max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens"
|
||||||
@@ -618,9 +663,15 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
base = resolve_url(_normalize_base(base_url))
|
base = resolve_url(_normalize_base(base_url))
|
||||||
if _detect_provider(base) == "anthropic":
|
provider = _safe_detect_provider(base)
|
||||||
|
if provider == "chatgpt-subscription":
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
if api_key:
|
||||||
|
return fetch_available_models(api_key, timeout=timeout)
|
||||||
|
return []
|
||||||
|
if provider == "anthropic":
|
||||||
# Try Anthropic's /v1/models endpoint first
|
# Try Anthropic's /v1/models endpoint first
|
||||||
url = build_models_url(base)
|
url = _safe_build_models_url(base)
|
||||||
headers = {"anthropic-version": "2023-06-01"}
|
headers = {"anthropic-version": "2023-06-01"}
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["x-api-key"] = api_key
|
headers["x-api-key"] = api_key
|
||||||
@@ -643,8 +694,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
return []
|
return []
|
||||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||||
return list(ANTHROPIC_MODELS)
|
return list(ANTHROPIC_MODELS)
|
||||||
url = build_models_url(base)
|
url = _safe_build_models_url(base)
|
||||||
headers = build_headers(api_key, base)
|
headers = _safe_build_headers(api_key, base)
|
||||||
try:
|
try:
|
||||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
@@ -702,7 +753,7 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
"""Reachability probe that does not require installed/listed models."""
|
"""Reachability probe that does not require installed/listed models."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
base = resolve_url(_normalize_base(base_url))
|
base = resolve_url(_normalize_base(base_url))
|
||||||
headers = build_headers(api_key, base)
|
headers = _safe_build_headers(api_key, base)
|
||||||
|
|
||||||
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
|
# Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version,
|
||||||
# /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
|
# /api/tags. Probe native paths for Ollama-style endpoints, but avoid using
|
||||||
@@ -752,36 +803,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) ->
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# OpenAI-compatible servers (vLLM, llama.cpp, SGLang, lmdeploy, …) expose
|
|
||||||
# /v1/models but return 404 on the bare /v1 root. The probe used to GET
|
|
||||||
# the base URL only, so a fully-working vLLM endpoint (chats fine!) read
|
|
||||||
# as offline because /v1 → 404. Try /models first; fall back to the base
|
|
||||||
# URL only if /models couldn't be reached (TCP-level failure).
|
|
||||||
models_url = build_models_url(base)
|
|
||||||
try:
|
|
||||||
r = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
|
||||||
result = _result_from_response(r)
|
|
||||||
if result["reachable"]:
|
|
||||||
return result
|
|
||||||
last_error = result.get("error")
|
|
||||||
except Exception as e:
|
|
||||||
last_error = str(e)[:120]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
result = _result_from_response(r)
|
result = _result_from_response(r)
|
||||||
if result["reachable"]:
|
if result["reachable"]:
|
||||||
return result
|
return result
|
||||||
# 4xx from a reachable HTTP server (404 /v1, 401/403 missing key) is
|
|
||||||
# still proof the upstream is alive. Only treat connection-level
|
|
||||||
# failures, 5xx, and redirect-to-/login as truly offline.
|
|
||||||
sc = result.get("status_code") or 0
|
sc = result.get("status_code") or 0
|
||||||
if 400 <= sc < 500 and sc not in (407, 408, 421, 425, 429):
|
if 400 <= sc < 500 and sc not in (401, 403):
|
||||||
return {
|
models_url = _safe_build_models_url(base)
|
||||||
"reachable": True,
|
try:
|
||||||
"status_code": sc,
|
r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
"error": None,
|
result2 = _result_from_response(r2)
|
||||||
}
|
if result2["reachable"]:
|
||||||
|
return result2
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if sc:
|
||||||
|
return result
|
||||||
last_error = result.get("error") or last_error
|
last_error = result.get("error") or last_error
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
last_error = str(e)[:120]
|
last_error = str(e)[:120]
|
||||||
@@ -878,6 +916,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None):
|
|||||||
return [m for m in merged if m not in hidden]
|
return [m for m in merged if m not in hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_fingerprint(api_key: Optional[str]) -> str:
|
||||||
|
"""Stable, non-secret label for distinguishing same-URL credentials."""
|
||||||
|
key = (api_key or "").strip()
|
||||||
|
if not key:
|
||||||
|
return ""
|
||||||
|
return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8]
|
||||||
|
|
||||||
|
|
||||||
def setup_model_routes(model_discovery):
|
def setup_model_routes(model_discovery):
|
||||||
router = APIRouter(prefix="/api")
|
router = APIRouter(prefix="/api")
|
||||||
|
|
||||||
@@ -1056,7 +1102,7 @@ def setup_model_routes(model_discovery):
|
|||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
base = _normalize_base(ep.base_url)
|
||||||
provider = _detect_provider(base)
|
provider = _safe_detect_provider(base)
|
||||||
# Merge cached + pinned models, then filter out hidden ones
|
# Merge cached + pinned models, then filter out hidden ones
|
||||||
ep_model_type = getattr(ep, "model_type", None) or "llm"
|
ep_model_type = getattr(ep, "model_type", None) or "llm"
|
||||||
model_ids = _visible_models(
|
model_ids = _visible_models(
|
||||||
@@ -1132,8 +1178,9 @@ def setup_model_routes(model_discovery):
|
|||||||
raise HTTPException(401, "Not authenticated")
|
raise HTTPException(401, "Not authenticated")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.error("Auth gate error in GET /api/models, failing closed: %s", e)
|
||||||
|
raise HTTPException(status_code=500, detail="Internal error")
|
||||||
# Admins see every endpoint (they manage the global pool); regular
|
# Admins see every endpoint (they manage the global pool); regular
|
||||||
# users get the owner-scoped view.
|
# users get the owner-scoped view.
|
||||||
_is_admin = False
|
_is_admin = False
|
||||||
@@ -1242,7 +1289,7 @@ def setup_model_routes(model_discovery):
|
|||||||
results = []
|
results = []
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
base = _normalize_base(ep.base_url)
|
||||||
provider = _detect_provider(base)
|
provider = _safe_detect_provider(base)
|
||||||
kind = _effective_endpoint_kind(ep, base)
|
kind = _effective_endpoint_kind(ep, base)
|
||||||
cached_count = len(_cached_model_ids(ep))
|
cached_count = len(_cached_model_ids(ep))
|
||||||
entry = {
|
entry = {
|
||||||
@@ -1457,6 +1504,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"name": r.name,
|
"name": r.name,
|
||||||
"base_url": r.base_url,
|
"base_url": r.base_url,
|
||||||
"has_key": bool(r.api_key),
|
"has_key": bool(r.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(r.api_key),
|
||||||
"is_enabled": r.is_enabled,
|
"is_enabled": r.is_enabled,
|
||||||
"models": visible,
|
"models": visible,
|
||||||
"pinned_models": pinned,
|
"pinned_models": pinned,
|
||||||
@@ -1529,15 +1577,27 @@ def setup_model_routes(model_discovery):
|
|||||||
# re-adding manually-added endpoints under their host:port name.
|
# re-adding manually-added endpoints under their host:port name.
|
||||||
from src.auth_helpers import get_current_user as _gcu_dedup
|
from src.auth_helpers import get_current_user as _gcu_dedup
|
||||||
_caller = _gcu_dedup(request) or None
|
_caller = _gcu_dedup(request) or None
|
||||||
|
_incoming_api_key = api_key.strip()
|
||||||
_db_dedup = SessionLocal()
|
_db_dedup = SessionLocal()
|
||||||
try:
|
try:
|
||||||
existing = (
|
_same_url_rows = (
|
||||||
_db_dedup.query(ModelEndpoint)
|
_db_dedup.query(ModelEndpoint)
|
||||||
.filter(ModelEndpoint.base_url == base_url)
|
.filter(ModelEndpoint.base_url == base_url)
|
||||||
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
.filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller))
|
||||||
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
.order_by(ModelEndpoint.owner.desc()) # prefer owned over shared
|
||||||
.first()
|
.all()
|
||||||
)
|
)
|
||||||
|
existing = None
|
||||||
|
_empty_key_existing = None
|
||||||
|
for _candidate in _same_url_rows:
|
||||||
|
_candidate_key = (getattr(_candidate, "api_key", None) or "").strip()
|
||||||
|
if _candidate_key == _incoming_api_key:
|
||||||
|
existing = _candidate
|
||||||
|
break
|
||||||
|
if _incoming_api_key and not _candidate_key and _empty_key_existing is None:
|
||||||
|
_empty_key_existing = _candidate
|
||||||
|
if existing is None and _incoming_api_key and _empty_key_existing is not None:
|
||||||
|
existing = _empty_key_existing
|
||||||
if existing:
|
if existing:
|
||||||
changed = False
|
changed = False
|
||||||
# Persist any incoming pinned IDs onto the existing row. An
|
# Persist any incoming pinned IDs onto the existing row. An
|
||||||
@@ -1586,6 +1646,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": existing.id,
|
"id": existing.id,
|
||||||
"name": existing.name,
|
"name": existing.name,
|
||||||
"base_url": existing.base_url,
|
"base_url": existing.base_url,
|
||||||
|
"has_key": bool(existing.api_key),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(existing.api_key),
|
||||||
"models": _visible_models(
|
"models": _visible_models(
|
||||||
existing_models,
|
existing_models,
|
||||||
getattr(existing, "hidden_models", None),
|
getattr(existing, "hidden_models", None),
|
||||||
@@ -1659,6 +1721,8 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep_id,
|
"id": ep_id,
|
||||||
"name": name.strip(),
|
"name": name.strip(),
|
||||||
"base_url": base_url,
|
"base_url": base_url,
|
||||||
|
"has_key": bool(api_key.strip()),
|
||||||
|
"api_key_fingerprint": _api_key_fingerprint(api_key),
|
||||||
"models": _merge_model_ids(model_ids, _pinned),
|
"models": _merge_model_ids(model_ids, _pinned),
|
||||||
"pinned_models": _pinned,
|
"pinned_models": _pinned,
|
||||||
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
"online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")),
|
||||||
|
|||||||
@@ -17,8 +17,6 @@ from typing import Any, Dict, Optional
|
|||||||
import httpx
|
import httpx
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
|
||||||
|
|
||||||
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
|
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
|
||||||
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
|
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
|
||||||
or "https://chatgpt.com/backend-api/codex"
|
or "https://chatgpt.com/backend-api/codex"
|
||||||
@@ -33,6 +31,11 @@ _AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
|
|||||||
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _database_handles():
|
||||||
|
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
return ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
|
||||||
|
|
||||||
def _refresh_lock_for(auth_id: str) -> threading.Lock:
|
def _refresh_lock_for(auth_id: str) -> threading.Lock:
|
||||||
with _AUTH_REFRESH_LOCKS_GUARD:
|
with _AUTH_REFRESH_LOCKS_GUARD:
|
||||||
lock = _AUTH_REFRESH_LOCKS.get(auth_id)
|
lock = _AUTH_REFRESH_LOCKS.get(auth_id)
|
||||||
@@ -249,6 +252,7 @@ def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCE
|
|||||||
|
|
||||||
|
|
||||||
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
|
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
|
||||||
|
ProviderAuthSession, SessionLocal, utcnow_naive = _database_handles()
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
q = db.query(ProviderAuthSession).filter(
|
q = db.query(ProviderAuthSession).filter(
|
||||||
|
|||||||
Reference in New Issue
Block a user