diff --git a/routes/model_routes.py b/routes/model_routes.py index 6b76dc71f..133758e82 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -4,6 +4,7 @@ import os import re import uuid import json +import hashlib import socket import time as _time import logging @@ -502,9 +503,53 @@ def _is_chat_model(model_id: str) -> bool: return True +def _safe_detect_provider(base_url: str) -> str: + """Best-effort provider detection that must not break endpoint probing.""" + try: + return _detect_provider(base_url) + except Exception as exc: + logger.debug("Provider detection failed for %s: %s", base_url, exc) + return "" + + +def _safe_build_models_url(base_url: str) -> str: + """Build a /models URL without letting optional provider imports break probes.""" + try: + return build_models_url(base_url) + except Exception as exc: + logger.debug("Model URL detection failed for %s: %s", base_url, exc) + return f"{(base_url or '').rstrip('/')}/models" + + +def _safe_build_headers(api_key: Optional[str], base_url: str) -> dict: + """Build auth headers without letting optional provider imports break probes.""" + try: + return build_headers(api_key, base_url) + except Exception as exc: + logger.debug("Header detection failed for %s: %s", base_url, exc) + return {"Authorization": f"Bearer {api_key}"} if api_key else {} + + +def _is_discovery_only_provider(provider: str) -> bool: + return provider == "chatgpt-subscription" + + +def _resolve_probe_key(ep) -> Optional[str]: + """API key/bearer to probe an endpoint with.""" + try: + from src.endpoint_resolver import resolve_endpoint_runtime + _base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None)) + return key + except Exception as exc: + logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), exc) + return None + + def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict: """Send a realistic completion request to a single model. Returns {status, latency_ms, error?}.""" - provider = _detect_provider(base) + provider = _safe_detect_provider(base) + if _is_discovery_only_provider(provider): + return {"status": "ok", "latency_ms": 0, "skipped": True} messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Say OK"}, @@ -523,12 +568,12 @@ def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 1 elif provider == "ollama": from src.llm_core import _build_ollama_payload target_url = build_chat_url(base) - h = build_headers(api_key, base) + h = _safe_build_headers(api_key, base) h["Content-Type"] = "application/json" payload = _build_ollama_payload(model_id, messages, 0.0, 5, stream=False, tools=_test_tools) else: target_url = build_chat_url(base) - h = build_headers(api_key, base) + h = _safe_build_headers(api_key, base) h["Content-Type"] = "application/json" from src.llm_core import _uses_max_completion_tokens, _restricts_temperature _max_key = "max_completion_tokens" if _uses_max_completion_tokens(model_id) else "max_tokens" @@ -618,9 +663,15 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis For Anthropic, queries their /v1/models API, falling back to hardcoded list.""" from src.endpoint_resolver import resolve_url base = resolve_url(_normalize_base(base_url)) - if _detect_provider(base) == "anthropic": + provider = _safe_detect_provider(base) + if provider == "chatgpt-subscription": + from src.chatgpt_subscription import fetch_available_models + if api_key: + return fetch_available_models(api_key, timeout=timeout) + return [] + if provider == "anthropic": # Try Anthropic's /v1/models endpoint first - url = build_models_url(base) + url = _safe_build_models_url(base) headers = {"anthropic-version": "2023-06-01"} if api_key: headers["x-api-key"] = api_key @@ -643,8 +694,8 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis return [] logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}") return list(ANTHROPIC_MODELS) - url = build_models_url(base) - headers = build_headers(api_key, base) + url = _safe_build_models_url(base) + headers = _safe_build_headers(api_key, base) try: r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify()) r.raise_for_status() @@ -702,7 +753,7 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> """Reachability probe that does not require installed/listed models.""" from src.endpoint_resolver import resolve_url base = resolve_url(_normalize_base(base_url)) - headers = build_headers(api_key, base) + headers = _safe_build_headers(api_key, base) # Ollama exposes /v1/models (OpenAI-compatible) AND native /api/version, # /api/tags. Probe native paths for Ollama-style endpoints, but avoid using @@ -752,36 +803,23 @@ def _ping_endpoint(base_url: str, api_key: str = None, timeout: float = 1.5) -> except Exception: pass - # OpenAI-compatible servers (vLLM, llama.cpp, SGLang, lmdeploy, …) expose - # /v1/models but return 404 on the bare /v1 root. The probe used to GET - # the base URL only, so a fully-working vLLM endpoint (chats fine!) read - # as offline because /v1 → 404. Try /models first; fall back to the base - # URL only if /models couldn't be reached (TCP-level failure). - models_url = build_models_url(base) - try: - r = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify()) - result = _result_from_response(r) - if result["reachable"]: - return result - last_error = result.get("error") - except Exception as e: - last_error = str(e)[:120] - try: r = httpx.get(base, headers=headers, timeout=timeout, verify=llm_verify()) result = _result_from_response(r) if result["reachable"]: return result - # 4xx from a reachable HTTP server (404 /v1, 401/403 missing key) is - # still proof the upstream is alive. Only treat connection-level - # failures, 5xx, and redirect-to-/login as truly offline. sc = result.get("status_code") or 0 - if 400 <= sc < 500 and sc not in (407, 408, 421, 425, 429): - return { - "reachable": True, - "status_code": sc, - "error": None, - } + if 400 <= sc < 500 and sc not in (401, 403): + models_url = _safe_build_models_url(base) + try: + r2 = httpx.get(models_url, headers=headers, timeout=timeout, verify=llm_verify()) + result2 = _result_from_response(r2) + if result2["reachable"]: + return result2 + except Exception: + pass + if sc: + return result last_error = result.get("error") or last_error except Exception as e: last_error = str(e)[:120] @@ -878,6 +916,14 @@ def _visible_models(cached_models, hidden_models, pinned_models=None): return [m for m in merged if m not in hidden] +def _api_key_fingerprint(api_key: Optional[str]) -> str: + """Stable, non-secret label for distinguishing same-URL credentials.""" + key = (api_key or "").strip() + if not key: + return "" + return hashlib.sha256(key.encode("utf-8")).hexdigest()[:8] + + def setup_model_routes(model_discovery): router = APIRouter(prefix="/api") @@ -1056,7 +1102,7 @@ def setup_model_routes(model_discovery): for ep in endpoints: base = _normalize_base(ep.base_url) - provider = _detect_provider(base) + provider = _safe_detect_provider(base) # Merge cached + pinned models, then filter out hidden ones ep_model_type = getattr(ep, "model_type", None) or "llm" model_ids = _visible_models( @@ -1132,8 +1178,9 @@ def setup_model_routes(model_discovery): raise HTTPException(401, "Not authenticated") except HTTPException: raise - except Exception: - pass + except Exception as e: + logger.error("Auth gate error in GET /api/models, failing closed: %s", e) + raise HTTPException(status_code=500, detail="Internal error") # Admins see every endpoint (they manage the global pool); regular # users get the owner-scoped view. _is_admin = False @@ -1242,7 +1289,7 @@ def setup_model_routes(model_discovery): results = [] for ep in endpoints: base = _normalize_base(ep.base_url) - provider = _detect_provider(base) + provider = _safe_detect_provider(base) kind = _effective_endpoint_kind(ep, base) cached_count = len(_cached_model_ids(ep)) entry = { @@ -1457,6 +1504,7 @@ def setup_model_routes(model_discovery): "name": r.name, "base_url": r.base_url, "has_key": bool(r.api_key), + "api_key_fingerprint": _api_key_fingerprint(r.api_key), "is_enabled": r.is_enabled, "models": visible, "pinned_models": pinned, @@ -1529,15 +1577,27 @@ def setup_model_routes(model_discovery): # re-adding manually-added endpoints under their host:port name. from src.auth_helpers import get_current_user as _gcu_dedup _caller = _gcu_dedup(request) or None + _incoming_api_key = api_key.strip() _db_dedup = SessionLocal() try: - existing = ( + _same_url_rows = ( _db_dedup.query(ModelEndpoint) .filter(ModelEndpoint.base_url == base_url) .filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == _caller)) .order_by(ModelEndpoint.owner.desc()) # prefer owned over shared - .first() + .all() ) + existing = None + _empty_key_existing = None + for _candidate in _same_url_rows: + _candidate_key = (getattr(_candidate, "api_key", None) or "").strip() + if _candidate_key == _incoming_api_key: + existing = _candidate + break + if _incoming_api_key and not _candidate_key and _empty_key_existing is None: + _empty_key_existing = _candidate + if existing is None and _incoming_api_key and _empty_key_existing is not None: + existing = _empty_key_existing if existing: changed = False # Persist any incoming pinned IDs onto the existing row. An @@ -1586,6 +1646,8 @@ def setup_model_routes(model_discovery): "id": existing.id, "name": existing.name, "base_url": existing.base_url, + "has_key": bool(existing.api_key), + "api_key_fingerprint": _api_key_fingerprint(existing.api_key), "models": _visible_models( existing_models, getattr(existing, "hidden_models", None), @@ -1659,6 +1721,8 @@ def setup_model_routes(model_discovery): "id": ep_id, "name": name.strip(), "base_url": base_url, + "has_key": bool(api_key.strip()), + "api_key_fingerprint": _api_key_fingerprint(api_key), "models": _merge_model_ids(model_ids, _pinned), "pinned_models": _pinned, "online": bool(model_ids) or bool(_pinned) or bool(ping.get("reachable")), diff --git a/src/chatgpt_subscription.py b/src/chatgpt_subscription.py index 263c4f529..e65ccbc8d 100644 --- a/src/chatgpt_subscription.py +++ b/src/chatgpt_subscription.py @@ -17,8 +17,6 @@ from typing import Any, Dict, Optional import httpx from fastapi import HTTPException -from core.database import ProviderAuthSession, SessionLocal, utcnow_naive - DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = ( os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/") or "https://chatgpt.com/backend-api/codex" @@ -33,6 +31,11 @@ _AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {} _AUTH_REFRESH_LOCKS_GUARD = threading.Lock() +def _database_handles(): + from core.database import ProviderAuthSession, SessionLocal, utcnow_naive + return ProviderAuthSession, SessionLocal, utcnow_naive + + def _refresh_lock_for(auth_id: str) -> threading.Lock: with _AUTH_REFRESH_LOCKS_GUARD: lock = _AUTH_REFRESH_LOCKS.get(auth_id) @@ -249,6 +252,7 @@ def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCE def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]: + ProviderAuthSession, SessionLocal, utcnow_naive = _database_handles() db = SessionLocal() try: q = db.query(ProviderAuthSession).filter(