mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-28 07:35:27 -04:00
feat: add ChatGPT Subscription provider (#2876)
* feat: Add ChatGPT Subscription support and related features - Introduced a new provider option for ChatGPT Subscription in the endpoint selection UI. - Implemented OAuth flow for ChatGPT Subscription sign-in, including polling for authorization status. - Updated admin interface to handle ChatGPT Subscription, including disabling API key input and providing user guidance. - Enhanced cost tracking logic to differentiate between subscription and non-subscription endpoints. - Added new slash commands for managing skills, including listing, searching, and invoking skills. - Implemented caching for skill catalog to optimize performance. - Updated tests to cover new ChatGPT Subscription functionality and ensure proper endpoint probing. - Refactored existing code to accommodate new features and improve maintainability. * refactor: share provider device-flow setup - reuse one device-flow backend for Copilot and ChatGPT Subscription - add one frontend device-flow helper for Settings and /setup - put GitHub Copilot back into Add Models, now as a dropdown option - make provider selection just select; clicking Add starts sign-in - stop ChatGPT Subscription setup from opening auth tabs automatically - make /setup copilot and /setup chatgpt-subscription work from chat - show ChatGPT Subscription in the /setup suggestions - show the real error message when setup fails - add focused tests for the shared flow and setup UI * feat(chatgpt-subscription): harden credential lifecycle and streamline auth UX Backend: - Resolve runtime bearer for provider-auth endpoints at probe time via a shared _resolve_probe_key() that delegates to resolve_endpoint_runtime, applied across all probe/refresh call sites. - Skip live completion probes and health pings for discovery-only providers (centralized behind _is_discovery_only_provider) — the Codex/Responses API has no such endpoints, so status is derived from cached models. - Never persist the short lived ChatGPT bearer to the plaintext sessions table; proactively clear any stale bearer left by an earlier code path. - Revoke orphaned ProviderAuthSession credentials when the last endpoint backing them is deleted (_delete_orphaned_provider_auth), surfaced via cleared_provider_auth in the delete response. Frontend (admin.js): - Auto-start the device-auth flow on provider selection so the authorization panel (code + Authorize) shows immediately instead of behind a "Sign in" click. - Remove the redundant top button for device auth providers, move retry into the panel via an inline "Try again". - Drop the self-evident hint text and add an execCommand clipboard fallback so Copy works in non-secure (HTTP/LAN) contexts. * fix: harden chatgpt subscription provider * chore: remove PR media from branch * Fix chatgpt subscription recovery and token handling --------- Co-authored-by: 5p00kyy <admin@5p00ky.dev>
This commit is contained in:
+39
-25
@@ -57,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
||||
# Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
|
||||
|
||||
|
||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||
@@ -98,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
provider = _detect_provider(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
if provider == "anthropic":
|
||||
# Anthropic: match against hardcoded model list
|
||||
@@ -114,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
else:
|
||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||
try:
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
models_url = build_models_url(base)
|
||||
if models_url:
|
||||
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
model_ids = json.loads(ep.cached_models or "[]")
|
||||
except Exception:
|
||||
model_ids = []
|
||||
|
||||
@@ -1121,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
|
||||
total_models = 0
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
provider = _detect_provider(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
model_ids = []
|
||||
if provider == "anthropic":
|
||||
model_ids = list(ANTHROPIC_MODELS)
|
||||
else:
|
||||
try:
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
models_url = build_models_url(base)
|
||||
if models_url:
|
||||
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
model_ids = json.loads(ep.cached_models or "[]")
|
||||
except Exception:
|
||||
model_ids = ["(endpoint offline)"]
|
||||
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
"""ChatGPT subscription / Codex backend OAuth helpers.
|
||||
|
||||
This provider is intentionally separate from OpenAI API-key endpoints. It uses
|
||||
OpenAI account OAuth device authorization, stores refresh tokens server-side,
|
||||
and resolves a fresh bearer token at request time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
||||
|
||||
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
|
||||
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
|
||||
or "https://chatgpt.com/backend-api/codex"
|
||||
)
|
||||
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
|
||||
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
|
||||
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
|
||||
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
|
||||
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def _refresh_lock_for(auth_id: str) -> threading.Lock:
|
||||
with _AUTH_REFRESH_LOCKS_GUARD:
|
||||
lock = _AUTH_REFRESH_LOCKS.get(auth_id)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_AUTH_REFRESH_LOCKS[auth_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatGPTSubscriptionError(RuntimeError):
|
||||
"""Base error for ChatGPT subscription provider failures."""
|
||||
|
||||
|
||||
class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
|
||||
"""Stored OAuth credentials are invalid or expired beyond refresh."""
|
||||
|
||||
|
||||
class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
|
||||
"""Upstream quota/rate limit; reconnecting will not fix it."""
|
||||
|
||||
|
||||
class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
|
||||
"""No matching owner-scoped auth session exists."""
|
||||
|
||||
|
||||
def is_chatgpt_subscription_base(url: str) -> bool:
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(url or "")
|
||||
host = (parsed.hostname or "").lower().rstrip(".")
|
||||
path = (parsed.path or "").rstrip("/")
|
||||
except Exception:
|
||||
return False
|
||||
return host == "chatgpt.com" and (
|
||||
path == "/backend-api/codex" or path.startswith("/backend-api/codex/")
|
||||
)
|
||||
|
||||
|
||||
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
|
||||
headers = {
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Origin": "https://chatgpt.com",
|
||||
"Referer": "https://chatgpt.com/codex",
|
||||
"User-Agent": "Odysseus ChatGPT Subscription",
|
||||
}
|
||||
if access_token:
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
return headers
|
||||
|
||||
|
||||
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
|
||||
if not access_token:
|
||||
return []
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
|
||||
headers=chatgpt_headers(access_token),
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
return []
|
||||
data = response.json()
|
||||
except Exception:
|
||||
return []
|
||||
entries = data.get("models", []) if isinstance(data, dict) else []
|
||||
sortable: list[tuple[int, str]] = []
|
||||
for item in entries:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
slug = item.get("slug")
|
||||
if not isinstance(slug, str) or not slug.strip():
|
||||
continue
|
||||
visibility = item.get("visibility", "")
|
||||
if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}:
|
||||
continue
|
||||
priority = item.get("priority")
|
||||
rank = int(priority) if isinstance(priority, (int, float)) else 10_000
|
||||
sortable.append((rank, slug.strip()))
|
||||
sortable.sort(key=lambda item: (item[0], item[1]))
|
||||
ordered: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _, slug in sortable:
|
||||
if slug not in seen:
|
||||
ordered.append(slug)
|
||||
seen.add(slug)
|
||||
return ordered
|
||||
|
||||
|
||||
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
|
||||
if response.status_code < 400:
|
||||
return
|
||||
code = ""
|
||||
message = f"ChatGPT Subscription {action} failed with HTTP {response.status_code}."
|
||||
try:
|
||||
payload = response.json()
|
||||
err = payload.get("error") if isinstance(payload, dict) else None
|
||||
if isinstance(err, dict):
|
||||
code = str(err.get("code") or err.get("type") or "").strip()
|
||||
msg = err.get("message")
|
||||
if msg:
|
||||
message = f"ChatGPT Subscription {action} failed: {msg}"
|
||||
elif isinstance(err, str):
|
||||
code = err.strip()
|
||||
desc = payload.get("error_description") or payload.get("message")
|
||||
if desc:
|
||||
message = f"ChatGPT Subscription {action} failed: {desc}"
|
||||
except Exception:
|
||||
pass
|
||||
if response.status_code == 429:
|
||||
raise ChatGPTSubscriptionRateLimited(
|
||||
"ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
|
||||
)
|
||||
if response.status_code in (401, 403) or code in {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}:
|
||||
raise ChatGPTSubscriptionReauthRequired(message)
|
||||
raise ChatGPTSubscriptionError(message)
|
||||
|
||||
|
||||
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
|
||||
_raise_for_oauth_response(response, action)
|
||||
try:
|
||||
data = response.json()
|
||||
except Exception as exc:
|
||||
raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
|
||||
if not isinstance(data, dict):
|
||||
raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
|
||||
return data
|
||||
|
||||
|
||||
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
|
||||
response = httpx.post(
|
||||
f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode",
|
||||
json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=timeout,
|
||||
)
|
||||
data = _json_or_error(response, "device-code request")
|
||||
if not data.get("device_auth_id") or not data.get("user_code"):
|
||||
raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")
|
||||
data.setdefault("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device")
|
||||
data.setdefault("interval", 5)
|
||||
data.setdefault("expires_in", 900)
|
||||
return data
|
||||
|
||||
|
||||
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
|
||||
response = httpx.post(
|
||||
f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token",
|
||||
json={"device_auth_id": device_auth_id, "user_code": user_code},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code in (403, 404):
|
||||
return {"status": "pending", "error": "authorization_pending"}
|
||||
return _json_or_error(response, "device-code poll")
|
||||
|
||||
|
||||
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
|
||||
response = httpx.post(
|
||||
CHATGPT_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": authorization_code,
|
||||
"redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
|
||||
"client_id": CHATGPT_OAUTH_CLIENT_ID,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
data = _json_or_error(response, "token exchange")
|
||||
if not data.get("access_token"):
|
||||
raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
|
||||
return data
|
||||
|
||||
|
||||
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
|
||||
del access_token
|
||||
if not refresh_token:
|
||||
raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
|
||||
response = httpx.post(
|
||||
CHATGPT_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CHATGPT_OAUTH_CLIENT_ID,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
data = _json_or_error(response, "token refresh")
|
||||
if not data.get("access_token"):
|
||||
raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
|
||||
return data
|
||||
|
||||
|
||||
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
|
||||
parts = (token or "").split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError("not a JWT")
|
||||
segment = parts[1]
|
||||
segment += "=" * (-len(segment) % 4)
|
||||
raw = base64.urlsafe_b64decode(segment.encode("ascii"))
|
||||
payload = json.loads(raw.decode("utf-8"))
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
|
||||
|
||||
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
|
||||
try:
|
||||
exp = int(_decode_jwt_payload(access_token).get("exp") or 0)
|
||||
except Exception:
|
||||
return True
|
||||
return exp <= int(time.time()) + int(skew_seconds)
|
||||
|
||||
|
||||
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(ProviderAuthSession).filter(
|
||||
ProviderAuthSession.id == auth_id,
|
||||
ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
)
|
||||
if owner:
|
||||
q = q.filter(ProviderAuthSession.owner == owner)
|
||||
row = q.first()
|
||||
if row is None:
|
||||
raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")
|
||||
|
||||
access_token = row.access_token or ""
|
||||
if force_refresh or access_token_is_expiring(access_token):
|
||||
with _refresh_lock_for(auth_id):
|
||||
db.refresh(row)
|
||||
access_token = row.access_token or ""
|
||||
refresh_token = row.refresh_token or ""
|
||||
if force_refresh or access_token_is_expiring(access_token):
|
||||
refreshed = refresh_oauth_tokens(access_token, refresh_token)
|
||||
row.access_token = refreshed["access_token"]
|
||||
if refreshed.get("refresh_token"):
|
||||
row.refresh_token = refreshed["refresh_token"]
|
||||
row.last_refresh = utcnow_naive()
|
||||
db.commit()
|
||||
db.refresh(row)
|
||||
access_token = row.access_token or ""
|
||||
|
||||
return {
|
||||
"provider": CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||
"base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"),
|
||||
"api_key": access_token,
|
||||
"auth_mode": row.auth_mode or "chatgpt",
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def to_http_exception(exc: Exception) -> HTTPException:
|
||||
if isinstance(exc, ChatGPTSubscriptionRateLimited):
|
||||
return HTTPException(429, str(exc))
|
||||
if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
|
||||
return HTTPException(401, f"{exc} Reconnect the provider.")
|
||||
return HTTPException(502, str(exc))
|
||||
|
||||
|
||||
def build_responses_input(messages: list[dict]) -> list[dict]:
|
||||
input_items: list[dict] = []
|
||||
for msg in messages or []:
|
||||
role = msg.get("role") or "user"
|
||||
if role == "tool":
|
||||
role = "user"
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
text = "\n".join(str(part.get("text") or part.get("content") or "") for part in content if isinstance(part, dict))
|
||||
else:
|
||||
text = "" if content is None else str(content)
|
||||
input_type = "output_text" if role == "assistant" else "input_text"
|
||||
input_items.append({"role": role, "content": [{"type": input_type, "text": text}]})
|
||||
return input_items
|
||||
@@ -70,6 +70,25 @@ def _endpoint_enabled_models(ep) -> list:
|
||||
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
|
||||
|
||||
|
||||
def resolve_endpoint_runtime(ep, owner: Optional[str] = None) -> Tuple[str, Optional[str]]:
|
||||
"""Resolve a ModelEndpoint row to its runtime base URL and bearer/API key.
|
||||
|
||||
Static-key providers use ``ModelEndpoint.api_key``. Session-backed providers
|
||||
store refreshable credentials in ProviderAuthSession and must resolve a
|
||||
current access token at call time.
|
||||
"""
|
||||
base = normalize_base(getattr(ep, "base_url", "") or "")
|
||||
api_key = getattr(ep, "api_key", None)
|
||||
auth_id = getattr(ep, "provider_auth_id", None)
|
||||
if auth_id:
|
||||
from src.chatgpt_subscription import resolve_runtime_credentials
|
||||
|
||||
creds = resolve_runtime_credentials(auth_id, owner=owner)
|
||||
base = normalize_base(creds.get("base_url") or base)
|
||||
api_key = creds.get("api_key")
|
||||
return base, api_key
|
||||
|
||||
|
||||
# Cache for Tailscale hostname → IP resolution
|
||||
_tailscale_cache: Dict[str, Optional[str]] = {}
|
||||
|
||||
@@ -133,7 +152,7 @@ def resolve_url(url: str) -> str:
|
||||
def normalize_base(url: str) -> str:
|
||||
"""Strip known API path suffixes from a base URL."""
|
||||
url = (url or "").strip().rstrip("/")
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages", "/responses"]:
|
||||
if url.endswith(suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
for suffix in ["/chat", "/tags", "/generate"]:
|
||||
@@ -158,10 +177,12 @@ def build_chat_url(base: str) -> str:
|
||||
return _anthropic_api_root(base) + "/v1/messages"
|
||||
if provider == "ollama":
|
||||
return _ollama_api_root(base) + "/chat"
|
||||
if provider == "chatgpt-subscription":
|
||||
return base.rstrip("/") + "/responses"
|
||||
return base + "/chat/completions"
|
||||
|
||||
|
||||
def build_models_url(base: str) -> str:
|
||||
def build_models_url(base: str) -> Optional[str]:
|
||||
"""Return the provider-specific model-list endpoint URL for a base."""
|
||||
base = resolve_url(base)
|
||||
provider = _detect_provider(base)
|
||||
@@ -169,6 +190,8 @@ def build_models_url(base: str) -> str:
|
||||
return _anthropic_api_root(base) + "/v1/models"
|
||||
if provider == "ollama":
|
||||
return _ollama_api_root(base) + "/tags"
|
||||
if provider == "chatgpt-subscription":
|
||||
return None
|
||||
return base + "/models"
|
||||
|
||||
|
||||
@@ -184,6 +207,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
||||
if provider == "copilot":
|
||||
from src.copilot import copilot_headers
|
||||
return copilot_headers(api_key)
|
||||
if provider == "chatgpt-subscription":
|
||||
from src.chatgpt_subscription import chatgpt_headers
|
||||
return chatgpt_headers(api_key)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if provider == "openrouter":
|
||||
@@ -262,9 +288,13 @@ def resolve_endpoint(
|
||||
if not ep:
|
||||
return fallback_url, fallback_model, fallback_headers
|
||||
|
||||
base = normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
|
||||
return fallback_url, fallback_model, fallback_headers
|
||||
chat_url = build_chat_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
# Discard a configured model the user has since disabled on the
|
||||
# endpoint (e.g. a stale `default_model` left pointing at a now-hidden
|
||||
@@ -308,9 +338,13 @@ def resolve_endpoint_by_id(
|
||||
ep = q.first()
|
||||
if not ep:
|
||||
return None
|
||||
base = normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
|
||||
return None
|
||||
chat_url = build_chat_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
m = (model or "").strip()
|
||||
# Drop a model the user disabled on the endpoint, then pick the first
|
||||
# enabled chat model rather than a hidden one.
|
||||
|
||||
+217
-7
@@ -426,6 +426,9 @@ def _detect_provider(url: str) -> str:
|
||||
return "openrouter"
|
||||
if _host_match(url, "groq.com"):
|
||||
return "groq"
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
if is_chatgpt_subscription_base(url):
|
||||
return "chatgpt-subscription"
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(url):
|
||||
return "copilot"
|
||||
@@ -462,6 +465,8 @@ def _provider_label(url: str) -> str:
|
||||
if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go"
|
||||
if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen"
|
||||
if _host_match(url, "groq.com"): return "Groq"
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
if is_chatgpt_subscription_base(url): return "ChatGPT Subscription"
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(url): return "GitHub Copilot"
|
||||
if _host_match(url, "mistral.ai"): return "Mistral"
|
||||
@@ -479,6 +484,77 @@ def _provider_label(url: str) -> str:
|
||||
return host or "provider"
|
||||
|
||||
|
||||
def _normalize_chatgpt_subscription_url(url: str) -> str:
|
||||
base = (url or "").strip().rstrip("/")
|
||||
if base.endswith("/responses"):
|
||||
return base
|
||||
return base + "/responses"
|
||||
|
||||
|
||||
def _message_content_as_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
if part:
|
||||
parts.append(str(part))
|
||||
continue
|
||||
if isinstance(part.get("text"), str):
|
||||
parts.append(part["text"])
|
||||
continue
|
||||
if isinstance(part.get("content"), str):
|
||||
parts.append(part["content"])
|
||||
return "\n".join(parts)
|
||||
return "" if content is None else str(content)
|
||||
|
||||
|
||||
def _chatgpt_subscription_instructions(messages: List[Dict]) -> str:
|
||||
instructions = [
|
||||
_message_content_as_text(msg.get("content")).strip()
|
||||
for msg in messages or []
|
||||
if (msg.get("role") or "") == "system"
|
||||
]
|
||||
instructions = [part for part in instructions if part]
|
||||
if instructions:
|
||||
return "\n\n".join(instructions)
|
||||
return "You are a helpful AI assistant."
|
||||
|
||||
|
||||
def _build_chatgpt_responses_payload(
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
*,
|
||||
stream: bool = False,
|
||||
) -> Dict:
|
||||
from src.chatgpt_subscription import build_responses_input
|
||||
|
||||
conversation = [msg for msg in (messages or []) if (msg.get("role") or "") != "system"]
|
||||
payload: Dict = {
|
||||
"model": model,
|
||||
"instructions": _chatgpt_subscription_instructions(messages),
|
||||
"input": build_responses_input(conversation),
|
||||
"stream": stream,
|
||||
"store": False,
|
||||
}
|
||||
if not _restricts_temperature(model):
|
||||
payload["temperature"] = temperature
|
||||
if max_tokens and max_tokens > 0:
|
||||
payload["max_output_tokens"] = max_tokens
|
||||
return payload
|
||||
|
||||
|
||||
def _format_chatgpt_subscription_error(status_code: int, text: str) -> str:
|
||||
if status_code in (401, 403):
|
||||
return "ChatGPT Subscription credentials expired or were rejected. Reconnect the provider."
|
||||
if status_code == 429:
|
||||
return "ChatGPT Subscription quota or rate limit was reached. Retry after the upstream limit resets."
|
||||
return _format_upstream_error(status_code, text, "https://chatgpt.com/backend-api/codex")
|
||||
|
||||
|
||||
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
|
||||
"""Turn an upstream HTTP error into a user-readable sentence.
|
||||
|
||||
@@ -874,7 +950,7 @@ def _normalize_anthropic_url(url: str) -> str:
|
||||
def _model_list_base(url: str) -> str:
|
||||
"""Normalize model/chat URLs to the configured endpoint base."""
|
||||
base = (url or "").strip().rstrip("/")
|
||||
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
|
||||
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages", "/responses"):
|
||||
if base.endswith(suffix):
|
||||
base = base[: -len(suffix)].rstrip("/")
|
||||
for suffix in ("/chat", "/tags", "/generate"):
|
||||
@@ -903,7 +979,12 @@ def _parse_model_cache(raw) -> List[str]:
|
||||
return out
|
||||
|
||||
|
||||
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
||||
def _configured_cached_model_ids(
|
||||
endpoint_url: str,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""Return cached models for a configured endpoint matching endpoint_url."""
|
||||
target = _model_list_base(endpoint_url)
|
||||
if not target:
|
||||
@@ -914,7 +995,13 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
||||
return []
|
||||
db = SessionLocal()
|
||||
try:
|
||||
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if endpoint_id:
|
||||
q = q.filter(ModelEndpoint.id == endpoint_id)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
rows = q.all()
|
||||
for ep in rows:
|
||||
if _model_list_base(getattr(ep, "base_url", "")) != target:
|
||||
continue
|
||||
@@ -933,9 +1020,16 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
|
||||
def list_model_ids(
|
||||
base_chat_url: str,
|
||||
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||
headers: Optional[Dict] = None,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""List available model IDs from an endpoint."""
|
||||
cached = _configured_cached_model_ids(base_chat_url)
|
||||
cached = _configured_cached_model_ids(base_chat_url, owner=owner, endpoint_id=endpoint_id)
|
||||
if cached:
|
||||
return cached
|
||||
provider = _detect_provider(base_chat_url)
|
||||
@@ -971,9 +1065,16 @@ def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||
pass
|
||||
return []
|
||||
|
||||
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]:
|
||||
def normalize_model_id(
|
||||
endpoint_url: str,
|
||||
requested: str,
|
||||
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Normalize a model ID to match available models."""
|
||||
avail = list_model_ids(endpoint_url, timeout)
|
||||
avail = list_model_ids(endpoint_url, timeout, owner=owner, endpoint_id=endpoint_id)
|
||||
if not avail:
|
||||
return None
|
||||
if requested in avail:
|
||||
@@ -1169,6 +1270,49 @@ async def llm_call_async(
|
||||
logger.debug(f"Returning cached response for key: {cache_key}")
|
||||
return cached_response
|
||||
|
||||
if provider == "chatgpt-subscription":
|
||||
# ChatGPT/Codex requires streamed Responses requests even for callers
|
||||
# that want a plain string (auto-title, memory extraction, etc.).
|
||||
# Reuse stream_llm's validated Codex SSE path and collect deltas.
|
||||
parts: List[str] = []
|
||||
async for chunk in stream_llm(
|
||||
url,
|
||||
model,
|
||||
messages_copy,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
):
|
||||
event_is_error = False
|
||||
for line in str(chunk).splitlines():
|
||||
if line.startswith("event:"):
|
||||
event_is_error = line[6:].strip() == "error"
|
||||
continue
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
raw = line[5:].strip()
|
||||
if not raw:
|
||||
continue
|
||||
if raw == "[DONE]":
|
||||
response = "".join(parts)
|
||||
_set_cached_response(cache_key, response)
|
||||
return response
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if event_is_error or data.get("error") or (data.get("status") and data.get("text")):
|
||||
status = int(data.get("status") or 502)
|
||||
text = data.get("text") or data.get("error") or "ChatGPT Subscription request failed"
|
||||
raise HTTPException(status, text)
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, str):
|
||||
parts.append(delta)
|
||||
response = "".join(parts)
|
||||
_set_cached_response(cache_key, response)
|
||||
return response
|
||||
|
||||
if provider == "anthropic":
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
@@ -1294,6 +1438,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
model, messages_copy, temperature, max_tokens,
|
||||
stream=True, tools=tools, num_ctx=get_context_length(url, model),
|
||||
)
|
||||
elif provider == "chatgpt-subscription":
|
||||
target_url = _normalize_chatgpt_subscription_url(url)
|
||||
h = _provider_headers(provider, headers)
|
||||
payload = _build_chatgpt_responses_payload(model, messages_copy, temperature, max_tokens, stream=True)
|
||||
else:
|
||||
target_url = url
|
||||
payload = {
|
||||
@@ -1325,6 +1473,68 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
return
|
||||
note_model_activity(target_url, model)
|
||||
|
||||
# ── ChatGPT Subscription / Codex Responses streaming ──
|
||||
if provider == "chatgpt-subscription":
|
||||
event_name = ""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
try:
|
||||
client = _get_http_client()
|
||||
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||
_clear_host_dead(target_url)
|
||||
if r.status_code != 200:
|
||||
raw = (await r.aread()).decode(errors="replace")
|
||||
friendly = _format_chatgpt_subscription_error(r.status_code, raw)
|
||||
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
|
||||
return
|
||||
async for line in r.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("event:"):
|
||||
event_name = line[6:].strip()
|
||||
continue
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
raw = line[5:].strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
evt = data.get("type") or event_name
|
||||
if evt == "response.output_text.delta":
|
||||
delta = data.get("delta") or ""
|
||||
if delta:
|
||||
yield f'data: {json.dumps({"delta": delta})}\n\n'
|
||||
elif evt == "response.completed":
|
||||
usage = (data.get("response") or {}).get("usage") or data.get("usage") or {}
|
||||
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or input_tokens
|
||||
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") or output_tokens
|
||||
if input_tokens or output_tokens:
|
||||
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": input_tokens, "output_tokens": output_tokens}})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
elif evt in ("response.failed", "error"):
|
||||
err = data.get("error") or (data.get("response") or {}).get("error") or {}
|
||||
text = err.get("message") if isinstance(err, dict) else str(err or "ChatGPT Subscription request failed")
|
||||
yield f'event: error\ndata: {json.dumps({"status": 502, "text": text})}\n\n'
|
||||
return
|
||||
yield "data: [DONE]\n\n"
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
_cooled = _mark_host_dead(target_url)
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"ChatGPT Subscription stream connect to {target_url} failed: {e}{_tail}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
|
||||
except httpx.ReadTimeout:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
|
||||
except httpx.NetworkError:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
|
||||
except Exception as e:
|
||||
logger.error(f"ChatGPT Subscription stream error: {e}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
||||
return
|
||||
|
||||
# ── Native Ollama streaming ──
|
||||
if provider == "ollama":
|
||||
_ollama_tool_calls: List[Dict] = []
|
||||
|
||||
Reference in New Issue
Block a user