From 1e0d9b92af31d9dc61108b661ccc438716f94fd7 Mon Sep 17 00:00:00 2001 From: stocky789 <113968354+stocky789@users.noreply.github.com> Date: Mon, 8 Jun 2026 18:19:18 +1000 Subject: [PATCH] feat: add ChatGPT Subscription provider (#2876) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add ChatGPT Subscription support and related features - Introduced a new provider option for ChatGPT Subscription in the endpoint selection UI. - Implemented OAuth flow for ChatGPT Subscription sign-in, including polling for authorization status. - Updated admin interface to handle ChatGPT Subscription, including disabling API key input and providing user guidance. - Enhanced cost tracking logic to differentiate between subscription and non-subscription endpoints. - Added new slash commands for managing skills, including listing, searching, and invoking skills. - Implemented caching for skill catalog to optimize performance. - Updated tests to cover new ChatGPT Subscription functionality and ensure proper endpoint probing. - Refactored existing code to accommodate new features and improve maintainability. * refactor: share provider device-flow setup - reuse one device-flow backend for Copilot and ChatGPT Subscription - add one frontend device-flow helper for Settings and /setup - put GitHub Copilot back into Add Models, now as a dropdown option - make provider selection just select; clicking Add starts sign-in - stop ChatGPT Subscription setup from opening auth tabs automatically - make /setup copilot and /setup chatgpt-subscription work from chat - show ChatGPT Subscription in the /setup suggestions - show the real error message when setup fails - add focused tests for the shared flow and setup UI * feat(chatgpt-subscription): harden credential lifecycle and streamline auth UX Backend: - Resolve runtime bearer for provider-auth endpoints at probe time via a shared _resolve_probe_key() that delegates to resolve_endpoint_runtime, applied across all probe/refresh call sites. - Skip live completion probes and health pings for discovery-only providers (centralized behind _is_discovery_only_provider) — the Codex/Responses API has no such endpoints, so status is derived from cached models. - Never persist the short lived ChatGPT bearer to the plaintext sessions table; proactively clear any stale bearer left by an earlier code path. - Revoke orphaned ProviderAuthSession credentials when the last endpoint backing them is deleted (_delete_orphaned_provider_auth), surfaced via cleared_provider_auth in the delete response. Frontend (admin.js): - Auto-start the device-auth flow on provider selection so the authorization panel (code + Authorize) shows immediately instead of behind a "Sign in" click. - Remove the redundant top button for device auth providers, move retry into the panel via an inline "Try again". - Drop the self-evident hint text and add an execCommand clipboard fallback so Copy works in non-secure (HTTP/LAN) contexts. * fix: harden chatgpt subscription provider * chore: remove PR media from branch * Fix chatgpt subscription recovery and token handling --------- Co-authored-by: 5p00kyy --- app.py | 4 + core/database.py | 39 ++ routes/chat_helpers.py | 114 ++++-- routes/chat_routes.py | 69 +++- routes/chatgpt_subscription_routes.py | 170 ++++++++ routes/copilot_routes.py | 184 ++++----- routes/device_flow.py | 193 +++++++++ routes/model_routes.py | 120 +++++- routes/research_routes.py | 65 +-- routes/session_routes.py | 39 +- routes/skills_routes.py | 70 ++++ routes/webhook_routes.py | 31 +- src/ai_interaction.py | 64 +-- src/chatgpt_subscription.py | 311 +++++++++++++++ src/endpoint_resolver.py | 46 ++- src/llm_core.py | 224 ++++++++++- static/index.html | 3 + static/js/admin.js | 269 +++++++++---- static/js/chatRenderer.js | 54 ++- static/js/providerDeviceFlow.js | 128 ++++++ static/js/providers.js | 5 + static/js/slashAutocomplete.js | 52 ++- static/js/slashCommands.js | 422 ++++++++++++++++---- tests/test_admin_device_flow_static.py | 65 +++ tests/test_chatgpt_subscription_routes.py | 280 +++++++++++++ tests/test_device_flow_routes.py | 138 +++++++ tests/test_endpoint_probing.py | 115 ++++-- tests/test_llm_core_temperature.py | 22 + tests/test_model_routes.py | 134 +++++-- tests/test_provider_detection.py | 10 + tests/test_provider_device_flow_js.py | 157 ++++++++ tests/test_research_endpoint_owner_scope.py | 28 +- tests/test_resolve_session_auth_chatgpt.py | 215 ++++++++++ tests/test_review_regressions.py | 2 +- tests/test_session_owner_attribution.py | 9 + tests/test_setup_device_auth_static.py | 42 ++ tests/test_slash_autocomplete_static.py | 17 + 37 files changed, 3425 insertions(+), 485 deletions(-) create mode 100644 routes/chatgpt_subscription_routes.py create mode 100644 routes/device_flow.py create mode 100644 src/chatgpt_subscription.py create mode 100644 static/js/providerDeviceFlow.js create mode 100644 tests/test_admin_device_flow_static.py create mode 100644 tests/test_chatgpt_subscription_routes.py create mode 100644 tests/test_device_flow_routes.py create mode 100644 tests/test_provider_device_flow_js.py create mode 100644 tests/test_resolve_session_auth_chatgpt.py create mode 100644 tests/test_setup_device_auth_static.py create mode 100644 tests/test_slash_autocomplete_static.py diff --git a/app.py b/app.py index 80e9d9f5b..97906bd46 100644 --- a/app.py +++ b/app.py @@ -598,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery)) from routes.copilot_routes import setup_copilot_routes app.include_router(setup_copilot_routes()) +# ChatGPT Subscription device-flow login +from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes +app.include_router(setup_chatgpt_subscription_routes()) + # TTS from routes.tts_routes import setup_tts_routes app.include_router(setup_tts_routes(tts_service)) diff --git a/core/database.py b/core/database.py index 85692e8c5..ee365c30c 100644 --- a/core/database.py +++ b/core/database.py @@ -361,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base): # is the historical default. When non-null, the model picker only shows # the endpoint to that user (admins always see everything). owner = Column(String, nullable=True, index=True) + # Optional OAuth/session-backed credential row. Used by subscription-backed + # providers that need refresh tokens instead of a static API key. + provider_auth_id = Column(String, nullable=True, index=True) + + +class ProviderAuthSession(TimestampMixin, Base): + """Encrypted OAuth/session credentials for refresh-aware model providers.""" + __tablename__ = "provider_auth_sessions" + + id = Column(String, primary_key=True, index=True) + provider = Column(String, nullable=False, index=True) + owner = Column(String, nullable=True, index=True) + label = Column(String, nullable=True) + base_url = Column(String, nullable=False) + access_token = Column(EncryptedText, nullable=True) + refresh_token = Column(EncryptedText, nullable=True) + last_refresh = Column(DateTime, nullable=True) + auth_mode = Column(String, nullable=True) class McpServer(TimestampMixin, Base): """Admin-configured MCP (Model Context Protocol) tool servers.""" @@ -801,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column(): logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}") +def _migrate_add_provider_auth_id_column(): + """Add provider_auth_id column to model_endpoints if it doesn't exist.""" + import sqlite3 + db_path = DATABASE_URL.replace("sqlite:///", "") + if not os.path.exists(db_path): + return + try: + conn = sqlite3.connect(db_path) + cursor = conn.execute("PRAGMA table_info(model_endpoints)") + columns = [row[1] for row in cursor.fetchall()] + if columns and "provider_auth_id" not in columns: + conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR") + conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)") + conn.commit() + logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints") + conn.close() + except Exception as e: + logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}") + + def _migrate_add_model_type_column(): """Add model_type column to model_endpoints if it doesn't exist.""" import sqlite3 @@ -1599,6 +1637,7 @@ def init_db(): _migrate_add_model_type_column() _migrate_add_model_endpoint_refresh_columns() _migrate_add_model_endpoint_owner_column() + _migrate_add_provider_auth_id_column() _migrate_add_supports_tools_column() _migrate_add_task_run_model_column() _migrate_add_owner_column() diff --git a/routes/chat_helpers.py b/routes/chat_helpers.py index 2e5db4478..5c04ab70e 100644 --- a/routes/chat_helpers.py +++ b/routes/chat_helpers.py @@ -196,14 +196,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None: Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None. """ import requests as _req - from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base + from src.endpoint_resolver import ( + build_chat_url, + build_headers, + build_models_url, + normalize_base, + resolve_endpoint_runtime, + ) + from src.chatgpt_subscription import is_chatgpt_subscription_base current_url = sess.endpoint_url or "" + owner = getattr(sess, "owner", None) db = SessionLocal() try: - endpoints = db.query(ModelEndpoint).filter( + q = db.query(ModelEndpoint).filter( ModelEndpoint.is_enabled == True - ).all() + ) + if owner: + from src.auth_helpers import owner_filter + q = owner_filter(q, ModelEndpoint, owner) + endpoints = q.all() finally: db.close() @@ -212,26 +224,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None: # Skip current endpoint if current_url and base in current_url: continue - # Quick ping - ping_url = build_models_url(base) - headers = build_headers(ep.api_key, base) try: - r = _req.get(ping_url, headers=headers, timeout=5) - r.raise_for_status() - data = r.json() - models = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if not models: - models = [ - m.get("name") or m.get("model") - for m in (data.get("models") or []) - if m.get("name") or m.get("model") - ] + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception: + continue + ping_url = build_models_url(base) + headers = build_headers(api_key, base) + try: + if ping_url: + r = _req.get(ping_url, headers=headers, timeout=5) + r.raise_for_status() + data = r.json() + models = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not models: + models = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] + else: + models = json.loads(ep.cached_models or "[]") if not models: continue # Found a working endpoint — update session new_model = models[0] chat_url = build_chat_url(base) - new_headers = build_headers(ep.api_key, base) + new_headers = build_headers(api_key, base) + persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers sess.model = new_model sess.endpoint_url = chat_url @@ -243,7 +262,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None: _db.query(DBSession).filter(DBSession.id == session_id).update({ "model": new_model, "endpoint_url": chat_url, - "headers": json.dumps(new_headers), + "headers": persisted_headers, }) _db.commit() finally: @@ -336,16 +355,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool: return False +def _has_auth_keys(headers) -> bool: + """True if a headers dict carries an Authorization/x-api-key entry.""" + return isinstance(headers, dict) and any( + k.lower() in ('authorization', 'x-api-key') for k in headers + ) + + def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None): """Ensure session has auth headers — resolve from endpoint DB if missing.""" - has_auth = sess.headers and isinstance(sess.headers, dict) and any( - k.lower() in ('authorization', 'x-api-key') for k in sess.headers - ) - if has_auth: + try: + from src.chatgpt_subscription import is_chatgpt_subscription_base + is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "") + except Exception: + is_chatgpt_subscription = False + has_auth = _has_auth_keys(sess.headers) + if has_auth and not is_chatgpt_subscription: return try: - from src.endpoint_resolver import build_headers, normalize_base + from src.endpoint_resolver import build_headers, resolve_endpoint_runtime db = SessionLocal() try: target_url = getattr(sess, "endpoint_url", "") or "" @@ -361,10 +390,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None): for ep in q.all(): if not _session_url_matches_endpoint(target_url, ep.base_url or ""): continue - if not ep.api_key: + try: + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception as e: + logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e) + return + if not api_key: + # No usable key (e.g. ChatGPT Subscription needs re-auth). + return + sess.headers = build_headers(api_key, base) + if is_chatgpt_subscription: + # The bearer is short-lived and re-resolved per request, so it + # stays request-local and is never written to the plaintext + # sessions.headers column. Proactively strip any bearer an + # older code path may have persisted so it does not linger. + stale_q = db.query(DBSession).filter(DBSession.id == session_id) + if owner: + stale_q = stale_q.filter(DBSession.owner == owner) + stored = stale_q.first() + if stored is not None and _has_auth_keys(stored.headers): + stale_q.update({"headers": {}}) + db.commit() + logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}") + logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}") return - base = normalize_base(ep.base_url or "") - sess.headers = build_headers(ep.api_key, base) update_q = db.query(DBSession).filter(DBSession.id == session_id) if owner: update_q = update_q.filter(DBSession.owner == owner) @@ -408,7 +457,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]: db = SessionLocal() try: - endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) + owner = getattr(sess, "owner", None) + if owner: + from src.auth_helpers import owner_filter + q = owner_filter(q, ModelEndpoint, owner) + endpoints = q.all() for ep in endpoints: try: if normalize_base(getattr(ep, "base_url", "") or "") != session_base: @@ -542,7 +596,11 @@ async def build_chat_context( # Normalize model ID. Prefer cached endpoint models so group chat does not # re-hit slow local /models endpoints on every participant turn. - norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model) + norm = _normalize_model_id_from_cache(sess) or normalize_model_id( + sess.endpoint_url, + sess.model, + owner=getattr(sess, "owner", None), + ) if norm: sess.model = norm diff --git a/routes/chat_routes.py b/routes/chat_routes.py index 063c1cba6..a718d3fbe 100644 --- a/routes/chat_routes.py +++ b/routes/chat_routes.py @@ -169,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None Covers the window between endpoint setup and the first chat send: the picker showed a model in the dropdown but the session record never got written (Issue #587 — UI uses the cached endpoint list, not s.model). - Without this, we'd POST the upstream with model="" and get a generic - 401/503 instead of using the model the user already picked. - - Returns True iff sess.model was repaired. + For ChatGPT Subscription, also repairs stale OpenAI API model names such as + ``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route. """ - if getattr(sess, "model", None): - return False + current_model = (getattr(sess, "model", "") or "").strip() + endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip() + is_chatgpt_subscription = False + if current_model: + try: + from src.chatgpt_subscription import is_chatgpt_subscription_base + is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url) + if not is_chatgpt_subscription: + return False + except Exception: + return False db = SessionLocal() try: # Prefer the endpoint whose base URL matches the session — we know the @@ -194,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None break if not ep: return False + if not is_chatgpt_subscription: + try: + from src.chatgpt_subscription import is_chatgpt_subscription_base + is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url) + except Exception: + is_chatgpt_subscription = False try: cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or []) except Exception: cached = [] if not cached: + visible = [] + else: + try: + visible = _visible_models(cached, getattr(ep, "hidden_models", None)) + except Exception: + visible = cached + if current_model and current_model in {str(item).strip() for item in visible}: return False - try: - visible = _visible_models(cached, getattr(ep, "hidden_models", None)) - except Exception: - visible = cached + if is_chatgpt_subscription: + live_models = [] + if getattr(ep, "provider_auth_id", None): + try: + from src.chatgpt_subscription import fetch_available_models + from src.endpoint_resolver import resolve_endpoint_runtime + _base, api_key = resolve_endpoint_runtime(ep, owner=owner) + if api_key: + live_models = fetch_available_models(api_key) + if live_models: + ep.cached_models = json.dumps(live_models) + db.commit() + except Exception: + live_models = [] + # ChatGPT Subscription recovery must use the live Codex catalog. + # Cached rows are only trusted above to avoid revalidating a model + # that is already present in the visible picker list. + cached = live_models + if not cached: + return False + try: + visible = _visible_models(cached, getattr(ep, "hidden_models", None)) + except Exception: + visible = cached + if current_model and current_model in {str(item).strip() for item in visible}: + return False if not visible: return False model = visible[0] @@ -213,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None # Persist so the next request, websocket reconnect, or page reload # picks up the same model (we'd otherwise re-pick on every send # and silently switch on the user if the cached order shifts). - db_session = db.query(DBSession).filter(DBSession.id == session_id).first() + db_session_q = db.query(DBSession).filter(DBSession.id == session_id) + if owner: + db_session_q = db_session_q.filter(DBSession.owner == owner) + db_session = db_session_q.first() if db_session: db_session.model = model db_session.updated_at = datetime.utcnow() db.commit() sess.model = model logger.info( - "Recovered empty session model for %s — picked %r from endpoint %s", + "Recovered session model for %s — picked %r from endpoint %s", session_id, model, ep.id, ) return True diff --git a/routes/chatgpt_subscription_routes.py b/routes/chatgpt_subscription_routes.py new file mode 100644 index 000000000..9c695b371 --- /dev/null +++ b/routes/chatgpt_subscription_routes.py @@ -0,0 +1,170 @@ +"""ChatGPT Subscription device-flow setup routes.""" + +import json +import logging +import uuid +from typing import Dict, Optional + +from fastapi import HTTPException, Request + +from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive +from routes.device_flow import ( + DeviceFlowPoll, + DeviceFlowStart, + PendingDeviceFlowStore, + create_device_flow_router, +) +from src.auth_helpers import get_current_user +from src import chatgpt_subscription + +logger = logging.getLogger(__name__) + +_DEVICE_FLOW_STORE = PendingDeviceFlowStore() + + +def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict: + access_token = tokens.get("access_token") + refresh_token = tokens.get("refresh_token") + if not access_token or not refresh_token: + raise ValueError("ChatGPT token response was missing access_token or refresh_token") + + base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL + models = chatgpt_subscription.fetch_available_models(access_token) + if not models: + raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.") + db = SessionLocal() + try: + auth = ( + db.query(ProviderAuthSession) + .filter( + ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER, + ProviderAuthSession.owner == owner, + ) + .first() + ) + if auth is None: + auth = ProviderAuthSession( + id=str(uuid.uuid4())[:8], + provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER, + owner=owner, + label="ChatGPT Subscription", + base_url=base, + auth_mode="chatgpt", + ) + db.add(auth) + auth.base_url = base + auth.access_token = access_token + auth.refresh_token = refresh_token + auth.last_refresh = utcnow_naive() + auth.auth_mode = "chatgpt" + + ep = ( + db.query(ModelEndpoint) + .filter( + ModelEndpoint.base_url == base, + ModelEndpoint.provider_auth_id == auth.id, + ModelEndpoint.owner == owner, + ) + .first() + ) + if ep is None: + ep = ModelEndpoint( + id=str(uuid.uuid4())[:8], + name="ChatGPT Subscription", + base_url=base, + model_type="llm", + endpoint_kind="api", + owner=owner, + ) + db.add(ep) + ep.name = "ChatGPT Subscription" + ep.base_url = base + ep.api_key = None + ep.provider_auth_id = auth.id + ep.is_enabled = True + ep.supports_tools = False + ep.model_type = "llm" + ep.endpoint_kind = "api" + ep.model_refresh_mode = "manual" + ep.cached_models = json.dumps(models) + db.commit() + result = { + "id": ep.id, + "name": ep.name, + "base_url": ep.base_url, + "models": models, + } + finally: + db.close() + + try: + from routes.model_routes import _invalidate_models_cache + + _invalidate_models_cache() + except Exception: + pass + return result + + +def _start_device_flow(request: Request, _form) -> DeviceFlowStart: + try: + data = chatgpt_subscription.request_device_code() + except Exception as exc: + raise chatgpt_subscription.to_http_exception(exc) + + device_auth_id = data.get("device_auth_id") + user_code = data.get("user_code") + if not device_auth_id or not user_code: + raise HTTPException(502, "ChatGPT did not return a complete device code") + verification_uri = data.get("verification_uri") or f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device" + return DeviceFlowStart( + pending={ + "device_auth_id": device_auth_id, + "user_code": user_code, + "owner": get_current_user(request) or None, + }, + response={ + "user_code": user_code, + "verification_uri": verification_uri, + }, + interval=int(data.get("interval") or 5), + expires_in=int(data.get("expires_in") or 900), + ) + + +def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll: + try: + data = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"]) + except Exception as exc: + logger.debug("ChatGPT device poll failed: %s", exc) + return DeviceFlowPoll.pending(str(exc)) + + authorization_code = data.get("authorization_code") + code_verifier = data.get("code_verifier") + if authorization_code and code_verifier: + try: + tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier) + result = _provision_endpoint(tokens, pending["owner"]) + except Exception as exc: + logger.exception("ChatGPT Subscription endpoint provisioning failed") + raise chatgpt_subscription.to_http_exception(exc) + return DeviceFlowPoll.authorized(result) + + err = data.get("error") or data.get("status") + if err in ("authorization_pending", "pending", None): + return DeviceFlowPoll.pending() + if err == "slow_down": + return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None) + if err in ("expired_token", "access_denied", "denied"): + return DeviceFlowPoll.failed(err) + return DeviceFlowPoll.pending(err or "unknown") + + +def setup_chatgpt_subscription_routes(): + return create_device_flow_router( + prefix="/api/chatgpt-subscription", + tags=["chatgpt-subscription"], + store=_DEVICE_FLOW_STORE, + start_flow=_start_device_flow, + poll_flow=_poll_device_flow, + ) diff --git a/routes/copilot_routes.py b/routes/copilot_routes.py index bb2b1d21f..1d8be52ce 100644 --- a/routes/copilot_routes.py +++ b/routes/copilot_routes.py @@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action). """ import json -import time import uuid import logging -import threading from typing import Dict, Optional import httpx -from fastapi import APIRouter, Request, Form, HTTPException +from fastapi import HTTPException, Request from core.database import SessionLocal, ModelEndpoint -from core.middleware import require_admin +from routes.device_flow import ( + DeviceFlowPoll, + DeviceFlowStart, + PendingDeviceFlowStore, + create_device_flow_router, +) from src.auth_helpers import get_current_user from src import copilot logger = logging.getLogger(__name__) -# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a -# bearer-like secret, so it lives here (server memory) rather than in the -# browser. Entries expire with the GitHub device code. -# -# NOTE: this is per-process state. The device flow assumes a single worker -# (Odysseus' default): with multiple uvicorn workers, the poll request can land -# on a worker that never saw the start, returning "Unknown or expired login -# session". Move this to a shared store (DB/Redis) if running multi-worker. -_PENDING: Dict[str, Dict] = {} -_PENDING_LOCK = threading.Lock() - - -def _prune_expired() -> None: - now = time.time() - with _PENDING_LOCK: - for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]: - _PENDING.pop(k, None) +_DEVICE_FLOW_STORE = PendingDeviceFlowStore() def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict: @@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict: return result -def setup_copilot_routes() -> APIRouter: - router = APIRouter(prefix="/api/copilot", tags=["copilot"]) +def _start_device_flow(request: Request, form) -> DeviceFlowStart: + host = copilot.GITHUB_HOST + ent = str(form.get("enterprise_url") or "").strip() + if ent: + host = copilot.normalize_domain(ent) + try: + data = copilot.request_device_code(host) + except httpx.HTTPStatusError as e: + status = e.response.status_code if e.response is not None else "unknown" + raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})") + except Exception as e: + raise HTTPException(502, f"GitHub device-code request failed: {e}") - @router.post("/device/start") - def device_start(request: Request, enterprise_url: str = Form("")): - require_admin(request) - _prune_expired() - host = copilot.GITHUB_HOST - ent = (enterprise_url or "").strip() - if ent: - host = copilot.normalize_domain(ent) - try: - data = copilot.request_device_code(host) - except httpx.HTTPStatusError as e: - status = e.response.status_code if e.response is not None else "unknown" - raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})") - except Exception as e: - raise HTTPException(502, f"GitHub device-code request failed: {e}") + device_code = data.get("device_code") + if not device_code: + raise HTTPException(502, "GitHub did not return a device code") - device_code = data.get("device_code") - if not device_code: - raise HTTPException(502, "GitHub did not return a device code") - interval = int(data.get("interval") or 5) - expires_in = int(data.get("expires_in") or 900) - poll_id = uuid.uuid4().hex - with _PENDING_LOCK: - _PENDING[poll_id] = { - "device_code": device_code, - "host": host, - "enterprise_url": ent, - "interval": interval, - "owner": get_current_user(request) or None, - "expires_at": time.time() + expires_in, - "next_poll_at": 0.0, - } - # verification_uri_complete embeds the user code, so the browser tab we - # open lands the user straight on GitHub's "Authorize" screen with the - # code pre-filled — one click, no manual code entry. - return { - "poll_id": poll_id, + # verification_uri_complete embeds the user code, so the browser tab we + # open lands the user straight on GitHub's "Authorize" screen with the + # code pre-filled — one click, no manual code entry. + return DeviceFlowStart( + pending={ + "device_code": device_code, + "host": host, + "enterprise_url": ent, + "owner": get_current_user(request) or None, + }, + response={ "user_code": data.get("user_code"), "verification_uri": data.get("verification_uri"), "verification_uri_complete": data.get("verification_uri_complete"), - "interval": interval, - "expires_in": expires_in, - } + }, + interval=int(data.get("interval") or 5), + expires_in=int(data.get("expires_in") or 900), + ) - @router.post("/device/poll") - def device_poll(request: Request, poll_id: str = Form(...)): - require_admin(request) - _prune_expired() - with _PENDING_LOCK: - pending = _PENDING.get(poll_id) - if not pending: - raise HTTPException(404, "Unknown or expired login session") - # Enforce GitHub's polling interval server-side so a chatty client - # can't trip slow_down. - now = time.time() - if now < pending.get("next_poll_at", 0): - return {"status": "pending"} +def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll: + try: + data = copilot.poll_access_token(pending["host"], pending["device_code"]) + except Exception as e: + return DeviceFlowPoll.pending(f"poll error: {e}") + token = data.get("access_token") + if token: + base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE try: - data = copilot.poll_access_token(pending["host"], pending["device_code"]) + result = _provision_endpoint(token, base, pending["owner"]) except Exception as e: - return {"status": "pending", "detail": f"poll error: {e}"} + logger.exception("Copilot endpoint provisioning failed") + raise HTTPException(500, f"Login succeeded but provisioning failed: {e}") + return DeviceFlowPoll.authorized(result) - token = data.get("access_token") - if token: - base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE - try: - result = _provision_endpoint(token, base, pending["owner"]) - except Exception as e: - logger.exception("Copilot endpoint provisioning failed") - with _PENDING_LOCK: - _PENDING.pop(poll_id, None) - raise HTTPException(500, f"Login succeeded but provisioning failed: {e}") - with _PENDING_LOCK: - _PENDING.pop(poll_id, None) - return {"status": "authorized", "endpoint": result} + err = data.get("error") + if err == "authorization_pending": + return DeviceFlowPoll.pending() + if err == "slow_down": + return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None) + if err in ("expired_token", "access_denied"): + return DeviceFlowPoll.failed(err) + # Unknown error — surface but keep the session for another try. + return DeviceFlowPoll.pending(err or "unknown") - err = data.get("error") - if err == "authorization_pending": - with _PENDING_LOCK: - if poll_id in _PENDING: - _PENDING[poll_id]["next_poll_at"] = now + pending["interval"] - return {"status": "pending"} - if err == "slow_down": - new_interval = int(data.get("interval") or (pending["interval"] + 5)) - with _PENDING_LOCK: - if poll_id in _PENDING: - _PENDING[poll_id]["interval"] = new_interval - _PENDING[poll_id]["next_poll_at"] = now + new_interval - return {"status": "pending"} - if err in ("expired_token", "access_denied"): - with _PENDING_LOCK: - _PENDING.pop(poll_id, None) - return {"status": "failed", "error": err} - # Unknown error — surface but keep the session for another try. - return {"status": "pending", "detail": err or "unknown"} - @router.post("/device/cancel") - def device_cancel(request: Request, poll_id: str = Form(...)): - require_admin(request) - with _PENDING_LOCK: - _PENDING.pop(poll_id, None) - return {"status": "cancelled"} - - return router +def setup_copilot_routes(): + return create_device_flow_router( + prefix="/api/copilot", + tags=["copilot"], + store=_DEVICE_FLOW_STORE, + start_flow=_start_device_flow, + poll_flow=_poll_device_flow, + ) diff --git a/routes/device_flow.py b/routes/device_flow.py new file mode 100644 index 000000000..8b8ab4ac8 --- /dev/null +++ b/routes/device_flow.py @@ -0,0 +1,193 @@ +"""Shared OAuth/device-flow route scaffolding for provider setup.""" + +from __future__ import annotations + +import inspect +import threading +import time +import uuid +from dataclasses import dataclass +from typing import Any, Callable, Iterable, Mapping, Optional + +from fastapi import APIRouter, Form, HTTPException, Request + +from core.middleware import require_admin + + +@dataclass(frozen=True) +class DeviceFlowStart: + """Provider-specific start result consumed by the shared route wrapper.""" + + pending: Mapping[str, Any] + response: Mapping[str, Any] + interval: int = 5 + expires_in: int = 900 + + +@dataclass(frozen=True) +class DeviceFlowPoll: + """Normalized provider poll outcome.""" + + status: str + endpoint: Optional[Mapping[str, Any]] = None + error: Optional[str] = None + detail: Optional[str] = None + interval: Optional[int] = None + + @classmethod + def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll": + return cls(status="pending", detail=detail) + + @classmethod + def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll": + return cls(status="slow_down", interval=interval, detail=detail) + + @classmethod + def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll": + return cls(status="authorized", endpoint=endpoint) + + @classmethod + def failed(cls, error: str) -> "DeviceFlowPoll": + return cls(status="failed", error=error) + + +class PendingDeviceFlowStore: + """Thread-safe in-memory pending device-flow store. + + Device codes and provider-side secrets stay inside this process. Each entry + stores provider payload separately from poll metadata so provider callbacks + only receive the fields they created. + """ + + def __init__(self, *, time_func: Callable[[], float] = time.time): + self._pending: dict[str, dict[str, Any]] = {} + self._lock = threading.Lock() + self._time = time_func + + def _now(self) -> float: + return float(self._time()) + + def prune_expired(self) -> None: + now = self._now() + with self._lock: + for key in [k for k, v in self._pending.items() if v.get("expires_at", 0) < now]: + self._pending.pop(key, None) + + def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str: + self.prune_expired() + poll_id = uuid.uuid4().hex + with self._lock: + self._pending[poll_id] = { + "payload": dict(payload), + "interval": max(int(interval or 5), 1), + "expires_at": self._now() + max(int(expires_in or 900), 1), + "next_poll_at": 0.0, + } + return poll_id + + def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]: + self.prune_expired() + with self._lock: + entry = self._pending.get(poll_id) + if entry is None: + return None + return dict(entry.get("payload") or {}) + + def is_throttled(self, poll_id: str) -> bool: + with self._lock: + entry = self._pending.get(poll_id) + return bool(entry and self._now() < float(entry.get("next_poll_at") or 0)) + + def schedule_next(self, poll_id: str) -> None: + now = self._now() + with self._lock: + entry = self._pending.get(poll_id) + if entry is not None: + entry["next_poll_at"] = now + int(entry.get("interval") or 5) + + def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None: + now = self._now() + with self._lock: + entry = self._pending.get(poll_id) + if entry is not None: + new_interval = int(interval or (int(entry.get("interval") or 5) + 5)) + entry["interval"] = max(new_interval, 1) + entry["next_poll_at"] = now + entry["interval"] + + def pop(self, poll_id: str) -> None: + with self._lock: + self._pending.pop(poll_id, None) + + +async def _maybe_await(value: Any) -> Any: + if inspect.isawaitable(value): + return await value + return value + + +def _pending_response(detail: Optional[str] = None) -> dict[str, Any]: + response: dict[str, Any] = {"status": "pending"} + if detail: + response["detail"] = detail + return response + + +def create_device_flow_router( + *, + prefix: str, + tags: Iterable[str], + store: PendingDeviceFlowStore, + start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart], + poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll], +) -> APIRouter: + """Create standard `/device/start|poll|cancel` routes for a provider.""" + + router = APIRouter(prefix=prefix, tags=list(tags)) + + @router.post("/device/start") + async def device_start(request: Request): + require_admin(request) + form = await request.form() + start = await _maybe_await(start_flow(request, form)) + interval = int(start.interval or 5) + expires_in = int(start.expires_in or 900) + poll_id = store.add(start.pending, interval=interval, expires_in=expires_in) + response = dict(start.response) + response.update({"poll_id": poll_id, "interval": interval, "expires_in": expires_in}) + return response + + @router.post("/device/poll") + async def device_poll(request: Request, poll_id: str = Form(...)): + require_admin(request) + payload = store.get_payload(poll_id) + if payload is None: + raise HTTPException(404, "Unknown or expired login session") + if store.is_throttled(poll_id): + return {"status": "pending"} + + try: + outcome = await _maybe_await(poll_flow(request, payload)) + except Exception: + store.pop(poll_id) + raise + + if outcome.status == "authorized": + store.pop(poll_id) + return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})} + if outcome.status == "failed": + store.pop(poll_id) + return {"status": "failed", "error": outcome.error or "denied"} + if outcome.status == "slow_down": + store.slow_down(poll_id, outcome.interval) + return _pending_response(outcome.detail) + + store.schedule_next(poll_id) + return _pending_response(outcome.detail) + + @router.post("/device/cancel") + def device_cancel(request: Request, poll_id: str = Form(...)): + require_admin(request) + store.pop(poll_id) + return {"status": "cancelled"} + + return router diff --git a/routes/model_routes.py b/routes/model_routes.py index 2d5be4154..15f1543c3 100644 --- a/routes/model_routes.py +++ b/routes/model_routes.py @@ -283,6 +283,7 @@ _HOST_TO_CURATED = ( ("fireworks.ai", "fireworks"), ("googleapis.com", "google"), ("x.ai", "xai"), + ("openrouter.ai", "openrouter"), ("ollama.com", "ollama"), ("opencode.ai/zen/go", "opencode-go"), @@ -493,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = ( def _is_chat_model(model_id: str) -> bool: """Return True if the model ID looks like a chat/completions-capable model.""" mid = model_id.lower() + if mid in {"gpt-5.1-codex"}: + return True for prefix in _NON_CHAT_PREFIXES: if mid.startswith(prefix): return False @@ -505,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool: return True -def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict: +def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool: + """Delete a ProviderAuthSession once no endpoint still references it. + + Subscription providers (e.g. ChatGPT Subscription) keep their refresh token + in ProviderAuthSession rather than ModelEndpoint.api_key. When the last + endpoint backed by that auth row is removed, the stored credentials should + be cleared instead of lingering. Returns True if a row was deleted. + ``exclude_ep_id`` drops the endpoint currently being deleted from the + reference count so it does not keep its own auth alive. + """ + if not auth_id: + return False + from core.database import ProviderAuthSession + still_referenced = db.query(ModelEndpoint.id).filter( + ModelEndpoint.provider_auth_id == auth_id, + ModelEndpoint.id != exclude_ep_id, + ).first() + if still_referenced is not None: + return False + auth_row = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first() + if auth_row is None: + return False + db.delete(auth_row) + return True + + +def _is_discovery_only_provider(provider: str) -> bool: + """Provider that only supports model discovery, not live probing. + + ChatGPT Subscription speaks the Responses/Codex API and has no + chat-completions or general health endpoint, so completion probes and + reachability pings are skipped — status is derived from cached models. + """ + return provider == "chatgpt-subscription" + + +def _resolve_probe_key(ep) -> Optional[str]: + """API key/bearer to probe an endpoint with. + + Delegates to ``resolve_endpoint_runtime``, which already returns the static + ``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes) + the runtime bearer for session-backed providers (e.g. ChatGPT Subscription). + Returns None if resolution fails (e.g. re-auth required) so probing skips + rather than raising. Reads only already-loaded scalar attributes of ``ep``. + """ + try: + from src.endpoint_resolver import resolve_endpoint_runtime + _base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None)) + return key + except Exception as e: + logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e) + return None + + +def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict: """Send a realistic completion request to a single model. Returns {status, latency_ms, error?}.""" provider = _detect_provider(base) + if _is_discovery_only_provider(provider): + # Responses/Codex API, not chat-completions: a completion probe would + # 400 and the re-probe flow would then hide every model. Discovery-only. + return {"status": "ok", "latency_ms": 0, "skipped": True} messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Say OK"}, @@ -621,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis For Anthropic, queries their /v1/models API, falling back to hardcoded list.""" from src.endpoint_resolver import resolve_url base = resolve_url(_normalize_base(base_url)) + if _detect_provider(base) == "chatgpt-subscription": + from src.chatgpt_subscription import fetch_available_models + if api_key: + return fetch_available_models(api_key, timeout=timeout) + return [] if _detect_provider(base) == "anthropic": # Try Anthropic's /v1/models endpoint first url = build_models_url(base) @@ -647,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}") return list(ANTHROPIC_MODELS) url = build_models_url(base) + if not url: + curated_key = _match_provider_curated(base, None) + fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None + return list(fallback or []) headers = build_headers(api_key, base) try: r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify()) @@ -998,6 +1068,17 @@ def setup_model_routes(model_discovery): ok, info = _should_refresh_endpoint(ep, now, force=force) if not ok: continue + if getattr(ep, "provider_auth_id", None): + try: + from src.endpoint_resolver import resolve_endpoint_runtime + info["base"], info["api_key"] = resolve_endpoint_runtime( + ep, + owner=getattr(ep, "owner", None), + ) + info["key"] = _refresh_key(info["base"], info["api_key"]) + except Exception as e: + logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e) + continue groups.setdefault(info["key"], { "base": info["base"], "api_key": info["api_key"], @@ -1266,12 +1347,20 @@ def setup_model_routes(model_discovery): "endpoint_kind": kind, } try: - t0 = _time.time() - ping = _ping_endpoint(base, ep.api_key, timeout=1.5) - entry["latency_ms"] = round((_time.time() - t0) * 1000) - entry["status"] = "online" if ping.get("reachable") or cached_count else "offline" - entry["error"] = ping.get("error") - entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0) + if _is_discovery_only_provider(provider): + # No general health endpoint — an unauthenticated GET just + # 401s. Report status from cached models instead of pinging. + entry["latency_ms"] = None + entry["status"] = "online" if cached_count else "offline" + entry["error"] = None + entry["model_count"] = cached_count + else: + t0 = _time.time() + ping = _ping_endpoint(base, ep.api_key, timeout=1.5) + entry["latency_ms"] = round((_time.time() - t0) * 1000) + entry["status"] = "online" if ping.get("reachable") or cached_count else "offline" + entry["error"] = ping.get("error") + entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0) except Exception as e: entry["latency_ms"] = None entry["status"] = "online" if cached_count else "offline" @@ -1304,7 +1393,7 @@ def setup_model_routes(model_discovery): if ep_id and ep_id not in endpoints_cache: ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() if ep: - endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key} + endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)} ep_data = endpoints_cache.get(ep_id) if not ep_data: # Try to find by base_url from the model's endpoint field @@ -1343,7 +1432,7 @@ def setup_model_routes(model_discovery): "id": ep.id, "name": ep.name, "base_url": ep.base_url, - "api_key": ep.api_key, + "api_key": _resolve_probe_key(ep), }) finally: db.close() @@ -1432,12 +1521,14 @@ def setup_model_routes(model_discovery): # Endpoint counts as reachable if it has any model — including # admin-pinned IDs that a probe would never surface. status = "online" if (all_models or pinned) else "offline" + base = _normalize_base(r.base_url) ping = None - if not all_models and not pinned and r.is_enabled: + # Discovery-only providers have no health endpoint — an + # unauthenticated ping just 401s, so don't bother. + if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)): ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0) if ping.get("reachable"): status = "empty" - base = _normalize_base(r.base_url) kind = _effective_endpoint_kind(r, base) results.append({ "id": r.id, @@ -1713,7 +1804,7 @@ def setup_model_routes(model_discovery): ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first() if not ep: raise HTTPException(404, "Endpoint not found") - ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key} + ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)} finally: db.close() @@ -1777,7 +1868,7 @@ def setup_model_routes(model_discovery): category = _classify_endpoint(base, kind) timeout = _manual_refresh_timeout(ep, category, refresh_timeout) try: - probed = _probe_endpoint(base, ep.api_key, timeout=timeout) + probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout) except Exception as exc: logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc) probed = [] @@ -2116,7 +2207,9 @@ def setup_model_routes(model_discovery): cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id) cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url) cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url) + auth_id = getattr(ep, "provider_auth_id", None) db.delete(ep) + cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id) db.commit() _invalidate_models_cache() _local_probe_cache["data"] = None @@ -2126,6 +2219,7 @@ def setup_model_routes(model_discovery): "cleared_user_preferences": cleared_user_preferences, "cleared_sessions": cleared_sessions, "cleared_loaded_sessions": cleared_loaded_sessions, + "cleared_provider_auth": cleared_provider_auth, } finally: db.close() diff --git a/routes/research_routes.py b/routes/research_routes.py index c48ba3b5d..ea9d207a3 100644 --- a/routes/research_routes.py +++ b/routes/research_routes.py @@ -75,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None): return owner_filter(q, ModelEndpoint, owner).first() +def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None): + """Resolve a ModelEndpoint row into (chat_url, model, headers). + + Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for + panel-selected research endpoints. ChatGPT Subscription endpoints keep + OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty. + """ + from src.endpoint_resolver import ( + build_chat_url, + build_headers, + resolve_endpoint_runtime as resolve_model_endpoint_runtime, + ) + + try: + base, api_key = resolve_model_endpoint_runtime(ep, owner=owner) + except Exception as e: + logger.warning("Could not resolve endpoint credentials for research: %s", e) + return None + + ep_model = (model or "").strip() + if not ep_model: + try: + models = json.loads(ep.cached_models) if ep.cached_models else [] + if models: + ep_model = _first_chat_model(models) + except Exception: + pass + if not ep_model: + return None + return build_chat_url(base), ep_model, build_headers(api_key, base) + + def setup_research_routes(research_handler, session_manager=None) -> APIRouter: router = APIRouter(tags=["research"]) @@ -371,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: if body.endpoint_id: from src.database import SessionLocal - from src.endpoint_resolver import normalize_base, build_chat_url, build_headers db = SessionLocal() try: # Owner-scoped: never resolve another user's private endpoint @@ -380,18 +411,10 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: ep = _owned_enabled_endpoint(db, user, body.endpoint_id) if not ep: raise HTTPException(404, "Endpoint not found or disabled") - base = normalize_base(ep.base_url) - ep_url = build_chat_url(base) - ep_headers = build_headers(ep.api_key, base) - ep_model = body.model or "" - if not ep_model: - try: - import json as _json - models = _json.loads(ep.cached_models) if ep.cached_models else [] - if models: - ep_model = _first_chat_model(models) - except Exception: - pass + resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model) + if not resolved: + raise HTTPException(400, "Endpoint is not configured with a usable model.") + ep_url, ep_model, ep_headers = resolved finally: db.close() else: @@ -408,7 +431,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user) if not ep_url: from src.database import SessionLocal - from src.endpoint_resolver import normalize_base, build_chat_url, build_headers db = SessionLocal() try: # Owner-scoped first-enabled fallback: the caller's own rows @@ -417,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter: # /api/v1/chat fallback (webhook_routes._first_enabled_endpoint). ep = _owned_enabled_endpoint(db, user) if ep: - base = normalize_base(ep.base_url) - ep_url = build_chat_url(base) - ep_headers = build_headers(ep.api_key, base) - ep_model = "" - if ep.cached_models: - try: - import json as _json - models = _json.loads(ep.cached_models) - if models: - ep_model = _first_chat_model(models) - except Exception: - pass + resolved = _resolve_endpoint_runtime(ep, owner=user) + if resolved: + ep_url, ep_model, ep_headers = resolved finally: db.close() if not ep_url: diff --git a/routes/session_routes.py b/routes/session_routes.py index fd2721e3f..5bd693383 100644 --- a/routes/session_routes.py +++ b/routes/session_routes.py @@ -92,18 +92,13 @@ def _reject_compact_during_active_run(session_id: str) -> None: def _verify_session_owner(request: Request, session_id: str, session_manager=None): - """Verify the current user owns the session. Raises 404 if not. + """Verify the current user owns the session, honoring single-user modes. - Ownership is checked against the DB row when one exists (unchanged). If - there is no DB row but the caller owns an in-memory "ghost" session — one - that lives only in ``session_manager`` because it was never persisted, or - its DB row was removed out-of-band — fall back to the in-memory owner so the - user can still manage and delete it. Without this fallback such sessions are - listed by ``/api/sessions`` (they come from the in-memory manager) yet every - per-session operation 404s, making them impossible to delete (issue #1044). - - ``session_manager`` is optional and defaults to ``None`` so existing callers - that only care about persisted sessions keep their exact prior behavior. + Authenticated requests must match the stored DB or in-memory owner. When + auth is disabled and no user is present, treat the app as single-user mode: + verify that the session exists, but do not compare its stored owner. This + keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped + rows created while auth was previously enabled. """ user = effective_user(request) if not user and not _auth_disabled(): @@ -114,13 +109,13 @@ def _verify_session_owner(request: Request, session_id: str, session_manager=Non finally: db.close() if row is not None: - if row.owner != user: + if user and row.owner != user: raise HTTPException(404, f"Session {session_id} not found") return # No DB row — allow the caller to act on an in-memory ghost they own. if session_manager is not None: ghost = getattr(session_manager, "sessions", {}).get(session_id) - if ghost is not None and getattr(ghost, "owner", None) == user: + if ghost is not None and (not user or getattr(ghost, "owner", None) == user): return raise HTTPException(404, f"Session {session_id} not found") @@ -372,8 +367,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ pass elif not model_to_use: from src.llm_core import list_model_ids - ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT, - headers=validation_headers) + ids = list_model_ids( + endpoint_url, + timeout=REQUEST_TIMEOUT, + headers=validation_headers, + owner=user, + endpoint_id=endpoint_id.strip() if endpoint_id else None, + ) if not ids: raise HTTPException(400, "Cannot reach /v1/models") # Default to the first CHAT model — endpoints often list embedding/ @@ -387,8 +387,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_ from src.llm_core import list_model_ids import os as _os req_base = _os.path.basename(model_to_use.rstrip("/")) - avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT, - headers=validation_headers) + avail = list_model_ids( + endpoint_url, + timeout=REQUEST_TIMEOUT, + headers=validation_headers, + owner=user, + endpoint_id=endpoint_id.strip() if endpoint_id else None, + ) if not avail: raise HTTPException(400, "Cannot reach /v1/models") if model_to_use not in avail: diff --git a/routes/skills_routes.py b/routes/skills_routes.py index 705502e48..8a7c5c269 100644 --- a/routes/skills_routes.py +++ b/routes/skills_routes.py @@ -1109,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter: idx = skills_manager.index_for(owner=user) return {"index": idx, "count": len(idx)} + @router.get("/slash-catalog") + async def get_slash_catalog(request: Request): + """Return skills that are available as slash commands. + + Mirrors the agent prompt's published-skill index so the UI never offers + a slash command the model would not normally be allowed to discover. + """ + user = _owner(request) + all_skills = {s.get("name"): s for s in skills_manager.load(owner=user)} + entries = [] + for s in skills_manager.index_for(owner=user): + name = (s.get("name") or "").strip() + if not name: + continue + full = all_skills.get(name) or {} + category = (s.get("category") or full.get("category") or "general").strip() or "general" + entries.append({ + "type": "skill", + "token": f"/{name}", + "name": name, + "category": f"Skills / {category}", + "help": s.get("description") or full.get("description") or "", + "usage": f"/{name} ", + "uses": int(full.get("uses") or 0), + "last_used": full.get("last_used"), + }) + entries.sort(key=lambda row: row["name"]) + return {"skills": entries, "count": len(entries)} + @router.get("/builtin") async def list_builtin_skills(request: Request): """Read-only list of the agent's built-in tool capabilities (research, @@ -1272,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter: _fire_skill_added(user) return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry} + @router.post("/{skill_id}/invoke") + async def invoke_skill(request: Request, skill_id: str): + """Build a skill-pinned prompt for slash-command invocation. + + This is intentionally server-side so availability, ownership, and usage + accounting use the same rules as the SkillsManager. + """ + user = _owner(request) + try: + body = await request.json() + except Exception: + body = {} + request_text = (body.get("request") or "").strip() if isinstance(body, dict) else "" + + invokable = { + s.get("name"): s for s in skills_manager.index_for(owner=user) + if (s.get("name") or "").strip() + } + match = invokable.get(skill_id) + if not match: + raise HTTPException(404, "Skill is not available for slash invocation") + + name = match.get("name") + md = skills_manager.read_skill_md(name, owner=user) + if md is None: + raise HTTPException(404, "Skill source unavailable") + + skills_manager.record_use(name, owner=user) + message = ( + "Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n" + f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n" + + (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)") + ) + return { + "ok": True, + "type": "skill", + "name": name, + "command": f"/{name}", + "message": message, + } + @router.get("/{skill_id}") async def get_skill(request: Request, skill_id: str): user = _owner(request) diff --git a/routes/webhook_routes.py b/routes/webhook_routes.py index 5cf739fda..da6288e7a 100644 --- a/routes/webhook_routes.py +++ b/routes/webhook_routes.py @@ -325,22 +325,33 @@ def setup_webhook_routes( endpoint_url = build_chat_url(base_url) model = body.model or "auto" api_key = ep.api_key + if getattr(ep, "provider_auth_id", None): + try: + from src.endpoint_resolver import resolve_endpoint_runtime + base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner) + endpoint_url = build_chat_url(base_url) + except Exception: + raise HTTPException(500, "Could not resolve endpoint credentials") if model == "auto": try: async with httpx.AsyncClient(timeout=5) as client: models_url = build_models_url(base_url) hdrs = build_headers(api_key, base_url) - resp = await client.get(models_url, headers=hdrs) - resp.raise_for_status() - data = resp.json() - ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if not ids: - ids = [ - m.get("name") or m.get("model") - for m in (data.get("models") or []) - if m.get("name") or m.get("model") - ] + if models_url: + resp = await client.get(models_url, headers=hdrs) + resp.raise_for_status() + data = resp.json() + ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not ids: + ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] + else: + import json as _json + ids = _json.loads(ep.cached_models or "[]") model = ids[0] if ids else "auto" except Exception: raise HTTPException(500, "Could not discover models from endpoint") diff --git a/src/ai_interaction.py b/src/ai_interaction.py index a03a5b0ac..4dbab9a66 100644 --- a/src/ai_interaction.py +++ b/src/ai_interaction.py @@ -57,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None): # Model resolution # --------------------------------------------------------------------------- -from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url +from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]: @@ -98,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di (f" matching '{target_endpoint_name}'" if target_endpoint_name else "")) for ep in endpoints: - base = _normalize_base(ep.base_url) + try: + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception: + continue provider = _detect_provider(base) - headers = build_headers(ep.api_key, base) + headers = build_headers(api_key, base) if provider == "anthropic": # Anthropic: match against hardcoded model list @@ -114,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di else: # OpenAI-compatible and native Ollama: probe the provider's model list. try: - r = httpx.get(build_models_url(base), headers=headers, timeout=5) - r.raise_for_status() - data = r.json() - model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if not model_ids: - model_ids = [ - m.get("name") or m.get("model") - for m in (data.get("models") or []) - if m.get("name") or m.get("model") - ] + models_url = build_models_url(base) + if models_url: + r = httpx.get(models_url, headers=headers, timeout=5) + r.raise_for_status() + data = r.json() + model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not model_ids: + model_ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] + else: + model_ids = json.loads(ep.cached_models or "[]") except Exception: model_ids = [] @@ -1121,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner: total_models = 0 for ep in endpoints: - base = _normalize_base(ep.base_url) + try: + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception: + continue provider = _detect_provider(base) - headers = build_headers(ep.api_key, base) + headers = build_headers(api_key, base) model_ids = [] if provider == "anthropic": model_ids = list(ANTHROPIC_MODELS) else: try: - r = httpx.get(build_models_url(base), headers=headers, timeout=5) - r.raise_for_status() - data = r.json() - model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] - if not model_ids: - model_ids = [ - m.get("name") or m.get("model") - for m in (data.get("models") or []) - if m.get("name") or m.get("model") - ] + models_url = build_models_url(base) + if models_url: + r = httpx.get(models_url, headers=headers, timeout=5) + r.raise_for_status() + data = r.json() + model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")] + if not model_ids: + model_ids = [ + m.get("name") or m.get("model") + for m in (data.get("models") or []) + if m.get("name") or m.get("model") + ] + else: + model_ids = json.loads(ep.cached_models or "[]") except Exception: model_ids = ["(endpoint offline)"] diff --git a/src/chatgpt_subscription.py b/src/chatgpt_subscription.py new file mode 100644 index 000000000..263c4f529 --- /dev/null +++ b/src/chatgpt_subscription.py @@ -0,0 +1,311 @@ +"""ChatGPT subscription / Codex backend OAuth helpers. + +This provider is intentionally separate from OpenAI API-key endpoints. It uses +OpenAI account OAuth device authorization, stores refresh tokens server-side, +and resolves a fresh bearer token at request time. +""" + +from __future__ import annotations + +import base64 +import json +import os +import threading +import time +from typing import Any, Dict, Optional + +import httpx +from fastapi import HTTPException + +from core.database import ProviderAuthSession, SessionLocal, utcnow_naive + +DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = ( + os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/") + or "https://chatgpt.com/backend-api/codex" +) +CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription" +CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token" +CHATGPT_OAUTH_ISSUER = "https://auth.openai.com" +CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback" +CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 +_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {} +_AUTH_REFRESH_LOCKS_GUARD = threading.Lock() + + +def _refresh_lock_for(auth_id: str) -> threading.Lock: + with _AUTH_REFRESH_LOCKS_GUARD: + lock = _AUTH_REFRESH_LOCKS.get(auth_id) + if lock is None: + lock = threading.Lock() + _AUTH_REFRESH_LOCKS[auth_id] = lock + return lock + + +class ChatGPTSubscriptionError(RuntimeError): + """Base error for ChatGPT subscription provider failures.""" + + +class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError): + """Stored OAuth credentials are invalid or expired beyond refresh.""" + + +class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError): + """Upstream quota/rate limit; reconnecting will not fix it.""" + + +class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError): + """No matching owner-scoped auth session exists.""" + + +def is_chatgpt_subscription_base(url: str) -> bool: + try: + from urllib.parse import urlparse + + parsed = urlparse(url or "") + host = (parsed.hostname or "").lower().rstrip(".") + path = (parsed.path or "").rstrip("/") + except Exception: + return False + return host == "chatgpt.com" and ( + path == "/backend-api/codex" or path.startswith("/backend-api/codex/") + ) + + +def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]: + headers = { + "Accept": "application/json, text/event-stream", + "Origin": "https://chatgpt.com", + "Referer": "https://chatgpt.com/codex", + "User-Agent": "Odysseus ChatGPT Subscription", + } + if access_token: + headers["Authorization"] = f"Bearer {access_token}" + return headers + + +def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]: + if not access_token: + return [] + try: + response = httpx.get( + "https://chatgpt.com/backend-api/codex/models?client_version=1.0.0", + headers=chatgpt_headers(access_token), + timeout=timeout, + ) + if response.status_code != 200: + return [] + data = response.json() + except Exception: + return [] + entries = data.get("models", []) if isinstance(data, dict) else [] + sortable: list[tuple[int, str]] = [] + for item in entries: + if not isinstance(item, dict): + continue + slug = item.get("slug") + if not isinstance(slug, str) or not slug.strip(): + continue + visibility = item.get("visibility", "") + if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}: + continue + priority = item.get("priority") + rank = int(priority) if isinstance(priority, (int, float)) else 10_000 + sortable.append((rank, slug.strip())) + sortable.sort(key=lambda item: (item[0], item[1])) + ordered: list[str] = [] + seen: set[str] = set() + for _, slug in sortable: + if slug not in seen: + ordered.append(slug) + seen.add(slug) + return ordered + + +def _raise_for_oauth_response(response: httpx.Response, action: str) -> None: + if response.status_code < 400: + return + code = "" + message = f"ChatGPT Subscription {action} failed with HTTP {response.status_code}." + try: + payload = response.json() + err = payload.get("error") if isinstance(payload, dict) else None + if isinstance(err, dict): + code = str(err.get("code") or err.get("type") or "").strip() + msg = err.get("message") + if msg: + message = f"ChatGPT Subscription {action} failed: {msg}" + elif isinstance(err, str): + code = err.strip() + desc = payload.get("error_description") or payload.get("message") + if desc: + message = f"ChatGPT Subscription {action} failed: {desc}" + except Exception: + pass + if response.status_code == 429: + raise ChatGPTSubscriptionRateLimited( + "ChatGPT Subscription quota or rate limit was reached. Credentials are still valid." + ) + if response.status_code in (401, 403) or code in {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}: + raise ChatGPTSubscriptionReauthRequired(message) + raise ChatGPTSubscriptionError(message) + + +def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]: + _raise_for_oauth_response(response, action) + try: + data = response.json() + except Exception as exc: + raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc + if not isinstance(data, dict): + raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.") + return data + + +def request_device_code(timeout: float = 15.0) -> Dict[str, Any]: + response = httpx.post( + f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode", + json={"client_id": CHATGPT_OAUTH_CLIENT_ID}, + headers={"Content-Type": "application/json"}, + timeout=timeout, + ) + data = _json_or_error(response, "device-code request") + if not data.get("device_auth_id") or not data.get("user_code"): + raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.") + data.setdefault("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device") + data.setdefault("interval", 5) + data.setdefault("expires_in", 900) + return data + + +def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]: + response = httpx.post( + f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token", + json={"device_auth_id": device_auth_id, "user_code": user_code}, + headers={"Content-Type": "application/json"}, + timeout=timeout, + ) + if response.status_code in (403, 404): + return {"status": "pending", "error": "authorization_pending"} + return _json_or_error(response, "device-code poll") + + +def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]: + response = httpx.post( + CHATGPT_OAUTH_TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "authorization_code", + "code": authorization_code, + "redirect_uri": CHATGPT_OAUTH_REDIRECT_URI, + "client_id": CHATGPT_OAUTH_CLIENT_ID, + "code_verifier": code_verifier, + }, + timeout=timeout, + ) + data = _json_or_error(response, "token exchange") + if not data.get("access_token"): + raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.") + return data + + +def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]: + del access_token + if not refresh_token: + raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.") + response = httpx.post( + CHATGPT_OAUTH_TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CHATGPT_OAUTH_CLIENT_ID, + }, + timeout=timeout, + ) + data = _json_or_error(response, "token refresh") + if not data.get("access_token"): + raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.") + return data + + +def _decode_jwt_payload(token: str) -> Dict[str, Any]: + parts = (token or "").split(".") + if len(parts) < 2: + raise ValueError("not a JWT") + segment = parts[1] + segment += "=" * (-len(segment) % 4) + raw = base64.urlsafe_b64decode(segment.encode("ascii")) + payload = json.loads(raw.decode("utf-8")) + return payload if isinstance(payload, dict) else {} + + +def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool: + try: + exp = int(_decode_jwt_payload(access_token).get("exp") or 0) + except Exception: + return True + return exp <= int(time.time()) + int(skew_seconds) + + +def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]: + db = SessionLocal() + try: + q = db.query(ProviderAuthSession).filter( + ProviderAuthSession.id == auth_id, + ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER, + ) + if owner: + q = q.filter(ProviderAuthSession.owner == owner) + row = q.first() + if row is None: + raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.") + + access_token = row.access_token or "" + if force_refresh or access_token_is_expiring(access_token): + with _refresh_lock_for(auth_id): + db.refresh(row) + access_token = row.access_token or "" + refresh_token = row.refresh_token or "" + if force_refresh or access_token_is_expiring(access_token): + refreshed = refresh_oauth_tokens(access_token, refresh_token) + row.access_token = refreshed["access_token"] + if refreshed.get("refresh_token"): + row.refresh_token = refreshed["refresh_token"] + row.last_refresh = utcnow_naive() + db.commit() + db.refresh(row) + access_token = row.access_token or "" + + return { + "provider": CHATGPT_SUBSCRIPTION_PROVIDER, + "base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"), + "api_key": access_token, + "auth_mode": row.auth_mode or "chatgpt", + } + finally: + db.close() + + +def to_http_exception(exc: Exception) -> HTTPException: + if isinstance(exc, ChatGPTSubscriptionRateLimited): + return HTTPException(429, str(exc)) + if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)): + return HTTPException(401, f"{exc} Reconnect the provider.") + return HTTPException(502, str(exc)) + + +def build_responses_input(messages: list[dict]) -> list[dict]: + input_items: list[dict] = [] + for msg in messages or []: + role = msg.get("role") or "user" + if role == "tool": + role = "user" + content = msg.get("content") + if isinstance(content, list): + text = "\n".join(str(part.get("text") or part.get("content") or "") for part in content if isinstance(part, dict)) + else: + text = "" if content is None else str(content) + input_type = "output_text" if role == "assistant" else "input_text" + input_items.append({"role": role, "content": [{"type": input_type, "text": text}]}) + return input_items diff --git a/src/endpoint_resolver.py b/src/endpoint_resolver.py index c6c8d5902..1ae7ace84 100644 --- a/src/endpoint_resolver.py +++ b/src/endpoint_resolver.py @@ -70,6 +70,25 @@ def _endpoint_enabled_models(ep) -> list: return [m for m in _endpoint_cached_models(ep) if m not in hidden] +def resolve_endpoint_runtime(ep, owner: Optional[str] = None) -> Tuple[str, Optional[str]]: + """Resolve a ModelEndpoint row to its runtime base URL and bearer/API key. + + Static-key providers use ``ModelEndpoint.api_key``. Session-backed providers + store refreshable credentials in ProviderAuthSession and must resolve a + current access token at call time. + """ + base = normalize_base(getattr(ep, "base_url", "") or "") + api_key = getattr(ep, "api_key", None) + auth_id = getattr(ep, "provider_auth_id", None) + if auth_id: + from src.chatgpt_subscription import resolve_runtime_credentials + + creds = resolve_runtime_credentials(auth_id, owner=owner) + base = normalize_base(creds.get("base_url") or base) + api_key = creds.get("api_key") + return base, api_key + + # Cache for Tailscale hostname → IP resolution _tailscale_cache: Dict[str, Optional[str]] = {} @@ -133,7 +152,7 @@ def resolve_url(url: str) -> str: def normalize_base(url: str) -> str: """Strip known API path suffixes from a base URL.""" url = (url or "").strip().rstrip("/") - for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]: + for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages", "/responses"]: if url.endswith(suffix): url = url[: -len(suffix)].rstrip("/") for suffix in ["/chat", "/tags", "/generate"]: @@ -158,10 +177,12 @@ def build_chat_url(base: str) -> str: return _anthropic_api_root(base) + "/v1/messages" if provider == "ollama": return _ollama_api_root(base) + "/chat" + if provider == "chatgpt-subscription": + return base.rstrip("/") + "/responses" return base + "/chat/completions" -def build_models_url(base: str) -> str: +def build_models_url(base: str) -> Optional[str]: """Return the provider-specific model-list endpoint URL for a base.""" base = resolve_url(base) provider = _detect_provider(base) @@ -169,6 +190,8 @@ def build_models_url(base: str) -> str: return _anthropic_api_root(base) + "/v1/models" if provider == "ollama": return _ollama_api_root(base) + "/tags" + if provider == "chatgpt-subscription": + return None return base + "/models" @@ -184,6 +207,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]: if provider == "copilot": from src.copilot import copilot_headers return copilot_headers(api_key) + if provider == "chatgpt-subscription": + from src.chatgpt_subscription import chatgpt_headers + return chatgpt_headers(api_key) if api_key: headers["Authorization"] = f"Bearer {api_key}" if provider == "openrouter": @@ -262,9 +288,13 @@ def resolve_endpoint( if not ep: return fallback_url, fallback_model, fallback_headers - base = normalize_base(ep.base_url) + try: + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception as e: + logger.warning("Could not resolve endpoint runtime credentials: %s", e) + return fallback_url, fallback_model, fallback_headers chat_url = build_chat_url(base) - headers = build_headers(ep.api_key, base) + headers = build_headers(api_key, base) # Discard a configured model the user has since disabled on the # endpoint (e.g. a stale `default_model` left pointing at a now-hidden @@ -308,9 +338,13 @@ def resolve_endpoint_by_id( ep = q.first() if not ep: return None - base = normalize_base(ep.base_url) + try: + base, api_key = resolve_endpoint_runtime(ep, owner=owner) + except Exception as e: + logger.warning("Could not resolve endpoint runtime credentials: %s", e) + return None chat_url = build_chat_url(base) - headers = build_headers(ep.api_key, base) + headers = build_headers(api_key, base) m = (model or "").strip() # Drop a model the user disabled on the endpoint, then pick the first # enabled chat model rather than a hidden one. diff --git a/src/llm_core.py b/src/llm_core.py index 4076356dd..2fbfc8178 100644 --- a/src/llm_core.py +++ b/src/llm_core.py @@ -426,6 +426,9 @@ def _detect_provider(url: str) -> str: return "openrouter" if _host_match(url, "groq.com"): return "groq" + from src.chatgpt_subscription import is_chatgpt_subscription_base + if is_chatgpt_subscription_base(url): + return "chatgpt-subscription" from src.copilot import is_copilot_base if is_copilot_base(url): return "copilot" @@ -462,6 +465,8 @@ def _provider_label(url: str) -> str: if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go" if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen" if _host_match(url, "groq.com"): return "Groq" + from src.chatgpt_subscription import is_chatgpt_subscription_base + if is_chatgpt_subscription_base(url): return "ChatGPT Subscription" from src.copilot import is_copilot_base if is_copilot_base(url): return "GitHub Copilot" if _host_match(url, "mistral.ai"): return "Mistral" @@ -479,6 +484,77 @@ def _provider_label(url: str) -> str: return host or "provider" +def _normalize_chatgpt_subscription_url(url: str) -> str: + base = (url or "").strip().rstrip("/") + if base.endswith("/responses"): + return base + return base + "/responses" + + +def _message_content_as_text(content) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for part in content: + if not isinstance(part, dict): + if part: + parts.append(str(part)) + continue + if isinstance(part.get("text"), str): + parts.append(part["text"]) + continue + if isinstance(part.get("content"), str): + parts.append(part["content"]) + return "\n".join(parts) + return "" if content is None else str(content) + + +def _chatgpt_subscription_instructions(messages: List[Dict]) -> str: + instructions = [ + _message_content_as_text(msg.get("content")).strip() + for msg in messages or [] + if (msg.get("role") or "") == "system" + ] + instructions = [part for part in instructions if part] + if instructions: + return "\n\n".join(instructions) + return "You are a helpful AI assistant." + + +def _build_chatgpt_responses_payload( + model: str, + messages: List[Dict], + temperature: float, + max_tokens: int, + *, + stream: bool = False, +) -> Dict: + from src.chatgpt_subscription import build_responses_input + + conversation = [msg for msg in (messages or []) if (msg.get("role") or "") != "system"] + payload: Dict = { + "model": model, + "instructions": _chatgpt_subscription_instructions(messages), + "input": build_responses_input(conversation), + "stream": stream, + "store": False, + } + if not _restricts_temperature(model): + payload["temperature"] = temperature + if max_tokens and max_tokens > 0: + payload["max_output_tokens"] = max_tokens + return payload + + +def _format_chatgpt_subscription_error(status_code: int, text: str) -> str: + if status_code in (401, 403): + return "ChatGPT Subscription credentials expired or were rejected. Reconnect the provider." + if status_code == 429: + return "ChatGPT Subscription quota or rate limit was reached. Retry after the upstream limit resets." + return _format_upstream_error(status_code, text, "https://chatgpt.com/backend-api/codex") + + def _format_upstream_error(status: int, body: bytes | str, url: str) -> str: """Turn an upstream HTTP error into a user-readable sentence. @@ -874,7 +950,7 @@ def _normalize_anthropic_url(url: str) -> str: def _model_list_base(url: str) -> str: """Normalize model/chat URLs to the configured endpoint base.""" base = (url or "").strip().rstrip("/") - for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"): + for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages", "/responses"): if base.endswith(suffix): base = base[: -len(suffix)].rstrip("/") for suffix in ("/chat", "/tags", "/generate"): @@ -903,7 +979,12 @@ def _parse_model_cache(raw) -> List[str]: return out -def _configured_cached_model_ids(endpoint_url: str) -> List[str]: +def _configured_cached_model_ids( + endpoint_url: str, + *, + owner: Optional[str] = None, + endpoint_id: Optional[str] = None, +) -> List[str]: """Return cached models for a configured endpoint matching endpoint_url.""" target = _model_list_base(endpoint_url) if not target: @@ -914,7 +995,13 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]: return [] db = SessionLocal() try: - rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() + q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True) + if endpoint_id: + q = q.filter(ModelEndpoint.id == endpoint_id) + if owner: + from src.auth_helpers import owner_filter + q = owner_filter(q, ModelEndpoint, owner) + rows = q.all() for ep in rows: if _model_list_base(getattr(ep, "base_url", "")) != target: continue @@ -933,9 +1020,16 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]: return [] -def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]: +def list_model_ids( + base_chat_url: str, + timeout: int = LLMConfig.DEFAULT_TIMEOUT, + headers: Optional[Dict] = None, + *, + owner: Optional[str] = None, + endpoint_id: Optional[str] = None, +) -> List[str]: """List available model IDs from an endpoint.""" - cached = _configured_cached_model_ids(base_chat_url) + cached = _configured_cached_model_ids(base_chat_url, owner=owner, endpoint_id=endpoint_id) if cached: return cached provider = _detect_provider(base_chat_url) @@ -971,9 +1065,16 @@ def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, pass return [] -def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]: +def normalize_model_id( + endpoint_url: str, + requested: str, + timeout: int = LLMConfig.DEFAULT_TIMEOUT, + *, + owner: Optional[str] = None, + endpoint_id: Optional[str] = None, +) -> Optional[str]: """Normalize a model ID to match available models.""" - avail = list_model_ids(endpoint_url, timeout) + avail = list_model_ids(endpoint_url, timeout, owner=owner, endpoint_id=endpoint_id) if not avail: return None if requested in avail: @@ -1169,6 +1270,49 @@ async def llm_call_async( logger.debug(f"Returning cached response for key: {cache_key}") return cached_response + if provider == "chatgpt-subscription": + # ChatGPT/Codex requires streamed Responses requests even for callers + # that want a plain string (auto-title, memory extraction, etc.). + # Reuse stream_llm's validated Codex SSE path and collect deltas. + parts: List[str] = [] + async for chunk in stream_llm( + url, + model, + messages_copy, + temperature=temperature, + max_tokens=max_tokens, + headers=headers, + timeout=timeout, + ): + event_is_error = False + for line in str(chunk).splitlines(): + if line.startswith("event:"): + event_is_error = line[6:].strip() == "error" + continue + if not line.startswith("data:"): + continue + raw = line[5:].strip() + if not raw: + continue + if raw == "[DONE]": + response = "".join(parts) + _set_cached_response(cache_key, response) + return response + try: + data = json.loads(raw) + except json.JSONDecodeError: + continue + if event_is_error or data.get("error") or (data.get("status") and data.get("text")): + status = int(data.get("status") or 502) + text = data.get("text") or data.get("error") or "ChatGPT Subscription request failed" + raise HTTPException(status, text) + delta = data.get("delta") + if isinstance(delta, str): + parts.append(delta) + response = "".join(parts) + _set_cached_response(cache_key, response) + return response + if provider == "anthropic": target_url = _normalize_anthropic_url(url) h = _build_anthropic_headers(headers) @@ -1294,6 +1438,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl model, messages_copy, temperature, max_tokens, stream=True, tools=tools, num_ctx=get_context_length(url, model), ) + elif provider == "chatgpt-subscription": + target_url = _normalize_chatgpt_subscription_url(url) + h = _provider_headers(provider, headers) + payload = _build_chatgpt_responses_payload(model, messages_copy, temperature, max_tokens, stream=True) else: target_url = url payload = { @@ -1325,6 +1473,68 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl return note_model_activity(target_url, model) + # ── ChatGPT Subscription / Codex Responses streaming ── + if provider == "chatgpt-subscription": + event_name = "" + input_tokens = 0 + output_tokens = 0 + try: + client = _get_http_client() + async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r: + _clear_host_dead(target_url) + if r.status_code != 200: + raw = (await r.aread()).decode(errors="replace") + friendly = _format_chatgpt_subscription_error(r.status_code, raw) + yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n' + return + async for line in r.aiter_lines(): + if not line: + continue + if line.startswith("event:"): + event_name = line[6:].strip() + continue + if not line.startswith("data:"): + continue + raw = line[5:].strip() + if not raw: + continue + try: + data = json.loads(raw) + except json.JSONDecodeError: + continue + evt = data.get("type") or event_name + if evt == "response.output_text.delta": + delta = data.get("delta") or "" + if delta: + yield f'data: {json.dumps({"delta": delta})}\n\n' + elif evt == "response.completed": + usage = (data.get("response") or {}).get("usage") or data.get("usage") or {} + input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or input_tokens + output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") or output_tokens + if input_tokens or output_tokens: + yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": input_tokens, "output_tokens": output_tokens}})}\n\n' + yield "data: [DONE]\n\n" + return + elif evt in ("response.failed", "error"): + err = data.get("error") or (data.get("response") or {}).get("error") or {} + text = err.get("message") if isinstance(err, dict) else str(err or "ChatGPT Subscription request failed") + yield f'event: error\ndata: {json.dumps({"status": 502, "text": text})}\n\n' + return + yield "data: [DONE]\n\n" + except (httpx.ConnectError, httpx.ConnectTimeout) as e: + _cooled = _mark_host_dead(target_url) + _tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry" + logger.warning(f"ChatGPT Subscription stream connect to {target_url} failed: {e}{_tail}") + yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n' + except httpx.ReadTimeout: + yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n' + except httpx.NetworkError: + yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n' + except Exception as e: + logger.error(f"ChatGPT Subscription stream error: {e}") + yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n' + return + # ── Native Ollama streaming ── if provider == "ollama": _ollama_tool_calls: List[Dict] = [] diff --git a/static/index.html b/static/index.html index 7f4394d39..522129fe9 100644 --- a/static/index.html +++ b/static/index.html @@ -2108,6 +2108,8 @@ + + @@ -2136,6 +2138,7 @@
+
diff --git a/static/js/admin.js b/static/js/admin.js index b9512149b..a9a281a34 100644 --- a/static/js/admin.js +++ b/static/js/admin.js @@ -5,6 +5,7 @@ import uiModule from './ui.js'; import settingsModule from './settings.js'; import { providerLogo } from './providers.js'; import { sortModelObjects } from './modelSort.js'; +import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js'; let initialized = false; let modalEl = null; @@ -707,6 +708,80 @@ function initEndpointForm() { const pickerBtn = el('adm-provider-btn'); const pickerMenu = el('adm-provider-menu'); const pickerCurrent = picker ? picker.querySelector('.adm-provider-current') : null; + const DEVICE_AUTH_PROVIDER_VALUES = new Set(Object.keys(PROVIDER_DEVICE_FLOWS)); + let deviceAuthPolling = false; + function _selectedProviderOption() { + return provider && provider.selectedOptions ? provider.selectedOptions[0] : null; + } + function _selectedDeviceAuthProvider() { + const opt = _selectedProviderOption(); + const flow = opt && opt.dataset ? opt.dataset.authFlow : ''; + if (flow && DEVICE_AUTH_PROVIDER_VALUES.has(flow)) return flow; + return DEVICE_AUTH_PROVIDER_VALUES.has(provider.value) ? provider.value : ''; + } + function _isDeviceAuthSelected() { + return !!_selectedDeviceAuthProvider(); + } + function _setApiFormForProvider() { + const deviceAuthProvider = _selectedDeviceAuthProvider(); + const deviceAuthConfig = PROVIDER_DEVICE_FLOWS[deviceAuthProvider] || null; + const apiKey = el('adm-epApiKey'); + const testBtn = el('adm-epApiTestBtn'); + const addBtn = el('adm-epAddBtn'); + const status = el('adm-deviceAuthStatus'); + const msg = _endpointMsg('api'); + if (deviceAuthConfig) { + urlInput.value = ''; + urlInput.placeholder = deviceAuthProvider === 'copilot' + ? 'GitHub Copilot uses GitHub account sign-in' + : 'ChatGPT Subscription uses OpenAI account sign-in'; + urlInput.readOnly = true; + if (apiKey) { + apiKey.value = ''; + apiKey.placeholder = 'No API key needed'; + apiKey.disabled = true; + } + if (testBtn) { + testBtn.disabled = true; + testBtn.style.opacity = '0.45'; + testBtn.style.cursor = 'not-allowed'; + } + if (addBtn) { + addBtn.disabled = false; + addBtn.textContent = 'Add'; + addBtn.style.width = '55px'; + addBtn.style.display = ''; + } + if (kindSel) kindSel.value = 'api'; + if (msg) { + msg.textContent = ''; + msg.className = ''; + } + } else { + urlInput.placeholder = 'Base URL or pick provider'; + urlInput.readOnly = false; + if (apiKey) { + apiKey.placeholder = 'API key'; + apiKey.disabled = false; + } + if (testBtn) { + testBtn.disabled = false; + testBtn.style.opacity = ''; + testBtn.style.cursor = ''; + } + if (addBtn) { + addBtn.disabled = false; + addBtn.textContent = 'Add'; + addBtn.style.width = '55px'; + addBtn.style.display = ''; + } + if (msg) { + msg.textContent = ''; + msg.className = ''; + } + if (!deviceAuthPolling && status) status.textContent = ''; + } + } function _renderPickerMenu() { if (!pickerMenu) return; pickerMenu.innerHTML = Array.from(provider.options).map(o => { @@ -748,9 +823,16 @@ function initEndpointForm() { } provider.addEventListener('change', () => { + if (_isDeviceAuthSelected()) { + _setApiFormForProvider(); + _renderPickerMenu(); + _syncPickerCurrent(); + return; + } if (provider.value) urlInput.value = provider.value; else urlInput.value = ''; if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy'; + _setApiFormForProvider(); }); urlInput.addEventListener('input', () => { if (provider.value && urlInput.value.trim() !== provider.value) { @@ -838,6 +920,12 @@ function initEndpointForm() { const apiCancelTestBtn = el('adm-epApiCancelTestBtn'); if (apiTestBtn) { apiTestBtn.addEventListener('click', async () => { + if (_isDeviceAuthSelected()) { + const msg = _endpointMsg('api'); + msg.textContent = ''; + msg.className = ''; + return; + } const msg = _endpointMsg('api'); msg.textContent = ''; msg.className = ''; const rawUrl = (urlInput.value || provider.value).trim(); @@ -885,6 +973,11 @@ function initEndpointForm() { } el('adm-epAddBtn').addEventListener('click', async () => { + const deviceAuthProvider = _selectedDeviceAuthProvider(); + if (deviceAuthProvider) { + await _startProviderDeviceAuth(deviceAuthProvider, el('adm-epAddBtn')); + return; + } const msg = _endpointMsg('api'); msg.textContent = ''; msg.className = ''; const rawUrl = (urlInput.value || provider.value).trim(); @@ -936,76 +1029,116 @@ function initEndpointForm() { btn.disabled = false; btn.textContent = 'Add'; }); - // GitHub Copilot — device-flow login. Starts the flow, shows the user a - // code + verification link, and polls until they authorise (or it expires). - const copilotBtn = el('adm-copilotConnectBtn'); - if (copilotBtn) { - let copilotPolling = false; - copilotBtn.addEventListener('click', async () => { - if (copilotPolling) return; - const status = el('adm-copilotStatus'); - const reset = () => { copilotBtn.disabled = false; copilotBtn.textContent = 'Connect GitHub Copilot'; copilotPolling = false; }; - status.textContent = ''; status.className = 'adm-ep-inline-msg'; - copilotBtn.disabled = true; copilotBtn.textContent = 'Starting...'; - copilotPolling = true; - let start; - try { - const res = await fetch('/api/copilot/device/start', { method: 'POST', body: new FormData(), credentials: 'same-origin' }); - start = await res.json(); - if (!res.ok) { status.textContent = start.detail || 'Failed to start login'; status.className = 'admin-error'; reset(); return; } - } catch (e) { status.textContent = 'Request failed'; status.className = 'admin-error'; reset(); return; } + async function _startProviderDeviceAuth(providerKey, triggerEl = null) { + if (deviceAuthPolling) return; + const config = PROVIDER_DEVICE_FLOWS[providerKey]; + if (!config) return; + const status = el('adm-deviceAuthStatus') || _endpointMsg('api'); + if (!status) return; + const triggerText = triggerEl ? triggerEl.textContent : ''; + // Render an error with an inline "Try again" (the top button is hidden for + // device-auth providers, so retry lives here). Built with DOM methods, not + // innerHTML. Call reset() first so the deviceAuthPolling guard is cleared. + const showAuthError = (text) => { + status.className = 'admin-error'; + status.textContent = text + ' '; + const retry = document.createElement('button'); + retry.type = 'button'; + retry.className = 'admin-btn-sm'; + retry.textContent = 'Try again'; + retry.addEventListener('click', () => { _startProviderDeviceAuth(providerKey, triggerEl); }); + status.appendChild(retry); + }; + const reset = () => { + if (triggerEl) { + triggerEl.disabled = false; + triggerEl.textContent = triggerText || 'Add'; + } + deviceAuthPolling = false; + _setApiFormForProvider(); + }; + status.textContent = ''; + status.className = 'adm-ep-inline-msg'; + if (triggerEl) { + triggerEl.disabled = true; + triggerEl.textContent = 'Starting...'; + } + deviceAuthPolling = true; + _setApiFormForProvider(); + status.textContent = `Starting ${config.label} sign-in...`; - const { poll_id, user_code, verification_uri, verification_uri_complete, interval, expires_in } = start; - // Prefer the "complete" URL — it embeds the code so the user only has to - // click "Authorize" (no manual code entry). - const authUrl = verification_uri_complete || verification_uri || ''; - const esc = (s) => String(s || '').replace(/[<>&"]/g, (c) => ({ '<': '<', '>': '>', '&': '&', '"': '"' }[c])); - copilotBtn.textContent = 'Waiting…'; - - // Cohesive waiting panel: spinner + status line, the device code as a - // copyable chip, and a primary "Authorize on GitHub" action. - status.className = ''; - status.innerHTML = - '
' + - '
' + - 'Waiting for GitHub authorization…
' + - '
' + - 'Code' + - '' + esc(user_code) + '' + - '' + - '
' + - 'Authorize on GitHub ↗' + - '
A new tab opened on GitHub — approve there to finish. Didn\'t open? Use the button above.
' + - '
'; - const copyBtn = status.querySelector('.adm-copilot-copy'); - if (copyBtn) copyBtn.addEventListener('click', async () => { - try { await navigator.clipboard.writeText(user_code || ''); copyBtn.textContent = 'Copied'; setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500); } catch (e) {} + try { + const result = await runProviderDeviceFlow(providerKey, { + openWindow: () => {}, + onStart: ({ start, authUrl }) => { + if (triggerEl) triggerEl.textContent = 'Waiting...'; + status.className = ''; + const authLabel = providerKey === 'copilot' ? 'Authorize on GitHub' : 'Authorize with OpenAI'; + const waitLabel = providerKey === 'copilot' ? 'Waiting for GitHub authorization...' : 'Waiting for ChatGPT authorization...'; + status.innerHTML = + '
' + + '
' + + '' + esc(waitLabel) + '
' + + '
' + + 'Code' + + '' + esc(start.user_code) + '' + + '' + + '
' + + '' + esc(authLabel) + ' ↗' + + '
'; + const copyBtn = status.querySelector('.adm-device-auth-copy'); + if (copyBtn) copyBtn.addEventListener('click', async () => { + const code = start.user_code || ''; + let ok = false; + try { + if (navigator.clipboard && window.isSecureContext) { + await navigator.clipboard.writeText(code); + ok = true; + } + } catch (e) {} + if (!ok) { + // navigator.clipboard is unavailable in non-secure contexts (HTTP + // self-host over a LAN IP), so fall back to execCommand('copy'). + const ta = document.createElement('textarea'); + ta.value = code; + ta.style.cssText = 'position:fixed;top:0;left:0;width:1px;height:1px;padding:0;border:0;opacity:0;font-size:16px;'; + document.body.appendChild(ta); + ta.focus(); + ta.select(); + try { ta.setSelectionRange(0, code.length); } catch (e) {} + try { ok = document.execCommand('copy'); } catch (e) {} + ta.remove(); + } + copyBtn.textContent = ok ? 'Copied' : 'Failed'; + setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500); + }); + }, }); - try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {} - - const deadline = Date.now() + (expires_in || 900) * 1000; - const stepMs = Math.max((interval || 5), 2) * 1000; - const done = (cls, text) => { status.className = cls; status.textContent = text; reset(); }; - const poll = async () => { - if (Date.now() > deadline) { done('admin-error', 'Authorization expired — try again.'); return; } - try { - const fd = new FormData(); fd.append('poll_id', poll_id); - const r = await fetch('/api/copilot/device/poll', { method: 'POST', body: fd, credentials: 'same-origin' }); - const d = await r.json(); - if (d.status === 'authorized') { - const n = ((d.endpoint && d.endpoint.models) || []).length; - done('admin-success', '✓ Connected — ' + n + ' Copilot model' + (n !== 1 ? 's' : '') + ' available.'); - if (d.endpoint && d.endpoint.id) _recentlyAddedEpId = String(d.endpoint.id); - await loadEndpoints(); - await _selectAddedModelInChat(d.endpoint || {}); - return; - } - if (d.status === 'failed') { done('admin-error', 'Authorization failed (' + (d.error || 'denied') + ').'); return; } - } catch (e) { /* transient — keep polling */ } - setTimeout(poll, stepMs); - }; - setTimeout(poll, stepMs); - }); + if (result.status === 'authorized') { + const endpoint = result.endpoint || {}; + const n = ((endpoint && endpoint.models) || []).length; + status.className = 'admin-success'; + status.textContent = 'Connected - ' + n + ' ' + config.label + ' model' + (n !== 1 ? 's' : '') + ' available.'; + if (endpoint && endpoint.id) _recentlyAddedEpId = String(endpoint.id); + await loadEndpoints(); + await _selectAddedModelInChat(endpoint || {}); + reset(); + return; + } + if (result.status === 'failed') { + reset(); + showAuthError('Authorization failed (' + (result.error || 'denied') + ').'); + return; + } + if (result.status === 'expired') { + reset(); + showAuthError('Authorization expired.'); + return; + } + } catch (e) { + reset(); + showAuthError(formatDeviceFlowError(e)); + } } // Local "Add" button — sibling form for self-hosted base URLs. diff --git a/static/js/chatRenderer.js b/static/js/chatRenderer.js index 088142302..fc7ed1aeb 100644 --- a/static/js/chatRenderer.js +++ b/static/js/chatRenderer.js @@ -680,9 +680,11 @@ export function applyModelColor(roleEl, modelName) { html += '
Max tokens ' + _mt.toLocaleString() + ' (configured)
'; } } - if (info && info.input != null) html += '
Input $' + info.input.toFixed(2) + ' / 1M
'; - if (info && info.output != null) html += '
Output $' + info.output.toFixed(2) + ' / 1M
'; - if (!info) html += '
No pricing data available
'; + if (isCostTrackedEndpoint(_epUrl)) { + if (info && info.input != null) html += '
Input $' + info.input.toFixed(2) + ' / 1M
'; + if (info && info.output != null) html += '
Output $' + info.output.toFixed(2) + ' / 1M
'; + if (!info) html += '
No pricing data available
'; + } popup.innerHTML = html; const rect = roleEl.getBoundingClientRect(); popup.style.top = (rect.bottom + 4) + 'px'; @@ -735,11 +737,31 @@ export function isLocalEndpoint(url) { return false; } -/** Cost for the current turn, returning null (free) for local endpoints. */ -function _billableCost(model, inputTokens, outputTokens) { - const url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl) +export function isSubscriptionEndpoint(url) { + if (!url) return false; + try { + const parsed = new URL(url); + const path = parsed.pathname.replace(/\/+$/, ''); + return parsed.hostname === 'chatgpt.com' + && (path === '/backend-api/codex' || path.startsWith('/backend-api/codex/')); + } catch (_e) { + return false; + } +} + +function _currentEndpointUrl() { + return (window.sessionModule && window.sessionModule.getCurrentEndpointUrl) ? window.sessionModule.getCurrentEndpointUrl() : null; - if (isLocalEndpoint(url)) return null; +} + +export function isCostTrackedEndpoint(url) { + return !isLocalEndpoint(url) && !isSubscriptionEndpoint(url); +} + +/** Cost for the current turn, returning null for non-billable endpoints. */ +function _billableCost(model, inputTokens, outputTokens) { + const url = _currentEndpointUrl(); + if (!isCostTrackedEndpoint(url)) return null; return getModelCost(model, inputTokens, outputTokens); } @@ -784,11 +806,10 @@ export function resetSessionCost(sessionId) { export function updateSessionCostUI() { const el = document.getElementById('session-cost-display'); if (!el) return; - // Local model? It's free — hide the badge and clear any stale cost that a - // previous (buggy) cloud-rate billing left in localStorage for this session. - const _url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl) - ? window.sessionModule.getCurrentEndpointUrl() : null; - if (isLocalEndpoint(_url)) { + // Non-billable endpoint? Hide the badge and clear stale cost that a previous + // cloud-rate calculation may have left in localStorage for this session. + const _url = _currentEndpointUrl(); + if (!isCostTrackedEndpoint(_url)) { const sid = window.sessionModule && window.sessionModule.getCurrentSessionId(); if (sid && getSessionCost(sid) > 0) { try { @@ -1708,7 +1729,8 @@ export function displayMetrics(messageElement, metrics) { e.stopPropagation(); document.querySelectorAll('.ctx-popup').forEach(p => { if (typeof p._dismiss === 'function') p._dismiss(); else p.remove(); }); - const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : 'n/a'; + const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : ''; + const costRows = costStr ? `
Cost ${costStr}
` : ''; const speedStr = tps != null && tps !== 'undefined' ? `${tps} tok/s` : 'n/a'; const totalTok = inputTokens + outputTokens; const ctxColor = ctxPct >= 85 ? 'var(--red, #e06c75)' : ctxPct >= 70 ? '#ff9900' : 'var(--color-muted-alt, #6b7280)'; @@ -1722,7 +1744,7 @@ export function displayMetrics(messageElement, metrics) { // Session total cost let sessionCostStr = ''; const sc = getSessionCost(); - if (sc > 0) { + if (costStr && sc > 0) { sessionCostStr = `
Session $${sc < 0.01 ? sc.toFixed(4) : sc.toFixed(3)}
`; } @@ -1738,7 +1760,7 @@ export function displayMetrics(messageElement, metrics) {
Time ${responseTime}s
${prepTime != null ? `
Prep ${prepTime}s
` : ''} ${modelWaitTime != null ? `
Model wait ${modelWaitTime}s
` : ''} -
Cost ${costStr}
+ ${costRows} ${sessionCostStr} ${prepDetails ? `
Agent prep
@@ -2392,6 +2414,8 @@ const chatRenderer = { modelColor, applyModelColor, getModelCost, + isCostTrackedEndpoint, + isSubscriptionEndpoint, getImageCost, getSessionCost, resetSessionCost, diff --git a/static/js/providerDeviceFlow.js b/static/js/providerDeviceFlow.js new file mode 100644 index 000000000..5b2975d87 --- /dev/null +++ b/static/js/providerDeviceFlow.js @@ -0,0 +1,128 @@ +// Shared DOM-free provider device-flow runner. + +export const PROVIDER_DEVICE_FLOWS = { + copilot: { + label: 'GitHub Copilot', + startUrl: '/api/copilot/device/start', + pollUrl: '/api/copilot/device/poll', + authUrl(start) { + return start?.verification_uri_complete || start?.verification_uri || ''; + }, + }, + 'chatgpt-subscription': { + label: 'ChatGPT Subscription', + startUrl: '/api/chatgpt-subscription/device/start', + pollUrl: '/api/chatgpt-subscription/device/poll', + authUrl(start) { + return start?.verification_uri || ''; + }, + }, +}; + +function _formData() { + if (typeof FormData !== 'undefined') return new FormData(); + return new URLSearchParams(); +} + +async function _jsonOrEmpty(response) { + try { + return await response.json(); + } catch (_) { + return {}; + } +} + +function _messageFromPayload(payload, fallback) { + if (payload && typeof payload.detail === 'string' && payload.detail.trim()) { + return payload.detail.trim(); + } + if (payload && typeof payload.error === 'string' && payload.error.trim()) { + return payload.error.trim(); + } + if (payload && typeof payload.message === 'string' && payload.message.trim()) { + return payload.message.trim(); + } + return fallback; +} + +export function formatDeviceFlowError(error, fallback = 'Request failed') { + if (!error) return fallback; + if (typeof error === 'string') return error; + if (error.detail) return String(error.detail); + if (error.message) return String(error.message); + return fallback; +} + +async function _fetchJson(fetchImpl, url, options, fallback) { + const response = await fetchImpl(url, options); + const payload = await _jsonOrEmpty(response); + if (!response.ok) { + throw new Error(_messageFromPayload(payload, fallback || `Request failed (HTTP ${response.status})`)); + } + return payload; +} + +function _defaultSleep(ms) { + return new Promise(resolve => setTimeout(resolve, ms)); +} + +async function _callCallback(fn, payload) { + if (typeof fn === 'function') await fn(payload); +} + +export async function runProviderDeviceFlow(provider, options = {}) { + const cfg = PROVIDER_DEVICE_FLOWS[provider]; + if (!cfg) throw new Error(`Unknown device-flow provider: ${provider}`); + + const fetchImpl = options.fetchImpl || globalThis.fetch?.bind(globalThis); + if (!fetchImpl) throw new Error('Fetch API is unavailable'); + + const openWindow = options.openWindow || ((url) => { + if (globalThis.window && typeof globalThis.window.open === 'function') { + globalThis.window.open(url, '_blank', 'noopener'); + } + }); + const sleep = options.sleep || _defaultSleep; + const now = options.now || (() => Date.now()); + const formData = options.formData || _formData(); + + const start = await _fetchJson(fetchImpl, cfg.startUrl, { + method: 'POST', + body: formData, + credentials: 'same-origin', + }, `Failed to start ${cfg.label} sign-in`); + + if (!start.poll_id) throw new Error(`${cfg.label} sign-in did not return a poll id`); + const authUrl = cfg.authUrl(start); + await _callCallback(options.onStart, { provider, config: cfg, start, authUrl }); + if (authUrl) openWindow(authUrl); + + const deadline = now() + Number(start.expires_in || 900) * 1000; + let stepMs = Math.max(Number(start.interval || 5), 2) * 1000; + + while (true) { + if (now() > deadline) return { status: 'expired' }; + await _callCallback(options.onWaiting, { provider, config: cfg, start, authUrl }); + await sleep(stepMs); + if (now() > deadline) return { status: 'expired' }; + + const fd = _formData(); + fd.append('poll_id', start.poll_id); + const poll = await _fetchJson(fetchImpl, cfg.pollUrl, { + method: 'POST', + body: fd, + credentials: 'same-origin', + }, `${cfg.label} sign-in poll failed`); + await _callCallback(options.onPoll, { provider, config: cfg, start, poll }); + + if (poll.status === 'authorized') { + return { status: 'authorized', endpoint: poll.endpoint || {} }; + } + if (poll.status === 'failed') { + return { status: 'failed', error: poll.error || 'denied' }; + } + if (poll.interval) { + stepMs = Math.max(Number(poll.interval || 5), 2) * 1000; + } + } +} diff --git a/static/js/providers.js b/static/js/providers.js index 327e0bbff..1c9c5080a 100644 --- a/static/js/providers.js +++ b/static/js/providers.js @@ -15,6 +15,10 @@ const _PROVIDERS = [ [/opencode/i, ''], + // GitHub / Copilot + [/github|copilot/i, + ''], + // OpenRouter [/openrouter|open router/i, ''], @@ -102,6 +106,7 @@ export function providerLogo(modelId) { // doesn't match `x.ai`. const _ENDPOINT_LABELS = [ [/(^|\.)githubcopilot\.com$/i, "GitHub Copilot"], + [/(^|\.)chatgpt\.com$/i, "ChatGPT Subscription"], [/(^|\.)openrouter\.ai$/i, "OpenRouter"], [/(^|\.)anthropic\.com$/i, "Anthropic"], [/(^|\.)openai\.com$/i, "OpenAI"], diff --git a/static/js/slashAutocomplete.js b/static/js/slashAutocomplete.js index 8745c98a6..14645acfe 100644 --- a/static/js/slashAutocomplete.js +++ b/static/js/slashAutocomplete.js @@ -5,7 +5,7 @@ import { COMMANDS, LEGACY_ALIASES } from './slashCommands.js'; const POPUP_ID = 'slash-autocomplete'; -const MAX_VISIBLE = 12; +const MAX_VISIBLE = 14; // Flatten the registry into a searchable list of leaf entries. Each entry is // either a top-level command or a "cmd sub" pair (so subcommands get their @@ -81,6 +81,23 @@ function _flatten() { return out; } +async function _loadSkillEntries() { + try { + const res = await fetch('/api/skills/slash-catalog', { credentials: 'same-origin' }); + if (!res.ok) return []; + const data = await res.json(); + return (Array.isArray(data.skills) ? data.skills : []).map(s => ({ + token: s.token || `/${s.name}`, + aliases: [], + category: s.category || 'Skills', + help: s.help || 'Run skill', + usage: s.usage || `${s.token || `/${s.name}`} `, + })).filter(e => e.token && e.token.startsWith('/')); + } catch { + return []; + } +} + function _scoreMatch(entry, query) { // query already starts with "/". Match against token + aliases. Prefix wins // over substring; alias match scores slightly lower than token match. @@ -98,6 +115,17 @@ function _scoreMatch(entry, query) { return 0; } +function _exactCommandGroupItems(all, query) { + const q = query.toLowerCase(); + if (!/^\/[a-z0-9_-]+$/i.test(q)) return []; + const parent = all.find(entry => entry.token.toLowerCase() === q); + if (!parent) return []; + const prefix = q + ' '; + const children = all.filter(entry => entry.token.toLowerCase().startsWith(prefix)); + if (!children.length) return []; + return children.concat(parent); +} + function _ensurePopup(textarea) { let el = document.getElementById(POPUP_ID); if (el) return el; @@ -164,7 +192,7 @@ export function initSlashAutocomplete(textarea) { if (!textarea || textarea._slashAcWired) return; textarea._slashAcWired = true; - const all = _flatten(); + let all = _flatten(); let popup = null; let visible = false; let items = []; @@ -191,12 +219,17 @@ export function initSlashAutocomplete(textarea) { // the menu hides — we don't autocomplete mid-sentence. if (!v.startsWith('/') || v.includes('\n')) { hide(); return; } const query = v.trim(); - items = all + const groupItems = _exactCommandGroupItems(all, query); + if (groupItems.length) { + items = groupItems.slice(0, MAX_VISIBLE); + } else { + items = all .map(e => ({ e, s: _scoreMatch(e, query) })) .filter(x => x.s > 0) .sort((a, b) => b.s - a.s) .slice(0, MAX_VISIBLE) .map(x => x.e); + } if (!items.length && query.length > 1) { hide(); return; } if (!items.length) { // Just "/" with no matches — fall back to showing everything up to MAX_VISIBLE @@ -207,6 +240,19 @@ export function initSlashAutocomplete(textarea) { _render(popup, items, selectedIdx, query); }; + _loadSkillEntries().then(skillEntries => { + if (!skillEntries.length) return; + const seen = new Set(all.map(e => e.token)); + const merged = all.slice(); + for (const entry of skillEntries) { + if (seen.has(entry.token)) continue; + seen.add(entry.token); + merged.push(entry); + } + all = merged; + if (visible) refresh(); + }); + const insert = (token) => { textarea.value = token + ' '; textarea.dispatchEvent(new Event('input', { bubbles: true })); diff --git a/static/js/slashCommands.js b/static/js/slashCommands.js index d1ed3e4ff..be4cb6798 100644 --- a/static/js/slashCommands.js +++ b/static/js/slashCommands.js @@ -21,6 +21,7 @@ import workspaceModule from './workspace.js'; import settingsModule from './settings.js'; import cookbookModule from './cookbook.js'; import { EVAL_PROMPTS } from './compare/index.js'; +import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js'; // ── Module state ────────────────────────────────────────────────────── @@ -58,11 +59,28 @@ const SETUP_PROVIDER_URLS = { 'opencode-go': { name: 'OpenCode Go', url: 'https://opencode.ai/zen/go/v1' }, }; const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'ollama', 'xai', 'anthropic', 'groq', 'gemini', 'opencode-zen', 'opencode-go']; -const SETUP_PROVIDER_HINT = SETUP_PROVIDER_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_NAMES[SETUP_PROVIDER_NAMES.length - 1]; +const SETUP_DEVICE_AUTH_PROVIDERS = [ + { key: 'copilot', name: 'GitHub Copilot', aliases: ['github'], command: '/setup copilot' }, + { key: 'chatgpt-subscription', name: 'ChatGPT Subscription', aliases: ['chatgptsubscription', 'chatgpt-sub', 'codex'], command: '/setup chatgpt-subscription' }, +]; +const SETUP_PROVIDER_HINT_NAMES = SETUP_PROVIDER_NAMES.concat(SETUP_DEVICE_AUTH_PROVIDERS.map(provider => provider.key)); +const SETUP_PROVIDER_HINT = SETUP_PROVIDER_HINT_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_HINT_NAMES[SETUP_PROVIDER_HINT_NAMES.length - 1]; const SETUP_LOCAL_ICON = ''; const SETUP_API_ICON = ''; const SETUP_SETTINGS_ICON = ''; +function _setupApiProviderChips() { + return SETUP_PROVIDER_NAMES.map(name => + '' + name + '' + ).join(' '); +} + +function _setupDeviceAuthProviderChips() { + return SETUP_DEVICE_AUTH_PROVIDERS.map(provider => + '' + provider.name + '' + ).join(' '); +} + function _setupProviderFromInput(input) { const raw = (input || '').trim().toLowerCase().replace(/\s+/g, ''); const aliases = { @@ -84,6 +102,17 @@ function _setupProviderFromInput(input) { return SETUP_PROVIDER_URLS[aliases[raw] || raw] || null; } +function _setupDeviceAuthProviderFromInput(input) { + const raw = (input || '').trim().toLowerCase().replace(/\s+/g, '').replace(/_/g, '-'); + if (!raw) return ''; + for (const provider of SETUP_DEVICE_AUTH_PROVIDERS) { + const candidates = [provider.key, provider.name, ...(provider.aliases || [])] + .map(value => String(value || '').toLowerCase().replace(/\s+/g, '').replace(/_/g, '-')); + if (candidates.includes(raw)) return provider.key; + } + return ''; +} + function _extractSetupProviderCredential(input) { const raw = (input || '').trim(); if (!raw) return null; @@ -158,9 +187,8 @@ function _setupReply(text, remember = true) { } function _showSetupEndpointChoices() { - const providers = SETUP_PROVIDER_NAMES.map(name => - '' + name + '' - ).join(' '); + const providers = _setupApiProviderChips(); + const deviceAuthProviders = _setupDeviceAuthProviderChips(); return slashReply( '
' + '
' + @@ -178,6 +206,7 @@ function _showSetupEndpointChoices() { '
Paste provider name then API key (example):
' + '
deepseek sk-...
' + '
Supported providers:
' + providers + '
' + + '
Account sign-in:
' + deviceAuthProviders + '
' + '
' + '
' ); @@ -208,9 +237,8 @@ function _showSetupEndpointChoicesStreamed(options = {}) { text: 'deepseek sk-...', copyText: 'deepseek sk-...', }, - { kind: 'p', html: 'Supported providers:
' + SETUP_PROVIDER_NAMES.map(name => - '' + name + '' - ).join(' ') }, + { kind: 'p', html: 'Supported providers:
' + _setupApiProviderChips() }, + { kind: 'p', html: 'Account sign-in:
' + _setupDeviceAuthProviderChips() }, ]; return typewriterBlocksReply(blocks, { gap: '4px', bodyClass: 'setup-guide-no-censor', interval: 3 }); } @@ -231,7 +259,7 @@ async function _hasConfiguredModels() { } function _setupProviderPrompt() { - const chips = SETUP_PROVIDER_NAMES.map(name => + const chips = SETUP_PROVIDER_HINT_NAMES.map(name => '' + name + '' ).join(' '); slashReply('Supported providers:
' + chips); @@ -286,6 +314,53 @@ function slashReply(text) { return { el: div, body }; } +let _skillCatalogCache = { at: 0, items: [] }; + +async function _loadSkillSlashCatalog(force = false) { + const now = Date.now(); + if (!force && (now - _skillCatalogCache.at) < 15000) return _skillCatalogCache.items; + try { + const res = await fetch(`${API_BASE}/api/skills/slash-catalog`, { credentials: 'same-origin' }); + if (!res.ok) throw new Error('catalog unavailable'); + const data = await res.json(); + const items = Array.isArray(data.skills) ? data.skills : []; + _skillCatalogCache = { at: now, items }; + return items; + } catch { + return _skillCatalogCache.items || []; + } +} + +function _submitComposedMessage(text) { + const msgInput = document.getElementById('message'); + const form = document.getElementById('chat-form'); + if (!msgInput || !form) return false; + msgInput.value = text; + msgInput.dispatchEvent(new Event('input', { bubbles: true })); + if (typeof form.requestSubmit === 'function') form.requestSubmit(); + else form.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true })); + return true; +} + +async function _invokeSkillByName(name, requestText, ctx) { + const res = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(name)}/invoke`, { + method: 'POST', + credentials: 'same-origin', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ request: requestText || '' }) + }); + if (!res.ok) { + const err = await res.json().catch(() => null); + slashReply(ctx?.esc ? ctx.esc(err?.detail || 'Skill is not available') : 'Skill is not available'); + return true; + } + const data = await res.json(); + if (!data.message || !_submitComposedMessage(data.message)) { + slashReply('Could not start skill invocation.'); + } + return true; +} + /** Minimal footer for slash replies: copy + dismiss */ function _slashFooter(msgEl) { const footer = document.createElement('div'); @@ -681,6 +756,13 @@ async function handleSetupWizard(mode, input) { await _setupProviderPrompt(); return; } + const deviceAuthProvider = _setupDeviceAuthProviderFromInput(input); + if (deviceAuthProvider) { + _addMessage('user', input); + setupMode = false; + await _setupProviderDeviceFlow(deviceAuthProvider); + return; + } const paired = _extractSetupProviderCredential(input); const provider = paired?.provider || _setupProviderFromInput(input); if (!provider) { @@ -1429,6 +1511,42 @@ async function _cmdModels(args, ctx) { return true; } +async function _cmdModel(args, ctx) { + const sub = (args[0] || '').toLowerCase(); + if (sub === 'list' || sub === 'ls') return _cmdModels(args.slice(1), ctx); + + const model = sessionModule.getCurrentModel ? sessionModule.getCurrentModel() : ''; + const endpoint = sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : ''; + slashReply(`
${[
+    `Current model: ${ctx.esc(model || 'None selected')}`,
+    endpoint ? `Endpoint: ${ctx.esc(endpoint)}` : 'Endpoint: not available',
+    '',
+    'Usage: /model list to show all available models'
+  ].join('\n')}
`); + return true; +} + +async function _cmdMcp(args, ctx) { + const res = await fetch(`${API_BASE}/api/mcp/servers`, { credentials: 'same-origin' }); + if (!res.ok) { + slashReply('MCP status is unavailable for this user.'); + return true; + } + const servers = await res.json(); + if (!Array.isArray(servers) || !servers.length) { + slashReply('No MCP servers configured.'); + return true; + } + const lines = servers.map(s => { + const status = s.status || (s.is_enabled ? 'enabled' : 'disabled'); + const enabled = Number(s.enabled_tool_count ?? s.tool_count ?? 0); + const total = Number(s.tool_count ?? enabled); + return `${s.name || s.id || 'MCP server'} - ${status} (${enabled}/${total} tools)`; + }); + slashReply(`
${lines.map(line => ctx.esc(line)).join('\n')}
`); + return true; +} + // ── Memory ── async function _cmdMemoryList(args, ctx) { @@ -1507,6 +1625,73 @@ async function _cmdMemorySearch(args, ctx) { return true; } +// ── Skills ── + +async function _cmdSkills(args, ctx) { + const sub = (args[0] || 'list').toLowerCase(); + const rest = args.slice(1); + + if (sub === 'list' || sub === 'ls') { + const skills = await _loadSkillSlashCatalog(true); + if (!skills.length) { + slashReply('No published skills available for slash commands.'); + return true; + } + const lines = skills.map(s => { + const uses = Number(s.uses || 0); + const useText = uses > 0 ? ` uses:${uses}` : ''; + return `${ctx.esc(String(s.token || '').padEnd(24))}${ctx.esc(s.help || '')}${useText}`; + }); + slashReply(`
${lines.join('\n')}
`); + return true; + } + + if (sub === 'search' || sub === 'find') { + const query = rest.join(' ').trim(); + if (!query) { slashReply('Usage: /skills search query'); return true; } + const res = await fetch(`${API_BASE}/api/skills/search`, { + method: 'POST', + credentials: 'same-origin', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ query }) + }); + if (!res.ok) { slashReply('Skill search failed.'); return true; } + const data = await res.json(); + const skills = Array.isArray(data.skills) ? data.skills : []; + if (!skills.length) { slashReply(`No skills found for "${ctx.esc(query)}".`); return true; } + const lines = skills.map(s => + ctx.esc(`/${s.name || s.id || ''}`.padEnd(24)) + ctx.esc(s.description || '') + ); + slashReply(`
${lines.join('\n')}
`); + return true; + } + + if (sub === 'view' || sub === 'cat' || sub === 'show') { + const name = (rest[0] || '').trim(); + if (!name) { slashReply('Usage: /skills view name'); return true; } + const res = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(name)}/markdown`, { credentials: 'same-origin' }); + if (!res.ok) { slashReply(`Skill "${ctx.esc(name)}" was not found.`); return true; } + const data = await res.json(); + slashReply(`
${ctx.esc(data.markdown || '')}
`); + return true; + } + + if (sub === 'use' || sub === 'run') { + const name = (rest[0] || '').trim(); + if (!name) { slashReply('Usage: /skills use name request'); return true; } + return _invokeSkillByName(name, rest.slice(1).join(' ').trim(), ctx); + } + + slashReply('Usage: /skills list | search query | view name | use name request'); + return true; +} + +async function _cmdReloadSkills(args, ctx) { + const skills = await _loadSkillSlashCatalog(true); + slashReply(`Reloaded skills. ${skills.length} skill command${skills.length === 1 ? '' : 's'} available.`); + return true; +} + // ── Note (quick Notes shortcut) ── async function _cmdNote(args, ctx) { @@ -1799,6 +1984,53 @@ Uploads: ${d.uploads || '?'}`); return true; } +async function _cmdUsage(args, ctx) { + const sid = ctx.sid; + if (!sid) { + slashReply('No active session.'); + return true; + } + + let session = null; + try { + const sessions = sessionModule.getSessions ? sessionModule.getSessions() : []; + session = (sessions || []).find(s => s.id === sid) || null; + if (!session) { + const res = await fetch(`${API_BASE}/api/sessions`, { credentials: 'same-origin' }); + if (res.ok) { + const data = await res.json(); + const items = Array.isArray(data) ? data : (data.sessions || data.items || []); + session = items.find(s => s.id === sid) || null; + } + } + } catch (_) {} + + const model = session?.model || 'Unknown'; + const endpointUrl = session?.endpoint_url || ( + sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : '' + ); + const messageCount = Number(session?.message_count || 0); + const totalTokens = Number(session?.total_tokens || 0); + const costTracked = chatRenderer.isCostTrackedEndpoint ? chatRenderer.isCostTrackedEndpoint(endpointUrl) : true; + const cost = costTracked && chatRenderer.getSessionCost ? Number(chatRenderer.getSessionCost(sid) || 0) : 0; + const costLine = costTracked + ? (cost > 0 + ? `Estimated local cost: $${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` + : 'Estimated local cost: unavailable or zero') + : 'Estimated local cost: not tracked for this endpoint'; + + slashReply(`
${[
+    `Session: ${ctx.esc(session?.name || 'Current chat')}`,
+    `Model: ${ctx.esc(model)}`,
+    `Messages: ${messageCount.toLocaleString()}`,
+    `Recorded tokens: ${totalTokens.toLocaleString()}`,
+    costLine,
+    '',
+    'Provider account usage is not available from here; check the provider dashboard for account quota/billing.'
+  ].join('\n')}
`); + return true; +} + // ── Context compaction ── async function _cmdCompact(args, ctx) { @@ -4783,39 +5015,53 @@ function _clearSetupCommandInput() { } } -// GitHub Copilot device-flow sign-in, driven from chat (mirrors the Settings -// "Connect GitHub Copilot" button). Replies via the setup guide messages. -async function _setupCopilot() { +async function _setupProviderDeviceFlow(providerKey) { _clearSetupGuideMessages(); - await _setupReply('Starting GitHub Copilot sign-in…'); - let start; + const config = PROVIDER_DEVICE_FLOWS[providerKey]; + if (!config) { + await _setupReply('Provider not recognised.'); + return; + } + await _setupReply(`Starting ${config.label} sign-in...`); try { - const r = await fetch(`${API_BASE}/api/copilot/device/start`, { method: 'POST', body: new FormData(), credentials: 'same-origin' }); - start = await r.json(); - if (!r.ok) { await _setupReply(start.detail || 'Failed to start Copilot sign-in.'); return; } - } catch (e) { await _setupReply('Request failed.'); return; } - const authUrl = start.verification_uri_complete || start.verification_uri || ''; - await _setupReply(`Opening GitHub — approve the request (code ${start.user_code}). Waiting…`); - try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {} - const deadline = Date.now() + (start.expires_in || 900) * 1000; - const stepMs = Math.max((start.interval || 5), 2) * 1000; - const poll = async () => { - if (Date.now() > deadline) { await _setupReply('Copilot sign-in expired — run /setup copilot again.'); return; } - try { - const fd = new FormData(); fd.append('poll_id', start.poll_id); - const r = await fetch(`${API_BASE}/api/copilot/device/poll`, { method: 'POST', body: fd, credentials: 'same-origin' }); - const d = await r.json(); - if (d.status === 'authorized') { - const n = ((d.endpoint && d.endpoint.models) || []).length; - await _setupReply(`Connected — ${n} Copilot model${n !== 1 ? 's' : ''} available.`); - if (modelsModule) modelsModule.refreshModels(true); - return; - } - if (d.status === 'failed') { await _setupReply('Copilot sign-in failed (' + (d.error || 'denied') + ').'); return; } - } catch (e) { /* transient — keep polling */ } - setTimeout(poll, stepMs); - }; - setTimeout(poll, stepMs); + const result = await runProviderDeviceFlow(providerKey, { + onStart: async ({ start, authUrl }) => { + const place = providerKey === 'copilot' ? 'GitHub' : 'OpenAI'; + const action = providerKey === 'copilot' ? 'approve the request' : 'enter the code'; + if (providerKey === 'chatgpt-subscription') { + slashReply( + '
' + + '
Open this URL in your browser, enter the code, then come back here. Waiting...
' + + '
Code: ' + uiModule.esc(start.user_code || '') + '
' + + '' + + '
' + ); + return; + } + await _setupReply(`Opening ${place} - ${action} (code ${start.user_code}). Waiting...`); + }, + openWindow: (url) => { + if (providerKey === 'chatgpt-subscription') return; + try { if (url) window.open(url, '_blank', 'noopener'); } catch (e) {} + }, + }); + if (result.status === 'authorized') { + const n = ((result.endpoint && result.endpoint.models) || []).length; + await _setupReply(`Connected - ${n} ${config.label} model${n !== 1 ? 's' : ''} available.`); + if (modelsModule) modelsModule.refreshModels(true); + return; + } + if (result.status === 'failed') { + await _setupReply(`${config.label} sign-in failed (${result.error || 'denied'}).`); + return; + } + if (result.status === 'expired') { + await _setupReply(`${config.label} sign-in expired - run /setup ${providerKey} again.`); + return; + } + } catch (e) { + await _setupReply(formatDeviceFlowError(e)); + } } async function _cmdSetup(args, ctx) { @@ -4823,7 +5069,11 @@ async function _cmdSetup(args, ctx) { _clearSetupCommandInput(); const topic = (args[0] || '').trim().toLowerCase(); const topicArgs = args.slice(1); - if (topic === 'copilot' || topic === 'github') { await _setupCopilot(); return true; } + const deviceAuthProvider = _setupDeviceAuthProviderFromInput(topic); + if (deviceAuthProvider) { + await _setupProviderDeviceFlow(deviceAuthProvider); + return true; + } const provider = _setupProviderFromInput(topic); if (provider) { _clearSetupGuideMessages(); @@ -5463,8 +5713,20 @@ async function _cmdHelp(args, ctx) { lines.push(''); } } + const skillCommands = await _loadSkillSlashCatalog(false); + if (skillCommands.length) { + lines.push('Skills:'); + for (const skill of skillCommands.slice(0, 20)) { + const token = String(skill.token || '').padEnd(21); + lines.push(` ${ctx.esc(token)}${ctx.esc(skill.help || '')}`); + } + if (skillCommands.length > 20) { + lines.push(` ... ${skillCommands.length - 20} more. Use /skills list`); + } + lines.push(''); + } lines.push('Tip: / --help for details'); - lines.push('Shortcuts: /new /rename /fork /web /bash /memories /forget'); + lines.push('Shortcuts: /new /rename /fork /web /bash /memories /skills'); slashReply(`
${lines.join('\n')}
`); return true; } @@ -5539,6 +5801,20 @@ const COMMANDS = { 'search': { handler: _cmdMemorySearch, alias: ['grep'], help: 'Search memories', usage: '/memory search q' } } }, + skills: { + alias: ['skill'], + category: 'Memory', + help: 'List, search, inspect, or run skills', + handler: _cmdSkills, + usage: '/skills list | search query | view name | use name request', + }, + 'reload-skills': { + alias: ['reload_skills'], + category: 'Memory', + help: 'Refresh the slash skill catalog', + handler: _cmdReloadSkills, + usage: '/reload-skills', + }, rag: { alias: [], category: 'RAG', @@ -5572,7 +5848,7 @@ const COMMANDS = { category: 'Getting started', help: 'Add local or API model endpoints', handler: _cmdSetup, - usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup endpoint', + usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup chatgpt-subscription', // Provider subs so the autocomplete popup surfaces "/setup deepseek", // "/setup openai", etc. when the user types "/setup de". Each sub's // handler is a thin wrapper that re-prepends the sub name and @@ -5590,6 +5866,7 @@ const COMMANDS = { xai: { help: 'xAI (Grok)', alias: ['grok'], usage: '/setup xai xai-...', handler: (a, c) => _cmdSetup(['xai', ...a], c) }, ollama: { help: 'Ollama Cloud', usage: '/setup ollama KEY', handler: (a, c) => _cmdSetup(['ollama', ...a], c) }, copilot: { help: 'GitHub Copilot', usage: '/setup copilot', handler: (a, c) => _cmdSetup(['copilot', ...a], c) }, + 'chatgpt-subscription': { help: 'ChatGPT Subscription', alias: ['codex'], usage: '/setup chatgpt-subscription', handler: (a, c) => _cmdSetup(['chatgpt-subscription', ...a], c) }, local: { help: 'Local model server (vLLM / LM Studio / llama.cpp / Ollama)', usage: '/setup local http://localhost:8000/v1', handler: (a, c) => _cmdSetup(['local', ...a], c) }, @@ -5767,8 +6044,22 @@ const COMMANDS = { handler: (args, ctx) => _cmdToolPanel('compare', args, ctx), usage: '/compare' }, + mcp: { + alias: [], + category: 'Tools', + help: 'Show MCP server status', + handler: _cmdMcp, + usage: '/mcp' + }, + model: { + alias: [], + category: 'Settings', + help: 'Show current chat model', + handler: _cmdModel, + usage: '/model · /model list' + }, models: { - alias: ['model'], + alias: [], category: 'Settings', help: 'List available models', handler: _cmdModels, @@ -5799,10 +6090,16 @@ const COMMANDS = { handler: _cmdStats, usage: '/stats' }, + usage: { + alias: ['cost', 'tokens'], + category: 'Utility', + help: 'Show local usage for the current chat', + handler: _cmdUsage, + usage: '/usage' + }, compact: { alias: [], category: 'Utility', - hidden: true, help: 'Compact older chat messages', handler: _cmdCompact, usage: '/compact' @@ -6075,33 +6372,13 @@ async function handleSlashCommand(input) { } // --- 4. Skill invocation: / [request] --- - // If `rawCmd` matches a published skill, pin its SKILL.md to the user's - // message and re-submit. Lets you fire a stored procedure on demand - // without the model having to discover the skill itself. + // If `rawCmd` matches a published skill, the backend records usage and + // returns a skill-pinned message to submit as the next agent turn. try { - const skillRes = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(rawCmd)}/markdown`, { credentials: 'same-origin' }); - if (skillRes.ok) { - const skillData = await skillRes.json(); - const md = skillData.markdown || ''; - if (md) { - _showUser(); - const request = args.join(' ').trim(); - const msgInput = document.getElementById('message'); - const composed = - `Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n` + - `--- BEGIN SKILL ---\n${md}\n--- END SKILL ---\n\n` + - (request ? `Request: ${request}` : `Request: (use the skill as appropriate)`); - if (msgInput) { - msgInput.value = composed; - const form = document.getElementById('chat-form'); - if (form && typeof form.requestSubmit === 'function') { - form.requestSubmit(); - } else if (form) { - form.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true })); - } - } - return true; - } + const catalog = await _loadSkillSlashCatalog(false); + if (catalog.some(s => s.name === rawCmd)) { + _showUser(); + return await _invokeSkillByName(rawCmd, args.join(' ').trim(), ctx); } } catch (_) { /* fall through to fuzzy match */ } @@ -6158,10 +6435,13 @@ export function initSlashCommands(deps) { const providerEl = e.target.closest('.setup-clickable-provider'); if (providerEl) { e.preventDefault(); + const providerKey = providerEl.dataset.setupProvider || providerEl.textContent.trim(); const providerName = providerEl.textContent.trim(); const messageInput = document.getElementById('message'); if (messageInput) { - const text = providerName + ' sk-'; + const text = providerEl.dataset.setupKind === 'device-auth' + ? '/setup ' + providerKey + : providerName + ' sk-'; messageInput.value = text; messageInput.dispatchEvent(new Event('input', { bubbles: true })); messageInput.focus(); diff --git a/tests/test_admin_device_flow_static.py b/tests/test_admin_device_flow_static.py new file mode 100644 index 000000000..94f837340 --- /dev/null +++ b/tests/test_admin_device_flow_static.py @@ -0,0 +1,65 @@ +"""Static regressions for Add Models provider device-flow UX.""" + +from pathlib import Path + + +_REPO = Path(__file__).resolve().parent.parent +_INDEX = (_REPO / "static" / "index.html").read_text(encoding="utf-8") +_ADMIN = (_REPO / "static" / "js" / "admin.js").read_text(encoding="utf-8") + + +def _between(src: str, start: str, end: str) -> str: + start_idx = src.index(start) + end_idx = src.index(end, start_idx) + return src[start_idx:end_idx] + + +def test_copilot_and_chatgpt_subscription_are_dropdown_device_auth_options(): + assert 'value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot' in _INDEX + assert 'value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription' in _INDEX + assert 'id="adm-deviceAuthStatus"' in _INDEX + + +def test_provider_selection_is_inert_and_add_button_starts_device_flow(): + change_block = _between(_ADMIN, "provider.addEventListener('change'", "urlInput.addEventListener('input'") + add_block = _between(_ADMIN, "el('adm-epAddBtn').addEventListener('click'", "async function _startProviderDeviceAuth") + + assert "_startProviderDeviceAuth" not in change_block + assert "_startProviderDeviceAuth(deviceAuthProvider" in add_block + + +def test_device_auth_selection_disables_and_dims_api_test_button(): + form_block = _between(_ADMIN, "function _setApiFormForProvider()", "function _renderPickerMenu()") + + assert "testBtn.disabled = true" in form_block + assert "testBtn.style.opacity = '0.45'" in form_block + assert "testBtn.style.cursor = 'not-allowed'" in form_block + assert "testBtn.disabled = false" in form_block + assert "testBtn.style.opacity = ''" in form_block + assert "testBtn.style.cursor = ''" in form_block + + +def test_device_auth_keeps_manual_auth_button_without_auto_opening_tab(): + auth_block = _between(_ADMIN, "async function _startProviderDeviceAuth", "// Local \"Add\" button") + + assert "Authorize with OpenAI" in auth_block + assert "Authorize on GitHub" in auth_block + assert "adm-copilot-panel" in auth_block + assert "adm-device-auth-copy" in auth_block + assert "openWindow: () => {}" in auth_block + assert "A new tab opened" not in auth_block + + +def test_loud_oauth_copy_and_removed_button_hooks_do_not_return(): + forbidden = [ + "Click Add to start", + "uses account sign-in", + "Uses ChatGPT/Codex OAuth, not an OpenAI API key.", + "adm-chatgptStatus", + "adm-chatgptConnectBtn", + "adm-copilotConnectBtn", + "adm-copilotStatus", + ] + for needle in forbidden: + assert needle not in _INDEX + assert needle not in _ADMIN diff --git a/tests/test_chatgpt_subscription_routes.py b/tests/test_chatgpt_subscription_routes.py new file mode 100644 index 000000000..8661efe37 --- /dev/null +++ b/tests/test_chatgpt_subscription_routes.py @@ -0,0 +1,280 @@ +"""DB-backed ChatGPT Subscription endpoint provisioning tests.""" + +import json + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from core.database import Base, ModelEndpoint, ProviderAuthSession +import routes.chatgpt_subscription_routes as csr + + +def _mem_db(monkeypatch): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(bind=engine) + # Match production (core.database SessionLocal is autoflush=False): a pending + # db.delete(ep) is NOT flushed before the orphan-auth reference-count SELECT, + # which is exactly why _delete_orphaned_provider_auth needs exclude_ep_id. + TestSessionLocal = sessionmaker(bind=engine, autoflush=False) + monkeypatch.setattr(csr, "SessionLocal", TestSessionLocal) + return TestSessionLocal + + +def test_provision_creates_owner_scoped_auth_session_and_endpoint(monkeypatch): + TestSessionLocal = _mem_db(monkeypatch) + monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5", "o4-mini"]) + + res = csr._provision_endpoint({"access_token": "AT", "refresh_token": "RT"}, "alice") + + assert res["name"] == "ChatGPT Subscription" + assert res["base_url"] == csr.chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL + assert res["models"] == ["gpt-5.5", "o4-mini"] + + db = TestSessionLocal() + try: + auth = db.query(ProviderAuthSession).first() + ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == res["id"]).first() + assert auth is not None + assert auth.owner == "alice" + assert auth.provider == csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER + assert auth.access_token == "AT" + assert auth.refresh_token == "RT" + assert auth.auth_mode == "chatgpt" + assert ep is not None + assert ep.owner == "alice" + assert ep.api_key is None + assert ep.provider_auth_id == auth.id + assert ep.endpoint_kind == "api" + assert ep.model_refresh_mode == "manual" + assert ep.supports_tools is False + assert json.loads(ep.cached_models) == ["gpt-5.5", "o4-mini"] + finally: + db.close() + + +def test_provision_refreshes_existing_auth_session_and_endpoint(monkeypatch): + TestSessionLocal = _mem_db(monkeypatch) + monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5"]) + + first = csr._provision_endpoint({"access_token": "OLD", "refresh_token": "OLD-RT"}, "bob") + second = csr._provision_endpoint({"access_token": "NEW", "refresh_token": "NEW-RT"}, "bob") + + assert first["id"] == second["id"] + db = TestSessionLocal() + try: + auth_rows = db.query(ProviderAuthSession).filter(ProviderAuthSession.owner == "bob").all() + ep_rows = db.query(ModelEndpoint).filter(ModelEndpoint.owner == "bob").all() + assert len(auth_rows) == 1 + assert len(ep_rows) == 1 + assert auth_rows[0].access_token == "NEW" + assert auth_rows[0].refresh_token == "NEW-RT" + assert ep_rows[0].provider_auth_id == auth_rows[0].id + finally: + db.close() + + +def test_provision_rejects_missing_tokens(monkeypatch): + _mem_db(monkeypatch) + with pytest.raises(ValueError, match="missing access_token or refresh_token"): + csr._provision_endpoint({"access_token": "AT"}, "alice") + + +def test_provision_rejects_accounts_without_usable_models(monkeypatch): + _mem_db(monkeypatch) + monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: []) + + with pytest.raises(ValueError, match="no usable Codex models"): + csr._provision_endpoint({"access_token": "AT", "refresh_token": "RT"}, "alice") + + +def _add_auth_and_endpoints(db, *, auth_id="auth1", ep_ids=("ep1",)): + db.add(ProviderAuthSession( + id=auth_id, provider=csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER, + owner="alice", base_url="https://chatgpt.com/backend-api/codex", + refresh_token="RT", auth_mode="chatgpt", + )) + for ep_id in ep_ids: + db.add(ModelEndpoint( + id=ep_id, name="ChatGPT Subscription", + base_url="https://chatgpt.com/backend-api/codex", + provider_auth_id=auth_id, owner="alice", + )) + db.commit() + + +def test_delete_orphaned_provider_auth_revokes_when_last_endpoint_removed(monkeypatch): + from routes.model_routes import _delete_orphaned_provider_auth + + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + _add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1",)) + # Mirror the production delete route: db.delete(ep) is issued (but not yet + # flushed/committed) BEFORE the orphan check runs. + ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first() + db.delete(ep1) + # ep1 (its only referencing endpoint) is being deleted, so the auth clears. + assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep1") is True + db.commit() + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is None + finally: + db.close() + + +def test_delete_orphaned_provider_auth_requires_exclude_ep_id_for_pending_delete(monkeypatch): + from routes.model_routes import _delete_orphaned_provider_auth + + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + _add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1",)) + ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first() + db.delete(ep1) + # Without exclude_ep_id, the un-flushed pending delete leaves ep1 visible + # to the reference-count SELECT (autoflush=False), so the helper must + # conservatively KEEP the auth row. This is the bug exclude_ep_id fixes. + assert _delete_orphaned_provider_auth(db, "auth1") is False + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None + finally: + db.close() + + +def test_delete_orphaned_provider_auth_keeps_auth_while_another_endpoint_uses_it(monkeypatch): + from routes.model_routes import _delete_orphaned_provider_auth + + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + _add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1", "ep2")) + # ep2 still references auth1, so deleting ep1 must NOT revoke it. + assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep1") is False + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None + finally: + db.close() + + +def test_delete_orphaned_provider_auth_noop_without_auth_id(monkeypatch): + from routes.model_routes import _delete_orphaned_provider_auth + + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + assert _delete_orphaned_provider_auth(db, None, exclude_ep_id="ep1") is False + finally: + db.close() + + +def test_delete_orphaned_provider_auth_noop_when_auth_row_missing(monkeypatch): + from routes.model_routes import _delete_orphaned_provider_auth + + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + # Endpoint points at an auth_id whose ProviderAuthSession is already gone. + db.add(ModelEndpoint( + id="ep1", name="ChatGPT Subscription", + base_url="https://chatgpt.com/backend-api/codex", + provider_auth_id="ghost", owner="alice", + )) + db.commit() + ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first() + db.delete(ep1) + # No other endpoint references "ghost" and no auth row exists → no-op, no error. + assert _delete_orphaned_provider_auth(db, "ghost", exclude_ep_id="ep1") is False + finally: + db.close() + + +def _delete_route(monkeypatch, TestSessionLocal): + """Resolve the real DELETE /model-endpoints/{ep_id} route, wired to the test DB. + + Neutralizes the route's unrelated cleanup side effects (settings/prefs files, + in-memory session manager) so the test stays hermetic and focuses on the + provider-auth revocation wiring. + """ + import routes.model_routes as mr + import routes.prefs_routes as prefs_routes + import src.ai_interaction as ai_interaction + + monkeypatch.setattr(mr, "SessionLocal", TestSessionLocal) + monkeypatch.setattr(mr, "require_admin", lambda request: None) + monkeypatch.setattr(mr, "_load_settings", lambda: {}) + monkeypatch.setattr(mr, "_save_settings", lambda settings: None) + monkeypatch.setattr(prefs_routes, "_load", lambda: {}) + monkeypatch.setattr(prefs_routes, "_save", lambda prefs: None) + monkeypatch.setattr(ai_interaction, "get_session_manager", lambda: None) + + router = mr.setup_model_routes(model_discovery=None) + for route in router.routes: + if getattr(route, "path", "") == "/api/model-endpoints/{ep_id}" and "DELETE" in getattr(route, "methods", set()): + return route.endpoint + raise AssertionError("DELETE /api/model-endpoints/{ep_id} not found") + + +def test_delete_endpoint_route_revokes_orphaned_provider_auth(monkeypatch): + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + _add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1",)) + finally: + db.close() + + delete_endpoint = _delete_route(monkeypatch, TestSessionLocal) + result = delete_endpoint("ep1", object()) + + assert result["deleted"] is True + # The last (only) endpoint backed by auth1 is gone, so the route revokes it. + assert result["cleared_provider_auth"] is True + db = TestSessionLocal() + try: + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is None + assert db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first() is None + finally: + db.close() + + +def test_delete_endpoint_route_keeps_auth_when_shared(monkeypatch): + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + _add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1", "ep2")) + finally: + db.close() + + delete_endpoint = _delete_route(monkeypatch, TestSessionLocal) + result = delete_endpoint("ep1", object()) + + assert result["deleted"] is True + # ep2 still references auth1, so deleting ep1 must NOT revoke the credentials. + assert result["cleared_provider_auth"] is False + db = TestSessionLocal() + try: + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None + finally: + db.close() + + +def test_delete_orphaned_provider_auth_revokes_only_after_last_of_several(monkeypatch): + from routes.model_routes import _delete_orphaned_provider_auth + + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + _add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1", "ep2")) + + # Delete ep1 first: ep2 still references auth1, so the row survives. + ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first() + db.delete(ep1) + assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep1") is False + db.commit() + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None + + # Now delete the last endpoint ep2: the auth row is finally cleared. + ep2 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep2").first() + db.delete(ep2) + assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep2") is True + db.commit() + assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is None + finally: + db.close() diff --git a/tests/test_device_flow_routes.py b/tests/test_device_flow_routes.py new file mode 100644 index 000000000..d8d01d8ce --- /dev/null +++ b/tests/test_device_flow_routes.py @@ -0,0 +1,138 @@ +"""Shared device-flow route helper regressions.""" + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + +from routes import device_flow + + +def _client(monkeypatch, now_ref, start_flow, poll_flow): + store = device_flow.PendingDeviceFlowStore(time_func=lambda: now_ref[0]) + router = device_flow.create_device_flow_router( + prefix="/api/test-device", + tags=["test-device"], + store=store, + start_flow=start_flow, + poll_flow=poll_flow, + ) + app = FastAPI() + app.include_router(router) + monkeypatch.setattr(device_flow, "require_admin", lambda request: None) + return TestClient(app) + + +def _start(_request, _form): + return device_flow.DeviceFlowStart( + pending={"secret": "server-only", "owner": "alice"}, + response={"user_code": "ABCD-EFGH", "verification_uri": "https://example.test/device"}, + interval=5, + expires_in=20, + ) + + +def test_pending_poll_is_throttled_until_interval(monkeypatch): + now = [100.0] + calls = [] + + def poll(_request, pending): + calls.append(dict(pending)) + return device_flow.DeviceFlowPoll.pending() + + client = _client(monkeypatch, now, _start, poll) + start = client.post("/api/test-device/device/start").json() + + first = client.post("/api/test-device/device/poll", data={"poll_id": start["poll_id"]}) + assert first.json() == {"status": "pending"} + assert calls == [{"secret": "server-only", "owner": "alice"}] + + second = client.post("/api/test-device/device/poll", data={"poll_id": start["poll_id"]}) + assert second.json() == {"status": "pending"} + assert len(calls) == 1 + + now[0] += 5 + third = client.post("/api/test-device/device/poll", data={"poll_id": start["poll_id"]}) + assert third.json() == {"status": "pending"} + assert len(calls) == 2 + + +def test_slow_down_updates_poll_interval(monkeypatch): + now = [100.0] + calls = [] + + def poll(_request, _pending): + calls.append(now[0]) + if len(calls) == 1: + return device_flow.DeviceFlowPoll.slow_down(interval=10) + return device_flow.DeviceFlowPoll.authorized({"id": "ep1", "models": ["gpt-4o"]}) + + client = _client(monkeypatch, now, _start, poll) + poll_id = client.post("/api/test-device/device/start").json()["poll_id"] + + assert client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json() == {"status": "pending"} + now[0] += 9 + assert client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json() == {"status": "pending"} + assert len(calls) == 1 + + now[0] += 1 + assert client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json() == { + "status": "authorized", + "endpoint": {"id": "ep1", "models": ["gpt-4o"]}, + } + + +def test_authorized_and_failed_polls_remove_pending_session(monkeypatch): + now = [100.0] + outcomes = [ + device_flow.DeviceFlowPoll.authorized({"id": "ep1"}), + device_flow.DeviceFlowPoll.failed("access_denied"), + ] + + def poll(_request, _pending): + return outcomes.pop(0) + + client = _client(monkeypatch, now, _start, poll) + first = client.post("/api/test-device/device/start").json()["poll_id"] + second = client.post("/api/test-device/device/start").json()["poll_id"] + + assert client.post("/api/test-device/device/poll", data={"poll_id": first}).json()["status"] == "authorized" + assert client.post("/api/test-device/device/poll", data={"poll_id": first}).status_code == 404 + + assert client.post("/api/test-device/device/poll", data={"poll_id": second}).json() == { + "status": "failed", + "error": "access_denied", + } + assert client.post("/api/test-device/device/poll", data={"poll_id": second}).status_code == 404 + + +def test_cancel_and_expiry_remove_pending_session(monkeypatch): + now = [100.0] + + def poll(_request, _pending): + return device_flow.DeviceFlowPoll.pending() + + client = _client(monkeypatch, now, _start, poll) + cancelled = client.post("/api/test-device/device/start").json()["poll_id"] + assert client.post("/api/test-device/device/cancel", data={"poll_id": cancelled}).json() == {"status": "cancelled"} + assert client.post("/api/test-device/device/poll", data={"poll_id": cancelled}).status_code == 404 + + expired = client.post("/api/test-device/device/start").json()["poll_id"] + now[0] += 21 + assert client.post("/api/test-device/device/poll", data={"poll_id": expired}).status_code == 404 + + +def test_routes_are_admin_gated(monkeypatch): + now = [100.0] + + def poll(_request, _pending): + return device_flow.DeviceFlowPoll.pending() + + client = _client(monkeypatch, now, _start, poll) + + def deny(_request): + raise HTTPException(403, "admin required") + + monkeypatch.setattr(device_flow, "require_admin", deny) + assert client.post("/api/test-device/device/start").status_code == 403 + assert client.post("/api/test-device/device/poll", data={"poll_id": "missing"}).status_code == 403 + assert client.post("/api/test-device/device/cancel", data={"poll_id": "missing"}).status_code == 403 diff --git a/tests/test_endpoint_probing.py b/tests/test_endpoint_probing.py index 0206ebfb7..ea4835c16 100644 --- a/tests/test_endpoint_probing.py +++ b/tests/test_endpoint_probing.py @@ -25,32 +25,36 @@ from unittest.mock import MagicMock import httpx import pytest -from tests.helpers.import_state import clear_fake_endpoint_resolver_modules +from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state -# Match test_model_routes.py: if another test stubbed src.endpoint_resolver -# during collection, drop the stub so the real URL helpers load here. -clear_fake_endpoint_resolver_modules() +with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"): + # Match test_model_routes.py: if another test stubbed src.endpoint_resolver + # during collection, drop the stub so the real URL helpers load here. + clear_fake_endpoint_resolver_modules() -if "core.database" not in sys.modules: - _core_db = types.ModuleType("core.database") - for _name in [ - "SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document", - "DocumentVersion", "GalleryImage", "GalleryAlbum", "Note", - "CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer", - ]: - setattr(_core_db, _name, MagicMock()) - sys.modules["core.database"] = _core_db + if "core.database" not in sys.modules: + _core_db = types.ModuleType("core.database") + for _name in [ + "SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document", + "DocumentVersion", "GalleryImage", "GalleryAlbum", "Note", + "CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer", + "ProviderAuthSession", "Base", + ]: + setattr(_core_db, _name, MagicMock()) + _core_db.utcnow_naive = MagicMock() + sys.modules["core.database"] = _core_db -import routes.model_routes as model_routes -import src.endpoint_resolver as endpoint_resolver -from routes.model_routes import ( - _probe_endpoint, - _ping_endpoint, - _probe_single_model, - _classify_endpoint, - _rewrite_loopback_for_docker, - _PROVIDER_CURATED, -) + import routes.model_routes as model_routes + import src.endpoint_resolver as endpoint_resolver + from routes.model_routes import ( + _probe_endpoint, + _ping_endpoint, + _probe_single_model, + _resolve_probe_key, + _classify_endpoint, + _rewrite_loopback_for_docker, + _PROVIDER_CURATED, + ) def _patch_resolve(monkeypatch): @@ -117,6 +121,26 @@ class TestProbeEndpointParsing: ) assert _probe_endpoint("https://api.example.com/v1") == [] + def test_chatgpt_subscription_probe_uses_discovery_only(self, monkeypatch): + _patch_resolve(monkeypatch) + calls = [] + + def fake_fetch(access_token, timeout=5): + calls.append((access_token, timeout)) + return ["gpt-5.5"] + + monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", fake_fetch) + + assert _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS", timeout=7) == ["gpt-5.5"] + assert calls == [("ACCESS", 7)] + + def test_chatgpt_subscription_probe_without_discovery_returns_empty(self, monkeypatch): + _patch_resolve(monkeypatch) + monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", lambda access_token, timeout=5: []) + + assert _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS") == [] + assert _probe_endpoint("https://chatgpt.com/backend-api/codex") == [] + # ── _ping_endpoint: reachability classification ── @@ -321,6 +345,51 @@ class TestProbeSingleModel: _probe_single_model("https://api.anthropic.com/v1", "sk-ant", "claude-sonnet-4-5", with_tools=True) assert "input_schema" in captured["payload"]["tools"][0] + def test_chatgpt_subscription_skips_completion_probe(self, monkeypatch): + # This provider speaks the Responses/Codex API. A chat-completions probe + # would 400 and (via the re-probe flow) hide every model, so it must be + # short-circuited as discovery-only without any HTTP call. + _patch_resolve(monkeypatch) + + def boom(*args, **kwargs): + raise AssertionError("must not send a completion probe for chatgpt-subscription") + + monkeypatch.setattr(model_routes.httpx, "post", boom) + result = _probe_single_model("https://chatgpt.com/backend-api/codex", None, "gpt-5.1-codex") + assert result["status"] == "ok" + assert result.get("skipped") is True + # Pin the full documented return shape — downstream JSON/UI reads latency_ms. + assert result["latency_ms"] == 0 + + +# ── _resolve_probe_key: static key vs provider-auth runtime token ── + +class TestResolveProbeKey: + def test_static_endpoint_uses_api_key(self): + ep = types.SimpleNamespace(id="e1", api_key="sk-static", provider_auth_id=None, owner=None) + assert _resolve_probe_key(ep) == "sk-static" + + def test_provider_auth_endpoint_resolves_runtime_token(self, monkeypatch): + ep = types.SimpleNamespace(id="e2", api_key=None, provider_auth_id="auth123", owner="alice") + seen = {} + + def fake_runtime(endpoint, owner=None): + seen["owner"] = owner + return ("https://chatgpt.com/backend-api/codex", "live-bearer") + + monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", fake_runtime) + assert _resolve_probe_key(ep) == "live-bearer" + assert seen["owner"] == "alice" + + def test_provider_auth_resolution_failure_returns_none(self, monkeypatch): + ep = types.SimpleNamespace(id="e3", api_key=None, provider_auth_id="auth123", owner=None) + + def boom(endpoint, owner=None): + raise RuntimeError("reauth required") + + monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", boom) + assert _resolve_probe_key(ep) is None + # ── _classify_endpoint: Tailscale CGNAT range ── diff --git a/tests/test_llm_core_temperature.py b/tests/test_llm_core_temperature.py index 00be525b7..f49d3dba0 100644 --- a/tests/test_llm_core_temperature.py +++ b/tests/test_llm_core_temperature.py @@ -75,6 +75,28 @@ def test_normal_model_payload_keeps_temperature_above_one(monkeypatch): assert payload["temperature"] == 1.2 +def test_chatgpt_subscription_payload_uses_max_output_tokens(): + payload = llm_core._build_chatgpt_responses_payload( + "gpt-5.1-codex", + [{"role": "user", "content": "Say OK"}], + temperature=0.2, + max_tokens=37, + ) + + assert payload["max_output_tokens"] == 37 + + +def test_chatgpt_subscription_payload_omits_empty_max_output_tokens(): + payload = llm_core._build_chatgpt_responses_payload( + "gpt-5.1-codex", + [{"role": "user", "content": "Say OK"}], + temperature=0.2, + max_tokens=0, + ) + + assert "max_output_tokens" not in payload + + def _anthropic_payload(temperature): return llm_core._build_anthropic_payload( "claude-3-5-sonnet", diff --git a/tests/test_model_routes.py b/tests/test_model_routes.py index 54a0b4125..a39b3e7ae 100644 --- a/tests/test_model_routes.py +++ b/tests/test_model_routes.py @@ -11,49 +11,51 @@ from types import SimpleNamespace import httpx import pytest -from tests.helpers.import_state import clear_fake_endpoint_resolver_modules +from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state -# Other tests stub this module during collection. These helper tests need -# the real URL normalization helpers so Anthropic /v1 handling is covered. -clear_fake_endpoint_resolver_modules() +with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"): + # Other tests stub this module during collection. These helper tests need + # the real URL normalization helpers so Anthropic /v1 handling is covered. + clear_fake_endpoint_resolver_modules() -if "core.database" not in sys.modules: - _core_db = types.ModuleType("core.database") - for _name in [ - "SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document", - "DocumentVersion", "GalleryImage", "GalleryAlbum", "Note", - "CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", - "McpServer", - ]: - setattr(_core_db, _name, MagicMock()) - sys.modules["core.database"] = _core_db + if "core.database" not in sys.modules: + _core_db = types.ModuleType("core.database") + for _name in [ + "SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document", + "DocumentVersion", "GalleryImage", "GalleryAlbum", "Note", + "CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", + "McpServer", "ProviderAuthSession", "Base", + ]: + setattr(_core_db, _name, MagicMock()) + _core_db.utcnow_naive = MagicMock() + sys.modules["core.database"] = _core_db -import routes.model_routes as model_routes -import src.database as src_database -import src.endpoint_resolver as endpoint_resolver -import src.llm_core as llm_core -from routes.model_routes import ( - _match_provider_curated, - _curate_models, - _visible_models, - _normalize_model_ids, - _api_key_fingerprint, - _is_chat_model, - _classify_endpoint, - _effective_endpoint_kind, - _probe_endpoint, - _ping_endpoint, - _parse_model_list, - _normalize_refresh_mode, - _truthy, - _speech_settings_using_endpoint, - _clear_speech_settings_for_endpoint, - _endpoint_settings_using_endpoint, - _clear_endpoint_settings_for_endpoint, - _clear_user_pref_endpoint_refs, - _PROVIDER_CURATED, -) -from src.llm_core import ANTHROPIC_MODELS + import routes.model_routes as model_routes + import src.database as src_database + import src.endpoint_resolver as endpoint_resolver + import src.llm_core as llm_core + from routes.model_routes import ( + _match_provider_curated, + _curate_models, + _visible_models, + _normalize_model_ids, + _api_key_fingerprint, + _is_chat_model, + _classify_endpoint, + _effective_endpoint_kind, + _probe_endpoint, + _ping_endpoint, + _parse_model_list, + _normalize_refresh_mode, + _truthy, + _speech_settings_using_endpoint, + _clear_speech_settings_for_endpoint, + _endpoint_settings_using_endpoint, + _clear_endpoint_settings_for_endpoint, + _clear_user_pref_endpoint_refs, + _PROVIDER_CURATED, + ) + from src.llm_core import ANTHROPIC_MODELS # ── speech endpoint settings ── @@ -687,8 +689,7 @@ class _PinnedFakeRequest: def _get_route(path, method): - from routes.model_routes import setup_model_routes - router = setup_model_routes(model_discovery=None) + router = model_routes.setup_model_routes(model_discovery=None) for route in router.routes: if getattr(route, "path", "") == path and method in getattr(route, "methods", set()): return route.endpoint @@ -787,6 +788,55 @@ def test_reprobe_preserves_pinned_models(monkeypatch): assert json.loads(ep.cached_models) == ["m1"] +def test_reprobe_chatgpt_subscription_does_not_hide_models(monkeypatch): + # The whole point of the _probe_single_model short-circuit is that re-probing + # a chatgpt-subscription endpoint must NOT mark every (un-probeable) model as + # failed and write them all into hidden_models. Assert that end-to-end at the + # route level, with the REAL _probe_single_model doing the skip. + ep = _make_endpoint( + base_url="https://chatgpt.com/backend-api/codex", + api_key=None, + hidden_models=json.dumps(["stale-hidden"]), + ) + db = _PinnedFakeDb([ep]) + monkeypatch.setattr(model_routes, "SessionLocal", lambda: db) + monkeypatch.setattr(model_routes, "require_admin", lambda request: None) + monkeypatch.setattr(model_routes, "_normalize_base", lambda url: url.rstrip("/")) + monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["gpt-5.1-codex", "gpt-5.1"]) + monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True) + # Any completion probe would be a bug for this provider. + monkeypatch.setattr( + model_routes.httpx, "post", + lambda *a, **k: (_ for _ in ()).throw(AssertionError("must not probe chatgpt-subscription")), + ) + endpoint = _get_route("/api/model-endpoints/{ep_id}/probe", "GET") + + response = endpoint("ep1", _PinnedFakeRequest()) + chunks = [] + + async def _drain(): + async for chunk in response.body_iterator: + chunks.append(chunk.decode() if isinstance(chunk, bytes) else chunk) + + asyncio.run(_drain()) + + events = [] + for chunk in chunks: + for line in chunk.splitlines(): + if line.startswith("data: "): + events.append(json.loads(line[len("data: "):])) + + done = next(e for e in events if e.get("type") == "probe_done") + results = [e for e in events if e.get("type") == "probe_result"] + + # Every model was skipped as ok; none failed → nothing hidden. + assert done["hidden"] == 0 + assert done["ok"] == len(results) == 2 + assert all(r["status"] == "ok" and r.get("skipped") is True for r in results) + # The stale hidden_models is cleared, not repopulated with every model. + assert ep.hidden_models is None + + def test_visible_models_handles_malformed_strings(): # Non-JSON cached/pinned strings are treated as comma/newline lists and # never raise; a malformed hidden string is normalized too. diff --git a/tests/test_provider_detection.py b/tests/test_provider_detection.py index fb53291bf..372a3950d 100644 --- a/tests/test_provider_detection.py +++ b/tests/test_provider_detection.py @@ -42,6 +42,10 @@ class TestHostMatch: class TestDetectProviderRealHosts: + def test_chatgpt_subscription_codex_backend(self): + assert llm_core._detect_provider("https://chatgpt.com/backend-api/codex") == "chatgpt-subscription" + assert llm_core._detect_provider("https://chatgpt.com/backend-api/codex/responses") == "chatgpt-subscription" + def test_anthropic(self): assert llm_core._detect_provider("https://api.anthropic.com") == "anthropic" @@ -93,6 +97,12 @@ class TestBuildersRejectLookalikeHosts: def test_real_anthropic_chat(self): assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages" + def test_chatgpt_subscription_chat_uses_responses(self): + assert build_chat_url("https://chatgpt.com/backend-api/codex") == "https://chatgpt.com/backend-api/codex/responses" + + def test_chatgpt_subscription_models_uses_no_live_probe(self): + assert build_models_url("https://chatgpt.com/backend-api/codex") is None + def test_lookalike_anthropic_chat_is_openai(self): assert build_chat_url("https://notanthropic.com") == "https://notanthropic.com/chat/completions" diff --git a/tests/test_provider_device_flow_js.py b/tests/test_provider_device_flow_js.py new file mode 100644 index 000000000..37bcd29a5 --- /dev/null +++ b/tests/test_provider_device_flow_js.py @@ -0,0 +1,157 @@ +"""Node-driven tests for the shared provider device-flow runner.""" + +import json +import shutil +import subprocess +from pathlib import Path + +import pytest + +_REPO = Path(__file__).resolve().parent.parent +_HELPER = _REPO / "static" / "js" / "providerDeviceFlow.js" +pytestmark = pytest.mark.skipif(not shutil.which("node"), reason="node not on PATH") + + +def _run_node(script: str): + proc = subprocess.run( + ["node", "--input-type=module"], + input=script, + capture_output=True, + text=True, + cwd=str(_REPO), + timeout=30, + ) + assert proc.returncode == 0, proc.stderr + return json.loads(proc.stdout.strip()) + + +def test_copilot_success_uses_complete_verification_uri(): + js = f""" + import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}'; + const calls = []; + const opened = []; + let polls = 0; + const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }}); + const fetchImpl = async (url) => {{ + calls.push(url); + if (url.endsWith('/device/start')) {{ + return response(true, 200, {{ + poll_id: 'poll-1', + user_code: 'GH-CODE', + verification_uri: 'https://github.com/login/device', + verification_uri_complete: 'https://github.com/login/device?user_code=GH-CODE', + interval: 2, + expires_in: 30, + }}); + }} + polls += 1; + return response(true, 200, polls === 1 + ? {{ status: 'pending' }} + : {{ status: 'authorized', endpoint: {{ id: 'ep1', models: ['gpt-4o'] }} }} + ); + }}; + const result = await runProviderDeviceFlow('copilot', {{ + fetchImpl, + openWindow: (url) => opened.push(url), + sleep: async () => {{}}, + now: () => 0, + }}); + console.log(JSON.stringify({{ result, calls, opened }})); + """ + out = _run_node(js) + assert out["result"]["status"] == "authorized" + assert out["result"]["endpoint"]["id"] == "ep1" + assert out["opened"] == ["https://github.com/login/device?user_code=GH-CODE"] + assert out["calls"] == ["/api/copilot/device/start", "/api/copilot/device/poll", "/api/copilot/device/poll"] + + +def test_chatgpt_success_uses_plain_verification_uri(): + js = f""" + import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}'; + const opened = []; + const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }}); + const fetchImpl = async (url) => {{ + if (url.endsWith('/device/start')) {{ + return response(true, 200, {{ + poll_id: 'poll-1', + user_code: 'OA-CODE', + verification_uri: 'https://auth.openai.com/codex/device', + interval: 2, + expires_in: 30, + }}); + }} + return response(true, 200, {{ status: 'authorized', endpoint: {{ id: 'chatgpt', models: ['gpt-5.5'] }} }}); + }}; + const result = await runProviderDeviceFlow('chatgpt-subscription', {{ + fetchImpl, + openWindow: (url) => opened.push(url), + sleep: async () => {{}}, + now: () => 0, + }}); + console.log(JSON.stringify({{ result, opened }})); + """ + out = _run_node(js) + assert out["result"]["status"] == "authorized" + assert out["opened"] == ["https://auth.openai.com/codex/device"] + + +def test_start_errors_surface_backend_detail(): + js = f""" + import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}'; + const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }}); + try {{ + await runProviderDeviceFlow('copilot', {{ + fetchImpl: async () => response(false, 502, {{ detail: 'GitHub device-code request failed: upstream down' }}), + openWindow: () => {{}}, + sleep: async () => {{}}, + now: () => 0, + }}); + }} catch (err) {{ + console.log(JSON.stringify({{ message: err.message }})); + }} + """ + out = _run_node(js) + assert out["message"] == "GitHub device-code request failed: upstream down" + + +def test_thrown_fetch_errors_are_preserved(): + js = f""" + import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}'; + try {{ + await runProviderDeviceFlow('chatgpt-subscription', {{ + fetchImpl: async () => {{ throw new Error('network offline'); }}, + openWindow: () => {{}}, + sleep: async () => {{}}, + now: () => 0, + }}); + }} catch (err) {{ + console.log(JSON.stringify({{ message: err.message }})); + }} + """ + out = _run_node(js) + assert out["message"] == "network offline" + + +def test_expired_flow_returns_expired_status(): + js = f""" + import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}'; + let currentTime = 0; + const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }}); + const result = await runProviderDeviceFlow('copilot', {{ + fetchImpl: async (url) => url.endsWith('/device/start') + ? response(true, 200, {{ + poll_id: 'poll-1', + user_code: 'GH-CODE', + verification_uri: 'https://github.com/login/device', + interval: 2, + expires_in: 1, + }}) + : response(true, 200, {{ status: 'pending' }}), + openWindow: () => {{}}, + sleep: async () => {{ currentTime += 2000; }}, + now: () => currentTime, + }}); + console.log(JSON.stringify(result)); + """ + out = _run_node(js) + assert out == {"status": "expired"} diff --git a/tests/test_research_endpoint_owner_scope.py b/tests/test_research_endpoint_owner_scope.py index baa71d382..e30e5d994 100644 --- a/tests/test_research_endpoint_owner_scope.py +++ b/tests/test_research_endpoint_owner_scope.py @@ -24,7 +24,7 @@ _sd = types.ModuleType("src.database") _sd.ModelEndpoint = MagicMock() sys.modules.setdefault("src.database", _sd) -from routes.research_routes import _owned_enabled_endpoint # noqa: E402 +from routes.research_routes import _owned_enabled_endpoint, _resolve_endpoint_runtime # noqa: E402 class _Predicate: @@ -129,3 +129,29 @@ def test_null_owner_is_legacy_single_user_noop(): rows = [_ep("ep-x", "bob"), _ep("ep-y", "alice")] ep = _resolve(rows, None, "ep-x") assert ep is not None and ep.id == "ep-x" + + +def test_runtime_resolution_uses_provider_auth_for_chatgpt_subscription(monkeypatch): + ep = SimpleNamespace( + id="ep-chatgpt", + owner="alice", + base_url="https://chatgpt.com/backend-api/codex", + api_key=None, + provider_auth_id="auth-1", + cached_models='["gpt-5.5"]', + hidden_models=None, + ) + + monkeypatch.setattr( + "src.chatgpt_subscription.resolve_runtime_credentials", + lambda auth_id, owner=None: { + "base_url": "https://chatgpt.com/backend-api/codex", + "api_key": "fresh-access-token", + }, + ) + + url, model, headers = _resolve_endpoint_runtime(ep, owner="alice", model="") + + assert url == "https://chatgpt.com/backend-api/codex/responses" + assert model == "gpt-5.5" + assert headers["Authorization"] == "Bearer fresh-access-token" diff --git a/tests/test_resolve_session_auth_chatgpt.py b/tests/test_resolve_session_auth_chatgpt.py new file mode 100644 index 000000000..ebba8298d --- /dev/null +++ b/tests/test_resolve_session_auth_chatgpt.py @@ -0,0 +1,215 @@ +"""resolve_session_auth must not persist the ChatGPT Subscription bearer. + +The ChatGPT Subscription access token is a short-lived OAuth bearer re-resolved +(and refreshed) on every request. resolve_session_auth() may set it on the +in-memory session for the current request, but it must never write it back into +the sessions table — otherwise the live token sits at rest as +"Authorization: Bearer ...". Only the encrypted refresh token in +ProviderAuthSession is allowed to persist. +""" + +import types + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +import routes.chat_helpers as chat_helpers +import src.endpoint_resolver as endpoint_resolver +from core.database import Base, ModelEndpoint, Session as DbSession + +_CODEX_BASE = "https://chatgpt.com/backend-api/codex" + + +def _mem_db(monkeypatch): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(bind=engine) + # Match production SessionLocal (core.database) which is autoflush=False. + TestSessionLocal = sessionmaker(bind=engine, autoflush=False) + monkeypatch.setattr(chat_helpers, "SessionLocal", TestSessionLocal) + return TestSessionLocal + + +def test_chatgpt_subscription_auth_is_not_written_to_sessions_table(monkeypatch): + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + db.add(ModelEndpoint( + id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE, + provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None, + )) + db.add(DbSession( + id="sess1", name="chat", endpoint_url=_CODEX_BASE, + model="gpt-5.1-codex", owner="alice", headers={}, + )) + db.commit() + finally: + db.close() + + # A live access token is resolved at request time. + monkeypatch.setattr( + endpoint_resolver, "resolve_endpoint_runtime", + lambda ep, owner=None: (_CODEX_BASE, "live-access-token"), + ) + + sess = types.SimpleNamespace( + id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex", + owner="alice", headers={}, + ) + chat_helpers.resolve_session_auth(sess, "sess1", owner="alice") + + # In-memory session got request-local auth for this request... + assert any(k.lower() == "authorization" for k in sess.headers) + assert sess.headers["Authorization"] == "Bearer live-access-token" + + # ...but the DB row must NOT have the bearer persisted. + db = TestSessionLocal() + try: + row = db.query(DbSession).filter(DbSession.id == "sess1").first() + stored = row.headers or {} + assert not any(k.lower() == "authorization" for k in stored), ( + f"ChatGPT bearer leaked into sessions table: {stored}" + ) + finally: + db.close() + + +def test_non_subscription_auth_is_still_persisted_to_sessions_table(monkeypatch): + """The early-return must be scoped to ChatGPT Subscription only. + + Ordinary endpoints rely on resolve_session_auth() persisting the resolved + headers into the sessions table so they aren't re-resolved on every request. + If the is_chatgpt_subscription guard ever widened, this would silently break; + this test pins the persistence path as still reached for normal endpoints. + """ + base = "https://api.example.com/v1" + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + db.add(ModelEndpoint( + id="ep1", name="Generic", base_url=base, + owner="alice", is_enabled=True, api_key="sk-static", + )) + db.add(DbSession( + id="sess1", name="chat", endpoint_url=base, + model="gpt-x", owner="alice", headers={}, + )) + db.commit() + finally: + db.close() + + monkeypatch.setattr( + endpoint_resolver, "resolve_endpoint_runtime", + lambda ep, owner=None: (base, "sk-static"), + ) + + sess = types.SimpleNamespace( + id="sess1", endpoint_url=base, model="gpt-x", owner="alice", headers={}, + ) + chat_helpers.resolve_session_auth(sess, "sess1", owner="alice") + + # In-memory session got auth... + assert any(k.lower() in ("authorization", "x-api-key") for k in sess.headers) + + # ...AND it was persisted to the DB row (the normal, non-subscription path). + db = TestSessionLocal() + try: + row = db.query(DbSession).filter(DbSession.id == "sess1").first() + stored = row.headers or {} + assert any(k.lower() in ("authorization", "x-api-key") for k in stored), ( + f"non-subscription auth was not persisted: {stored}" + ) + finally: + db.close() + + +def test_chatgpt_subscription_clears_previously_persisted_bearer(monkeypatch): + """A bearer left at rest by an older code path is stripped on next resolve.""" + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + db.add(ModelEndpoint( + id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE, + provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None, + )) + # Simulate the leak: a stale bearer already sitting in the sessions table. + db.add(DbSession( + id="sess1", name="chat", endpoint_url=_CODEX_BASE, + model="gpt-5.1-codex", owner="alice", + headers={"Authorization": "Bearer stale-leaked-token"}, + )) + db.commit() + finally: + db.close() + + monkeypatch.setattr( + endpoint_resolver, + "resolve_endpoint_runtime", + lambda ep, owner=None: (_CODEX_BASE, "live-access-token"), + ) + + sess = types.SimpleNamespace( + id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex", + owner="alice", headers={}, + ) + chat_helpers.resolve_session_auth(sess, "sess1", owner="alice") + + # The stale bearer must have been stripped from the DB row. + db = TestSessionLocal() + try: + row = db.query(DbSession).filter(DbSession.id == "sess1").first() + stored = row.headers or {} + assert not any(k.lower() == "authorization" for k in stored), ( + f"stale ChatGPT bearer was not cleared: {stored}" + ) + finally: + db.close() + + +def test_chatgpt_subscription_fallback_auth_is_not_written_to_sessions_table(monkeypatch): + """Fallback endpoint selection must keep the resolved bearer request-local.""" + TestSessionLocal = _mem_db(monkeypatch) + db = TestSessionLocal() + try: + db.add(ModelEndpoint( + id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE, + provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None, + cached_models='["gpt-5.1-codex"]', + )) + db.add(DbSession( + id="sess1", name="chat", endpoint_url="https://old.example/v1", + model="old-model", owner="alice", headers={}, + )) + db.commit() + finally: + db.close() + + monkeypatch.setattr( + endpoint_resolver, + "resolve_endpoint_runtime", + lambda ep, owner=None: (_CODEX_BASE, "live-access-token"), + ) + + sess = types.SimpleNamespace( + id="sess1", endpoint_url="https://old.example/v1", model="old-model", + owner="alice", headers={}, + ) + result = chat_helpers.try_fallback_endpoint(sess, "sess1") + + assert result == { + "model": "gpt-5.1-codex", + "endpoint_url": _CODEX_BASE + "/responses", + "endpoint_name": "ChatGPT Subscription", + } + assert sess.headers["Authorization"] == "Bearer live-access-token" + + db = TestSessionLocal() + try: + row = db.query(DbSession).filter(DbSession.id == "sess1").first() + assert row.model == "gpt-5.1-codex" + assert row.endpoint_url == _CODEX_BASE + "/responses" + stored = row.headers or {} + assert not any(k.lower() == "authorization" for k in stored), ( + f"ChatGPT fallback bearer leaked into sessions table: {stored}" + ) + finally: + db.close() diff --git a/tests/test_review_regressions.py b/tests/test_review_regressions.py index cda2c720a..b3988f88e 100644 --- a/tests/test_review_regressions.py +++ b/tests/test_review_regressions.py @@ -386,7 +386,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message) monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {}) monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester") - monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model: None) + monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None) monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact) monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages) diff --git a/tests/test_session_owner_attribution.py b/tests/test_session_owner_attribution.py index 421bdea17..3dbaf53cf 100644 --- a/tests/test_session_owner_attribution.py +++ b/tests/test_session_owner_attribution.py @@ -137,3 +137,12 @@ def test_unauthenticated_caller_rejected(monkeypatch): with pytest.raises(HTTPException) as exc: SR._verify_session_owner(req, "sid") assert exc.value.status_code == 401 + + +def test_auth_disabled_allows_owner_stamped_session(monkeypatch): + monkeypatch.setenv("AUTH_ENABLED", "false") + monkeypatch.setattr(SR, "SessionLocal", _session_local_returning("admin")) + req = _req(api_token=False, current_user=None) + + # Single-user/auth-disabled mode should verify existence but not compare owner. + SR._verify_session_owner(req, "sid-owned-by-admin") diff --git a/tests/test_setup_device_auth_static.py b/tests/test_setup_device_auth_static.py new file mode 100644 index 000000000..4ba7d61c9 --- /dev/null +++ b/tests/test_setup_device_auth_static.py @@ -0,0 +1,42 @@ +"""Static regressions for `/setup` account sign-in providers.""" + +from pathlib import Path + + +_REPO = Path(__file__).resolve().parent.parent +_SLASH = (_REPO / "static" / "js" / "slashCommands.js").read_text(encoding="utf-8") + + +def _between(src: str, start: str, end: str) -> str: + start_idx = src.index(start) + end_idx = src.index(end, start_idx) + return src[start_idx:end_idx] + + +def test_setup_guide_lists_account_sign_in_providers(): + guide_block = _between(_SLASH, "function _showSetupEndpointChoices", "async function _hasConfiguredModels") + + assert 'data-setup-provider="' in _SLASH + assert "provider.key" in _SLASH + assert "'copilot'" in _SLASH + assert "'chatgpt-subscription'" in _SLASH + assert "/setup copilot" in _SLASH + assert "/setup chatgpt-subscription" in _SLASH + + +def test_clicking_account_sign_in_provider_prefills_setup_command_not_api_key(): + click_block = _between(_SLASH, "const providerEl = e.target.closest('.setup-clickable-provider')", "// 3. Check") + + assert "providerEl.dataset.setupProvider" in click_block + assert "providerEl.dataset.setupKind === 'device-auth'" in click_block + assert "'/setup ' + providerKey" in click_block + + +def test_setup_chatgpt_subscription_prints_auth_url_without_auto_opening_tab(): + flow_block = _between(_SLASH, "async function _setupProviderDeviceFlow", "async function _cmdSetup") + + assert "providerKey === 'chatgpt-subscription'" in flow_block + assert "Open this URL" in flow_block + assert "authUrl" in flow_block + assert 'href="\' + uiModule.esc(authUrl || \'\') + \'"' in flow_block + assert "if (providerKey === 'chatgpt-subscription') return;" in flow_block diff --git a/tests/test_slash_autocomplete_static.py b/tests/test_slash_autocomplete_static.py new file mode 100644 index 000000000..a7549e271 --- /dev/null +++ b/tests/test_slash_autocomplete_static.py @@ -0,0 +1,17 @@ +"""Static regressions for slash autocomplete command-group expansion.""" + +from pathlib import Path + + +_REPO = Path(__file__).resolve().parent.parent +_AC = (_REPO / "static" / "js" / "slashAutocomplete.js").read_text(encoding="utf-8") + + +def test_exact_parent_command_expands_subcommands_before_top_level_row_cap(): + assert "function _exactCommandGroupItems" in _AC + assert "entry.token.toLowerCase().startsWith(prefix)" in _AC + assert "items = groupItems.slice(0, MAX_VISIBLE);" in _AC + + +def test_setup_group_has_room_for_chatgpt_subscription_suggestion(): + assert "const MAX_VISIBLE = 14;" in _AC