feat: add ChatGPT Subscription provider (#2876)

* feat: Add ChatGPT Subscription support and related features

- Introduced a new provider option for ChatGPT Subscription in the endpoint selection UI.
- Implemented OAuth flow for ChatGPT Subscription sign-in, including polling for authorization status.
- Updated admin interface to handle ChatGPT Subscription, including disabling API key input and providing user guidance.
- Enhanced cost tracking logic to differentiate between subscription and non-subscription endpoints.
- Added new slash commands for managing skills, including listing, searching, and invoking skills.
- Implemented caching for skill catalog to optimize performance.
- Updated tests to cover new ChatGPT Subscription functionality and ensure proper endpoint probing.
- Refactored existing code to accommodate new features and improve maintainability.

* refactor: share provider device-flow setup

- reuse one device-flow backend for Copilot and ChatGPT Subscription
- add one frontend device-flow helper for Settings and /setup
- put GitHub Copilot back into Add Models, now as a dropdown option
- make provider selection just select; clicking Add starts sign-in
- stop ChatGPT Subscription setup from opening auth tabs automatically
- make /setup copilot and /setup chatgpt-subscription work from chat
- show ChatGPT Subscription in the /setup suggestions
- show the real error message when setup fails
- add focused tests for the shared flow and setup UI

* feat(chatgpt-subscription): harden credential lifecycle and streamline auth UX

Backend:
- Resolve runtime bearer for provider-auth endpoints at probe time via a
  shared _resolve_probe_key() that delegates to resolve_endpoint_runtime,
  applied across all probe/refresh call sites.
- Skip live completion probes and health pings for discovery-only providers
  (centralized behind _is_discovery_only_provider) — the Codex/Responses API
  has no such endpoints, so status is derived from cached models.
- Never persist the short-lived ChatGPT bearer to the plaintext sessions
  table; proactively clear any stale bearer left by an earlier code path.
- Revoke orphaned ProviderAuthSession credentials when the last endpoint
  backing them is deleted (_delete_orphaned_provider_auth), surfaced via
  cleared_provider_auth in the delete response.

Frontend (admin.js):
- Auto-start the device-auth flow on provider selection so the authorization
  panel (code + Authorize) shows immediately instead of behind a "Sign in" click.
- Remove the redundant top button for device-auth providers and move retry
  into the panel via an inline "Try again".
- Drop the self-evident hint text and add an execCommand clipboard fallback so
  Copy works in non-secure (HTTP/LAN) contexts.

* fix: harden chatgpt subscription provider

* chore: remove PR media from branch

* Fix chatgpt subscription recovery and token handling

---------

Co-authored-by: 5p00kyy <admin@5p00ky.dev>
This commit is contained in:
stocky789
2026-06-08 18:19:18 +10:00
committed by GitHub
parent ac94885c84
commit 1e0d9b92af
37 changed files with 3425 additions and 485 deletions
+39 -25
View File
@@ -57,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
# Model resolution
# ---------------------------------------------------------------------------
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
@@ -98,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
for ep in endpoints:
base = _normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception:
continue
provider = _detect_provider(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
if provider == "anthropic":
# Anthropic: match against hardcoded model list
@@ -114,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
else:
# OpenAI-compatible and native Ollama: probe the provider's model list.
try:
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
models_url = build_models_url(base)
if models_url:
r = httpx.get(models_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
model_ids = json.loads(ep.cached_models or "[]")
except Exception:
model_ids = []
@@ -1121,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
total_models = 0
for ep in endpoints:
base = _normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception:
continue
provider = _detect_provider(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
model_ids = []
if provider == "anthropic":
model_ids = list(ANTHROPIC_MODELS)
else:
try:
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
models_url = build_models_url(base)
if models_url:
r = httpx.get(models_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
model_ids = json.loads(ep.cached_models or "[]")
except Exception:
model_ids = ["(endpoint offline)"]
+311
View File
@@ -0,0 +1,311 @@
"""ChatGPT subscription / Codex backend OAuth helpers.
This provider is intentionally separate from OpenAI API-key endpoints. It uses
OpenAI account OAuth device authorization, stores refresh tokens server-side,
and resolves a fresh bearer token at request time.
"""
from __future__ import annotations
import base64
import json
import os
import threading
import time
from typing import Any, Dict, Optional
import httpx
from fastapi import HTTPException
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
# Base URL of the ChatGPT/Codex backend. A non-blank CHATGPT_SUBSCRIPTION_BASE_URL
# environment variable overrides the built-in default; surrounding whitespace and
# trailing slashes are stripped from the override.
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
    os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
    or "https://chatgpt.com/backend-api/codex"
)
# Provider identifier matched against ProviderAuthSession.provider.
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
# OAuth client id sent with the device-code, token-exchange and refresh requests.
# NOTE(review): presumably a public (non-secret) client identifier — confirm.
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
# Token endpoint used for both authorization-code exchange and refresh.
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
# Issuer root; the device-authorization endpoints are built relative to it.
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
# Redirect URI passed to the token endpoint when exchanging an authorization code.
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
# An access token counts as "expiring" once its `exp` claim is within this many
# seconds of the current time.
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
# Registry of per-auth-session refresh locks, keyed by auth session id, plus the
# guard that serializes access to the registry itself.
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()


def _refresh_lock_for(auth_id: str) -> threading.Lock:
    """Return the lock dedicated to *auth_id*, creating it on first use.

    The registry is only touched while holding the guard lock, so every caller
    asking for the same id receives the same lock object.
    """
    with _AUTH_REFRESH_LOCKS_GUARD:
        try:
            return _AUTH_REFRESH_LOCKS[auth_id]
        except KeyError:
            created = threading.Lock()
            _AUTH_REFRESH_LOCKS[auth_id] = created
            return created
class ChatGPTSubscriptionError(RuntimeError):
    """Base error for ChatGPT subscription provider failures.

    Also raised directly for upstream failures that are neither rate limits
    nor credential problems (for example an OAuth response that is not JSON).
    """


class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
    """Stored OAuth credentials are invalid or expired beyond refresh.

    The user has to run the sign-in flow again to recover.
    """


class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
    """Upstream quota/rate limit; reconnecting will not fix it.

    Raised when the OAuth endpoint answers HTTP 429.
    """


class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
    """No matching owner-scoped auth session exists."""
def is_chatgpt_subscription_base(url: str) -> bool:
    """Tell whether *url* points at the ChatGPT Codex backend.

    The host must be exactly ``chatgpt.com`` (case-insensitive, a trailing dot
    is tolerated) and the path must be ``/backend-api/codex`` or live beneath
    it. Anything that cannot be parsed is reported as ``False``.
    """
    try:
        from urllib.parse import urlparse

        pieces = urlparse(url or "")
        hostname = (pieces.hostname or "").lower().rstrip(".")
        route = (pieces.path or "").rstrip("/")
    except Exception:
        return False
    if hostname != "chatgpt.com":
        return False
    if route == "/backend-api/codex":
        return True
    return route.startswith("/backend-api/codex/")
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
    """Build the request headers sent to the ChatGPT/Codex backend.

    An ``Authorization: Bearer`` header is included only when *access_token*
    is non-empty.
    """
    bearer = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    return {
        "Accept": "application/json, text/event-stream",
        "Origin": "https://chatgpt.com",
        "Referer": "https://chatgpt.com/codex",
        "User-Agent": "Odysseus ChatGPT Subscription",
        **bearer,
    }
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
    """Fetch the model slugs visible to this ChatGPT account.

    Returns an empty list when no token is supplied, when the request fails,
    when the status is not 200, or when the body is not valid JSON. Entries
    whose ``visibility`` is "hide"/"hidden" are skipped. The result is ordered
    by ``priority`` (entries without a numeric priority sort last), then by
    slug, with duplicate slugs removed.
    """
    if not access_token:
        return []
    try:
        reply = httpx.get(
            "https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
            headers=chatgpt_headers(access_token),
            timeout=timeout,
        )
        if reply.status_code != 200:
            return []
        body = reply.json()
    except Exception:
        return []

    catalog = body.get("models", []) if isinstance(body, dict) else []
    ranked: list[tuple[int, str]] = []
    for entry in catalog:
        if not isinstance(entry, dict):
            continue
        raw_slug = entry.get("slug")
        if not (isinstance(raw_slug, str) and raw_slug.strip()):
            continue
        shown = entry.get("visibility", "")
        if isinstance(shown, str) and shown.strip().lower() in {"hide", "hidden"}:
            continue
        weight = entry.get("priority")
        position = int(weight) if isinstance(weight, (int, float)) else 10_000
        ranked.append((position, raw_slug.strip()))

    # Tuples compare by rank first and slug second, matching the sort order.
    ranked.sort()
    # dict.fromkeys keeps the first occurrence of each slug in sorted order.
    return list(dict.fromkeys(name for _, name in ranked))
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
    """Raise the matching provider error for a failed OAuth HTTP response.

    Responses with a status below 400 pass through silently. Otherwise the
    body is inspected on a best-effort basis for an error code and message:
    HTTP 429 raises ``ChatGPTSubscriptionRateLimited``; HTTP 401/403 or a
    known credential error code raises ``ChatGPTSubscriptionReauthRequired``;
    everything else raises ``ChatGPTSubscriptionError``.
    """
    status = response.status_code
    if status < 400:
        return

    error_code = ""
    detail = f"ChatGPT Subscription {action} failed with HTTP {status}."
    try:
        body = response.json()
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, str):
            # Flat OAuth-style error: {"error": "...", "error_description": "..."}
            error_code = error.strip()
            description = body.get("error_description") or body.get("message")
            if description:
                detail = f"ChatGPT Subscription {action} failed: {description}"
        elif isinstance(error, dict):
            # Nested error object: {"error": {"code"/"type": ..., "message": ...}}
            error_code = str(error.get("code") or error.get("type") or "").strip()
            explanation = error.get("message")
            if explanation:
                detail = f"ChatGPT Subscription {action} failed: {explanation}"
    except Exception:
        # Body parsing is best-effort; fall back to the generic HTTP message.
        pass

    if status == 429:
        raise ChatGPTSubscriptionRateLimited(
            "ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
        )
    credential_codes = {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}
    if status in (401, 403) or error_code in credential_codes:
        raise ChatGPTSubscriptionReauthRequired(detail)
    raise ChatGPTSubscriptionError(detail)
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
    """Return the JSON object body of *response*, or raise a provider error.

    HTTP-level failures are raised first via ``_raise_for_oauth_response``.
    A body that is not valid JSON, or that is valid JSON but not an object,
    raises ``ChatGPTSubscriptionError``.
    """
    _raise_for_oauth_response(response, action)
    try:
        parsed = response.json()
    except Exception as exc:
        raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
    if isinstance(parsed, dict):
        return parsed
    raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
    """Start a device-authorization flow and return the issuer's response.

    The response must carry ``device_auth_id`` and ``user_code``; otherwise
    ``ChatGPTSubscriptionError`` is raised. Missing ``verification_uri``,
    ``interval`` and ``expires_in`` fields are filled with defaults.
    """
    endpoint = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode"
    reply = httpx.post(
        endpoint,
        json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    data = _json_or_error(reply, "device-code request")
    if not (data.get("device_auth_id") and data.get("user_code")):
        raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")
    fallbacks = (
        ("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device"),
        ("interval", 5),
        ("expires_in", 900),
    )
    for field, value in fallbacks:
        data.setdefault(field, value)
    return data
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Poll the issuer once for the outcome of a device authorization.

    HTTP 403 and 404 are reported as a pending authorization instead of an
    error; any other response is validated and returned as a JSON object.
    """
    endpoint = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token"
    body = {"device_auth_id": device_auth_id, "user_code": user_code}
    reply = httpx.post(
        endpoint,
        json=body,
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    still_waiting = reply.status_code in (403, 404)
    if still_waiting:
        return {"status": "pending", "error": "authorization_pending"}
    return _json_or_error(reply, "device-code poll")
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Exchange an authorization code (with its verifier) for OAuth tokens.

    Raises ``ChatGPTSubscriptionReauthRequired`` when the token endpoint
    answers successfully but without an ``access_token``.
    """
    form = {
        "grant_type": "authorization_code",
        "code": authorization_code,
        "redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
        "code_verifier": code_verifier,
    }
    reply = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(reply, "token exchange")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
    """Obtain fresh OAuth tokens using *refresh_token*.

    *access_token* is accepted for signature compatibility but not used.
    Raises ``ChatGPTSubscriptionReauthRequired`` when no refresh token is
    available or when the token endpoint returns no ``access_token``.
    """
    del access_token  # intentionally unused; only the refresh token is sent
    if not refresh_token:
        raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
    }
    reply = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(reply, "token refresh")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
    """Decode the (unverified) payload segment of a JWT.

    The signature is not checked. Raises ``ValueError`` when *token* has
    fewer than two dot-separated segments; decoding errors from the base64 or
    JSON step propagate. A payload that is valid JSON but not an object
    yields an empty dict.
    """
    pieces = (token or "").split(".")
    if len(pieces) < 2:
        raise ValueError("not a JWT")
    body = pieces[1]
    # JWT segments drop base64 padding; restore it before decoding.
    padded = body + "=" * (-len(body) % 4)
    decoded = base64.urlsafe_b64decode(padded.encode("ascii"))
    claims = json.loads(decoded.decode("utf-8"))
    if isinstance(claims, dict):
        return claims
    return {}
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
    """Report whether *access_token* expires within *skew_seconds* from now.

    A token that cannot be decoded, or that has no usable ``exp`` claim, is
    treated as expiring.
    """
    try:
        claims = _decode_jwt_payload(access_token)
        expires_at = int(claims.get("exp") or 0)
    except Exception:
        return True
    deadline = int(time.time()) + int(skew_seconds)
    return expires_at <= deadline
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
    """Resolve the current base URL and bearer token for a stored auth session.

    Looks up the ``ProviderAuthSession`` row with id *auth_id* for the ChatGPT
    subscription provider and, when its access token is expiring (or
    *force_refresh* is set), refreshes the tokens and persists them.

    Args:
        auth_id: Primary key of the stored auth session.
        owner: When truthy, the lookup is additionally restricted to rows
            owned by this user. When falsy, no owner filter is applied.
        force_refresh: Refresh the tokens even if the access token is not
            yet expiring.

    Returns:
        A dict with ``provider``, ``base_url``, ``api_key`` (the access token)
        and ``auth_mode``.

    Raises:
        ChatGPTSubscriptionAuthNotFound: No matching row exists.
        ChatGPTSubscriptionError: Propagated from the token refresh.
    """
    db = SessionLocal()
    try:
        q = db.query(ProviderAuthSession).filter(
            ProviderAuthSession.id == auth_id,
            ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
        )
        if owner:
            q = q.filter(ProviderAuthSession.owner == owner)
        row = q.first()
        if row is None:
            raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")
        access_token = row.access_token or ""
        if force_refresh or access_token_is_expiring(access_token):
            # Serialize refreshes per auth session within this process.
            with _refresh_lock_for(auth_id):
                # Reload the row after acquiring the lock — presumably so a
                # refresh committed by another thread while we waited is seen.
                db.refresh(row)
                access_token = row.access_token or ""
                refresh_token = row.refresh_token or ""
                # Re-check under the lock; skip the network call if the
                # reloaded token is no longer expiring (unless forced).
                if force_refresh or access_token_is_expiring(access_token):
                    refreshed = refresh_oauth_tokens(access_token, refresh_token)
                    row.access_token = refreshed["access_token"]
                    # Keep the existing refresh token when the response does
                    # not include a new one.
                    if refreshed.get("refresh_token"):
                        row.refresh_token = refreshed["refresh_token"]
                    row.last_refresh = utcnow_naive()
                    db.commit()
                    db.refresh(row)
                    access_token = row.access_token or ""
        return {
            "provider": CHATGPT_SUBSCRIPTION_PROVIDER,
            # Fall back to the module default when the row has no base URL.
            "base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"),
            "api_key": access_token,
            "auth_mode": row.auth_mode or "chatgpt",
        }
    finally:
        # Always release the session, including on lookup/refresh errors.
        db.close()
def to_http_exception(exc: Exception) -> HTTPException:
    """Translate a provider error into the HTTPException returned to clients.

    * ``ChatGPTSubscriptionRateLimited`` -> 429 with the original message.
    * ``ChatGPTSubscriptionReauthRequired`` / ``ChatGPTSubscriptionAuthNotFound``
      -> 401, with a "Reconnect the provider." hint appended.
    * Anything else -> 502 with the original message.
    """
    if isinstance(exc, ChatGPTSubscriptionRateLimited):
        return HTTPException(429, str(exc))
    if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
        hint = "Reconnect the provider."
        detail = str(exc).strip()
        # Some messages (e.g. the missing-refresh-token error) already end
        # with the hint; appending it again would repeat the sentence.
        if not detail.endswith(hint):
            detail = f"{detail} {hint}".strip()
        return HTTPException(401, detail)
    return HTTPException(502, str(exc))
def build_responses_input(messages: list[dict]) -> list[dict]:
    """Convert chat-style messages into Responses-API ``input`` items.

    Each message becomes ``{"role": ..., "content": [{"type": ..., "text": ...}]}``.
    A missing role defaults to ``"user"`` and ``"tool"`` messages are sent as
    ``"user"``. Assistant text is tagged ``output_text``; every other role is
    tagged ``input_text``.

    List-valued content is flattened to text joined by newlines. Parts with no
    text (for example image parts) are skipped instead of contributing empty
    lines, and bare non-dict parts are kept as text.

    Args:
        messages: Chat messages as dicts with ``role`` and ``content`` keys;
            ``None`` or an empty list yields an empty result.

    Returns:
        One input item per message, in the original order.
    """
    input_items: list[dict] = []
    for msg in messages or []:
        role = msg.get("role") or "user"
        if role == "tool":
            role = "user"
        content = msg.get("content")
        if isinstance(content, list):
            fragments: list[str] = []
            for part in content:
                if isinstance(part, dict):
                    value = part.get("text") or part.get("content")
                else:
                    value = part
                # Drop empty/textless parts so they do not add blank lines.
                if value:
                    fragments.append(str(value))
            text = "\n".join(fragments)
        else:
            text = "" if content is None else str(content)
        input_type = "output_text" if role == "assistant" else "input_text"
        input_items.append({"role": role, "content": [{"type": input_type, "text": text}]})
    return input_items
+40 -6
View File
@@ -70,6 +70,25 @@ def _endpoint_enabled_models(ep) -> list:
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
def resolve_endpoint_runtime(ep, owner: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """Work out the base URL and bearer/API key to use for an endpoint row.

    Endpoints without a ``provider_auth_id`` use their stored ``base_url`` and
    ``api_key`` as-is. Endpoints backed by a provider auth session resolve a
    current access token (and possibly a base URL) at call time; errors from
    that lookup propagate to the caller.
    """
    static_base = normalize_base(getattr(ep, "base_url", "") or "")
    static_key = getattr(ep, "api_key", None)
    session_id = getattr(ep, "provider_auth_id", None)
    if not session_id:
        return static_base, static_key

    # Imported lazily, as in the original, to keep the dependency local.
    from src.chatgpt_subscription import resolve_runtime_credentials

    session_creds = resolve_runtime_credentials(session_id, owner=owner)
    runtime_base = normalize_base(session_creds.get("base_url") or static_base)
    return runtime_base, session_creds.get("api_key")
# Cache for Tailscale hostname → IP resolution
_tailscale_cache: Dict[str, Optional[str]] = {}
@@ -133,7 +152,7 @@ def resolve_url(url: str) -> str:
def normalize_base(url: str) -> str:
"""Strip known API path suffixes from a base URL."""
url = (url or "").strip().rstrip("/")
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages", "/responses"]:
if url.endswith(suffix):
url = url[: -len(suffix)].rstrip("/")
for suffix in ["/chat", "/tags", "/generate"]:
@@ -158,10 +177,12 @@ def build_chat_url(base: str) -> str:
return _anthropic_api_root(base) + "/v1/messages"
if provider == "ollama":
return _ollama_api_root(base) + "/chat"
if provider == "chatgpt-subscription":
return base.rstrip("/") + "/responses"
return base + "/chat/completions"
def build_models_url(base: str) -> str:
def build_models_url(base: str) -> Optional[str]:
"""Return the provider-specific model-list endpoint URL for a base."""
base = resolve_url(base)
provider = _detect_provider(base)
@@ -169,6 +190,8 @@ def build_models_url(base: str) -> str:
return _anthropic_api_root(base) + "/v1/models"
if provider == "ollama":
return _ollama_api_root(base) + "/tags"
if provider == "chatgpt-subscription":
return None
return base + "/models"
@@ -184,6 +207,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
if provider == "copilot":
from src.copilot import copilot_headers
return copilot_headers(api_key)
if provider == "chatgpt-subscription":
from src.chatgpt_subscription import chatgpt_headers
return chatgpt_headers(api_key)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if provider == "openrouter":
@@ -262,9 +288,13 @@ def resolve_endpoint(
if not ep:
return fallback_url, fallback_model, fallback_headers
base = normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception as e:
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
return fallback_url, fallback_model, fallback_headers
chat_url = build_chat_url(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
# Discard a configured model the user has since disabled on the
# endpoint (e.g. a stale `default_model` left pointing at a now-hidden
@@ -308,9 +338,13 @@ def resolve_endpoint_by_id(
ep = q.first()
if not ep:
return None
base = normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception as e:
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
return None
chat_url = build_chat_url(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
m = (model or "").strip()
# Drop a model the user disabled on the endpoint, then pick the first
# enabled chat model rather than a hidden one.
+217 -7
View File
@@ -426,6 +426,9 @@ def _detect_provider(url: str) -> str:
return "openrouter"
if _host_match(url, "groq.com"):
return "groq"
from src.chatgpt_subscription import is_chatgpt_subscription_base
if is_chatgpt_subscription_base(url):
return "chatgpt-subscription"
from src.copilot import is_copilot_base
if is_copilot_base(url):
return "copilot"
@@ -462,6 +465,8 @@ def _provider_label(url: str) -> str:
if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go"
if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen"
if _host_match(url, "groq.com"): return "Groq"
from src.chatgpt_subscription import is_chatgpt_subscription_base
if is_chatgpt_subscription_base(url): return "ChatGPT Subscription"
from src.copilot import is_copilot_base
if is_copilot_base(url): return "GitHub Copilot"
if _host_match(url, "mistral.ai"): return "Mistral"
@@ -479,6 +484,77 @@ def _provider_label(url: str) -> str:
return host or "provider"
def _normalize_chatgpt_subscription_url(url: str) -> str:
    """Return *url* pointing at the ``/responses`` endpoint.

    Whitespace and trailing slashes are trimmed; ``/responses`` is appended
    unless the URL already ends with it.
    """
    trimmed = (url or "").strip().rstrip("/")
    return trimmed if trimmed.endswith("/responses") else trimmed + "/responses"
def _message_content_as_text(content) -> str:
    """Flatten a message ``content`` value into plain text.

    Strings pass through unchanged and ``None`` becomes ``""``. For a list,
    each dict part contributes its string ``text`` (or, failing that, its
    string ``content``), each truthy non-dict part contributes ``str(part)``,
    and the pieces are joined with newlines. Any other value is stringified.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if not isinstance(content, list):
        return str(content)

    collected = []
    for item in content:
        if isinstance(item, dict):
            # First string-valued field wins: "text" takes precedence.
            for field in ("text", "content"):
                value = item.get(field)
                if isinstance(value, str):
                    collected.append(value)
                    break
        elif item:
            collected.append(str(item))
    return "\n".join(collected)
def _chatgpt_subscription_instructions(messages: List[Dict]) -> str:
    """Collect the system messages into a single instructions string.

    Non-empty system message texts are joined with blank lines. When there is
    no usable system text, a generic assistant instruction is returned.
    """
    system_chunks = []
    for entry in messages or []:
        if (entry.get("role") or "") != "system":
            continue
        chunk = _message_content_as_text(entry.get("content")).strip()
        if chunk:
            system_chunks.append(chunk)
    if not system_chunks:
        return "You are a helpful AI assistant."
    return "\n\n".join(system_chunks)
def _build_chatgpt_responses_payload(
    model: str,
    messages: List[Dict],
    temperature: float,
    max_tokens: int,
    *,
    stream: bool = False,
) -> Dict:
    """Assemble the request body for a Responses-style call.

    System messages are folded into ``instructions``; the remaining messages
    become ``input``. ``temperature`` is sent only for models that do not
    restrict it, and ``max_output_tokens`` only for a positive *max_tokens*.
    ``store`` is always ``False``.
    """
    from src.chatgpt_subscription import build_responses_input

    non_system = [
        entry for entry in (messages or []) if (entry.get("role") or "") != "system"
    ]
    body: Dict = {
        "model": model,
        "instructions": _chatgpt_subscription_instructions(messages),
        "input": build_responses_input(non_system),
        "stream": stream,
        "store": False,
    }
    temperature_allowed = not _restricts_temperature(model)
    if temperature_allowed:
        body["temperature"] = temperature
    if max_tokens and max_tokens > 0:
        body["max_output_tokens"] = max_tokens
    return body
def _format_chatgpt_subscription_error(status_code: int, text: str) -> str:
    """Turn an upstream ChatGPT Subscription HTTP error into a user message.

    401/403 and 429 get fixed, provider-specific wording; every other status
    is delegated to the generic upstream error formatter.
    """
    if status_code == 429:
        return "ChatGPT Subscription quota or rate limit was reached. Retry after the upstream limit resets."
    if status_code == 401 or status_code == 403:
        return "ChatGPT Subscription credentials expired or were rejected. Reconnect the provider."
    return _format_upstream_error(status_code, text, "https://chatgpt.com/backend-api/codex")
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
"""Turn an upstream HTTP error into a user-readable sentence.
@@ -874,7 +950,7 @@ def _normalize_anthropic_url(url: str) -> str:
def _model_list_base(url: str) -> str:
"""Normalize model/chat URLs to the configured endpoint base."""
base = (url or "").strip().rstrip("/")
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages", "/responses"):
if base.endswith(suffix):
base = base[: -len(suffix)].rstrip("/")
for suffix in ("/chat", "/tags", "/generate"):
@@ -903,7 +979,12 @@ def _parse_model_cache(raw) -> List[str]:
return out
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
def _configured_cached_model_ids(
endpoint_url: str,
*,
owner: Optional[str] = None,
endpoint_id: Optional[str] = None,
) -> List[str]:
"""Return cached models for a configured endpoint matching endpoint_url."""
target = _model_list_base(endpoint_url)
if not target:
@@ -914,7 +995,13 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
return []
db = SessionLocal()
try:
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if endpoint_id:
q = q.filter(ModelEndpoint.id == endpoint_id)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
rows = q.all()
for ep in rows:
if _model_list_base(getattr(ep, "base_url", "")) != target:
continue
@@ -933,9 +1020,16 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
return []
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
def list_model_ids(
base_chat_url: str,
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
headers: Optional[Dict] = None,
*,
owner: Optional[str] = None,
endpoint_id: Optional[str] = None,
) -> List[str]:
"""List available model IDs from an endpoint."""
cached = _configured_cached_model_ids(base_chat_url)
cached = _configured_cached_model_ids(base_chat_url, owner=owner, endpoint_id=endpoint_id)
if cached:
return cached
provider = _detect_provider(base_chat_url)
@@ -971,9 +1065,16 @@ def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT,
pass
return []
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]:
def normalize_model_id(
endpoint_url: str,
requested: str,
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
*,
owner: Optional[str] = None,
endpoint_id: Optional[str] = None,
) -> Optional[str]:
"""Normalize a model ID to match available models."""
avail = list_model_ids(endpoint_url, timeout)
avail = list_model_ids(endpoint_url, timeout, owner=owner, endpoint_id=endpoint_id)
if not avail:
return None
if requested in avail:
@@ -1169,6 +1270,49 @@ async def llm_call_async(
logger.debug(f"Returning cached response for key: {cache_key}")
return cached_response
if provider == "chatgpt-subscription":
# ChatGPT/Codex requires streamed Responses requests even for callers
# that want a plain string (auto-title, memory extraction, etc.).
# Reuse stream_llm's validated Codex SSE path and collect deltas.
parts: List[str] = []
async for chunk in stream_llm(
url,
model,
messages_copy,
temperature=temperature,
max_tokens=max_tokens,
headers=headers,
timeout=timeout,
):
event_is_error = False
for line in str(chunk).splitlines():
if line.startswith("event:"):
event_is_error = line[6:].strip() == "error"
continue
if not line.startswith("data:"):
continue
raw = line[5:].strip()
if not raw:
continue
if raw == "[DONE]":
response = "".join(parts)
_set_cached_response(cache_key, response)
return response
try:
data = json.loads(raw)
except json.JSONDecodeError:
continue
if event_is_error or data.get("error") or (data.get("status") and data.get("text")):
status = int(data.get("status") or 502)
text = data.get("text") or data.get("error") or "ChatGPT Subscription request failed"
raise HTTPException(status, text)
delta = data.get("delta")
if isinstance(delta, str):
parts.append(delta)
response = "".join(parts)
_set_cached_response(cache_key, response)
return response
if provider == "anthropic":
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
@@ -1294,6 +1438,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
model, messages_copy, temperature, max_tokens,
stream=True, tools=tools, num_ctx=get_context_length(url, model),
)
elif provider == "chatgpt-subscription":
target_url = _normalize_chatgpt_subscription_url(url)
h = _provider_headers(provider, headers)
payload = _build_chatgpt_responses_payload(model, messages_copy, temperature, max_tokens, stream=True)
else:
target_url = url
payload = {
@@ -1325,6 +1473,68 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
return
note_model_activity(target_url, model)
# ── ChatGPT Subscription / Codex Responses streaming ──
if provider == "chatgpt-subscription":
event_name = ""
input_tokens = 0
output_tokens = 0
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
_clear_host_dead(target_url)
if r.status_code != 200:
raw = (await r.aread()).decode(errors="replace")
friendly = _format_chatgpt_subscription_error(r.status_code, raw)
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
return
async for line in r.aiter_lines():
if not line:
continue
if line.startswith("event:"):
event_name = line[6:].strip()
continue
if not line.startswith("data:"):
continue
raw = line[5:].strip()
if not raw:
continue
try:
data = json.loads(raw)
except json.JSONDecodeError:
continue
evt = data.get("type") or event_name
if evt == "response.output_text.delta":
delta = data.get("delta") or ""
if delta:
yield f'data: {json.dumps({"delta": delta})}\n\n'
elif evt == "response.completed":
usage = (data.get("response") or {}).get("usage") or data.get("usage") or {}
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or input_tokens
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") or output_tokens
if input_tokens or output_tokens:
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": input_tokens, "output_tokens": output_tokens}})}\n\n'
yield "data: [DONE]\n\n"
return
elif evt in ("response.failed", "error"):
err = data.get("error") or (data.get("response") or {}).get("error") or {}
text = err.get("message") if isinstance(err, dict) else str(err or "ChatGPT Subscription request failed")
yield f'event: error\ndata: {json.dumps({"status": 502, "text": text})}\n\n'
return
yield "data: [DONE]\n\n"
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
_cooled = _mark_host_dead(target_url)
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
logger.warning(f"ChatGPT Subscription stream connect to {target_url} failed: {e}{_tail}")
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
except httpx.ReadTimeout:
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
except httpx.NetworkError:
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
except Exception as e:
logger.error(f"ChatGPT Subscription stream error: {e}")
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
return
# ── Native Ollama streaming ──
if provider == "ollama":
_ollama_tool_calls: List[Dict] = []