feat: add ChatGPT Subscription provider (#2876)

* feat: Add ChatGPT Subscription support and related features

- Introduced a new provider option for ChatGPT Subscription in the endpoint selection UI.
- Implemented OAuth flow for ChatGPT Subscription sign-in, including polling for authorization status.
- Updated admin interface to handle ChatGPT Subscription, including disabling API key input and providing user guidance.
- Enhanced cost tracking logic to differentiate between subscription and non-subscription endpoints.
- Added new slash commands for managing skills, including listing, searching, and invoking skills.
- Implemented caching for skill catalog to optimize performance.
- Updated tests to cover new ChatGPT Subscription functionality and ensure proper endpoint probing.
- Refactored existing code to accommodate new features and improve maintainability.

* refactor: share provider device-flow setup

- reuse one device-flow backend for Copilot and ChatGPT Subscription
- add one frontend device-flow helper for Settings and /setup
- put GitHub Copilot back into Add Models, now as a dropdown option
- make provider selection just select; clicking Add starts sign-in
- stop ChatGPT Subscription setup from opening auth tabs automatically
- make /setup copilot and /setup chatgpt-subscription work from chat
- show ChatGPT Subscription in the /setup suggestions
- show the real error message when setup fails
- add focused tests for the shared flow and setup UI

* feat(chatgpt-subscription): harden credential lifecycle and streamline auth UX

Backend:
- Resolve runtime bearer for provider-auth endpoints at probe time via a
  shared _resolve_probe_key() that delegates to resolve_endpoint_runtime,
  applied across all probe/refresh call sites.
- Skip live completion probes and health pings for discovery-only providers
  (centralized behind _is_discovery_only_provider) — the Codex/Responses API
  has no such endpoints, so status is derived from cached models.
- Never persist the short-lived ChatGPT bearer to the plaintext sessions
  table; proactively clear any stale bearer left by an earlier code path.
- Revoke orphaned ProviderAuthSession credentials when the last endpoint
  backing them is deleted (_delete_orphaned_provider_auth), surfaced via
  cleared_provider_auth in the delete response.

Frontend (admin.js):
- Auto-start the device-auth flow on provider selection so the authorization
  panel (code + Authorize) shows immediately instead of behind a "Sign in" click.
- Remove the redundant top button for device-auth providers and move retry
  into the panel via an inline "Try again".
- Drop the self-evident hint text and add an execCommand clipboard fallback so
  Copy works in non-secure (HTTP/LAN) contexts.

* fix: harden chatgpt subscription provider

* chore: remove PR media from branch

* Fix chatgpt subscription recovery and token handling

---------

Co-authored-by: 5p00kyy <admin@5p00ky.dev>
This commit is contained in:
stocky789
2026-06-08 18:19:18 +10:00
committed by GitHub
parent ac94885c84
commit 1e0d9b92af
37 changed files with 3425 additions and 485 deletions
+4
View File
@@ -598,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery))
from routes.copilot_routes import setup_copilot_routes
app.include_router(setup_copilot_routes())
# ChatGPT Subscription device-flow login
from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes
app.include_router(setup_chatgpt_subscription_routes())
# TTS
from routes.tts_routes import setup_tts_routes
app.include_router(setup_tts_routes(tts_service))
+39
View File
@@ -361,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base):
# is the historical default. When non-null, the model picker only shows
# the endpoint to that user (admins always see everything).
owner = Column(String, nullable=True, index=True)
# Optional OAuth/session-backed credential row. Used by subscription-backed
# providers that need refresh tokens instead of a static API key.
provider_auth_id = Column(String, nullable=True, index=True)
class ProviderAuthSession(TimestampMixin, Base):
    """Encrypted OAuth/session credentials for refresh-aware model providers.

    A row holds the access/refresh token pair for one provider login.
    ``ModelEndpoint.provider_auth_id`` points at ``id`` so an endpoint can use
    these credentials instead of a static API key.
    """
    __tablename__ = "provider_auth_sessions"
    # Opaque string key; the referencing endpoint stores it in provider_auth_id.
    id = Column(String, primary_key=True, index=True)
    # Provider identifier the credentials belong to.
    provider = Column(String, nullable=False, index=True)
    # Owning user; NULL when the login is not tied to a specific user.
    owner = Column(String, nullable=True, index=True)
    # Optional human-readable name for the credential row.
    label = Column(String, nullable=True)
    # Base URL of the provider API these credentials are used against.
    base_url = Column(String, nullable=False)
    # Token columns use EncryptedText rather than a plain String column.
    access_token = Column(EncryptedText, nullable=True)
    refresh_token = Column(EncryptedText, nullable=True)
    # When the tokens were last stored or refreshed.
    last_refresh = Column(DateTime, nullable=True)
    # Provider-specific auth mode label (the ChatGPT setup route writes "chatgpt").
    auth_mode = Column(String, nullable=True)
class McpServer(TimestampMixin, Base):
"""Admin-configured MCP (Model Context Protocol) tool servers."""
@@ -801,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column():
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
def _migrate_add_provider_auth_id_column():
    """Add provider_auth_id column to model_endpoints if it doesn't exist.

    The migration is best-effort: it does nothing when the SQLite file derived
    from ``DATABASE_URL`` is missing or when ``model_endpoints`` has no columns
    (table absent), and any failure is logged as a warning instead of raised.
    Running it repeatedly is safe because the column is only added when absent.
    """
    import sqlite3
    from contextlib import closing

    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return
    log = logging.getLogger(__name__)
    try:
        # closing() guarantees the connection is released even when the
        # PRAGMA/ALTER statements raise; sqlite3's own context manager only
        # manages the transaction and would leave the handle open.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.execute("PRAGMA table_info(model_endpoints)")
            columns = [row[1] for row in cursor.fetchall()]
            if columns and "provider_auth_id" not in columns:
                conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR")
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id "
                    "ON model_endpoints(provider_auth_id)"
                )
                conn.commit()
                log.info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
    except Exception as e:
        log.warning("model_endpoints.provider_auth_id migration failed: %s", e)
def _migrate_add_model_type_column():
"""Add model_type column to model_endpoints if it doesn't exist."""
import sqlite3
@@ -1599,6 +1637,7 @@ def init_db():
_migrate_add_model_type_column()
_migrate_add_model_endpoint_refresh_columns()
_migrate_add_model_endpoint_owner_column()
_migrate_add_provider_auth_id_column()
_migrate_add_supports_tools_column()
_migrate_add_task_run_model_column()
_migrate_add_owner_column()
+86 -28
View File
@@ -196,14 +196,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
"""
import requests as _req
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
from src.endpoint_resolver import (
build_chat_url,
build_headers,
build_models_url,
normalize_base,
resolve_endpoint_runtime,
)
from src.chatgpt_subscription import is_chatgpt_subscription_base
current_url = sess.endpoint_url or ""
owner = getattr(sess, "owner", None)
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(
q = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True
).all()
)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
finally:
db.close()
@@ -212,26 +224,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
# Skip current endpoint
if current_url and base in current_url:
continue
# Quick ping
ping_url = build_models_url(base)
headers = build_headers(ep.api_key, base)
try:
r = _req.get(ping_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
models = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception:
continue
ping_url = build_models_url(base)
headers = build_headers(api_key, base)
try:
if ping_url:
r = _req.get(ping_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
models = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
models = json.loads(ep.cached_models or "[]")
if not models:
continue
# Found a working endpoint — update session
new_model = models[0]
chat_url = build_chat_url(base)
new_headers = build_headers(ep.api_key, base)
new_headers = build_headers(api_key, base)
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
sess.model = new_model
sess.endpoint_url = chat_url
@@ -243,7 +262,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
_db.query(DBSession).filter(DBSession.id == session_id).update({
"model": new_model,
"endpoint_url": chat_url,
"headers": json.dumps(new_headers),
"headers": persisted_headers,
})
_db.commit()
finally:
@@ -336,16 +355,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
return False
def _has_auth_keys(headers) -> bool:
    """Report whether *headers* is a dict holding a credential header.

    A key counts as a credential when its lower-cased name is
    ``authorization`` or ``x-api-key``. Anything that is not a dict
    (``None``, a JSON string, a list) yields ``False``.
    """
    if not isinstance(headers, dict):
        return False
    for name in headers:
        if name.lower() in ('authorization', 'x-api-key'):
            return True
    return False
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
)
if has_auth:
try:
from src.chatgpt_subscription import is_chatgpt_subscription_base
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
except Exception:
is_chatgpt_subscription = False
has_auth = _has_auth_keys(sess.headers)
if has_auth and not is_chatgpt_subscription:
return
try:
from src.endpoint_resolver import build_headers, normalize_base
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
db = SessionLocal()
try:
target_url = getattr(sess, "endpoint_url", "") or ""
@@ -361,10 +390,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
for ep in q.all():
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
continue
if not ep.api_key:
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception as e:
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
return
if not api_key:
# No usable key (e.g. ChatGPT Subscription needs re-auth).
return
sess.headers = build_headers(api_key, base)
if is_chatgpt_subscription:
# The bearer is short-lived and re-resolved per request, so it
# stays request-local and is never written to the plaintext
# sessions.headers column. Proactively strip any bearer an
# older code path may have persisted so it does not linger.
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
stale_q = stale_q.filter(DBSession.owner == owner)
stored = stale_q.first()
if stored is not None and _has_auth_keys(stored.headers):
stale_q.update({"headers": {}})
db.commit()
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
return
base = normalize_base(ep.base_url or "")
sess.headers = build_headers(ep.api_key, base)
update_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
update_q = update_q.filter(DBSession.owner == owner)
@@ -408,7 +457,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
owner = getattr(sess, "owner", None)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for ep in endpoints:
try:
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
@@ -542,7 +596,11 @@ async def build_chat_context(
# Normalize model ID. Prefer cached endpoint models so group chat does not
# re-hit slow local /models endpoints on every participant turn.
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
sess.endpoint_url,
sess.model,
owner=getattr(sess, "owner", None),
)
if norm:
sess.model = norm
+57 -12
View File
@@ -169,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
Covers the window between endpoint setup and the first chat send: the
picker showed a model in the dropdown but the session record never got
written (Issue #587 — UI uses the cached endpoint list, not s.model).
Without this, we'd POST the upstream with model="" and get a generic
401/503 instead of using the model the user already picked.
Returns True iff sess.model was repaired.
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
"""
if getattr(sess, "model", None):
return False
current_model = (getattr(sess, "model", "") or "").strip()
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
is_chatgpt_subscription = False
if current_model:
try:
from src.chatgpt_subscription import is_chatgpt_subscription_base
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
if not is_chatgpt_subscription:
return False
except Exception:
return False
db = SessionLocal()
try:
# Prefer the endpoint whose base URL matches the session — we know the
@@ -194,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
break
if not ep:
return False
if not is_chatgpt_subscription:
try:
from src.chatgpt_subscription import is_chatgpt_subscription_base
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
except Exception:
is_chatgpt_subscription = False
try:
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
except Exception:
cached = []
if not cached:
visible = []
else:
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if current_model and current_model in {str(item).strip() for item in visible}:
return False
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if is_chatgpt_subscription:
live_models = []
if getattr(ep, "provider_auth_id", None):
try:
from src.chatgpt_subscription import fetch_available_models
from src.endpoint_resolver import resolve_endpoint_runtime
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
if api_key:
live_models = fetch_available_models(api_key)
if live_models:
ep.cached_models = json.dumps(live_models)
db.commit()
except Exception:
live_models = []
# ChatGPT Subscription recovery must use the live Codex catalog.
# Cached rows are only trusted above to avoid revalidating a model
# that is already present in the visible picker list.
cached = live_models
if not cached:
return False
try:
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
except Exception:
visible = cached
if current_model and current_model in {str(item).strip() for item in visible}:
return False
if not visible:
return False
model = visible[0]
@@ -213,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
# Persist so the next request, websocket reconnect, or page reload
# picks up the same model (we'd otherwise re-pick on every send
# and silently switch on the user if the cached order shifts).
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
db_session_q = db_session_q.filter(DBSession.owner == owner)
db_session = db_session_q.first()
if db_session:
db_session.model = model
db_session.updated_at = datetime.utcnow()
db.commit()
sess.model = model
logger.info(
"Recovered empty session model for %s — picked %r from endpoint %s",
"Recovered session model for %s — picked %r from endpoint %s",
session_id, model, ep.id,
)
return True
+170
View File
@@ -0,0 +1,170 @@
"""ChatGPT Subscription device-flow setup routes."""
import json
import logging
import uuid
from typing import Dict, Optional
from fastapi import HTTPException, Request
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
from routes.device_flow import (
DeviceFlowPoll,
DeviceFlowStart,
PendingDeviceFlowStore,
create_device_flow_router,
)
from src.auth_helpers import get_current_user
from src import chatgpt_subscription
logger = logging.getLogger(__name__)
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
    """Store ChatGPT Subscription credentials and upsert the endpoint using them.

    Args:
        tokens: Token response; must carry non-empty ``access_token`` and
            ``refresh_token`` entries.
        owner: User the credential and endpoint rows belong to, or ``None``.

    Returns:
        A dict with the endpoint ``id``, ``name``, ``base_url`` and the list of
        discovered ``models``.

    Raises:
        ValueError: If either token is missing, or if model discovery returns
            nothing usable. Errors from model discovery or the database
            propagate unchanged.
    """
    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    if not access_token or not refresh_token:
        raise ValueError("ChatGPT token response was missing access_token or refresh_token")

    base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
    # Discover models before touching the database so a login without usable
    # models never leaves rows behind.
    models = chatgpt_subscription.fetch_available_models(access_token)
    if not models:
        raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")

    db = SessionLocal()
    try:
        # One credential row per (provider, owner); re-login overwrites it.
        auth = (
            db.query(ProviderAuthSession)
            .filter(
                ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
                ProviderAuthSession.owner == owner,
            )
            .first()
        )
        if auth is None:
            auth = ProviderAuthSession(
                id=str(uuid.uuid4())[:8],
                provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
                owner=owner,
                label="ChatGPT Subscription",
                base_url=base,
                auth_mode="chatgpt",
            )
            db.add(auth)
        auth.base_url = base
        auth.access_token = access_token
        auth.refresh_token = refresh_token
        auth.last_refresh = utcnow_naive()
        auth.auth_mode = "chatgpt"

        # Reuse the endpoint already bound to this credential row, if any.
        ep = (
            db.query(ModelEndpoint)
            .filter(
                ModelEndpoint.base_url == base,
                ModelEndpoint.provider_auth_id == auth.id,
                ModelEndpoint.owner == owner,
            )
            .first()
        )
        if ep is None:
            ep = ModelEndpoint(
                id=str(uuid.uuid4())[:8],
                name="ChatGPT Subscription",
                base_url=base,
                model_type="llm",
                endpoint_kind="api",
                owner=owner,
            )
            db.add(ep)
        ep.name = "ChatGPT Subscription"
        ep.base_url = base
        # The endpoint authenticates through the credential row, never a static key.
        ep.api_key = None
        ep.provider_auth_id = auth.id
        ep.is_enabled = True
        ep.supports_tools = False
        ep.model_type = "llm"
        ep.endpoint_kind = "api"
        ep.model_refresh_mode = "manual"
        ep.cached_models = json.dumps(models)
        db.commit()
        result = {
            "id": ep.id,
            "name": ep.name,
            "base_url": ep.base_url,
            "models": models,
        }
    finally:
        db.close()

    # Cache invalidation is best-effort: the endpoint is already committed, so
    # a failure here must not fail the login, but it should leave a trace.
    try:
        from routes.model_routes import _invalidate_models_cache
        _invalidate_models_cache()
    except Exception:
        logger.debug(
            "Could not invalidate the models cache after ChatGPT Subscription provisioning",
            exc_info=True,
        )
    return result
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
    """Request a ChatGPT device code and describe the pending login.

    The ``device_auth_id`` stays in the server-side pending payload; only the
    user code and verification URL are sent back to the browser. The submitted
    form is unused.

    Raises:
        HTTPException: Whatever ``to_http_exception`` builds when the
            device-code request fails, or 502 when the reply lacks
            ``device_auth_id`` or ``user_code``.
    """
    try:
        reply = chatgpt_subscription.request_device_code()
    except Exception as exc:
        raise chatgpt_subscription.to_http_exception(exc)

    device_auth_id = reply.get("device_auth_id")
    user_code = reply.get("user_code")
    if not (device_auth_id and user_code):
        raise HTTPException(502, "ChatGPT did not return a complete device code")

    # Fall back to the issuer's device page when no URL is supplied.
    verification_uri = reply.get("verification_uri")
    if not verification_uri:
        verification_uri = f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"

    server_side = {
        "device_auth_id": device_auth_id,
        "user_code": user_code,
        "owner": get_current_user(request) or None,
    }
    client_side = {
        "user_code": user_code,
        "verification_uri": verification_uri,
    }
    return DeviceFlowStart(
        pending=server_side,
        response=client_side,
        interval=int(reply.get("interval") or 5),
        expires_in=int(reply.get("expires_in") or 900),
    )
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
    """Poll ChatGPT once for the pending login and normalize the outcome.

    Args:
        _request: Unused; present to match the shared callback signature.
        pending: Payload stored at start (``device_auth_id``, ``user_code``,
            ``owner``).

    Returns:
        ``authorized`` with the provisioned endpoint once an authorization code
        and verifier arrive; ``failed`` for expired/denied logins; otherwise a
        ``pending`` or ``slow_down`` outcome.

    Raises:
        HTTPException: Built by ``to_http_exception`` when the code exchange or
            endpoint provisioning fails.
    """
    try:
        reply = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
    except Exception as exc:
        # A failed poll request is reported as "still pending" with the error text.
        logger.debug("ChatGPT device poll failed: %s", exc)
        return DeviceFlowPoll.pending(str(exc))

    authorization_code = reply.get("authorization_code")
    code_verifier = reply.get("code_verifier")
    if authorization_code and code_verifier:
        try:
            tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
            endpoint = _provision_endpoint(tokens, pending["owner"])
        except Exception as exc:
            logger.exception("ChatGPT Subscription endpoint provisioning failed")
            raise chatgpt_subscription.to_http_exception(exc)
        return DeviceFlowPoll.authorized(endpoint)

    state = reply.get("error") or reply.get("status")
    if state is None or state in ("authorization_pending", "pending"):
        return DeviceFlowPoll.pending()
    if state == "slow_down":
        suggested = int(reply.get("interval") or 0) or None
        return DeviceFlowPoll.slow_down(suggested)
    if state in ("expired_token", "access_denied", "denied"):
        return DeviceFlowPoll.failed(state)
    # Anything unrecognized keeps the login alive and surfaces the raw value.
    return DeviceFlowPoll.pending(state or "unknown")
def setup_chatgpt_subscription_routes():
    """Build the ``/api/chatgpt-subscription`` device-flow router.

    Wires this module's start/poll callbacks and its pending-login store into
    the shared device-flow route factory and returns the resulting router.
    """
    router = create_device_flow_router(
        prefix="/api/chatgpt-subscription",
        tags=["chatgpt-subscription"],
        store=_DEVICE_FLOW_STORE,
        start_flow=_start_device_flow,
        poll_flow=_poll_device_flow,
    )
    return router
+67 -117
View File
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
"""
import json
import time
import uuid
import logging
import threading
from typing import Dict, Optional
import httpx
from fastapi import APIRouter, Request, Form, HTTPException
from fastapi import HTTPException, Request
from core.database import SessionLocal, ModelEndpoint
from core.middleware import require_admin
from routes.device_flow import (
DeviceFlowPoll,
DeviceFlowStart,
PendingDeviceFlowStore,
create_device_flow_router,
)
from src.auth_helpers import get_current_user
from src import copilot
logger = logging.getLogger(__name__)
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
# bearer-like secret, so it lives here (server memory) rather than in the
# browser. Entries expire with the GitHub device code.
#
# NOTE: this is per-process state. The device flow assumes a single worker
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
# on a worker that never saw the start, returning "Unknown or expired login
# session". Move this to a shared store (DB/Redis) if running multi-worker.
_PENDING: Dict[str, Dict] = {}
_PENDING_LOCK = threading.Lock()
def _prune_expired() -> None:
now = time.time()
with _PENDING_LOCK:
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
_PENDING.pop(k, None)
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
return result
def setup_copilot_routes() -> APIRouter:
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
host = copilot.GITHUB_HOST
ent = str(form.get("enterprise_url") or "").strip()
if ent:
host = copilot.normalize_domain(ent)
try:
data = copilot.request_device_code(host)
except httpx.HTTPStatusError as e:
status = e.response.status_code if e.response is not None else "unknown"
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
except Exception as e:
raise HTTPException(502, f"GitHub device-code request failed: {e}")
@router.post("/device/start")
def device_start(request: Request, enterprise_url: str = Form("")):
require_admin(request)
_prune_expired()
host = copilot.GITHUB_HOST
ent = (enterprise_url or "").strip()
if ent:
host = copilot.normalize_domain(ent)
try:
data = copilot.request_device_code(host)
except httpx.HTTPStatusError as e:
status = e.response.status_code if e.response is not None else "unknown"
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
except Exception as e:
raise HTTPException(502, f"GitHub device-code request failed: {e}")
device_code = data.get("device_code")
if not device_code:
raise HTTPException(502, "GitHub did not return a device code")
device_code = data.get("device_code")
if not device_code:
raise HTTPException(502, "GitHub did not return a device code")
interval = int(data.get("interval") or 5)
expires_in = int(data.get("expires_in") or 900)
poll_id = uuid.uuid4().hex
with _PENDING_LOCK:
_PENDING[poll_id] = {
"device_code": device_code,
"host": host,
"enterprise_url": ent,
"interval": interval,
"owner": get_current_user(request) or None,
"expires_at": time.time() + expires_in,
"next_poll_at": 0.0,
}
# verification_uri_complete embeds the user code, so the browser tab we
# open lands the user straight on GitHub's "Authorize" screen with the
# code pre-filled — one click, no manual code entry.
return {
"poll_id": poll_id,
# verification_uri_complete embeds the user code, so the browser tab we
# open lands the user straight on GitHub's "Authorize" screen with the
# code pre-filled — one click, no manual code entry.
return DeviceFlowStart(
pending={
"device_code": device_code,
"host": host,
"enterprise_url": ent,
"owner": get_current_user(request) or None,
},
response={
"user_code": data.get("user_code"),
"verification_uri": data.get("verification_uri"),
"verification_uri_complete": data.get("verification_uri_complete"),
"interval": interval,
"expires_in": expires_in,
}
},
interval=int(data.get("interval") or 5),
expires_in=int(data.get("expires_in") or 900),
)
@router.post("/device/poll")
def device_poll(request: Request, poll_id: str = Form(...)):
require_admin(request)
_prune_expired()
with _PENDING_LOCK:
pending = _PENDING.get(poll_id)
if not pending:
raise HTTPException(404, "Unknown or expired login session")
# Enforce GitHub's polling interval server-side so a chatty client
# can't trip slow_down.
now = time.time()
if now < pending.get("next_poll_at", 0):
return {"status": "pending"}
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
try:
data = copilot.poll_access_token(pending["host"], pending["device_code"])
except Exception as e:
return DeviceFlowPoll.pending(f"poll error: {e}")
token = data.get("access_token")
if token:
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
try:
data = copilot.poll_access_token(pending["host"], pending["device_code"])
result = _provision_endpoint(token, base, pending["owner"])
except Exception as e:
return {"status": "pending", "detail": f"poll error: {e}"}
logger.exception("Copilot endpoint provisioning failed")
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
return DeviceFlowPoll.authorized(result)
token = data.get("access_token")
if token:
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
try:
result = _provision_endpoint(token, base, pending["owner"])
except Exception as e:
logger.exception("Copilot endpoint provisioning failed")
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
return {"status": "authorized", "endpoint": result}
err = data.get("error")
if err == "authorization_pending":
return DeviceFlowPoll.pending()
if err == "slow_down":
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
if err in ("expired_token", "access_denied"):
return DeviceFlowPoll.failed(err)
# Unknown error — surface but keep the session for another try.
return DeviceFlowPoll.pending(err or "unknown")
err = data.get("error")
if err == "authorization_pending":
with _PENDING_LOCK:
if poll_id in _PENDING:
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
return {"status": "pending"}
if err == "slow_down":
new_interval = int(data.get("interval") or (pending["interval"] + 5))
with _PENDING_LOCK:
if poll_id in _PENDING:
_PENDING[poll_id]["interval"] = new_interval
_PENDING[poll_id]["next_poll_at"] = now + new_interval
return {"status": "pending"}
if err in ("expired_token", "access_denied"):
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
return {"status": "failed", "error": err}
# Unknown error — surface but keep the session for another try.
return {"status": "pending", "detail": err or "unknown"}
@router.post("/device/cancel")
def device_cancel(request: Request, poll_id: str = Form(...)):
require_admin(request)
with _PENDING_LOCK:
_PENDING.pop(poll_id, None)
return {"status": "cancelled"}
return router
def setup_copilot_routes():
return create_device_flow_router(
prefix="/api/copilot",
tags=["copilot"],
store=_DEVICE_FLOW_STORE,
start_flow=_start_device_flow,
poll_flow=_poll_device_flow,
)
+193
View File
@@ -0,0 +1,193 @@
"""Shared OAuth/device-flow route scaffolding for provider setup."""
from __future__ import annotations
import inspect
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Mapping, Optional
from fastapi import APIRouter, Form, HTTPException, Request
from core.middleware import require_admin
@dataclass(frozen=True)
class DeviceFlowStart:
    """Provider-specific start result consumed by the shared route wrapper."""

    # Server-side payload kept in the pending store and handed back to the
    # provider's poll callback; it is not included in the HTTP response.
    pending: Mapping[str, Any]
    # Fields returned to the browser; the route adds poll_id, interval and
    # expires_in on top of these.
    response: Mapping[str, Any]
    # Seconds between polls.
    interval: int = 5
    # Seconds until the pending login expires.
    expires_in: int = 900
@dataclass(frozen=True)
class DeviceFlowPoll:
    """Normalized provider poll outcome.

    ``status`` is one of ``pending``, ``slow_down``, ``authorized`` or
    ``failed``; the remaining fields are filled only where they apply. Use the
    named constructors below rather than building instances by hand.
    """

    status: str
    endpoint: Optional[Mapping[str, Any]] = None
    error: Optional[str] = None
    detail: Optional[str] = None
    interval: Optional[int] = None

    @classmethod
    def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Login not completed yet; *detail* optionally explains why."""
        fields = {"status": "pending", "detail": detail}
        return cls(**fields)

    @classmethod
    def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Provider asked for slower polling, optionally naming the new interval."""
        fields = {"status": "slow_down", "interval": interval, "detail": detail}
        return cls(**fields)

    @classmethod
    def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
        """Login succeeded and *endpoint* describes what was provisioned."""
        fields = {"status": "authorized", "endpoint": endpoint}
        return cls(**fields)

    @classmethod
    def failed(cls, error: str) -> "DeviceFlowPoll":
        """Login ended for good with the given *error* code."""
        fields = {"status": "failed", "error": error}
        return cls(**fields)
class PendingDeviceFlowStore:
    """Thread-safe in-memory pending device-flow store.

    Every pending login is kept in this process under an opaque ``poll_id``.
    A record holds the provider's payload apart from the polling metadata
    (interval, expiry, next allowed poll time), and only a copy of the payload
    is ever handed out. The clock is injectable through ``time_func``.
    """

    def __init__(self, *, time_func: Callable[[], float] = time.time):
        self._pending: dict[str, dict[str, Any]] = {}
        self._lock = threading.Lock()
        self._time = time_func

    def _now(self) -> float:
        """Current time from the injected clock, as a float."""
        return float(self._time())

    def prune_expired(self) -> None:
        """Drop every record whose expiry lies strictly before now."""
        cutoff = self._now()
        with self._lock:
            stale = [
                poll_id
                for poll_id, record in self._pending.items()
                if record.get("expires_at", 0) < cutoff
            ]
            for poll_id in stale:
                self._pending.pop(poll_id, None)

    def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
        """Register a pending login and return its new ``poll_id``.

        Falsy *interval* / *expires_in* fall back to 5 and 900 seconds; both
        are clamped to at least one second.
        """
        self.prune_expired()
        poll_id = uuid.uuid4().hex
        with self._lock:
            record = {
                "payload": dict(payload),
                "interval": max(int(interval or 5), 1),
                "expires_at": self._now() + max(int(expires_in or 900), 1),
                "next_poll_at": 0.0,
            }
            self._pending[poll_id] = record
        return poll_id

    def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
        """Return a copy of the provider payload, or ``None`` if unknown/expired."""
        self.prune_expired()
        with self._lock:
            record = self._pending.get(poll_id)
            return None if record is None else dict(record.get("payload") or {})

    def is_throttled(self, poll_id: str) -> bool:
        """True while the record's next allowed poll time is still ahead."""
        with self._lock:
            record = self._pending.get(poll_id)
            if not record:
                return False
            return self._now() < float(record.get("next_poll_at") or 0)

    def schedule_next(self, poll_id: str) -> None:
        """Push the next allowed poll one interval into the future."""
        stamp = self._now()
        with self._lock:
            record = self._pending.get(poll_id)
            if record is None:
                return
            record["next_poll_at"] = stamp + int(record.get("interval") or 5)

    def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
        """Adopt *interval*, or add five seconds to the current one, then reschedule."""
        stamp = self._now()
        with self._lock:
            record = self._pending.get(poll_id)
            if record is None:
                return
            if interval:
                requested = int(interval)
            else:
                requested = int(record.get("interval") or 5) + 5
            record["interval"] = max(requested, 1)
            record["next_poll_at"] = stamp + record["interval"]

    def pop(self, poll_id: str) -> None:
        """Forget the record; unknown ids are ignored."""
        with self._lock:
            self._pending.pop(poll_id, None)
async def _maybe_await(value: Any) -> Any:
    """Resolve *value* if it is awaitable; otherwise hand it back untouched.

    Lets the shared routes accept provider callbacks written either as plain
    functions or as coroutines.
    """
    return (await value) if inspect.isawaitable(value) else value
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
    """Build the JSON body for a still-pending poll.

    The ``detail`` key is present only when a non-empty detail was supplied.
    """
    if detail:
        return {"status": "pending", "detail": detail}
    return {"status": "pending"}
def create_device_flow_router(
    *,
    prefix: str,
    tags: Iterable[str],
    store: PendingDeviceFlowStore,
    start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
    poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
) -> APIRouter:
    """Create standard `/device/start|poll|cancel` routes for a provider.

    ``start_flow`` and ``poll_flow`` may be plain or async callables; their
    results are awaited when awaitable. Every route calls ``require_admin``
    before doing anything else. Pending state lives in ``store`` and is
    addressed by the ``poll_id`` returned from ``/device/start``.
    """
    router = APIRouter(prefix=prefix, tags=list(tags))

    @router.post("/device/start")
    async def device_start(request: Request):
        require_admin(request)
        form_data = await request.form()
        started = await _maybe_await(start_flow(request, form_data))
        interval = int(started.interval or 5)
        expires_in = int(started.expires_in or 900)
        poll_id = store.add(started.pending, interval=interval, expires_in=expires_in)
        # Provider-specific fields first, then the shared polling metadata.
        reply = dict(started.response)
        reply["poll_id"] = poll_id
        reply["interval"] = interval
        reply["expires_in"] = expires_in
        return reply

    @router.post("/device/poll")
    async def device_poll(request: Request, poll_id: str = Form(...)):
        require_admin(request)
        pending = store.get_payload(poll_id)
        if pending is None:
            raise HTTPException(404, "Unknown or expired login session")
        if store.is_throttled(poll_id):
            # Too soon since the last upstream poll: answer without calling out.
            return {"status": "pending"}
        try:
            result = await _maybe_await(poll_flow(request, pending))
        except Exception:
            # Any failure of the poll callback ends this login session.
            store.pop(poll_id)
            raise

        state = result.status
        if state in ("authorized", "failed"):
            # Both outcomes are terminal, so the pending entry is dropped.
            store.pop(poll_id)
            if state == "authorized":
                return {"status": "authorized", "endpoint": dict(result.endpoint or {})}
            return {"status": "failed", "error": result.error or "denied"}

        if state == "slow_down":
            store.slow_down(poll_id, result.interval)
        else:
            store.schedule_next(poll_id)
        return _pending_response(result.detail)

    @router.post("/device/cancel")
    def device_cancel(request: Request, poll_id: str = Form(...)):
        require_admin(request)
        store.pop(poll_id)
        return {"status": "cancelled"}

    return router
+107 -13
View File
@@ -283,6 +283,7 @@ _HOST_TO_CURATED = (
("fireworks.ai", "fireworks"),
("googleapis.com", "google"),
("x.ai", "xai"),
("openrouter.ai", "openrouter"),
("ollama.com", "ollama"),
("opencode.ai/zen/go", "opencode-go"),
@@ -493,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
def _is_chat_model(model_id: str) -> bool:
"""Return True if the model ID looks like a chat/completions-capable model."""
mid = model_id.lower()
if mid in {"gpt-5.1-codex"}:
return True
for prefix in _NON_CHAT_PREFIXES:
if mid.startswith(prefix):
return False
@@ -505,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
return True
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
    """Remove a ProviderAuthSession row that no endpoint points at any more.

    Subscription providers (e.g. ChatGPT Subscription) keep their refresh
    token in ProviderAuthSession rather than ModelEndpoint.api_key, so the
    stored credentials should go away together with the last endpoint that
    uses them.

    ``exclude_ep_id`` is left out of the reference check so the endpoint that
    is currently being deleted does not keep its own auth row alive.

    Returns True when a row was marked for deletion on ``db``; the caller is
    responsible for committing. Returns False for a falsy ``auth_id``, when
    another endpoint still references the row, or when the row does not exist.
    """
    if not auth_id:
        return False

    from core.database import ProviderAuthSession

    other_reference = (
        db.query(ModelEndpoint.id)
        .filter(
            ModelEndpoint.provider_auth_id == auth_id,
            ModelEndpoint.id != exclude_ep_id,
        )
        .first()
    )
    orphan = None
    if other_reference is None:
        orphan = (
            db.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == auth_id)
            .first()
        )
    if orphan is None:
        return False
    db.delete(orphan)
    return True
def _is_discovery_only_provider(provider: str) -> bool:
    """Tell whether ``provider`` supports model discovery but no live probing.

    Currently only ChatGPT Subscription: it speaks the Responses/Codex API
    and has no chat-completions or general health endpoint, so callers skip
    completion probes and reachability pings and derive status from cached
    models instead.
    """
    discovery_only = "chatgpt-subscription"
    return provider == discovery_only
def _resolve_probe_key(ep) -> Optional[str]:
    """Return the API key or bearer token that ``ep`` should be probed with.

    The work is delegated to ``resolve_endpoint_runtime``, which is called
    with the endpoint and its ``owner`` attribute (``None`` when absent);
    only the key half of its ``(base, key)`` result is used. Per the original
    note, that resolver returns the static ``ModelEndpoint.api_key`` for keyed
    endpoints and resolves (and refreshes) the runtime bearer for
    session-backed providers such as ChatGPT Subscription.

    Any failure (e.g. re-auth required) is logged as a warning and yields
    ``None`` so probing is skipped instead of raising.
    """
    try:
        from src.endpoint_resolver import resolve_endpoint_runtime

        resolved = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
        _, probe_key = resolved
    except Exception as exc:
        logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), exc)
        return None
    return probe_key
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
provider = _detect_provider(base)
if _is_discovery_only_provider(provider):
# Responses/Codex API, not chat-completions: a completion probe would
# 400 and the re-probe flow would then hide every model. Discovery-only.
return {"status": "ok", "latency_ms": 0, "skipped": True}
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Say OK"},
@@ -621,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
from src.endpoint_resolver import resolve_url
base = resolve_url(_normalize_base(base_url))
if _detect_provider(base) == "chatgpt-subscription":
from src.chatgpt_subscription import fetch_available_models
if api_key:
return fetch_available_models(api_key, timeout=timeout)
return []
if _detect_provider(base) == "anthropic":
# Try Anthropic's /v1/models endpoint first
url = build_models_url(base)
@@ -647,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
return list(ANTHROPIC_MODELS)
url = build_models_url(base)
if not url:
curated_key = _match_provider_curated(base, None)
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
return list(fallback or [])
headers = build_headers(api_key, base)
try:
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
@@ -998,6 +1068,17 @@ def setup_model_routes(model_discovery):
ok, info = _should_refresh_endpoint(ep, now, force=force)
if not ok:
continue
if getattr(ep, "provider_auth_id", None):
try:
from src.endpoint_resolver import resolve_endpoint_runtime
info["base"], info["api_key"] = resolve_endpoint_runtime(
ep,
owner=getattr(ep, "owner", None),
)
info["key"] = _refresh_key(info["base"], info["api_key"])
except Exception as e:
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
continue
groups.setdefault(info["key"], {
"base": info["base"],
"api_key": info["api_key"],
@@ -1266,12 +1347,20 @@ def setup_model_routes(model_discovery):
"endpoint_kind": kind,
}
try:
t0 = _time.time()
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
entry["error"] = ping.get("error")
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
if _is_discovery_only_provider(provider):
# No general health endpoint — an unauthenticated GET just
# 401s. Report status from cached models instead of pinging.
entry["latency_ms"] = None
entry["status"] = "online" if cached_count else "offline"
entry["error"] = None
entry["model_count"] = cached_count
else:
t0 = _time.time()
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
entry["latency_ms"] = round((_time.time() - t0) * 1000)
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
entry["error"] = ping.get("error")
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
except Exception as e:
entry["latency_ms"] = None
entry["status"] = "online" if cached_count else "offline"
@@ -1304,7 +1393,7 @@ def setup_model_routes(model_discovery):
if ep_id and ep_id not in endpoints_cache:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if ep:
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
ep_data = endpoints_cache.get(ep_id)
if not ep_data:
# Try to find by base_url from the model's endpoint field
@@ -1343,7 +1432,7 @@ def setup_model_routes(model_discovery):
"id": ep.id,
"name": ep.name,
"base_url": ep.base_url,
"api_key": ep.api_key,
"api_key": _resolve_probe_key(ep),
})
finally:
db.close()
@@ -1432,12 +1521,14 @@ def setup_model_routes(model_discovery):
# Endpoint counts as reachable if it has any model — including
# admin-pinned IDs that a probe would never surface.
status = "online" if (all_models or pinned) else "offline"
base = _normalize_base(r.base_url)
ping = None
if not all_models and not pinned and r.is_enabled:
# Discovery-only providers have no health endpoint — an
# unauthenticated ping just 401s, so don't bother.
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
if ping.get("reachable"):
status = "empty"
base = _normalize_base(r.base_url)
kind = _effective_endpoint_kind(r, base)
results.append({
"id": r.id,
@@ -1713,7 +1804,7 @@ def setup_model_routes(model_discovery):
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
if not ep:
raise HTTPException(404, "Endpoint not found")
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
finally:
db.close()
@@ -1777,7 +1868,7 @@ def setup_model_routes(model_discovery):
category = _classify_endpoint(base, kind)
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
try:
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
except Exception as exc:
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
probed = []
@@ -2116,7 +2207,9 @@ def setup_model_routes(model_discovery):
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
auth_id = getattr(ep, "provider_auth_id", None)
db.delete(ep)
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
db.commit()
_invalidate_models_cache()
_local_probe_cache["data"] = None
@@ -2126,6 +2219,7 @@ def setup_model_routes(model_discovery):
"cleared_user_preferences": cleared_user_preferences,
"cleared_sessions": cleared_sessions,
"cleared_loaded_sessions": cleared_loaded_sessions,
"cleared_provider_auth": cleared_provider_auth,
}
finally:
db.close()
+39 -26
View File
@@ -75,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
return owner_filter(q, ModelEndpoint, owner).first()
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
    """Turn a ModelEndpoint row into ``(chat_url, model, headers)``.

    Credentials come from ``endpoint_resolver.resolve_endpoint_runtime``
    rather than ``ep.api_key`` directly, because ChatGPT Subscription
    endpoints keep their OAuth tokens in ProviderAuthSession and leave
    ``ep.api_key`` intentionally empty.

    The model is the stripped ``model`` argument when given; otherwise the
    first chat model found in the endpoint's cached model list.

    Returns ``None`` when credentials cannot be resolved (logged as a
    warning) or when no model can be determined.
    """
    from src.endpoint_resolver import (
        build_chat_url,
        build_headers,
        resolve_endpoint_runtime as resolve_model_endpoint_runtime,
    )

    try:
        base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
    except Exception as exc:
        logger.warning("Could not resolve endpoint credentials for research: %s", exc)
        return None

    chosen_model = (model or "").strip()
    if not chosen_model:
        # Best effort: an unreadable cache simply leaves the model unset.
        try:
            raw_cache = ep.cached_models
            cached = json.loads(raw_cache) if raw_cache else []
            if cached:
                chosen_model = _first_chat_model(cached)
        except Exception:
            pass
    if not chosen_model:
        return None

    chat_url = build_chat_url(base)
    headers = build_headers(api_key, base)
    return chat_url, chosen_model, headers
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
router = APIRouter(tags=["research"])
@@ -371,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
if body.endpoint_id:
from src.database import SessionLocal
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
# Owner-scoped: never resolve another user's private endpoint
@@ -380,18 +411,10 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
if not ep:
raise HTTPException(404, "Endpoint not found or disabled")
base = normalize_base(ep.base_url)
ep_url = build_chat_url(base)
ep_headers = build_headers(ep.api_key, base)
ep_model = body.model or ""
if not ep_model:
try:
import json as _json
models = _json.loads(ep.cached_models) if ep.cached_models else []
if models:
ep_model = _first_chat_model(models)
except Exception:
pass
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
if not resolved:
raise HTTPException(400, "Endpoint is not configured with a usable model.")
ep_url, ep_model, ep_headers = resolved
finally:
db.close()
else:
@@ -408,7 +431,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
if not ep_url:
from src.database import SessionLocal
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
# Owner-scoped first-enabled fallback: the caller's own rows
@@ -417,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
ep = _owned_enabled_endpoint(db, user)
if ep:
base = normalize_base(ep.base_url)
ep_url = build_chat_url(base)
ep_headers = build_headers(ep.api_key, base)
ep_model = ""
if ep.cached_models:
try:
import json as _json
models = _json.loads(ep.cached_models)
if models:
ep_model = _first_chat_model(models)
except Exception:
pass
resolved = _resolve_endpoint_runtime(ep, owner=user)
if resolved:
ep_url, ep_model, ep_headers = resolved
finally:
db.close()
if not ep_url:
+22 -17
View File
@@ -92,18 +92,13 @@ def _reject_compact_during_active_run(session_id: str) -> None:
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
"""Verify the current user owns the session. Raises 404 if not.
"""Verify the current user owns the session, honoring single-user modes.
Ownership is checked against the DB row when one exists (unchanged). If
there is no DB row but the caller owns an in-memory "ghost" session — one
that lives only in ``session_manager`` because it was never persisted, or
its DB row was removed out-of-band — fall back to the in-memory owner so the
user can still manage and delete it. Without this fallback such sessions are
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
per-session operation 404s, making them impossible to delete (issue #1044).
``session_manager`` is optional and defaults to ``None`` so existing callers
that only care about persisted sessions keep their exact prior behavior.
Authenticated requests must match the stored DB or in-memory owner. When
auth is disabled and no user is present, treat the app as single-user mode:
verify that the session exists, but do not compare its stored owner. This
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
rows created while auth was previously enabled.
"""
user = effective_user(request)
if not user and not _auth_disabled():
@@ -114,13 +109,13 @@ def _verify_session_owner(request: Request, session_id: str, session_manager=Non
finally:
db.close()
if row is not None:
if row.owner != user:
if user and row.owner != user:
raise HTTPException(404, f"Session {session_id} not found")
return
# No DB row — allow the caller to act on an in-memory ghost they own.
if session_manager is not None:
ghost = getattr(session_manager, "sessions", {}).get(session_id)
if ghost is not None and getattr(ghost, "owner", None) == user:
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
return
raise HTTPException(404, f"Session {session_id} not found")
@@ -372,8 +367,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
pass
elif not model_to_use:
from src.llm_core import list_model_ids
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers=validation_headers)
ids = list_model_ids(
endpoint_url,
timeout=REQUEST_TIMEOUT,
headers=validation_headers,
owner=user,
endpoint_id=endpoint_id.strip() if endpoint_id else None,
)
if not ids:
raise HTTPException(400, "Cannot reach /v1/models")
# Default to the first CHAT model — endpoints often list embedding/
@@ -387,8 +387,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
from src.llm_core import list_model_ids
import os as _os
req_base = _os.path.basename(model_to_use.rstrip("/"))
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers=validation_headers)
avail = list_model_ids(
endpoint_url,
timeout=REQUEST_TIMEOUT,
headers=validation_headers,
owner=user,
endpoint_id=endpoint_id.strip() if endpoint_id else None,
)
if not avail:
raise HTTPException(400, "Cannot reach /v1/models")
if model_to_use not in avail:
+70
View File
@@ -1109,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
idx = skills_manager.index_for(owner=user)
return {"index": idx, "count": len(idx)}
@router.get("/slash-catalog")
async def get_slash_catalog(request: Request):
    """List the skills that can be offered as slash commands.

    Only skills present in ``skills_manager.index_for`` for the caller are
    returned (per the original note, this mirrors the agent prompt's
    published-skill index so the UI never offers a command the model would
    not normally be allowed to discover). Description, category and usage
    statistics are filled in from the full skill records where the index
    entry lacks them. Entries are sorted by name.
    """
    user = _owner(request)
    details_by_name = {
        skill.get("name"): skill for skill in skills_manager.load(owner=user)
    }
    rows = []
    for item in skills_manager.index_for(owner=user):
        name = (item.get("name") or "").strip()
        if name:
            detail = details_by_name.get(name) or {}
            raw_category = item.get("category") or detail.get("category") or "general"
            # A whitespace-only category collapses to the default as well.
            category = raw_category.strip() or "general"
            help_text = item.get("description") or detail.get("description") or ""
            rows.append({
                "type": "skill",
                "token": f"/{name}",
                "name": name,
                "category": f"Skills / {category}",
                "help": help_text,
                "usage": f"/{name} <request>",
                "uses": int(detail.get("uses") or 0),
                "last_used": detail.get("last_used"),
            })
    rows.sort(key=lambda entry: entry["name"])
    return {"skills": rows, "count": len(rows)}
@router.get("/builtin")
async def list_builtin_skills(request: Request):
"""Read-only list of the agent's built-in tool capabilities (research,
@@ -1272,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
_fire_skill_added(user)
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
@router.post("/{skill_id}/invoke")
async def invoke_skill(request: Request, skill_id: str):
    """Build a skill-pinned prompt for slash-command invocation.

    This is intentionally server-side so availability, ownership, and usage
    accounting use the same rules as the SkillsManager.

    The optional JSON body may carry a ``request`` string. A missing or
    malformed body, or a ``request`` that is not a string, is treated as
    "no request text" rather than as an error.

    Raises:
        HTTPException(404): the skill is not in the caller's invokable index,
            or its markdown source cannot be read.
    """
    user = _owner(request)
    try:
        body = await request.json()
    except Exception:
        # Deliberately lenient: no body / invalid JSON means no request text.
        body = {}
    raw_request = body.get("request") if isinstance(body, dict) else None
    # JSON allows numbers, lists and objects here; calling .strip() on those
    # used to raise AttributeError and surface as a 500.
    request_text = raw_request.strip() if isinstance(raw_request, str) else ""

    invokable = {
        s.get("name"): s for s in skills_manager.index_for(owner=user)
        if (s.get("name") or "").strip()
    }
    match = invokable.get(skill_id)
    if not match:
        raise HTTPException(404, "Skill is not available for slash invocation")

    name = match.get("name")
    md = skills_manager.read_skill_md(name, owner=user)
    if md is None:
        raise HTTPException(404, "Skill source unavailable")

    # Count the use only once the prompt can actually be built.
    skills_manager.record_use(name, owner=user)
    message = (
        "Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
        f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
        + (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
    )
    return {
        "ok": True,
        "type": "skill",
        "name": name,
        "command": f"/{name}",
        "message": message,
    }
@router.get("/{skill_id}")
async def get_skill(request: Request, skill_id: str):
user = _owner(request)
+21 -10
View File
@@ -325,22 +325,33 @@ def setup_webhook_routes(
endpoint_url = build_chat_url(base_url)
model = body.model or "auto"
api_key = ep.api_key
if getattr(ep, "provider_auth_id", None):
try:
from src.endpoint_resolver import resolve_endpoint_runtime
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
endpoint_url = build_chat_url(base_url)
except Exception:
raise HTTPException(500, "Could not resolve endpoint credentials")
if model == "auto":
try:
async with httpx.AsyncClient(timeout=5) as client:
models_url = build_models_url(base_url)
hdrs = build_headers(api_key, base_url)
resp = await client.get(models_url, headers=hdrs)
resp.raise_for_status()
data = resp.json()
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not ids:
ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
if models_url:
resp = await client.get(models_url, headers=hdrs)
resp.raise_for_status()
data = resp.json()
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not ids:
ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
import json as _json
ids = _json.loads(ep.cached_models or "[]")
model = ids[0] if ids else "auto"
except Exception:
raise HTTPException(500, "Could not discover models from endpoint")
+39 -25
View File
@@ -57,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
# Model resolution
# ---------------------------------------------------------------------------
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
@@ -98,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
for ep in endpoints:
base = _normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception:
continue
provider = _detect_provider(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
if provider == "anthropic":
# Anthropic: match against hardcoded model list
@@ -114,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
else:
# OpenAI-compatible and native Ollama: probe the provider's model list.
try:
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
models_url = build_models_url(base)
if models_url:
r = httpx.get(models_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
model_ids = json.loads(ep.cached_models or "[]")
except Exception:
model_ids = []
@@ -1121,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
total_models = 0
for ep in endpoints:
base = _normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception:
continue
provider = _detect_provider(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
model_ids = []
if provider == "anthropic":
model_ids = list(ANTHROPIC_MODELS)
else:
try:
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
models_url = build_models_url(base)
if models_url:
r = httpx.get(models_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not model_ids:
model_ids = [
m.get("name") or m.get("model")
for m in (data.get("models") or [])
if m.get("name") or m.get("model")
]
else:
model_ids = json.loads(ep.cached_models or "[]")
except Exception:
model_ids = ["(endpoint offline)"]
+311
View File
@@ -0,0 +1,311 @@
"""ChatGPT subscription / Codex backend OAuth helpers.
This provider is intentionally separate from OpenAI API-key endpoints. It uses
OpenAI account OAuth device authorization, stores refresh tokens server-side,
and resolves a fresh bearer token at request time.
"""
from __future__ import annotations
import base64
import json
import os
import threading
import time
from typing import Any, Dict, Optional
import httpx
from fastapi import HTTPException
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
# Base URL for the provider: the CHATGPT_SUBSCRIPTION_BASE_URL environment
# variable (surrounding whitespace and trailing slashes removed) when it is set
# to a non-empty value, otherwise the chatgpt.com Codex backend.
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
or "https://chatgpt.com/backend-api/codex"
)
# Provider identifier; resolve_runtime_credentials filters ProviderAuthSession
# rows on this value.
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
# OAuth client id sent with the device-code, authorization-code and
# refresh-token requests made by this module.
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
# Token endpoint used for both the authorization-code exchange and refreshes.
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
# Base for the device-authorization endpoints and the default verification URI.
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
# Sent as redirect_uri when exchanging the authorization code.
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
# Default skew for access_token_is_expiring: a token whose "exp" falls within
# this many seconds from now is already treated as expiring.
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
# One lock per auth id, created on first use by _refresh_lock_for.
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
# Held while looking up or inserting into _AUTH_REFRESH_LOCKS.
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
def _refresh_lock_for(auth_id: str) -> threading.Lock:
    """Return the lock dedicated to ``auth_id``, creating it on first request.

    Lookup and creation both happen while holding the registry guard, so
    repeated calls with the same id always yield the same lock object.
    """
    with _AUTH_REFRESH_LOCKS_GUARD:
        try:
            return _AUTH_REFRESH_LOCKS[auth_id]
        except KeyError:
            created = threading.Lock()
            _AUTH_REFRESH_LOCKS[auth_id] = created
            return created
class ChatGPTSubscriptionError(RuntimeError):
    """Root of every error raised by the ChatGPT subscription provider."""


class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
    """The stored OAuth credentials are invalid, or expired beyond what a refresh can fix."""


class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
    """An upstream quota or rate limit was hit; reconnecting would not help."""


class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
    """No auth session matching the id (and owner, when scoped) exists."""
def is_chatgpt_subscription_base(url: str) -> bool:
    """Tell whether ``url`` points at the chatgpt.com Codex backend.

    The host must be exactly ``chatgpt.com`` (case-insensitive, a trailing
    dot is ignored) and the path must be ``/backend-api/codex`` or lie below
    it. Anything that cannot be parsed counts as not matching.
    """
    from urllib.parse import urlparse

    try:
        pieces = urlparse(url or "")
        hostname = (pieces.hostname or "").lower().rstrip(".")
        route = (pieces.path or "").rstrip("/")
    except Exception:
        return False

    if hostname != "chatgpt.com":
        return False
    if route == "/backend-api/codex":
        return True
    # The trailing slash keeps sibling paths such as ".../codexx" from matching.
    return route.startswith("/backend-api/codex/")
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
    """Build the request headers sent to the ChatGPT backend.

    The fixed Accept/Origin/Referer/User-Agent set is always present; an
    ``Authorization: Bearer …`` header is appended only for a truthy token.
    """
    bearer = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    return {
        "Accept": "application/json, text/event-stream",
        "Origin": "https://chatgpt.com",
        "Referer": "https://chatgpt.com/codex",
        "User-Agent": "Odysseus ChatGPT Subscription",
        **bearer,
    }
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
    """Fetch the model slugs visible to this ChatGPT subscription.

    Never raises: a missing token, any transport or JSON error, a non-200
    status, or a payload of an unexpected shape all yield an empty list.

    Args:
        access_token: Bearer token for the ChatGPT backend; falsy returns [].
        timeout: Request timeout in seconds, passed to ``httpx.get``.

    Returns:
        Unique slugs ordered by ascending ``priority`` and then by slug.
        Entries whose ``visibility`` is "hide" or "hidden" are left out.
    """
    if not access_token:
        return []
    try:
        response = httpx.get(
            "https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
            headers=chatgpt_headers(access_token),
            timeout=timeout,
        )
        if response.status_code != 200:
            return []
        data = response.json()
    except Exception:
        return []
    return _visible_model_slugs(data)


def _visible_model_slugs(data: Any) -> list[str]:
    """Extract visible model slugs from a decoded models payload, best priority first."""
    entries = data.get("models") if isinstance(data, dict) else None
    if not isinstance(entries, list):
        # `"models": null` (or any other non-list) used to reach the loop and
        # raise TypeError, breaking the "never raises" behaviour above.
        return []

    unranked = 10_000  # rank for entries without a usable numeric priority
    best_rank: dict[str, int] = {}
    for item in entries:
        if not isinstance(item, dict):
            continue
        slug = item.get("slug")
        if not isinstance(slug, str) or not slug.strip():
            continue
        visibility = item.get("visibility", "")
        if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}:
            continue
        priority = item.get("priority")
        try:
            rank = int(priority) if isinstance(priority, (int, float)) else unranked
        except (ValueError, OverflowError):
            # The JSON decoder accepts NaN/Infinity, which int() rejects.
            rank = unranked
        name = slug.strip()
        # A slug listed more than once keeps its best (lowest) rank.
        best_rank[name] = min(rank, best_rank.get(name, rank))

    return sorted(best_rank, key=lambda name: (best_rank[name], name))
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
    """Raise the matching provider error for a failed OAuth HTTP response.

    Responses with a status below 400 pass silently. Otherwise the body is
    inspected, best effort, for an error code and a human-readable message
    (``action`` names the failed step in that message), and then:

    * 429 raises ``ChatGPTSubscriptionRateLimited``;
    * 401/403, or one of the known invalid-credential error codes, raises
      ``ChatGPTSubscriptionReauthRequired``;
    * anything else raises ``ChatGPTSubscriptionError``.
    """
    if response.status_code < 400:
        return

    error_code = ""
    message = f"ChatGPT Subscription {action} failed with HTTP {response.status_code}."
    try:
        payload = response.json()
        err = payload.get("error") if isinstance(payload, dict) else None
        detail = None
        if isinstance(err, dict):
            # Structured form: {"error": {"code"/"type": ..., "message": ...}}
            error_code = str(err.get("code") or err.get("type") or "").strip()
            detail = err.get("message")
        elif isinstance(err, str):
            # Flat form: {"error": "...", "error_description"/"message": ...}
            error_code = err.strip()
            detail = payload.get("error_description") or payload.get("message")
        if detail:
            message = f"ChatGPT Subscription {action} failed: {detail}"
    except Exception:
        # An unparsable body just leaves the generic HTTP-status message.
        pass

    if response.status_code == 429:
        raise ChatGPTSubscriptionRateLimited(
            "ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
        )
    reauth_codes = {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}
    needs_reauth = response.status_code in (401, 403) or error_code in reauth_codes
    if needs_reauth:
        raise ChatGPTSubscriptionReauthRequired(message)
    raise ChatGPTSubscriptionError(message)
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
    """Return the response body as a dict, or raise a provider error.

    HTTP-level failures are raised by ``_raise_for_oauth_response`` first. A
    body that is not valid JSON, or whose top level is not an object, raises
    ``ChatGPTSubscriptionError``; ``action`` names the step in the message.
    """
    _raise_for_oauth_response(response, action)
    try:
        parsed = response.json()
    except Exception as exc:
        raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
    if isinstance(parsed, dict):
        return parsed
    raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
    """Start a device authorization and return the issuer's response.

    Raises ``ChatGPTSubscriptionError`` (or a subclass, via
    ``_json_or_error``) on HTTP/JSON failure or when ``device_auth_id`` or
    ``user_code`` is missing. ``verification_uri``, ``interval`` (5) and
    ``expires_in`` (900) are filled in when the response omits them.
    """
    endpoint = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode"
    reply = httpx.post(
        endpoint,
        json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    data = _json_or_error(reply, "device-code request")
    if not (data.get("device_auth_id") and data.get("user_code")):
        raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")

    fallbacks = (
        ("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device"),
        ("interval", 5),
        ("expires_in", 900),
    )
    for key, fallback in fallbacks:
        data.setdefault(key, fallback)
    return data
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Ask the issuer whether the user has approved the device code yet.

    A 403 or 404 reply is reported as still pending
    (``{"status": "pending", "error": "authorization_pending"}``). Any other
    reply is returned as parsed JSON, with failures raised by
    ``_json_or_error``.
    """
    endpoint = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token"
    credentials = {"device_auth_id": device_auth_id, "user_code": user_code}
    reply = httpx.post(
        endpoint,
        json=credentials,
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    still_waiting = reply.status_code in (403, 404)
    if not still_waiting:
        return _json_or_error(reply, "device-code poll")
    return {"status": "pending", "error": "authorization_pending"}
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Exchange an authorization code (with its PKCE verifier) for tokens.

    Returns the token endpoint's JSON object. Raises
    ``ChatGPTSubscriptionReauthRequired`` when it carries no ``access_token``;
    HTTP/JSON failures are raised by ``_json_or_error``.
    """
    form = {
        "grant_type": "authorization_code",
        "code": authorization_code,
        "redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
        "code_verifier": code_verifier,
    }
    reply = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(reply, "token exchange")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
    """Obtain a new access token using the stored refresh token.

    Args:
        access_token: Accepted as part of the signature but not used; it is
            discarded immediately.
        refresh_token: Refresh token sent to the token endpoint.
        timeout: Seconds allowed for the HTTP request.

    Returns:
        The decoded token response; always contains a truthy ``access_token``.

    Raises:
        ChatGPTSubscriptionReauthRequired: If ``refresh_token`` is empty or
            the response carries no access token.
    """
    del access_token  # the current bearer is never sent on refresh
    if not refresh_token:
        raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
    }
    reply = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(reply, "token refresh")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
    """Decode the payload segment of a JWT without verifying its signature.

    Args:
        token: Dot-separated JWT string (a falsy value is treated as "").

    Returns:
        The payload as a dict, or ``{}`` when the payload decodes to JSON
        that is not an object.

    Raises:
        ValueError: If the token has fewer than two dot-separated segments
            (decoding errors from ``base64``/``json`` also propagate).
    """
    pieces = (token or "").split(".")
    if len(pieces) < 2:
        raise ValueError("not a JWT")
    body = pieces[1]
    # JWT segments omit base64 padding; restore it to a multiple of four.
    padded = body + "=" * (-len(body) % 4)
    decoded = base64.urlsafe_b64decode(padded.encode("ascii")).decode("utf-8")
    claims = json.loads(decoded)
    if isinstance(claims, dict):
        return claims
    return {}
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
    """Report whether an access token is expired or about to expire.

    Args:
        access_token: JWT whose ``exp`` claim is inspected.
        skew_seconds: Safety margin; a token expiring within this many
            seconds counts as expiring.

    Returns:
        ``True`` when ``exp`` is missing, unreadable, or not later than
        now + ``skew_seconds``; ``False`` otherwise.
    """
    try:
        claims = _decode_jwt_payload(access_token)
        expires_at = int(claims.get("exp") or 0)
    except Exception:
        # An undecodable token is treated as expiring so callers refresh it.
        return True
    refresh_deadline = int(time.time()) + int(skew_seconds)
    return expires_at <= refresh_deadline
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
    """Load a ChatGPT Subscription auth session and return usable credentials.

    The stored access token is refreshed first when it is expiring or when
    ``force_refresh`` is set.

    Args:
        auth_id: Primary key of the ``ProviderAuthSession`` row.
        owner: When truthy, the row must also belong to this owner.
        force_refresh: Refresh the token even if it has not expired.

    Returns:
        A dict with ``provider``, ``base_url`` (trailing slash removed),
        ``api_key`` (the access token) and ``auth_mode``.

    Raises:
        ChatGPTSubscriptionAuthNotFound: If no matching row exists.
        Errors raised by ``refresh_oauth_tokens`` propagate unchanged.
    """
    session = SessionLocal()
    try:
        query = session.query(ProviderAuthSession).filter(
            ProviderAuthSession.id == auth_id,
            ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
        )
        if owner:
            query = query.filter(ProviderAuthSession.owner == owner)
        record = query.first()
        if record is None:
            raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")

        bearer = record.access_token or ""
        if force_refresh or access_token_is_expiring(bearer):
            with _refresh_lock_for(auth_id):
                # Re-read the row under the lock, then re-check before refreshing
                # (presumably so a concurrent refresh is picked up — verify).
                session.refresh(record)
                bearer = record.access_token or ""
                stored_refresh_token = record.refresh_token or ""
                if force_refresh or access_token_is_expiring(bearer):
                    renewed = refresh_oauth_tokens(bearer, stored_refresh_token)
                    record.access_token = renewed["access_token"]
                    # Keep the old refresh token when the issuer sends no new one.
                    if renewed.get("refresh_token"):
                        record.refresh_token = renewed["refresh_token"]
                    record.last_refresh = utcnow_naive()
                    session.commit()
                    session.refresh(record)
                    bearer = record.access_token or ""

        resolved_base = (record.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/")
        return {
            "provider": CHATGPT_SUBSCRIPTION_PROVIDER,
            "base_url": resolved_base,
            "api_key": bearer,
            "auth_mode": record.auth_mode or "chatgpt",
        }
    finally:
        session.close()
def to_http_exception(exc: Exception) -> HTTPException:
    """Map a ChatGPT Subscription error to the HTTPException sent to clients.

    Args:
        exc: The exception raised while talking to ChatGPT Subscription.

    Returns:
        ``HTTPException`` with status 429 for rate limiting, 401 for missing
        or rejected credentials (with a reconnect hint), and 502 otherwise.
    """
    if isinstance(exc, ChatGPTSubscriptionRateLimited):
        return HTTPException(429, str(exc))
    if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
        reconnect_hint = "Reconnect the provider."
        detail = str(exc)
        # Some messages (e.g. the missing-refresh-token error) already end with
        # the hint; appending it again would repeat the sentence.
        if not detail.rstrip().endswith(reconnect_hint):
            detail = f"{detail} {reconnect_hint}"
        return HTTPException(401, detail)
    return HTTPException(502, str(exc))
def build_responses_input(messages: list[dict]) -> list[dict]:
    """Convert chat-style messages into Responses API ``input`` items.

    Each message becomes ``{"role": ..., "content": [{"type": ..., "text": ...}]}``.
    ``tool`` messages are sent with the ``user`` role, assistant text uses the
    ``output_text`` type and everything else uses ``input_text``.

    Args:
        messages: Chat messages with ``role`` and ``content`` keys; ``None``
            or an empty list yields an empty result.

    Returns:
        One input item per message, in the same order.
    """
    input_items: list[dict] = []
    for msg in messages or []:
        role = msg.get("role") or "user"
        if role == "tool":
            role = "user"
        content = msg.get("content")
        if isinstance(content, list):
            fragments: list[str] = []
            for part in content:
                if isinstance(part, dict):
                    value = part.get("text") or part.get("content")
                    # Parts without text (e.g. images) are skipped so they do
                    # not leave blank lines in the joined text.
                    if value:
                        fragments.append(str(value))
                elif part:
                    # Bare strings (or other truthy scalars) inside a content
                    # list are kept rather than silently dropped.
                    fragments.append(str(part))
            text = "\n".join(fragments)
        else:
            text = "" if content is None else str(content)
        input_type = "output_text" if role == "assistant" else "input_text"
        input_items.append({"role": role, "content": [{"type": input_type, "text": text}]})
    return input_items
+40 -6
View File
@@ -70,6 +70,25 @@ def _endpoint_enabled_models(ep) -> list:
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
def resolve_endpoint_runtime(ep, owner: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """Return the runtime ``(base_url, api_key)`` pair for a ModelEndpoint row.

    Endpoints with a static key use ``ModelEndpoint.api_key`` as stored.
    Endpoints that carry a ``provider_auth_id`` keep refreshable credentials
    in ProviderAuthSession, so their access token is resolved at call time.
    """
    static_base = normalize_base(getattr(ep, "base_url", "") or "")
    static_key = getattr(ep, "api_key", None)
    session_id = getattr(ep, "provider_auth_id", None)
    if not session_id:
        return static_base, static_key

    # Imported lazily, as in the rest of this module's provider hooks.
    from src.chatgpt_subscription import resolve_runtime_credentials

    runtime = resolve_runtime_credentials(session_id, owner=owner)
    runtime_base = normalize_base(runtime.get("base_url") or static_base)
    return runtime_base, runtime.get("api_key")
# Cache for Tailscale hostname → IP resolution
_tailscale_cache: Dict[str, Optional[str]] = {}
@@ -133,7 +152,7 @@ def resolve_url(url: str) -> str:
def normalize_base(url: str) -> str:
"""Strip known API path suffixes from a base URL."""
url = (url or "").strip().rstrip("/")
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages", "/responses"]:
if url.endswith(suffix):
url = url[: -len(suffix)].rstrip("/")
for suffix in ["/chat", "/tags", "/generate"]:
@@ -158,10 +177,12 @@ def build_chat_url(base: str) -> str:
return _anthropic_api_root(base) + "/v1/messages"
if provider == "ollama":
return _ollama_api_root(base) + "/chat"
if provider == "chatgpt-subscription":
return base.rstrip("/") + "/responses"
return base + "/chat/completions"
def build_models_url(base: str) -> str:
def build_models_url(base: str) -> Optional[str]:
"""Return the provider-specific model-list endpoint URL for a base."""
base = resolve_url(base)
provider = _detect_provider(base)
@@ -169,6 +190,8 @@ def build_models_url(base: str) -> str:
return _anthropic_api_root(base) + "/v1/models"
if provider == "ollama":
return _ollama_api_root(base) + "/tags"
if provider == "chatgpt-subscription":
return None
return base + "/models"
@@ -184,6 +207,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
if provider == "copilot":
from src.copilot import copilot_headers
return copilot_headers(api_key)
if provider == "chatgpt-subscription":
from src.chatgpt_subscription import chatgpt_headers
return chatgpt_headers(api_key)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if provider == "openrouter":
@@ -262,9 +288,13 @@ def resolve_endpoint(
if not ep:
return fallback_url, fallback_model, fallback_headers
base = normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception as e:
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
return fallback_url, fallback_model, fallback_headers
chat_url = build_chat_url(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
# Discard a configured model the user has since disabled on the
# endpoint (e.g. a stale `default_model` left pointing at a now-hidden
@@ -308,9 +338,13 @@ def resolve_endpoint_by_id(
ep = q.first()
if not ep:
return None
base = normalize_base(ep.base_url)
try:
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
except Exception as e:
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
return None
chat_url = build_chat_url(base)
headers = build_headers(ep.api_key, base)
headers = build_headers(api_key, base)
m = (model or "").strip()
# Drop a model the user disabled on the endpoint, then pick the first
# enabled chat model rather than a hidden one.
+217 -7
View File
@@ -426,6 +426,9 @@ def _detect_provider(url: str) -> str:
return "openrouter"
if _host_match(url, "groq.com"):
return "groq"
from src.chatgpt_subscription import is_chatgpt_subscription_base
if is_chatgpt_subscription_base(url):
return "chatgpt-subscription"
from src.copilot import is_copilot_base
if is_copilot_base(url):
return "copilot"
@@ -462,6 +465,8 @@ def _provider_label(url: str) -> str:
if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go"
if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen"
if _host_match(url, "groq.com"): return "Groq"
from src.chatgpt_subscription import is_chatgpt_subscription_base
if is_chatgpt_subscription_base(url): return "ChatGPT Subscription"
from src.copilot import is_copilot_base
if is_copilot_base(url): return "GitHub Copilot"
if _host_match(url, "mistral.ai"): return "Mistral"
@@ -479,6 +484,77 @@ def _provider_label(url: str) -> str:
return host or "provider"
def _normalize_chatgpt_subscription_url(url: str) -> str:
    """Return ``url`` pointing at the ``/responses`` path.

    Surrounding whitespace and trailing slashes are removed first; the suffix
    is appended only when the URL does not already end with it.
    """
    suffix = "/responses"
    trimmed = (url or "").strip().rstrip("/")
    return trimmed if trimmed.endswith(suffix) else trimmed + suffix
def _message_content_as_text(content) -> str:
    """Flatten a chat message's ``content`` into plain text.

    Strings are returned unchanged and ``None`` becomes ``""``. For a list,
    each dict part contributes its string ``text`` (or, failing that, its
    string ``content``), each other truthy item contributes ``str(item)``,
    and the pieces are joined with newlines. Any other value is passed
    through ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    fragments: list[str] = []
    for item in content:
        if isinstance(item, dict):
            # First string-valued key wins; parts with neither are skipped.
            for key in ("text", "content"):
                value = item.get(key)
                if isinstance(value, str):
                    fragments.append(value)
                    break
        elif item:
            fragments.append(str(item))
    return "\n".join(fragments)
def _chatgpt_subscription_instructions(messages: List[Dict]) -> str:
    """Build the ``instructions`` string from the system messages.

    The stripped text of every ``system`` message is joined with blank
    lines. When none of them has any text, a generic assistant prompt is
    returned instead.
    """
    system_texts: List[str] = []
    for msg in messages or []:
        if (msg.get("role") or "") != "system":
            continue
        text = _message_content_as_text(msg.get("content")).strip()
        if text:
            system_texts.append(text)
    if not system_texts:
        return "You are a helpful AI assistant."
    return "\n\n".join(system_texts)
def _build_chatgpt_responses_payload(
    model: str,
    messages: List[Dict],
    temperature: float,
    max_tokens: int,
    *,
    stream: bool = False,
) -> Dict:
    """Assemble the JSON body for a ChatGPT Subscription Responses request.

    System messages are folded into ``instructions``; all other messages
    become ``input`` items. ``temperature`` is included only for models that
    do not restrict it, and ``max_output_tokens`` only for a positive
    ``max_tokens``. ``store`` is always ``False``.
    """
    from src.chatgpt_subscription import build_responses_input

    non_system: List[Dict] = []
    for msg in messages or []:
        if (msg.get("role") or "") != "system":
            non_system.append(msg)

    instructions = _chatgpt_subscription_instructions(messages)
    body: Dict = {
        "model": model,
        "instructions": instructions,
        "input": build_responses_input(non_system),
        "stream": stream,
        "store": False,
    }
    temperature_allowed = not _restricts_temperature(model)
    if temperature_allowed:
        body["temperature"] = temperature
    if max_tokens and max_tokens > 0:
        body["max_output_tokens"] = max_tokens
    return body
def _format_chatgpt_subscription_error(status_code: int, text: str) -> str:
    """Return a user-facing message for a failed ChatGPT Subscription call.

    429 and 401/403 get fixed, provider-specific wording; every other status
    is delegated to ``_format_upstream_error``.
    """
    if status_code == 429:
        return "ChatGPT Subscription quota or rate limit was reached. Retry after the upstream limit resets."
    credentials_rejected = status_code in (401, 403)
    if credentials_rejected:
        return "ChatGPT Subscription credentials expired or were rejected. Reconnect the provider."
    return _format_upstream_error(status_code, text, "https://chatgpt.com/backend-api/codex")
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
"""Turn an upstream HTTP error into a user-readable sentence.
@@ -874,7 +950,7 @@ def _normalize_anthropic_url(url: str) -> str:
def _model_list_base(url: str) -> str:
"""Normalize model/chat URLs to the configured endpoint base."""
base = (url or "").strip().rstrip("/")
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages", "/responses"):
if base.endswith(suffix):
base = base[: -len(suffix)].rstrip("/")
for suffix in ("/chat", "/tags", "/generate"):
@@ -903,7 +979,12 @@ def _parse_model_cache(raw) -> List[str]:
return out
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
def _configured_cached_model_ids(
endpoint_url: str,
*,
owner: Optional[str] = None,
endpoint_id: Optional[str] = None,
) -> List[str]:
"""Return cached models for a configured endpoint matching endpoint_url."""
target = _model_list_base(endpoint_url)
if not target:
@@ -914,7 +995,13 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
return []
db = SessionLocal()
try:
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if endpoint_id:
q = q.filter(ModelEndpoint.id == endpoint_id)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
rows = q.all()
for ep in rows:
if _model_list_base(getattr(ep, "base_url", "")) != target:
continue
@@ -933,9 +1020,16 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
return []
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
def list_model_ids(
base_chat_url: str,
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
headers: Optional[Dict] = None,
*,
owner: Optional[str] = None,
endpoint_id: Optional[str] = None,
) -> List[str]:
"""List available model IDs from an endpoint."""
cached = _configured_cached_model_ids(base_chat_url)
cached = _configured_cached_model_ids(base_chat_url, owner=owner, endpoint_id=endpoint_id)
if cached:
return cached
provider = _detect_provider(base_chat_url)
@@ -971,9 +1065,16 @@ def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT,
pass
return []
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]:
def normalize_model_id(
endpoint_url: str,
requested: str,
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
*,
owner: Optional[str] = None,
endpoint_id: Optional[str] = None,
) -> Optional[str]:
"""Normalize a model ID to match available models."""
avail = list_model_ids(endpoint_url, timeout)
avail = list_model_ids(endpoint_url, timeout, owner=owner, endpoint_id=endpoint_id)
if not avail:
return None
if requested in avail:
@@ -1169,6 +1270,49 @@ async def llm_call_async(
logger.debug(f"Returning cached response for key: {cache_key}")
return cached_response
if provider == "chatgpt-subscription":
# ChatGPT/Codex requires streamed Responses requests even for callers
# that want a plain string (auto-title, memory extraction, etc.).
# Reuse stream_llm's validated Codex SSE path and collect deltas.
parts: List[str] = []
async for chunk in stream_llm(
url,
model,
messages_copy,
temperature=temperature,
max_tokens=max_tokens,
headers=headers,
timeout=timeout,
):
event_is_error = False
for line in str(chunk).splitlines():
if line.startswith("event:"):
event_is_error = line[6:].strip() == "error"
continue
if not line.startswith("data:"):
continue
raw = line[5:].strip()
if not raw:
continue
if raw == "[DONE]":
response = "".join(parts)
_set_cached_response(cache_key, response)
return response
try:
data = json.loads(raw)
except json.JSONDecodeError:
continue
if event_is_error or data.get("error") or (data.get("status") and data.get("text")):
status = int(data.get("status") or 502)
text = data.get("text") or data.get("error") or "ChatGPT Subscription request failed"
raise HTTPException(status, text)
delta = data.get("delta")
if isinstance(delta, str):
parts.append(delta)
response = "".join(parts)
_set_cached_response(cache_key, response)
return response
if provider == "anthropic":
target_url = _normalize_anthropic_url(url)
h = _build_anthropic_headers(headers)
@@ -1294,6 +1438,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
model, messages_copy, temperature, max_tokens,
stream=True, tools=tools, num_ctx=get_context_length(url, model),
)
elif provider == "chatgpt-subscription":
target_url = _normalize_chatgpt_subscription_url(url)
h = _provider_headers(provider, headers)
payload = _build_chatgpt_responses_payload(model, messages_copy, temperature, max_tokens, stream=True)
else:
target_url = url
payload = {
@@ -1325,6 +1473,68 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
return
note_model_activity(target_url, model)
# ── ChatGPT Subscription / Codex Responses streaming ──
if provider == "chatgpt-subscription":
event_name = ""
input_tokens = 0
output_tokens = 0
try:
client = _get_http_client()
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
_clear_host_dead(target_url)
if r.status_code != 200:
raw = (await r.aread()).decode(errors="replace")
friendly = _format_chatgpt_subscription_error(r.status_code, raw)
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
return
async for line in r.aiter_lines():
if not line:
continue
if line.startswith("event:"):
event_name = line[6:].strip()
continue
if not line.startswith("data:"):
continue
raw = line[5:].strip()
if not raw:
continue
try:
data = json.loads(raw)
except json.JSONDecodeError:
continue
evt = data.get("type") or event_name
if evt == "response.output_text.delta":
delta = data.get("delta") or ""
if delta:
yield f'data: {json.dumps({"delta": delta})}\n\n'
elif evt == "response.completed":
usage = (data.get("response") or {}).get("usage") or data.get("usage") or {}
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or input_tokens
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") or output_tokens
if input_tokens or output_tokens:
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": input_tokens, "output_tokens": output_tokens}})}\n\n'
yield "data: [DONE]\n\n"
return
elif evt in ("response.failed", "error"):
err = data.get("error") or (data.get("response") or {}).get("error") or {}
text = err.get("message") if isinstance(err, dict) else str(err or "ChatGPT Subscription request failed")
yield f'event: error\ndata: {json.dumps({"status": 502, "text": text})}\n\n'
return
yield "data: [DONE]\n\n"
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
_cooled = _mark_host_dead(target_url)
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
logger.warning(f"ChatGPT Subscription stream connect to {target_url} failed: {e}{_tail}")
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
except httpx.ReadTimeout:
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
except httpx.NetworkError:
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
except Exception as e:
logger.error(f"ChatGPT Subscription stream error: {e}")
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
return
# ── Native Ollama streaming ──
if provider == "ollama":
_ollama_tool_calls: List[Dict] = []
+3
View File
@@ -2108,6 +2108,8 @@
<option value="https://api.anthropic.com" data-logo="anthropic">Anthropic</option>
<option value="https://api.deepseek.com/v1" data-logo="deepseek" selected>DeepSeek</option>
<option value="https://api.openai.com/v1" data-logo="openai">OpenAI</option>
<option value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot</option>
<option value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription</option>
<option value="https://openrouter.ai/api/v1" data-logo="openrouter">OpenRouter</option>
<option value="https://ollama.com/api" data-logo="ollama">Ollama Cloud</option>
<option value="https://api.groq.com/openai/v1" data-logo="groq">Groq</option>
@@ -2136,6 +2138,7 @@
<button class="admin-btn-add" id="adm-epAddBtn" style="width:55px;text-align:center;">Add</button>
</div>
<div id="adm-epApiMsg" class="adm-ep-inline-msg"></div>
<div id="adm-deviceAuthStatus" class="adm-ep-inline-msg"></div>
</div>
</div>
</div>
+201 -68
View File
@@ -5,6 +5,7 @@ import uiModule from './ui.js';
import settingsModule from './settings.js';
import { providerLogo } from './providers.js';
import { sortModelObjects } from './modelSort.js';
import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js';
let initialized = false;
let modalEl = null;
@@ -707,6 +708,80 @@ function initEndpointForm() {
const pickerBtn = el('adm-provider-btn');
const pickerMenu = el('adm-provider-menu');
const pickerCurrent = picker ? picker.querySelector('.adm-provider-current') : null;
const DEVICE_AUTH_PROVIDER_VALUES = new Set(Object.keys(PROVIDER_DEVICE_FLOWS));
let deviceAuthPolling = false;
// Return the first selected <option> of the provider select, or null when
// the select (or its selectedOptions collection) is unavailable.
function _selectedProviderOption() {
  if (!provider || !provider.selectedOptions) return null;
  return provider.selectedOptions[0];
}
// Return the device-flow provider key for the current selection, or '' when
// the selected provider does not use a device flow. The option's
// data-auth-flow attribute is checked first, then the raw select value.
function _selectedDeviceAuthProvider() {
  const selected = _selectedProviderOption();
  let declaredFlow = '';
  if (selected && selected.dataset) declaredFlow = selected.dataset.authFlow;
  if (declaredFlow && DEVICE_AUTH_PROVIDER_VALUES.has(declaredFlow)) {
    return declaredFlow;
  }
  if (DEVICE_AUTH_PROVIDER_VALUES.has(provider.value)) {
    return provider.value;
  }
  return '';
}
// True when the selected provider signs in through a device flow.
function _isDeviceAuthSelected() {
  const providerKey = _selectedDeviceAuthProvider();
  return Boolean(providerKey);
}
// Sync the API "Add" form with the selected provider. Device-auth providers
// need no URL or API key, so those inputs and the Test button are locked;
// any other selection restores the normal editable form. The inline API
// message is cleared in both cases.
function _setApiFormForProvider() {
  const deviceAuthProvider = _selectedDeviceAuthProvider();
  const deviceAuthConfig = PROVIDER_DEVICE_FLOWS[deviceAuthProvider] || null;
  const apiKey = el('adm-epApiKey');
  const testBtn = el('adm-epApiTestBtn');
  const addBtn = el('adm-epAddBtn');
  const status = el('adm-deviceAuthStatus');
  const msg = _endpointMsg('api');

  const restoreAddButton = () => {
    if (!addBtn) return;
    addBtn.disabled = false;
    addBtn.textContent = 'Add';
    addBtn.style.width = '55px';
    addBtn.style.display = '';
  };

  if (deviceAuthConfig) {
    urlInput.value = '';
    urlInput.placeholder = deviceAuthProvider === 'copilot'
      ? 'GitHub Copilot uses GitHub account sign-in'
      : 'ChatGPT Subscription uses OpenAI account sign-in';
    urlInput.readOnly = true;
    if (apiKey) {
      apiKey.value = '';
      apiKey.placeholder = 'No API key needed';
      apiKey.disabled = true;
    }
    if (testBtn) {
      testBtn.disabled = true;
      testBtn.style.opacity = '0.45';
      testBtn.style.cursor = 'not-allowed';
    }
    // The sign-in flow disables the Add button and relabels it right before
    // calling this function; while that flow is still polling, leave the
    // button alone so it is not re-enabled and reset to "Add" mid-flow.
    if (!deviceAuthPolling) restoreAddButton();
    if (kindSel) kindSel.value = 'api';
  } else {
    urlInput.placeholder = 'Base URL or pick provider';
    urlInput.readOnly = false;
    if (apiKey) {
      apiKey.placeholder = 'API key';
      apiKey.disabled = false;
    }
    if (testBtn) {
      testBtn.disabled = false;
      testBtn.style.opacity = '';
      testBtn.style.cursor = '';
    }
    restoreAddButton();
    // Keep the device-auth status visible while a sign-in is still polling.
    if (!deviceAuthPolling && status) status.textContent = '';
  }

  if (msg) {
    msg.textContent = '';
    msg.className = '';
  }
}
function _renderPickerMenu() {
if (!pickerMenu) return;
pickerMenu.innerHTML = Array.from(provider.options).map(o => {
@@ -748,9 +823,16 @@ function initEndpointForm() {
}
provider.addEventListener('change', () => {
if (_isDeviceAuthSelected()) {
_setApiFormForProvider();
_renderPickerMenu();
_syncPickerCurrent();
return;
}
if (provider.value) urlInput.value = provider.value;
else urlInput.value = '';
if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy';
_setApiFormForProvider();
});
urlInput.addEventListener('input', () => {
if (provider.value && urlInput.value.trim() !== provider.value) {
@@ -838,6 +920,12 @@ function initEndpointForm() {
const apiCancelTestBtn = el('adm-epApiCancelTestBtn');
if (apiTestBtn) {
apiTestBtn.addEventListener('click', async () => {
if (_isDeviceAuthSelected()) {
const msg = _endpointMsg('api');
msg.textContent = '';
msg.className = '';
return;
}
const msg = _endpointMsg('api');
msg.textContent = ''; msg.className = '';
const rawUrl = (urlInput.value || provider.value).trim();
@@ -885,6 +973,11 @@ function initEndpointForm() {
}
el('adm-epAddBtn').addEventListener('click', async () => {
const deviceAuthProvider = _selectedDeviceAuthProvider();
if (deviceAuthProvider) {
await _startProviderDeviceAuth(deviceAuthProvider, el('adm-epAddBtn'));
return;
}
const msg = _endpointMsg('api');
msg.textContent = ''; msg.className = '';
const rawUrl = (urlInput.value || provider.value).trim();
@@ -936,76 +1029,116 @@ function initEndpointForm() {
btn.disabled = false; btn.textContent = 'Add';
});
// GitHub Copilot — device-flow login. Starts the flow, shows the user a
// code + verification link, and polls until they authorise (or it expires).
const copilotBtn = el('adm-copilotConnectBtn');
if (copilotBtn) {
let copilotPolling = false;
copilotBtn.addEventListener('click', async () => {
if (copilotPolling) return;
const status = el('adm-copilotStatus');
const reset = () => { copilotBtn.disabled = false; copilotBtn.textContent = 'Connect GitHub Copilot'; copilotPolling = false; };
status.textContent = ''; status.className = 'adm-ep-inline-msg';
copilotBtn.disabled = true; copilotBtn.textContent = 'Starting...';
copilotPolling = true;
let start;
try {
const res = await fetch('/api/copilot/device/start', { method: 'POST', body: new FormData(), credentials: 'same-origin' });
start = await res.json();
if (!res.ok) { status.textContent = start.detail || 'Failed to start login'; status.className = 'admin-error'; reset(); return; }
} catch (e) { status.textContent = 'Request failed'; status.className = 'admin-error'; reset(); return; }
async function _startProviderDeviceAuth(providerKey, triggerEl = null) {
if (deviceAuthPolling) return;
const config = PROVIDER_DEVICE_FLOWS[providerKey];
if (!config) return;
const status = el('adm-deviceAuthStatus') || _endpointMsg('api');
if (!status) return;
const triggerText = triggerEl ? triggerEl.textContent : '';
// Render an error with an inline "Try again" (the top button is hidden for
// device-auth providers, so retry lives here). Built with DOM methods, not
// innerHTML. Call reset() first so the deviceAuthPolling guard is cleared.
const showAuthError = (text) => {
status.className = 'admin-error';
status.textContent = text + ' ';
const retry = document.createElement('button');
retry.type = 'button';
retry.className = 'admin-btn-sm';
retry.textContent = 'Try again';
retry.addEventListener('click', () => { _startProviderDeviceAuth(providerKey, triggerEl); });
status.appendChild(retry);
};
const reset = () => {
if (triggerEl) {
triggerEl.disabled = false;
triggerEl.textContent = triggerText || 'Add';
}
deviceAuthPolling = false;
_setApiFormForProvider();
};
status.textContent = '';
status.className = 'adm-ep-inline-msg';
if (triggerEl) {
triggerEl.disabled = true;
triggerEl.textContent = 'Starting...';
}
deviceAuthPolling = true;
_setApiFormForProvider();
status.textContent = `Starting ${config.label} sign-in...`;
const { poll_id, user_code, verification_uri, verification_uri_complete, interval, expires_in } = start;
// Prefer the "complete" URL — it embeds the code so the user only has to
// click "Authorize" (no manual code entry).
const authUrl = verification_uri_complete || verification_uri || '';
const esc = (s) => String(s || '').replace(/[<>&"]/g, (c) => ({ '<': '&lt;', '>': '&gt;', '&': '&amp;', '"': '&quot;' }[c]));
copilotBtn.textContent = 'Waiting…';
// Cohesive waiting panel: spinner + status line, the device code as a
// copyable chip, and a primary "Authorize on GitHub" action.
status.className = '';
status.innerHTML =
'<div class="adm-copilot-panel">' +
'<div class="adm-copilot-wait"><span class="admin-spinner"></span>' +
'<span>Waiting for GitHub authorization…</span></div>' +
'<div class="adm-copilot-coderow">' +
'<span class="adm-copilot-code-label">Code</span>' +
'<code class="adm-copilot-code">' + esc(user_code) + '</code>' +
'<button type="button" class="admin-btn-sm adm-copilot-copy">Copy</button>' +
'</div>' +
'<a class="admin-btn-add adm-copilot-auth" href="' + encodeURI(authUrl) + '" target="_blank" rel="noopener">Authorize on GitHub ↗</a>' +
'<div class="adm-copilot-hint">A new tab opened on GitHub — approve there to finish. Didn\'t open? Use the button above.</div>' +
'</div>';
const copyBtn = status.querySelector('.adm-copilot-copy');
if (copyBtn) copyBtn.addEventListener('click', async () => {
try { await navigator.clipboard.writeText(user_code || ''); copyBtn.textContent = 'Copied'; setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500); } catch (e) {}
try {
const result = await runProviderDeviceFlow(providerKey, {
openWindow: () => {},
onStart: ({ start, authUrl }) => {
if (triggerEl) triggerEl.textContent = 'Waiting...';
status.className = '';
const authLabel = providerKey === 'copilot' ? 'Authorize on GitHub' : 'Authorize with OpenAI';
const waitLabel = providerKey === 'copilot' ? 'Waiting for GitHub authorization...' : 'Waiting for ChatGPT authorization...';
status.innerHTML =
'<div class="adm-copilot-panel">' +
'<div class="adm-copilot-wait"><span class="admin-spinner"></span>' +
'<span>' + esc(waitLabel) + '</span></div>' +
'<div class="adm-copilot-coderow">' +
'<span class="adm-copilot-code-label">Code</span>' +
'<code class="adm-copilot-code">' + esc(start.user_code) + '</code>' +
'<button type="button" class="admin-btn-sm adm-device-auth-copy">Copy</button>' +
'</div>' +
'<a class="admin-btn-add adm-copilot-auth" href="' + encodeURI(authUrl || '') + '" target="_blank" rel="noopener">' + esc(authLabel) + ' ↗</a>' +
'</div>';
const copyBtn = status.querySelector('.adm-device-auth-copy');
if (copyBtn) copyBtn.addEventListener('click', async () => {
const code = start.user_code || '';
let ok = false;
try {
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(code);
ok = true;
}
} catch (e) {}
if (!ok) {
// navigator.clipboard is unavailable in non-secure contexts (HTTP
// self-host over a LAN IP), so fall back to execCommand('copy').
const ta = document.createElement('textarea');
ta.value = code;
ta.style.cssText = 'position:fixed;top:0;left:0;width:1px;height:1px;padding:0;border:0;opacity:0;font-size:16px;';
document.body.appendChild(ta);
ta.focus();
ta.select();
try { ta.setSelectionRange(0, code.length); } catch (e) {}
try { ok = document.execCommand('copy'); } catch (e) {}
ta.remove();
}
copyBtn.textContent = ok ? 'Copied' : 'Failed';
setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500);
});
},
});
try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {}
const deadline = Date.now() + (expires_in || 900) * 1000;
const stepMs = Math.max((interval || 5), 2) * 1000;
const done = (cls, text) => { status.className = cls; status.textContent = text; reset(); };
const poll = async () => {
if (Date.now() > deadline) { done('admin-error', 'Authorization expired — try again.'); return; }
try {
const fd = new FormData(); fd.append('poll_id', poll_id);
const r = await fetch('/api/copilot/device/poll', { method: 'POST', body: fd, credentials: 'same-origin' });
const d = await r.json();
if (d.status === 'authorized') {
const n = ((d.endpoint && d.endpoint.models) || []).length;
done('admin-success', '✓ Connected — ' + n + ' Copilot model' + (n !== 1 ? 's' : '') + ' available.');
if (d.endpoint && d.endpoint.id) _recentlyAddedEpId = String(d.endpoint.id);
await loadEndpoints();
await _selectAddedModelInChat(d.endpoint || {});
return;
}
if (d.status === 'failed') { done('admin-error', 'Authorization failed (' + (d.error || 'denied') + ').'); return; }
} catch (e) { /* transient — keep polling */ }
setTimeout(poll, stepMs);
};
setTimeout(poll, stepMs);
});
if (result.status === 'authorized') {
const endpoint = result.endpoint || {};
const n = ((endpoint && endpoint.models) || []).length;
status.className = 'admin-success';
status.textContent = 'Connected - ' + n + ' ' + config.label + ' model' + (n !== 1 ? 's' : '') + ' available.';
if (endpoint && endpoint.id) _recentlyAddedEpId = String(endpoint.id);
await loadEndpoints();
await _selectAddedModelInChat(endpoint || {});
reset();
return;
}
if (result.status === 'failed') {
reset();
showAuthError('Authorization failed (' + (result.error || 'denied') + ').');
return;
}
if (result.status === 'expired') {
reset();
showAuthError('Authorization expired.');
return;
}
} catch (e) {
reset();
showAuthError(formatDeviceFlowError(e));
}
}
// Local "Add" button — sibling form for self-hosted base URLs.
+39 -15
View File
@@ -680,9 +680,11 @@ export function applyModelColor(roleEl, modelName) {
html += '<div><span class="ctx-label">Max tokens</span> ' + _mt.toLocaleString() + ' <span style="opacity:0.4">(configured)</span></div>';
}
}
if (info && info.input != null) html += '<div><span class="ctx-label">Input</span> $' + info.input.toFixed(2) + ' / 1M</div>';
if (info && info.output != null) html += '<div><span class="ctx-label">Output</span> $' + info.output.toFixed(2) + ' / 1M</div>';
if (!info) html += '<div style="opacity:0.4;font-size:0.85em;margin-top:4px;">No pricing data available</div>';
if (isCostTrackedEndpoint(_epUrl)) {
if (info && info.input != null) html += '<div><span class="ctx-label">Input</span> $' + info.input.toFixed(2) + ' / 1M</div>';
if (info && info.output != null) html += '<div><span class="ctx-label">Output</span> $' + info.output.toFixed(2) + ' / 1M</div>';
if (!info) html += '<div style="opacity:0.4;font-size:0.85em;margin-top:4px;">No pricing data available</div>';
}
popup.innerHTML = html;
const rect = roleEl.getBoundingClientRect();
popup.style.top = (rect.bottom + 4) + 'px';
@@ -735,11 +737,31 @@ export function isLocalEndpoint(url) {
return false;
}
/** Cost for the current turn, returning null (free) for local endpoints. */
function _billableCost(model, inputTokens, outputTokens) {
const url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
/**
 * Report whether `url` addresses the ChatGPT Subscription (Codex) backend.
 *
 * A URL qualifies when its hostname is exactly `chatgpt.com` and its path,
 * ignoring trailing slashes, is `/backend-api/codex` or lies beneath it.
 * Empty or unparseable input yields false.
 *
 * @param {string} url Endpoint base URL to classify.
 * @returns {boolean}
 */
export function isSubscriptionEndpoint(url) {
  if (!url) return false;
  let target;
  try {
    target = new URL(url);
  } catch (_e) {
    // Not an absolute URL, so it cannot be the subscription backend.
    return false;
  }
  if (target.hostname !== 'chatgpt.com') return false;
  const codexRoot = '/backend-api/codex';
  const trimmedPath = target.pathname.replace(/\/+$/, '');
  return trimmedPath === codexRoot || trimmedPath.startsWith(codexRoot + '/');
}
/**
 * Endpoint URL of the active session as reported by
 * `window.sessionModule.getCurrentEndpointUrl()`, or null when the session
 * module or that accessor is not available on `window`.
 *
 * @returns {string|null}
 */
function _currentEndpointUrl() {
  const sessions = window.sessionModule;
  if (!sessions || !sessions.getCurrentEndpointUrl) return null;
  return sessions.getCurrentEndpointUrl();
}
/**
 * True when per-token cost should be shown for `url`: the endpoint is neither
 * local (per isLocalEndpoint) nor a subscription endpoint (per
 * isSubscriptionEndpoint).
 *
 * @param {string} url Endpoint base URL.
 * @returns {boolean}
 */
export function isCostTrackedEndpoint(url) {
  if (isLocalEndpoint(url)) return false;
  return !isSubscriptionEndpoint(url);
}
/**
 * Cost for the current turn. Returns null when the active endpoint is not
 * cost-tracked; otherwise defers to getModelCost.
 */
function _billableCost(model, inputTokens, outputTokens) {
  const billable = isCostTrackedEndpoint(_currentEndpointUrl());
  return billable ? getModelCost(model, inputTokens, outputTokens) : null;
}
@@ -784,11 +806,10 @@ export function resetSessionCost(sessionId) {
export function updateSessionCostUI() {
const el = document.getElementById('session-cost-display');
if (!el) return;
// Local model? It's free — hide the badge and clear any stale cost that a
// previous (buggy) cloud-rate billing left in localStorage for this session.
const _url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
? window.sessionModule.getCurrentEndpointUrl() : null;
if (isLocalEndpoint(_url)) {
// Non-billable endpoint? Hide the badge and clear stale cost that a previous
// cloud-rate calculation may have left in localStorage for this session.
const _url = _currentEndpointUrl();
if (!isCostTrackedEndpoint(_url)) {
const sid = window.sessionModule && window.sessionModule.getCurrentSessionId();
if (sid && getSessionCost(sid) > 0) {
try {
@@ -1708,7 +1729,8 @@ export function displayMetrics(messageElement, metrics) {
e.stopPropagation();
document.querySelectorAll('.ctx-popup').forEach(p => { if (typeof p._dismiss === 'function') p._dismiss(); else p.remove(); });
const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : 'n/a';
const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : '';
const costRows = costStr ? `<div><span class="ctx-label">Cost</span> ${costStr}</div>` : '';
const speedStr = tps != null && tps !== 'undefined' ? `${tps} tok/s` : 'n/a';
const totalTok = inputTokens + outputTokens;
const ctxColor = ctxPct >= 85 ? 'var(--red, #e06c75)' : ctxPct >= 70 ? '#ff9900' : 'var(--color-muted-alt, #6b7280)';
@@ -1722,7 +1744,7 @@ export function displayMetrics(messageElement, metrics) {
// Session total cost
let sessionCostStr = '';
const sc = getSessionCost();
if (sc > 0) {
if (costStr && sc > 0) {
sessionCostStr = `<div><span class="ctx-label">Session</span> $${sc < 0.01 ? sc.toFixed(4) : sc.toFixed(3)}</div>`;
}
@@ -1738,7 +1760,7 @@ export function displayMetrics(messageElement, metrics) {
<div><span class="ctx-label">Time</span> ${responseTime}s</div>
${prepTime != null ? `<div><span class="ctx-label">Prep</span> ${prepTime}s</div>` : ''}
${modelWaitTime != null ? `<div><span class="ctx-label">Model wait</span> ${modelWaitTime}s</div>` : ''}
<div><span class="ctx-label">Cost</span> ${costStr}</div>
${costRows}
${sessionCostStr}
${prepDetails ? `<div style="margin-top:6px;padding-top:6px;border-top:1px solid var(--border);font-size:0.85em;opacity:0.8;">
<div style="font-weight:600;margin-bottom:4px;color:var(--fg);">Agent prep</div>
@@ -2392,6 +2414,8 @@ const chatRenderer = {
modelColor,
applyModelColor,
getModelCost,
isCostTrackedEndpoint,
isSubscriptionEndpoint,
getImageCost,
getSessionCost,
resetSessionCost,
+128
View File
@@ -0,0 +1,128 @@
// Shared DOM-free provider device-flow runner.
// Registry of providers that sign in through a device flow, keyed by provider
// id. Each entry carries a display label, the start/poll API routes, and an
// `authUrl(start)` picker that extracts the URL the user must visit from the
// start response (empty string when none is present).
export const PROVIDER_DEVICE_FLOWS = {
  copilot: {
    label: 'GitHub Copilot',
    startUrl: '/api/copilot/device/start',
    pollUrl: '/api/copilot/device/poll',
    // Prefers `verification_uri_complete` when the start response has one.
    authUrl: (start) => {
      const info = start ?? {};
      return info.verification_uri_complete || info.verification_uri || '';
    },
  },
  'chatgpt-subscription': {
    label: 'ChatGPT Subscription',
    startUrl: '/api/chatgpt-subscription/device/start',
    pollUrl: '/api/chatgpt-subscription/device/poll',
    // Only the plain `verification_uri` is used for this provider.
    authUrl: (start) => {
      const info = start ?? {};
      return info.verification_uri || '';
    },
  },
};
// Create an empty request body: a FormData where the runtime provides one,
// otherwise a URLSearchParams (both expose `append`).
function _formData() {
  const hasFormData = typeof FormData !== 'undefined';
  return hasFormData ? new FormData() : new URLSearchParams();
}
// Parse a response body as JSON; any failure while doing so yields `{}`
// instead of throwing.
async function _jsonOrEmpty(response) {
  let parsed = {};
  try {
    parsed = await response.json();
  } catch (_) {
    // Keep the empty-object default.
  }
  return parsed;
}
// Pick a human-readable message out of an error payload. The fields are tried
// in priority order (detail, error, message); the first one holding a
// non-blank string wins and is returned trimmed. Otherwise `fallback`.
function _messageFromPayload(payload, fallback) {
  if (payload) {
    for (const field of ['detail', 'error', 'message']) {
      const text = payload[field];
      if (typeof text === 'string' && text.trim()) return text.trim();
    }
  }
  return fallback;
}
/**
 * Turn a thrown value from the device flow into display text.
 *
 * Strings pass through unchanged; for objects the `detail` field is preferred
 * over `message` and is stringified. Anything falsy or without either field
 * produces `fallback`.
 *
 * @param {*} error Thrown value (Error, payload-like object, or string).
 * @param {string} [fallback='Request failed'] Text used when nothing usable is found.
 * @returns {string}
 */
export function formatDeviceFlowError(error, fallback = 'Request failed') {
  if (!error) return fallback;
  if (typeof error === 'string') return error;
  const text = error.detail || error.message;
  return text ? String(text) : fallback;
}
// Issue a request with `fetchImpl` and return the parsed JSON body.
// A non-OK response throws an Error whose message comes from the payload
// (via _messageFromPayload), else from `fallback`, else from a generic
// "Request failed (HTTP <status>)" string.
async function _fetchJson(fetchImpl, url, options, fallback) {
  const response = await fetchImpl(url, options);
  const payload = await _jsonOrEmpty(response);
  if (response.ok) return payload;
  const defaultMessage = fallback || `Request failed (HTTP ${response.status})`;
  throw new Error(_messageFromPayload(payload, defaultMessage));
}
// Promise that resolves (with no value) after `ms` milliseconds.
function _defaultSleep(ms) {
  return new Promise((wake) => {
    setTimeout(wake, ms);
  });
}
// Await `fn(payload)` when `fn` is a function; silently do nothing otherwise.
async function _callCallback(fn, payload) {
  if (typeof fn !== 'function') return;
  await fn(payload);
}
// Convert a poll interval given in seconds into milliseconds, never below
// two seconds. Input that does not coerce to a finite number yields
// `fallbackMs`, so a malformed server value cannot produce a NaN delay
// (which timers treat as zero, i.e. a tight polling loop).
function _pollStepMs(seconds, fallbackMs) {
  const value = Number(seconds);
  if (!Number.isFinite(value)) return fallbackMs;
  return Math.max(value, 2) * 1000;
}

/**
 * Run a provider's device-flow sign-in from start to a terminal state.
 *
 * Posts to the provider's start route, reports the start payload through
 * `options.onStart`, opens the authorization URL (if any), then polls the
 * provider's poll route until the server reports `authorized` or `failed`, or
 * the locally computed deadline passes.
 *
 * @param {string} provider Key of PROVIDER_DEVICE_FLOWS.
 * @param {object} [options]
 * @param {Function} [options.fetchImpl] fetch replacement (defaults to global fetch).
 * @param {Function} [options.openWindow] Called with the auth URL; defaults to window.open in a new tab.
 * @param {Function} [options.sleep] `(ms) => Promise`, defaults to a setTimeout-based sleep.
 * @param {Function} [options.now] Clock returning milliseconds, defaults to Date.now.
 * @param {*} [options.formData] Body for the start request; defaults to an empty form.
 * @param {Function} [options.onStart] Awaited with `{ provider, config, start, authUrl }`.
 * @param {Function} [options.onWaiting] Awaited before each sleep.
 * @param {Function} [options.onPoll] Awaited with `{ provider, config, start, poll }` after each poll.
 * @returns {Promise<{status: 'authorized', endpoint: object} | {status: 'failed', error: string} | {status: 'expired'}>}
 * @throws {Error} For an unknown provider, missing fetch, a start response
 *   without `poll_id`, or any non-OK start/poll response.
 */
export async function runProviderDeviceFlow(provider, options = {}) {
  // Own-property lookup so inherited keys such as "constructor" are rejected.
  const cfg = Object.prototype.hasOwnProperty.call(PROVIDER_DEVICE_FLOWS, provider)
    ? PROVIDER_DEVICE_FLOWS[provider]
    : undefined;
  if (!cfg) throw new Error(`Unknown device-flow provider: ${provider}`);

  const fetchImpl = options.fetchImpl || globalThis.fetch?.bind(globalThis);
  if (!fetchImpl) throw new Error('Fetch API is unavailable');

  const openWindow = options.openWindow || ((url) => {
    if (globalThis.window && typeof globalThis.window.open === 'function') {
      globalThis.window.open(url, '_blank', 'noopener');
    }
  });
  const sleep = options.sleep || _defaultSleep;
  const now = options.now || (() => Date.now());
  const formData = options.formData || _formData();

  const start = await _fetchJson(fetchImpl, cfg.startUrl, {
    method: 'POST',
    body: formData,
    credentials: 'same-origin',
  }, `Failed to start ${cfg.label} sign-in`);
  if (!start.poll_id) throw new Error(`${cfg.label} sign-in did not return a poll id`);

  const authUrl = cfg.authUrl(start);
  await _callCallback(options.onStart, { provider, config: cfg, start, authUrl });
  if (authUrl) openWindow(authUrl);

  // A non-numeric expires_in would make the deadline NaN, and `now() > NaN`
  // is always false, so the loop would never expire. Fall back to 900 s.
  const expiresSec = Number(start.expires_in || 900);
  const deadline = now() + (Number.isFinite(expiresSec) ? expiresSec : 900) * 1000;
  let stepMs = _pollStepMs(start.interval || 5, 5000);

  while (true) {
    if (now() > deadline) return { status: 'expired' };
    await _callCallback(options.onWaiting, { provider, config: cfg, start, authUrl });
    await sleep(stepMs);
    // Re-check after sleeping so no poll is sent past the deadline.
    if (now() > deadline) return { status: 'expired' };

    const fd = _formData();
    fd.append('poll_id', start.poll_id);
    const poll = await _fetchJson(fetchImpl, cfg.pollUrl, {
      method: 'POST',
      body: fd,
      credentials: 'same-origin',
    }, `${cfg.label} sign-in poll failed`);
    await _callCallback(options.onPoll, { provider, config: cfg, start, poll });

    if (poll.status === 'authorized') {
      return { status: 'authorized', endpoint: poll.endpoint || {} };
    }
    if (poll.status === 'failed') {
      return { status: 'failed', error: poll.error || 'denied' };
    }
    // The server may adjust the polling cadence; keep the current step when
    // the new value is unusable.
    if (poll.interval) {
      stepMs = _pollStepMs(poll.interval, stepMs);
    }
  }
}
+5
View File
@@ -15,6 +15,10 @@ const _PROVIDERS = [
[/opencode/i,
'<svg viewBox="0 0 24 30" fill="currentColor"><path d="M18 6H6V24H18V6ZM24 30H0V0H24V30Z"/></svg>'],
// GitHub / Copilot
[/github|copilot/i,
'<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .5A12 12 0 0 0 8.2 23.9c.6.1.8-.3.8-.6v-2.1c-3.3.7-4-1.4-4-1.4-.5-1.4-1.3-1.8-1.3-1.8-1.1-.8.1-.8.1-.8 1.2.1 1.9 1.3 1.9 1.3 1.1 1.9 2.9 1.3 3.6 1 .1-.8.4-1.3.8-1.6-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.3-.5-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.4 11.4 0 0 1 6 0C15.3 4.7 16 5 16 5c.6 1.6.2 2.9.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .5Z"/></svg>'],
// OpenRouter
[/openrouter|open router/i,
'<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="5" cy="12" r="2.5"/><circle cx="19" cy="6" r="2.5"/><circle cx="19" cy="18" r="2.5"/><path d="M7.5 12h4.5c2 0 2.5-6 4.5-6"/><path d="M12 12c2 0 2.5 6 4.5 6"/></svg>'],
@@ -102,6 +106,7 @@ export function providerLogo(modelId) {
// doesn't match `x.ai`.
const _ENDPOINT_LABELS = [
[/(^|\.)githubcopilot\.com$/i, "GitHub Copilot"],
[/(^|\.)chatgpt\.com$/i, "ChatGPT Subscription"],
[/(^|\.)openrouter\.ai$/i, "OpenRouter"],
[/(^|\.)anthropic\.com$/i, "Anthropic"],
[/(^|\.)openai\.com$/i, "OpenAI"],
+49 -3
View File
@@ -5,7 +5,7 @@
import { COMMANDS, LEGACY_ALIASES } from './slashCommands.js';
const POPUP_ID = 'slash-autocomplete';
const MAX_VISIBLE = 12;
const MAX_VISIBLE = 14;
// Flatten the registry into a searchable list of leaf entries. Each entry is
// either a top-level command or a "cmd sub" pair (so subcommands get their
@@ -81,6 +81,23 @@ function _flatten() {
return out;
}
// Fetch the skill slash-command catalog and convert it into autocomplete
// entries ({ token, aliases, category, help, usage }).
//
// Best-effort: a non-OK response or any fetch/parse failure yields []. Each
// catalog item is validated on its own, so one malformed item (null, a
// non-string token, or neither token nor name) is skipped instead of
// producing a "/undefined" command or discarding the whole catalog.
async function _loadSkillEntries() {
  try {
    const res = await fetch('/api/skills/slash-catalog', { credentials: 'same-origin' });
    if (!res.ok) return [];
    const data = await res.json();
    const skills = data && Array.isArray(data.skills) ? data.skills : [];
    const entries = [];
    for (const s of skills) {
      if (!s || typeof s !== 'object') continue;
      let token = '';
      if (typeof s.token === 'string' && s.token) {
        token = s.token;
      } else if (s.name !== undefined && s.name !== null && String(s.name) !== '') {
        // No explicit token: derive the command from the skill name.
        token = `/${s.name}`;
      }
      // Only slash-prefixed tokens are valid commands.
      if (!token.startsWith('/')) continue;
      entries.push({
        token,
        aliases: [],
        category: s.category || 'Skills',
        help: s.help || 'Run skill',
        usage: s.usage || `${token} <request>`,
      });
    }
    return entries;
  } catch {
    return [];
  }
}
function _scoreMatch(entry, query) {
// query already starts with "/". Match against token + aliases. Prefix wins
// over substring; alias match scores slightly lower than token match.
@@ -98,6 +115,17 @@ function _scoreMatch(entry, query) {
return 0;
}
// When `query` is exactly a bare command (e.g. "/setup") that has
// subcommand entries ("/setup xyz"), return those subcommands followed by the
// parent entry. Returns [] when the query is not a single bare command token,
// when no entry matches it exactly, or when it has no subcommands.
// Matching is case-insensitive.
function _exactCommandGroupItems(all, query) {
  const needle = query.toLowerCase();
  if (!/^\/[a-z0-9_-]+$/i.test(needle)) return [];
  const childPrefix = needle + ' ';
  let parent = null;
  const children = [];
  for (const entry of all) {
    const token = entry.token.toLowerCase();
    if (token === needle) {
      // Keep the first exact match only.
      if (parent === null) parent = entry;
    } else if (token.startsWith(childPrefix)) {
      children.push(entry);
    }
  }
  if (parent === null || children.length === 0) return [];
  return children.concat(parent);
}
function _ensurePopup(textarea) {
let el = document.getElementById(POPUP_ID);
if (el) return el;
@@ -164,7 +192,7 @@ export function initSlashAutocomplete(textarea) {
if (!textarea || textarea._slashAcWired) return;
textarea._slashAcWired = true;
const all = _flatten();
let all = _flatten();
let popup = null;
let visible = false;
let items = [];
@@ -191,12 +219,17 @@ export function initSlashAutocomplete(textarea) {
// the menu hides — we don't autocomplete mid-sentence.
if (!v.startsWith('/') || v.includes('\n')) { hide(); return; }
const query = v.trim();
items = all
const groupItems = _exactCommandGroupItems(all, query);
if (groupItems.length) {
items = groupItems.slice(0, MAX_VISIBLE);
} else {
items = all
.map(e => ({ e, s: _scoreMatch(e, query) }))
.filter(x => x.s > 0)
.sort((a, b) => b.s - a.s)
.slice(0, MAX_VISIBLE)
.map(x => x.e);
}
if (!items.length && query.length > 1) { hide(); return; }
if (!items.length) {
// Just "/" with no matches — fall back to showing everything up to MAX_VISIBLE
@@ -207,6 +240,19 @@ export function initSlashAutocomplete(textarea) {
_render(popup, items, selectedIdx, query);
};
_loadSkillEntries().then(skillEntries => {
if (!skillEntries.length) return;
const seen = new Set(all.map(e => e.token));
const merged = all.slice();
for (const entry of skillEntries) {
if (seen.has(entry.token)) continue;
seen.add(entry.token);
merged.push(entry);
}
all = merged;
if (visible) refresh();
});
const insert = (token) => {
textarea.value = token + ' ';
textarea.dispatchEvent(new Event('input', { bubbles: true }));
+351 -71
View File
@@ -21,6 +21,7 @@ import workspaceModule from './workspace.js';
import settingsModule from './settings.js';
import cookbookModule from './cookbook.js';
import { EVAL_PROMPTS } from './compare/index.js';
import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js';
// ── Module state ──────────────────────────────────────────────────────
@@ -58,11 +59,28 @@ const SETUP_PROVIDER_URLS = {
'opencode-go': { name: 'OpenCode Go', url: 'https://opencode.ai/zen/go/v1' },
};
const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'ollama', 'xai', 'anthropic', 'groq', 'gemini', 'opencode-zen', 'opencode-go'];
const SETUP_PROVIDER_HINT = SETUP_PROVIDER_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_NAMES[SETUP_PROVIDER_NAMES.length - 1];
// Providers that are connected through an account sign-in (device flow)
// rather than by pasting an API key. Per entry:
//   key     - provider id; also the value returned by
//             _setupDeviceAuthProviderFromInput and put in data-setup-provider
//   name    - display name shown on the clickable chip
//   aliases - extra spellings accepted as input for this provider
//   command - slash command shown in the chip tooltip
const SETUP_DEVICE_AUTH_PROVIDERS = [
{ key: 'copilot', name: 'GitHub Copilot', aliases: ['github'], command: '/setup copilot' },
{ key: 'chatgpt-subscription', name: 'ChatGPT Subscription', aliases: ['chatgptsubscription', 'chatgpt-sub', 'codex'], command: '/setup chatgpt-subscription' },
];
// API-key provider names followed by the device-auth provider keys.
const SETUP_PROVIDER_HINT_NAMES = SETUP_PROVIDER_NAMES.concat(SETUP_DEVICE_AUTH_PROVIDERS.map(provider => provider.key));
// Human-readable list of the names above in the form "a, b, ..., or z".
const SETUP_PROVIDER_HINT = SETUP_PROVIDER_HINT_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_HINT_NAMES[SETUP_PROVIDER_HINT_NAMES.length - 1];
const SETUP_LOCAL_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><rect x="2" y="3" width="20" height="14" rx="2"/><path d="M8 21h8"/><path d="M12 17v4"/></svg>';
const SETUP_API_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>';
const SETUP_SETTINGS_ICON = '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-2px;margin-right:5px;"><circle cx="12" cy="12" r="3"/><path d="M19.4 15a1.65 1.65 0 0 0 .33 1.82l.06.06a2 2 0 0 1-2.83 2.83l-.06-.06a1.65 1.65 0 0 0-1.82-.33 1.65 1.65 0 0 0-1 1.51V21a2 2 0 0 1-4 0v-.09A1.65 1.65 0 0 0 9 19.4a1.65 1.65 0 0 0-1.82.33l-.06.06a2 2 0 0 1-2.83-2.83l.06-.06a1.65 1.65 0 0 0 .33-1.82 1.65 1.65 0 0 0-1.51-1H3a2 2 0 0 1 0-4h.09A1.65 1.65 0 0 0 4.6 9a1.65 1.65 0 0 0-.33-1.82l-.06-.06a2 2 0 0 1 2.83-2.83l.06.06a1.65 1.65 0 0 0 1.82.33H9a1.65 1.65 0 0 0 1-1.51V3a2 2 0 0 1 4 0v.09a1.65 1.65 0 0 0 1 1.51 1.65 1.65 0 0 0 1.82-.33l.06-.06a2 2 0 0 1 2.83 2.83l-.06.06a1.65 1.65 0 0 0-.33 1.82V9a1.65 1.65 0 0 0 1.51 1H21a2 2 0 0 1 0 4h-.09a1.65 1.65 0 0 0-1.51 1z"/></svg>';
// HTML for the clickable API-key provider chips: one underlined <span> per
// name in SETUP_PROVIDER_NAMES, tagged data-setup-kind="api-key", joined by
// single spaces.
function _setupApiProviderChips() {
  const chips = [];
  for (const name of SETUP_PROVIDER_NAMES) {
    chips.push(
      `<span class="setup-clickable-provider" data-setup-kind="api-key" data-setup-provider="${name}" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ${name}">${name}</span>`
    );
  }
  return chips.join(' ');
}
// HTML for the clickable account sign-in chips: one underlined <span> per
// entry in SETUP_DEVICE_AUTH_PROVIDERS, tagged data-setup-kind="device-auth",
// showing the provider name with its slash command as the tooltip.
function _setupDeviceAuthProviderChips() {
  const chips = [];
  for (const provider of SETUP_DEVICE_AUTH_PROVIDERS) {
    chips.push(
      `<span class="setup-clickable-provider" data-setup-kind="device-auth" data-setup-provider="${provider.key}" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Run ${provider.command}">${provider.name}</span>`
    );
  }
  return chips.join(' ');
}
function _setupProviderFromInput(input) {
const raw = (input || '').trim().toLowerCase().replace(/\s+/g, '');
const aliases = {
@@ -84,6 +102,17 @@ function _setupProviderFromInput(input) {
return SETUP_PROVIDER_URLS[aliases[raw] || raw] || null;
}
// Resolve free-form user input to a device-auth provider key.
// Input and candidates (key, display name, aliases) are compared after
// lower-casing, removing all whitespace and turning underscores into hyphens.
// Returns the matching provider's key, or '' when nothing matches.
function _setupDeviceAuthProviderFromInput(input) {
  const normalize = (value) =>
    String(value || '').toLowerCase().replace(/\s+/g, '').replace(/_/g, '-');
  const wanted = normalize((input || '').trim());
  if (!wanted) return '';
  const match = SETUP_DEVICE_AUTH_PROVIDERS.find((provider) => {
    const spellings = [provider.key, provider.name, ...(provider.aliases || [])];
    return spellings.some((spelling) => normalize(spelling) === wanted);
  });
  return match ? match.key : '';
}
function _extractSetupProviderCredential(input) {
const raw = (input || '').trim();
if (!raw) return null;
@@ -158,9 +187,8 @@ function _setupReply(text, remember = true) {
}
function _showSetupEndpointChoices() {
const providers = SETUP_PROVIDER_NAMES.map(name =>
'<span class="setup-clickable-provider" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
).join(' ');
const providers = _setupApiProviderChips();
const deviceAuthProviders = _setupDeviceAuthProviderChips();
return slashReply(
'<div class="setup-guide-no-censor" style="display:grid;gap:10px;">' +
'<div>' +
@@ -178,6 +206,7 @@ function _showSetupEndpointChoices() {
'<div>Paste provider name then API key (example):</div>' +
'<pre style="margin:4px 0 0;"><code class="setup-clickable-code" style="cursor:pointer;text-decoration:underline;" title="Click to fill in chat">deepseek sk-...</code></pre>' +
'<div style="margin-top:8px;font-size:1em;"><span>Supported providers:</span><br>' + providers + '</div>' +
'<div style="margin-top:8px;font-size:1em;"><span>Account sign-in:</span><br>' + deviceAuthProviders + '</div>' +
'</div>' +
'</div>'
);
@@ -208,9 +237,8 @@ function _showSetupEndpointChoicesStreamed(options = {}) {
text: 'deepseek sk-...',
copyText: 'deepseek sk-...',
},
{ kind: 'p', html: '<strong>Supported providers:</strong><br>' + SETUP_PROVIDER_NAMES.map(name =>
'<span class="setup-clickable-provider" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
).join(' ') },
{ kind: 'p', html: '<strong>Supported providers:</strong><br>' + _setupApiProviderChips() },
{ kind: 'p', html: '<strong>Account sign-in:</strong><br>' + _setupDeviceAuthProviderChips() },
];
return typewriterBlocksReply(blocks, { gap: '4px', bodyClass: 'setup-guide-no-censor', interval: 3 });
}
@@ -231,7 +259,7 @@ async function _hasConfiguredModels() {
}
function _setupProviderPrompt() {
const chips = SETUP_PROVIDER_NAMES.map(name =>
const chips = SETUP_PROVIDER_HINT_NAMES.map(name =>
'<span style="font-weight:650;">' + name + '</span>'
).join(' ');
slashReply('<b>Supported providers:</b><br>' + chips);
@@ -286,6 +314,53 @@ function slashReply(text) {
return { el: div, body };
}
// Last successfully fetched skill catalog and the time it was requested.
let _skillCatalogCache = { at: 0, items: [] };

// Return the skill slash-command catalog (raw `skills` array from the API).
// A cached copy younger than 15 s is reused unless `force` is set. When the
// request fails, the previously cached items are returned and the cache is
// left untouched.
async function _loadSkillSlashCatalog(force = false) {
  const startedAt = Date.now();
  const isFresh = (startedAt - _skillCatalogCache.at) < 15000;
  if (isFresh && !force) return _skillCatalogCache.items;
  try {
    const response = await fetch(`${API_BASE}/api/skills/slash-catalog`, { credentials: 'same-origin' });
    if (!response.ok) throw new Error('catalog unavailable');
    const payload = await response.json();
    const fetched = Array.isArray(payload.skills) ? payload.skills : [];
    _skillCatalogCache = { at: startedAt, items: fetched };
    return fetched;
  } catch {
    return _skillCatalogCache.items || [];
  }
}
// Put `text` into the chat input (#message), fire an `input` event on it, and
// submit the chat form (#chat-form). Uses form.requestSubmit() when present,
// otherwise dispatches a cancelable `submit` event.
// Returns false when either element is missing, true otherwise.
function _submitComposedMessage(text) {
  const composer = document.getElementById('message');
  const chatForm = document.getElementById('chat-form');
  if (!(composer && chatForm)) return false;

  composer.value = text;
  composer.dispatchEvent(new Event('input', { bubbles: true }));

  if (typeof chatForm.requestSubmit === 'function') {
    chatForm.requestSubmit();
  } else {
    chatForm.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true }));
  }
  return true;
}
// Ask the backend to invoke skill `name` with `requestText`, then submit the
// message it returns through the chat form. Failures are reported with
// slashReply; the server's `detail` text is shown only when `ctx.esc` is
// available to escape it. Always resolves to true.
async function _invokeSkillByName(name, requestText, ctx) {
  const endpoint = `${API_BASE}/api/skills/${encodeURIComponent(name)}/invoke`;
  const response = await fetch(endpoint, {
    method: 'POST',
    credentials: 'same-origin',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ request: requestText || '' })
  });

  if (response.ok) {
    const payload = await response.json();
    // Only attempt the submit when the backend actually produced a message.
    const started = Boolean(payload.message) && _submitComposedMessage(payload.message);
    if (!started) slashReply('Could not start skill invocation.');
    return true;
  }

  const failure = await response.json().catch(() => null);
  const notAvailable = 'Skill is not available';
  slashReply(ctx?.esc ? ctx.esc(failure?.detail || notAvailable) : notAvailable);
  return true;
}
/** Minimal footer for slash replies: copy + dismiss */
function _slashFooter(msgEl) {
const footer = document.createElement('div');
@@ -681,6 +756,13 @@ async function handleSetupWizard(mode, input) {
await _setupProviderPrompt();
return;
}
const deviceAuthProvider = _setupDeviceAuthProviderFromInput(input);
if (deviceAuthProvider) {
_addMessage('user', input);
setupMode = false;
await _setupProviderDeviceFlow(deviceAuthProvider);
return;
}
const paired = _extractSetupProviderCredential(input);
const provider = paired?.provider || _setupProviderFromInput(input);
if (!provider) {
@@ -1429,6 +1511,42 @@ async function _cmdModels(args, ctx) {
return true;
}
// /model: show the current model and endpoint URL.
// "/model list" (or "ls") delegates to _cmdModels with the remaining args.
async function _cmdModel(args, ctx) {
  const sub = (args[0] || '').toLowerCase();
  if (['list', 'ls'].includes(sub)) return _cmdModels(args.slice(1), ctx);

  const activeModel = sessionModule.getCurrentModel ? sessionModule.getCurrentModel() : '';
  const activeEndpoint = sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : '';

  const report = [];
  report.push(`Current model: ${ctx.esc(activeModel || 'None selected')}`);
  report.push(activeEndpoint ? `Endpoint: ${ctx.esc(activeEndpoint)}` : 'Endpoint: not available');
  report.push('');
  report.push('Usage: /model list to show all available models');
  slashReply(`<pre>${report.join('\n')}</pre>`);
  return true;
}
// /mcp: list the MCP servers returned by /api/mcp/servers, one line each with
// status and "enabled/total" tool counts. Always resolves to true.
async function _cmdMcp(args, ctx) {
  const response = await fetch(`${API_BASE}/api/mcp/servers`, { credentials: 'same-origin' });
  if (!response.ok) {
    slashReply('MCP status is unavailable for this user.');
    return true;
  }

  const servers = await response.json();
  const hasServers = Array.isArray(servers) && servers.length > 0;
  if (!hasServers) {
    slashReply('No MCP servers configured.');
    return true;
  }

  const escapedLines = [];
  for (const server of servers) {
    // Explicit status wins; otherwise derive it from the is_enabled flag.
    const status = server.status || (server.is_enabled ? 'enabled' : 'disabled');
    // Missing counts fall back to each other, then to zero.
    const enabled = Number(server.enabled_tool_count ?? server.tool_count ?? 0);
    const total = Number(server.tool_count ?? enabled);
    const label = server.name || server.id || 'MCP server';
    escapedLines.push(ctx.esc(`${label} - ${status} (${enabled}/${total} tools)`));
  }
  slashReply(`<pre>${escapedLines.join('\n')}</pre>`);
  return true;
}
// ── Memory ──
async function _cmdMemoryList(args, ctx) {
@@ -1507,6 +1625,73 @@ async function _cmdMemorySearch(args, ctx) {
return true;
}
// ── Skills ──
// /skills dispatcher. The first argument selects the action (default "list"):
//   list | ls          - show the slash-command skill catalog (forced refresh)
//   search | find Q    - POST the query to /api/skills/search and list hits
//   view | cat | show N - print the skill's markdown
//   use | run N REQ    - invoke the skill with the remaining text
// Anything else prints the usage line. Always resolves to true.
async function _cmdSkills(args, ctx) {
  const action = (args[0] || 'list').toLowerCase();
  const params = args.slice(1);

  switch (action) {
    case 'list':
    case 'ls': {
      const catalog = await _loadSkillSlashCatalog(true);
      if (!catalog.length) {
        slashReply('No published skills available for slash commands.');
        return true;
      }
      const rows = catalog.map((skill) => {
        const uses = Number(skill.uses || 0);
        const usesSuffix = uses > 0 ? ` uses:${uses}` : '';
        const tokenColumn = ctx.esc(String(skill.token || '').padEnd(24));
        const helpColumn = ctx.esc(skill.help || '');
        return `${tokenColumn}${helpColumn}${usesSuffix}`;
      });
      slashReply(`<pre>${rows.join('\n')}</pre>`);
      return true;
    }

    case 'search':
    case 'find': {
      const query = params.join(' ').trim();
      if (!query) {
        slashReply('Usage: /skills search query');
        return true;
      }
      const response = await fetch(`${API_BASE}/api/skills/search`, {
        method: 'POST',
        credentials: 'same-origin',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ query })
      });
      if (!response.ok) {
        slashReply('Skill search failed.');
        return true;
      }
      const payload = await response.json();
      const hits = Array.isArray(payload.skills) ? payload.skills : [];
      if (!hits.length) {
        slashReply(`No skills found for "${ctx.esc(query)}".`);
        return true;
      }
      const rows = [];
      for (const hit of hits) {
        const command = `/${hit.name || hit.id || ''}`.padEnd(24);
        rows.push(ctx.esc(command) + ctx.esc(hit.description || ''));
      }
      slashReply(`<pre>${rows.join('\n')}</pre>`);
      return true;
    }

    case 'view':
    case 'cat':
    case 'show': {
      const skillName = (params[0] || '').trim();
      if (!skillName) {
        slashReply('Usage: /skills view name');
        return true;
      }
      const response = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(skillName)}/markdown`, { credentials: 'same-origin' });
      if (!response.ok) {
        slashReply(`Skill "${ctx.esc(skillName)}" was not found.`);
        return true;
      }
      const payload = await response.json();
      slashReply(`<pre>${ctx.esc(payload.markdown || '')}</pre>`);
      return true;
    }

    case 'use':
    case 'run': {
      const skillName = (params[0] || '').trim();
      if (!skillName) {
        slashReply('Usage: /skills use name request');
        return true;
      }
      return _invokeSkillByName(skillName, params.slice(1).join(' ').trim(), ctx);
    }

    default:
      slashReply('Usage: /skills list | search query | view name | use name request');
      return true;
  }
}
// /reload-skills: force-refresh the skill catalog and report how many skill
// commands are now available.
async function _cmdReloadSkills(args, ctx) {
  const catalog = await _loadSkillSlashCatalog(true);
  const count = catalog.length;
  const noun = count === 1 ? 'skill command' : 'skill commands';
  slashReply(`Reloaded skills. ${count} ${noun} available.`);
  return true;
}
// ── Note (quick Notes shortcut) ──
async function _cmdNote(args, ctx) {
@@ -1799,6 +1984,53 @@ Uploads: ${d.uploads || '?'}</pre>`);
return true;
}
// /usage: summarise the current session (name, model, message count, recorded
// tokens) plus the locally estimated cost when the endpoint is cost-tracked.
// The session record is looked up in the session module first and, if absent,
// in /api/sessions; lookup errors are ignored and defaults are shown instead.
async function _cmdUsage(args, ctx) {
  const sessionId = ctx.sid;
  if (!sessionId) {
    slashReply('No active session.');
    return true;
  }

  const matchesCurrent = (candidate) => candidate.id === sessionId;
  let session = null;
  try {
    const known = sessionModule.getSessions ? sessionModule.getSessions() : [];
    session = (known || []).find(matchesCurrent) || null;
    if (!session) {
      const response = await fetch(`${API_BASE}/api/sessions`, { credentials: 'same-origin' });
      if (response.ok) {
        const payload = await response.json();
        // The endpoint may return a bare array or wrap it in sessions/items.
        const listed = Array.isArray(payload) ? payload : (payload.sessions || payload.items || []);
        session = listed.find(matchesCurrent) || null;
      }
    }
  } catch (_) {}

  const model = session?.model || 'Unknown';
  let endpointUrl = session?.endpoint_url;
  if (!endpointUrl) {
    endpointUrl = sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : '';
  }
  const messageCount = Number(session?.message_count || 0);
  const totalTokens = Number(session?.total_tokens || 0);

  // Treat the endpoint as cost-tracked when the renderer cannot classify it.
  const costTracked = chatRenderer.isCostTrackedEndpoint
    ? chatRenderer.isCostTrackedEndpoint(endpointUrl)
    : true;

  let costLine;
  if (!costTracked) {
    costLine = 'Estimated local cost: not tracked for this endpoint';
  } else {
    const cost = chatRenderer.getSessionCost ? Number(chatRenderer.getSessionCost(sessionId) || 0) : 0;
    if (cost > 0) {
      costLine = `Estimated local cost: $${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}`;
    } else {
      costLine = 'Estimated local cost: unavailable or zero';
    }
  }

  const report = [
    `Session: ${ctx.esc(session?.name || 'Current chat')}`,
    `Model: ${ctx.esc(model)}`,
    `Messages: ${messageCount.toLocaleString()}`,
    `Recorded tokens: ${totalTokens.toLocaleString()}`,
    costLine,
    '',
    'Provider account usage is not available from here; check the provider dashboard for account quota/billing.'
  ];
  slashReply(`<pre>${report.join('\n')}</pre>`);
  return true;
}
// ── Context compaction ──
async function _cmdCompact(args, ctx) {
@@ -4783,39 +5015,53 @@ function _clearSetupCommandInput() {
}
}
// GitHub Copilot device-flow sign-in, driven from chat (mirrors the Settings
// "Connect GitHub Copilot" button). Replies via the setup guide messages.
async function _setupCopilot() {
async function _setupProviderDeviceFlow(providerKey) {
_clearSetupGuideMessages();
await _setupReply('Starting GitHub Copilot sign-in…');
let start;
const config = PROVIDER_DEVICE_FLOWS[providerKey];
if (!config) {
await _setupReply('Provider not recognised.');
return;
}
await _setupReply(`Starting ${config.label} sign-in...`);
try {
const r = await fetch(`${API_BASE}/api/copilot/device/start`, { method: 'POST', body: new FormData(), credentials: 'same-origin' });
start = await r.json();
if (!r.ok) { await _setupReply(start.detail || 'Failed to start Copilot sign-in.'); return; }
} catch (e) { await _setupReply('Request failed.'); return; }
const authUrl = start.verification_uri_complete || start.verification_uri || '';
await _setupReply(`Opening GitHub — approve the request (code ${start.user_code}). Waiting…`);
try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {}
const deadline = Date.now() + (start.expires_in || 900) * 1000;
const stepMs = Math.max((start.interval || 5), 2) * 1000;
const poll = async () => {
if (Date.now() > deadline) { await _setupReply('Copilot sign-in expired — run /setup copilot again.'); return; }
try {
const fd = new FormData(); fd.append('poll_id', start.poll_id);
const r = await fetch(`${API_BASE}/api/copilot/device/poll`, { method: 'POST', body: fd, credentials: 'same-origin' });
const d = await r.json();
if (d.status === 'authorized') {
const n = ((d.endpoint && d.endpoint.models) || []).length;
await _setupReply(`Connected — ${n} Copilot model${n !== 1 ? 's' : ''} available.`);
if (modelsModule) modelsModule.refreshModels(true);
return;
}
if (d.status === 'failed') { await _setupReply('Copilot sign-in failed (' + (d.error || 'denied') + ').'); return; }
} catch (e) { /* transient — keep polling */ }
setTimeout(poll, stepMs);
};
setTimeout(poll, stepMs);
const result = await runProviderDeviceFlow(providerKey, {
onStart: async ({ start, authUrl }) => {
const place = providerKey === 'copilot' ? 'GitHub' : 'OpenAI';
const action = providerKey === 'copilot' ? 'approve the request' : 'enter the code';
if (providerKey === 'chatgpt-subscription') {
slashReply(
'<div class="setup-guide-no-censor" style="display:grid;gap:6px;">' +
'<div>Open this URL in your browser, enter the code, then come back here. Waiting...</div>' +
'<div>Code: <code>' + uiModule.esc(start.user_code || '') + '</code></div>' +
'<div><a href="' + uiModule.esc(authUrl || '') + '" target="_blank" rel="noopener noreferrer">' + uiModule.esc(authUrl || '') + '</a></div>' +
'</div>'
);
return;
}
await _setupReply(`Opening ${place} - ${action} (code ${start.user_code}). Waiting...`);
},
openWindow: (url) => {
if (providerKey === 'chatgpt-subscription') return;
try { if (url) window.open(url, '_blank', 'noopener'); } catch (e) {}
},
});
if (result.status === 'authorized') {
const n = ((result.endpoint && result.endpoint.models) || []).length;
await _setupReply(`Connected - ${n} ${config.label} model${n !== 1 ? 's' : ''} available.`);
if (modelsModule) modelsModule.refreshModels(true);
return;
}
if (result.status === 'failed') {
await _setupReply(`${config.label} sign-in failed (${result.error || 'denied'}).`);
return;
}
if (result.status === 'expired') {
await _setupReply(`${config.label} sign-in expired - run /setup ${providerKey} again.`);
return;
}
} catch (e) {
await _setupReply(formatDeviceFlowError(e));
}
}
async function _cmdSetup(args, ctx) {
@@ -4823,7 +5069,11 @@ async function _cmdSetup(args, ctx) {
_clearSetupCommandInput();
const topic = (args[0] || '').trim().toLowerCase();
const topicArgs = args.slice(1);
if (topic === 'copilot' || topic === 'github') { await _setupCopilot(); return true; }
const deviceAuthProvider = _setupDeviceAuthProviderFromInput(topic);
if (deviceAuthProvider) {
await _setupProviderDeviceFlow(deviceAuthProvider);
return true;
}
const provider = _setupProviderFromInput(topic);
if (provider) {
_clearSetupGuideMessages();
@@ -5463,8 +5713,20 @@ async function _cmdHelp(args, ctx) {
lines.push('');
}
}
const skillCommands = await _loadSkillSlashCatalog(false);
if (skillCommands.length) {
lines.push('Skills:');
for (const skill of skillCommands.slice(0, 20)) {
const token = String(skill.token || '').padEnd(21);
lines.push(` ${ctx.esc(token)}${ctx.esc(skill.help || '')}`);
}
if (skillCommands.length > 20) {
lines.push(` ... ${skillCommands.length - 20} more. Use /skills list`);
}
lines.push('');
}
lines.push('Tip: /<command> --help for details');
lines.push('Shortcuts: /new /rename /fork /web /bash /memories /forget');
lines.push('Shortcuts: /new /rename /fork /web /bash /memories /skills');
slashReply(`<pre style="line-height:1.7">${lines.join('\n')}</pre>`);
return true;
}
@@ -5539,6 +5801,20 @@ const COMMANDS = {
'search': { handler: _cmdMemorySearch, alias: ['grep'], help: 'Search memories', usage: '/memory search q' }
}
},
skills: {
alias: ['skill'],
category: 'Memory',
help: 'List, search, inspect, or run skills',
handler: _cmdSkills,
usage: '/skills list | search query | view name | use name request',
},
'reload-skills': {
alias: ['reload_skills'],
category: 'Memory',
help: 'Refresh the slash skill catalog',
handler: _cmdReloadSkills,
usage: '/reload-skills',
},
rag: {
alias: [],
category: 'RAG',
@@ -5572,7 +5848,7 @@ const COMMANDS = {
category: 'Getting started',
help: 'Add local or API model endpoints',
handler: _cmdSetup,
usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup endpoint',
usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup chatgpt-subscription',
// Provider subs so the autocomplete popup surfaces "/setup deepseek",
// "/setup openai", etc. when the user types "/setup de". Each sub's
// handler is a thin wrapper that re-prepends the sub name and
@@ -5590,6 +5866,7 @@ const COMMANDS = {
xai: { help: 'xAI (Grok)', alias: ['grok'], usage: '/setup xai xai-...', handler: (a, c) => _cmdSetup(['xai', ...a], c) },
ollama: { help: 'Ollama Cloud', usage: '/setup ollama KEY', handler: (a, c) => _cmdSetup(['ollama', ...a], c) },
copilot: { help: 'GitHub Copilot', usage: '/setup copilot', handler: (a, c) => _cmdSetup(['copilot', ...a], c) },
'chatgpt-subscription': { help: 'ChatGPT Subscription', alias: ['codex'], usage: '/setup chatgpt-subscription', handler: (a, c) => _cmdSetup(['chatgpt-subscription', ...a], c) },
local: { help: 'Local model server (vLLM / LM Studio / llama.cpp / Ollama)',
usage: '/setup local http://localhost:8000/v1',
handler: (a, c) => _cmdSetup(['local', ...a], c) },
@@ -5767,8 +6044,22 @@ const COMMANDS = {
handler: (args, ctx) => _cmdToolPanel('compare', args, ctx),
usage: '/compare'
},
mcp: {
alias: [],
category: 'Tools',
help: 'Show MCP server status',
handler: _cmdMcp,
usage: '/mcp'
},
model: {
alias: [],
category: 'Settings',
help: 'Show current chat model',
handler: _cmdModel,
usage: '/model · /model list'
},
models: {
alias: ['model'],
alias: [],
category: 'Settings',
help: 'List available models',
handler: _cmdModels,
@@ -5799,10 +6090,16 @@ const COMMANDS = {
handler: _cmdStats,
usage: '/stats'
},
usage: {
alias: ['cost', 'tokens'],
category: 'Utility',
help: 'Show local usage for the current chat',
handler: _cmdUsage,
usage: '/usage'
},
compact: {
alias: [],
category: 'Utility',
hidden: true,
help: 'Compact older chat messages',
handler: _cmdCompact,
usage: '/compact'
@@ -6075,33 +6372,13 @@ async function handleSlashCommand(input) {
}
// --- 4. Skill invocation: /<skill-name> [request] ---
// If `rawCmd` matches a published skill, pin its SKILL.md to the user's
// message and re-submit. Lets you fire a stored procedure on demand
// without the model having to discover the skill itself.
// If `rawCmd` matches a published skill, the backend records usage and
// returns a skill-pinned message to submit as the next agent turn.
try {
const skillRes = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(rawCmd)}/markdown`, { credentials: 'same-origin' });
if (skillRes.ok) {
const skillData = await skillRes.json();
const md = skillData.markdown || '';
if (md) {
_showUser();
const request = args.join(' ').trim();
const msgInput = document.getElementById('message');
const composed =
`Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n` +
`--- BEGIN SKILL ---\n${md}\n--- END SKILL ---\n\n` +
(request ? `Request: ${request}` : `Request: (use the skill as appropriate)`);
if (msgInput) {
msgInput.value = composed;
const form = document.getElementById('chat-form');
if (form && typeof form.requestSubmit === 'function') {
form.requestSubmit();
} else if (form) {
form.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true }));
}
}
return true;
}
const catalog = await _loadSkillSlashCatalog(false);
if (catalog.some(s => s.name === rawCmd)) {
_showUser();
return await _invokeSkillByName(rawCmd, args.join(' ').trim(), ctx);
}
} catch (_) { /* fall through to fuzzy match */ }
@@ -6158,10 +6435,13 @@ export function initSlashCommands(deps) {
const providerEl = e.target.closest('.setup-clickable-provider');
if (providerEl) {
e.preventDefault();
const providerKey = providerEl.dataset.setupProvider || providerEl.textContent.trim();
const providerName = providerEl.textContent.trim();
const messageInput = document.getElementById('message');
if (messageInput) {
const text = providerName + ' sk-';
const text = providerEl.dataset.setupKind === 'device-auth'
? '/setup ' + providerKey
: providerName + ' sk-';
messageInput.value = text;
messageInput.dispatchEvent(new Event('input', { bubbles: true }));
messageInput.focus();
+65
View File
@@ -0,0 +1,65 @@
"""Static regressions for Add Models provider device-flow UX."""
from pathlib import Path
# Two directory levels above this test module.
_REPO = Path(__file__).resolve().parents[1]
# The front-end sources are read once at import time; every test below is a
# plain substring check against this text.
_STATIC_DIR = _REPO / "static"
_INDEX = (_STATIC_DIR / "index.html").read_text(encoding="utf-8")
_ADMIN = (_STATIC_DIR / "js" / "admin.js").read_text(encoding="utf-8")
def _between(src: str, start: str, end: str) -> str:
start_idx = src.index(start)
end_idx = src.index(end, start_idx)
return src[start_idx:end_idx]
def test_copilot_and_chatgpt_subscription_are_dropdown_device_auth_options():
    """index.html carries both device-auth option fragments and the status element id."""
    expected_fragments = (
        'value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot',
        'value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription',
        'id="adm-deviceAuthStatus"',
    )
    for fragment in expected_fragments:
        assert fragment in _INDEX
def test_provider_selection_is_inert_and_add_button_starts_device_flow():
    """The provider ``change`` handler never starts device auth; the Add click handler does."""
    change_handler_src = _between(
        _ADMIN,
        "provider.addEventListener('change'",
        "urlInput.addEventListener('input'",
    )
    add_handler_src = _between(
        _ADMIN,
        "el('adm-epAddBtn').addEventListener('click'",
        "async function _startProviderDeviceAuth",
    )

    starts_on_change = "_startProviderDeviceAuth" in change_handler_src
    assert not starts_on_change
    assert "_startProviderDeviceAuth(deviceAuthProvider" in add_handler_src
def test_device_auth_selection_disables_and_dims_api_test_button():
    """``_setApiFormForProvider`` contains both the disable/dim and the restore statements."""
    form_src = _between(_ADMIN, "function _setApiFormForProvider()", "function _renderPickerMenu()")
    required_statements = (
        # Disabled + dimmed state.
        "testBtn.disabled = true",
        "testBtn.style.opacity = '0.45'",
        "testBtn.style.cursor = 'not-allowed'",
        # Restored state.
        "testBtn.disabled = false",
        "testBtn.style.opacity = ''",
        "testBtn.style.cursor = ''",
    )
    for statement in required_statements:
        assert statement in form_src
def test_device_auth_keeps_manual_auth_button_without_auto_opening_tab():
    """The device-auth starter keeps its manual buttons and a no-op ``openWindow``."""
    auth_src = _between(_ADMIN, "async function _startProviderDeviceAuth", "// Local \"Add\" button")
    must_contain = (
        "Authorize with OpenAI",
        "Authorize on GitHub",
        "adm-copilot-panel",
        "adm-device-auth-copy",
        "openWindow: () => {}",
    )
    for fragment in must_contain:
        assert fragment in auth_src
    # Copy announcing an automatically opened tab must be gone.
    assert "A new tab opened" not in auth_src
def test_loud_oauth_copy_and_removed_button_hooks_do_not_return():
    """None of the listed copy strings or element ids appear in index.html or admin.js."""
    banned_needles = (
        "Click Add to start",
        "uses account sign-in",
        "Uses ChatGPT/Codex OAuth, not an OpenAI API key.",
        "adm-chatgptStatus",
        "adm-chatgptConnectBtn",
        "adm-copilotConnectBtn",
        "adm-copilotStatus",
    )
    for needle in banned_needles:
        for source_text in (_INDEX, _ADMIN):
            assert needle not in source_text
+280
View File
@@ -0,0 +1,280 @@
"""DB-backed ChatGPT Subscription endpoint provisioning tests."""
import json
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from core.database import Base, ModelEndpoint, ProviderAuthSession
import routes.chatgpt_subscription_routes as csr
def _mem_db(monkeypatch):
    """Create an in-memory SQLite schema and install its session factory on ``csr``.

    Returns the session factory so callers can open their own sessions.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=memory_engine)
    # Match production (core.database SessionLocal is autoflush=False): a pending
    # db.delete(ep) is NOT flushed before the orphan-auth reference-count SELECT,
    # which is exactly why _delete_orphaned_provider_auth needs exclude_ep_id.
    session_factory = sessionmaker(autoflush=False, bind=memory_engine)
    monkeypatch.setattr(csr, "SessionLocal", session_factory)
    return session_factory
def test_provision_creates_owner_scoped_auth_session_and_endpoint(monkeypatch):
    """Provisioning writes an auth row and an endpoint row owned by the caller."""
    session_factory = _mem_db(monkeypatch)
    monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5", "o4-mini"])

    provisioned = csr._provision_endpoint({"access_token": "AT", "refresh_token": "RT"}, "alice")

    # Shape of the dict handed back to the caller.
    assert provisioned["name"] == "ChatGPT Subscription"
    assert provisioned["base_url"] == csr.chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
    assert provisioned["models"] == ["gpt-5.5", "o4-mini"]

    session = session_factory()
    try:
        auth_row = session.query(ProviderAuthSession).first()
        endpoint_row = session.query(ModelEndpoint).filter(ModelEndpoint.id == provisioned["id"]).first()

        # Tokens are stored on the auth row.
        assert auth_row is not None
        assert auth_row.owner == "alice"
        assert auth_row.provider == csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER
        assert auth_row.access_token == "AT"
        assert auth_row.refresh_token == "RT"
        assert auth_row.auth_mode == "chatgpt"

        # The endpoint row holds no key of its own and links to the auth row.
        assert endpoint_row is not None
        assert endpoint_row.owner == "alice"
        assert endpoint_row.api_key is None
        assert endpoint_row.provider_auth_id == auth_row.id
        assert endpoint_row.endpoint_kind == "api"
        assert endpoint_row.model_refresh_mode == "manual"
        assert endpoint_row.supports_tools is False
        assert json.loads(endpoint_row.cached_models) == ["gpt-5.5", "o4-mini"]
    finally:
        session.close()
def test_provision_refreshes_existing_auth_session_and_endpoint(monkeypatch):
    """A second provision for the same owner updates the rows instead of duplicating them."""
    session_factory = _mem_db(monkeypatch)
    monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5"])

    initial = csr._provision_endpoint({"access_token": "OLD", "refresh_token": "OLD-RT"}, "bob")
    repeated = csr._provision_endpoint({"access_token": "NEW", "refresh_token": "NEW-RT"}, "bob")
    assert initial["id"] == repeated["id"]

    session = session_factory()
    try:
        bobs_auth = session.query(ProviderAuthSession).filter(ProviderAuthSession.owner == "bob").all()
        bobs_endpoints = session.query(ModelEndpoint).filter(ModelEndpoint.owner == "bob").all()

        assert len(bobs_auth) == 1
        assert len(bobs_endpoints) == 1

        auth_row = bobs_auth[0]
        assert auth_row.access_token == "NEW"
        assert auth_row.refresh_token == "NEW-RT"
        assert bobs_endpoints[0].provider_auth_id == auth_row.id
    finally:
        session.close()
def test_provision_rejects_missing_tokens(monkeypatch):
    """A token payload without ``refresh_token`` is rejected with ``ValueError``."""
    _mem_db(monkeypatch)
    incomplete_tokens = {"access_token": "AT"}
    with pytest.raises(ValueError, match="missing access_token or refresh_token"):
        csr._provision_endpoint(incomplete_tokens, "alice")
def test_provision_rejects_accounts_without_usable_models(monkeypatch):
    """An empty model discovery result makes provisioning raise ``ValueError``."""
    _mem_db(monkeypatch)
    monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: [])
    complete_tokens = {"access_token": "AT", "refresh_token": "RT"}
    with pytest.raises(ValueError, match="no usable Codex models"):
        csr._provision_endpoint(complete_tokens, "alice")
def _add_auth_and_endpoints(db, *, auth_id="auth1", ep_ids=("ep1",)):
    """Insert one auth row plus one endpoint row per id in ``ep_ids``, then commit.

    Every endpoint row references ``auth_id`` through ``provider_auth_id``.
    """
    auth_row = ProviderAuthSession(
        id=auth_id,
        provider=csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
        owner="alice",
        base_url="https://chatgpt.com/backend-api/codex",
        refresh_token="RT",
        auth_mode="chatgpt",
    )
    db.add(auth_row)
    db.add_all(
        [
            ModelEndpoint(
                id=endpoint_id,
                name="ChatGPT Subscription",
                base_url="https://chatgpt.com/backend-api/codex",
                provider_auth_id=auth_id,
                owner="alice",
            )
            for endpoint_id in ep_ids
        ]
    )
    db.commit()
def test_delete_orphaned_provider_auth_revokes_when_last_endpoint_removed(monkeypatch):
    """Deleting the only referencing endpoint clears the auth row."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1",))
        # Mirror the production delete route: the endpoint delete is issued (but
        # not yet flushed/committed) BEFORE the orphan check runs.
        doomed_endpoint = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(doomed_endpoint)

        # ep1 (the only referencing endpoint) is being deleted, so the auth clears.
        revoked = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep1")
        assert revoked is True
        session.commit()

        surviving_auth = session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        assert surviving_auth is None
    finally:
        session.close()
def test_delete_orphaned_provider_auth_requires_exclude_ep_id_for_pending_delete(monkeypatch):
    """Without ``exclude_ep_id`` a pending delete still counts as a reference."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1",))
        pending_delete = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(pending_delete)

        # Without exclude_ep_id, the un-flushed pending delete leaves ep1 visible
        # to the reference-count SELECT (autoflush=False), so the helper must
        # conservatively KEEP the auth row. This is the bug exclude_ep_id fixes.
        revoked = _delete_orphaned_provider_auth(session, "auth1")
        assert revoked is False

        kept_auth = session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        assert kept_auth is not None
    finally:
        session.close()
def test_delete_orphaned_provider_auth_keeps_auth_while_another_endpoint_uses_it(monkeypatch):
    """An auth row shared by two endpoints survives when only one is excluded."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1", "ep2"))

        # ep2 still references auth1, so deleting ep1 must NOT revoke it.
        revoked = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep1")
        assert revoked is False

        kept_auth = session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        assert kept_auth is not None
    finally:
        session.close()
def test_delete_orphaned_provider_auth_noop_without_auth_id(monkeypatch):
    """Passing ``None`` as the auth id returns ``False``."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        outcome = _delete_orphaned_provider_auth(session, None, exclude_ep_id="ep1")
        assert outcome is False
    finally:
        session.close()
def test_delete_orphaned_provider_auth_noop_when_auth_row_missing(monkeypatch):
    """An endpoint pointing at a non-existent auth row yields ``False`` without raising."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        # Endpoint points at an auth_id whose ProviderAuthSession is already gone.
        dangling_endpoint = ModelEndpoint(
            id="ep1",
            name="ChatGPT Subscription",
            base_url="https://chatgpt.com/backend-api/codex",
            provider_auth_id="ghost",
            owner="alice",
        )
        session.add(dangling_endpoint)
        session.commit()

        stored_endpoint = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(stored_endpoint)

        # No other endpoint references "ghost" and no auth row exists → no-op, no error.
        outcome = _delete_orphaned_provider_auth(session, "ghost", exclude_ep_id="ep1")
        assert outcome is False
    finally:
        session.close()
def _delete_route(monkeypatch, TestSessionLocal):
    """Return the handler of the real DELETE /api/model-endpoints/{ep_id} route.

    The route module is pointed at the test session factory, and the admin
    check, the settings/prefs load and save functions and the session-manager
    getter are replaced with no-op stand-ins, so the test stays hermetic and
    only exercises the provider-auth revocation wiring.

    Raises:
        AssertionError: If the router exposes no matching DELETE route.
    """
    import routes.model_routes as mr
    import routes.prefs_routes as prefs_routes
    import src.ai_interaction as ai_interaction

    stand_ins = (
        (mr, "SessionLocal", TestSessionLocal),
        (mr, "require_admin", lambda request: None),
        (mr, "_load_settings", lambda: {}),
        (mr, "_save_settings", lambda settings: None),
        (prefs_routes, "_load", lambda: {}),
        (prefs_routes, "_save", lambda prefs: None),
        (ai_interaction, "get_session_manager", lambda: None),
    )
    for module, attribute, replacement in stand_ins:
        monkeypatch.setattr(module, attribute, replacement)

    router = mr.setup_model_routes(model_discovery=None)
    for route in router.routes:
        if getattr(route, "path", "") != "/api/model-endpoints/{ep_id}":
            continue
        if "DELETE" not in getattr(route, "methods", set()):
            continue
        return route.endpoint
    raise AssertionError("DELETE /api/model-endpoints/{ep_id} not found")
def test_delete_endpoint_route_revokes_orphaned_provider_auth(monkeypatch):
    """Deleting the only endpoint through the route also removes its auth row."""
    session_factory = _mem_db(monkeypatch)

    seed_session = session_factory()
    try:
        _add_auth_and_endpoints(seed_session, auth_id="auth1", ep_ids=("ep1",))
    finally:
        seed_session.close()

    delete_endpoint = _delete_route(monkeypatch, session_factory)
    payload = delete_endpoint("ep1", object())

    assert payload["deleted"] is True
    # The last (only) endpoint backed by auth1 is gone, so the route revokes it.
    assert payload["cleared_provider_auth"] is True

    check_session = session_factory()
    try:
        leftover_auth = check_session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        leftover_endpoint = check_session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        assert leftover_auth is None
        assert leftover_endpoint is None
    finally:
        check_session.close()
def test_delete_endpoint_route_keeps_auth_when_shared(monkeypatch):
    """Deleting one of two endpoints through the route leaves the shared auth row."""
    session_factory = _mem_db(monkeypatch)

    seed_session = session_factory()
    try:
        _add_auth_and_endpoints(seed_session, auth_id="auth1", ep_ids=("ep1", "ep2"))
    finally:
        seed_session.close()

    delete_endpoint = _delete_route(monkeypatch, session_factory)
    payload = delete_endpoint("ep1", object())

    assert payload["deleted"] is True
    # ep2 still references auth1, so deleting ep1 must NOT revoke the credentials.
    assert payload["cleared_provider_auth"] is False

    check_session = session_factory()
    try:
        shared_auth = check_session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        assert shared_auth is not None
    finally:
        check_session.close()
def test_delete_orphaned_provider_auth_revokes_only_after_last_of_several(monkeypatch):
    """With two endpoints, the auth row is cleared only when the second one goes."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1", "ep2"))

        # Delete ep1 first: ep2 still references auth1, so the row survives.
        first_endpoint = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(first_endpoint)
        revoked_early = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep1")
        assert revoked_early is False
        session.commit()
        auth_after_first = session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        assert auth_after_first is not None

        # Now delete the last endpoint ep2: the auth row is finally cleared.
        last_endpoint = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep2").first()
        session.delete(last_endpoint)
        revoked_late = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep2")
        assert revoked_late is True
        session.commit()
        auth_after_last = session.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first()
        assert auth_after_last is None
    finally:
        session.close()
+138
View File
@@ -0,0 +1,138 @@
"""Shared device-flow route helper regressions."""
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from routes import device_flow
def _client(monkeypatch, now_ref, start_flow, poll_flow):
    """Build a ``TestClient`` for a device-flow router under ``/api/test-device``.

    The pending store reads its time from ``now_ref[0]``, so a test advances
    time by mutating that list. ``device_flow.require_admin`` is replaced with
    a stand-in that accepts every request.
    """

    def read_clock():
        return now_ref[0]

    pending_store = device_flow.PendingDeviceFlowStore(time_func=read_clock)
    flow_router = device_flow.create_device_flow_router(
        prefix="/api/test-device",
        tags=["test-device"],
        store=pending_store,
        start_flow=start_flow,
        poll_flow=poll_flow,
    )

    application = FastAPI()
    application.include_router(flow_router)
    monkeypatch.setattr(device_flow, "require_admin", lambda request: None)
    return TestClient(application)
def _start(_request, _form):
    """Start-flow stub returning a fixed ``DeviceFlowStart`` (interval 5, expires_in 20)."""
    pending_state = {"secret": "server-only", "owner": "alice"}
    client_payload = {"user_code": "ABCD-EFGH", "verification_uri": "https://example.test/device"}
    return device_flow.DeviceFlowStart(
        pending=pending_state,
        response=client_payload,
        interval=5,
        expires_in=20,
    )
def test_pending_poll_is_throttled_until_interval(monkeypatch):
    """A repeat poll inside the interval answers ``pending`` without calling the poll callback."""
    clock = [100.0]
    seen_pending = []

    def poll(_request, pending):
        seen_pending.append(dict(pending))
        return device_flow.DeviceFlowPoll.pending()

    client = _client(monkeypatch, clock, _start, poll)
    started = client.post("/api/test-device/device/start").json()

    def poll_once():
        return client.post("/api/test-device/device/poll", data={"poll_id": started["poll_id"]})

    # First poll reaches the callback, which receives the pending dict from _start.
    assert poll_once().json() == {"status": "pending"}
    assert seen_pending == [{"secret": "server-only", "owner": "alice"}]

    # Second poll at the same clock value: callback count is unchanged.
    assert poll_once().json() == {"status": "pending"}
    assert len(seen_pending) == 1

    # After advancing the clock by the interval the callback runs again.
    clock[0] += 5
    assert poll_once().json() == {"status": "pending"}
    assert len(seen_pending) == 2
def test_slow_down_updates_poll_interval(monkeypatch):
    """After a ``slow_down(interval=10)`` result the callback is not re-run until 10s pass."""
    clock = [100.0]
    call_times = []

    def poll(_request, _pending):
        call_times.append(clock[0])
        if len(call_times) == 1:
            return device_flow.DeviceFlowPoll.slow_down(interval=10)
        return device_flow.DeviceFlowPoll.authorized({"id": "ep1", "models": ["gpt-4o"]})

    client = _client(monkeypatch, clock, _start, poll)
    poll_id = client.post("/api/test-device/device/start").json()["poll_id"]

    def poll_body():
        return client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json()

    # The slow_down result is reported to the client as plain "pending".
    assert poll_body() == {"status": "pending"}

    # 9s later: still "pending" and the callback has not been called again.
    clock[0] += 9
    assert poll_body() == {"status": "pending"}
    assert len(call_times) == 1

    # 10s after the first poll the callback runs and authorizes.
    clock[0] += 1
    assert poll_body() == {
        "status": "authorized",
        "endpoint": {"id": "ep1", "models": ["gpt-4o"]},
    }
def test_authorized_and_failed_polls_remove_pending_session(monkeypatch):
    """After an authorized or failed result, re-polling the same id returns 404."""
    clock = [100.0]
    scripted_results = [
        device_flow.DeviceFlowPoll.authorized({"id": "ep1"}),
        device_flow.DeviceFlowPoll.failed("access_denied"),
    ]

    def poll(_request, _pending):
        return scripted_results.pop(0)

    client = _client(monkeypatch, clock, _start, poll)

    def start_flow_id():
        return client.post("/api/test-device/device/start").json()["poll_id"]

    def poll_for(poll_id):
        return client.post("/api/test-device/device/poll", data={"poll_id": poll_id})

    authorized_id = start_flow_id()
    failed_id = start_flow_id()

    assert poll_for(authorized_id).json()["status"] == "authorized"
    assert poll_for(authorized_id).status_code == 404

    assert poll_for(failed_id).json() == {
        "status": "failed",
        "error": "access_denied",
    }
    assert poll_for(failed_id).status_code == 404
def test_cancel_and_expiry_remove_pending_session(monkeypatch):
    """Polling a cancelled id, or an id past ``expires_in``, returns 404."""
    clock = [100.0]

    def poll(_request, _pending):
        return device_flow.DeviceFlowPoll.pending()

    client = _client(monkeypatch, clock, _start, poll)

    # Cancelled flow.
    cancelled_id = client.post("/api/test-device/device/start").json()["poll_id"]
    cancel_response = client.post("/api/test-device/device/cancel", data={"poll_id": cancelled_id})
    assert cancel_response.json() == {"status": "cancelled"}
    poll_after_cancel = client.post("/api/test-device/device/poll", data={"poll_id": cancelled_id})
    assert poll_after_cancel.status_code == 404

    # Expired flow: _start sets expires_in=20, so 21s later it is gone.
    expired_id = client.post("/api/test-device/device/start").json()["poll_id"]
    clock[0] += 21
    poll_after_expiry = client.post("/api/test-device/device/poll", data={"poll_id": expired_id})
    assert poll_after_expiry.status_code == 404
def test_routes_are_admin_gated(monkeypatch):
    """start, poll and cancel all answer 403 when ``require_admin`` raises."""
    clock = [100.0]

    def poll(_request, _pending):
        return device_flow.DeviceFlowPoll.pending()

    client = _client(monkeypatch, clock, _start, poll)

    def deny(_request):
        raise HTTPException(403, "admin required")

    # Override the permissive stand-in that _client installed.
    monkeypatch.setattr(device_flow, "require_admin", deny)

    start_response = client.post("/api/test-device/device/start")
    assert start_response.status_code == 403

    for gated_path in ("/api/test-device/device/poll", "/api/test-device/device/cancel"):
        gated_response = client.post(gated_path, data={"poll_id": "missing"})
        assert gated_response.status_code == 403
+92 -23
View File
@@ -25,32 +25,36 @@ from unittest.mock import MagicMock
import httpx
import pytest
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state
# Match test_model_routes.py: if another test stubbed src.endpoint_resolver
# during collection, drop the stub so the real URL helpers load here.
clear_fake_endpoint_resolver_modules()
with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"):
# Match test_model_routes.py: if another test stubbed src.endpoint_resolver
# during collection, drop the stub so the real URL helpers load here.
clear_fake_endpoint_resolver_modules()
if "core.database" not in sys.modules:
_core_db = types.ModuleType("core.database")
for _name in [
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer",
]:
setattr(_core_db, _name, MagicMock())
sys.modules["core.database"] = _core_db
if "core.database" not in sys.modules:
_core_db = types.ModuleType("core.database")
for _name in [
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer",
"ProviderAuthSession", "Base",
]:
setattr(_core_db, _name, MagicMock())
_core_db.utcnow_naive = MagicMock()
sys.modules["core.database"] = _core_db
import routes.model_routes as model_routes
import src.endpoint_resolver as endpoint_resolver
from routes.model_routes import (
_probe_endpoint,
_ping_endpoint,
_probe_single_model,
_classify_endpoint,
_rewrite_loopback_for_docker,
_PROVIDER_CURATED,
)
import routes.model_routes as model_routes
import src.endpoint_resolver as endpoint_resolver
from routes.model_routes import (
_probe_endpoint,
_ping_endpoint,
_probe_single_model,
_resolve_probe_key,
_classify_endpoint,
_rewrite_loopback_for_docker,
_PROVIDER_CURATED,
)
def _patch_resolve(monkeypatch):
@@ -117,6 +121,26 @@ class TestProbeEndpointParsing:
)
assert _probe_endpoint("https://api.example.com/v1") == []
def test_chatgpt_subscription_probe_uses_discovery_only(self, monkeypatch):
    """The probe forwards the key and timeout to ``fetch_available_models`` and returns its list."""
    _patch_resolve(monkeypatch)
    fetch_calls = []

    def fake_fetch(access_token, timeout=5):
        fetch_calls.append((access_token, timeout))
        return ["gpt-5.5"]

    monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", fake_fetch)

    discovered = _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS", timeout=7)

    assert discovered == ["gpt-5.5"]
    assert fetch_calls == [("ACCESS", 7)]
def test_chatgpt_subscription_probe_without_discovery_returns_empty(self, monkeypatch):
    """With discovery returning nothing, the probe yields ``[]`` with or without a key."""
    _patch_resolve(monkeypatch)
    monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", lambda access_token, timeout=5: [])

    with_key = _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS")
    assert with_key == []

    without_key = _probe_endpoint("https://chatgpt.com/backend-api/codex")
    assert without_key == []
# ── _ping_endpoint: reachability classification ──
@@ -321,6 +345,51 @@ class TestProbeSingleModel:
_probe_single_model("https://api.anthropic.com/v1", "sk-ant", "claude-sonnet-4-5", with_tools=True)
assert "input_schema" in captured["payload"]["tools"][0]
def test_chatgpt_subscription_skips_completion_probe(self, monkeypatch):
    """The single-model probe reports a skipped ``ok`` without calling ``httpx.post``."""
    # This provider speaks the Responses/Codex API. A chat-completions probe
    # would 400 and (via the re-probe flow) hide every model, so it must be
    # short-circuited as discovery-only without any HTTP call.
    _patch_resolve(monkeypatch)

    def forbid_post(*args, **kwargs):
        raise AssertionError("must not send a completion probe for chatgpt-subscription")

    monkeypatch.setattr(model_routes.httpx, "post", forbid_post)

    outcome = _probe_single_model("https://chatgpt.com/backend-api/codex", None, "gpt-5.1-codex")

    assert outcome["status"] == "ok"
    assert outcome.get("skipped") is True
    # Pin the full documented return shape — downstream JSON/UI reads latency_ms.
    assert outcome["latency_ms"] == 0
# ── _resolve_probe_key: static key vs provider-auth runtime token ──
class TestResolveProbeKey:
    """``_resolve_probe_key`` with a static key versus a provider-auth endpoint."""

    def test_static_endpoint_uses_api_key(self):
        """An endpoint without ``provider_auth_id`` yields its own ``api_key``."""
        static_endpoint = types.SimpleNamespace(id="e1", api_key="sk-static", provider_auth_id=None, owner=None)
        assert _resolve_probe_key(static_endpoint) == "sk-static"

    def test_provider_auth_endpoint_resolves_runtime_token(self, monkeypatch):
        """A provider-auth endpoint yields the resolver's token and passes the owner through."""
        auth_endpoint = types.SimpleNamespace(id="e2", api_key=None, provider_auth_id="auth123", owner="alice")
        captured = {}

        def fake_runtime(endpoint, owner=None):
            captured["owner"] = owner
            return ("https://chatgpt.com/backend-api/codex", "live-bearer")

        monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", fake_runtime)

        resolved_key = _resolve_probe_key(auth_endpoint)

        assert resolved_key == "live-bearer"
        assert captured["owner"] == "alice"

    def test_provider_auth_resolution_failure_returns_none(self, monkeypatch):
        """A resolver that raises makes the helper return ``None``."""
        auth_endpoint = types.SimpleNamespace(id="e3", api_key=None, provider_auth_id="auth123", owner=None)

        def failing_runtime(endpoint, owner=None):
            raise RuntimeError("reauth required")

        monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", failing_runtime)

        assert _resolve_probe_key(auth_endpoint) is None
# ── _classify_endpoint: Tailscale CGNAT range ──
+22
View File
@@ -75,6 +75,28 @@ def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
assert payload["temperature"] == 1.2
def test_chatgpt_subscription_payload_uses_max_output_tokens():
    """``max_tokens=37`` appears in the payload under ``max_output_tokens``."""
    conversation = [{"role": "user", "content": "Say OK"}]
    built = llm_core._build_chatgpt_responses_payload(
        "gpt-5.1-codex",
        conversation,
        temperature=0.2,
        max_tokens=37,
    )
    assert built["max_output_tokens"] == 37
def test_chatgpt_subscription_payload_omits_empty_max_output_tokens():
    """``max_tokens=0`` leaves ``max_output_tokens`` out of the payload."""
    conversation = [{"role": "user", "content": "Say OK"}]
    built = llm_core._build_chatgpt_responses_payload(
        "gpt-5.1-codex",
        conversation,
        temperature=0.2,
        max_tokens=0,
    )
    assert "max_output_tokens" not in built
def _anthropic_payload(temperature):
return llm_core._build_anthropic_payload(
"claude-3-5-sonnet",
+92 -42
View File
@@ -11,49 +11,51 @@ from types import SimpleNamespace
import httpx
import pytest
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state
# Other tests stub this module during collection. These helper tests need
# the real URL normalization helpers so Anthropic /v1 handling is covered.
clear_fake_endpoint_resolver_modules()
with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"):
# Other tests stub this module during collection. These helper tests need
# the real URL normalization helpers so Anthropic /v1 handling is covered.
clear_fake_endpoint_resolver_modules()
if "core.database" not in sys.modules:
_core_db = types.ModuleType("core.database")
for _name in [
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun",
"McpServer",
]:
setattr(_core_db, _name, MagicMock())
sys.modules["core.database"] = _core_db
if "core.database" not in sys.modules:
_core_db = types.ModuleType("core.database")
for _name in [
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun",
"McpServer", "ProviderAuthSession", "Base",
]:
setattr(_core_db, _name, MagicMock())
_core_db.utcnow_naive = MagicMock()
sys.modules["core.database"] = _core_db
import routes.model_routes as model_routes
import src.database as src_database
import src.endpoint_resolver as endpoint_resolver
import src.llm_core as llm_core
from routes.model_routes import (
_match_provider_curated,
_curate_models,
_visible_models,
_normalize_model_ids,
_api_key_fingerprint,
_is_chat_model,
_classify_endpoint,
_effective_endpoint_kind,
_probe_endpoint,
_ping_endpoint,
_parse_model_list,
_normalize_refresh_mode,
_truthy,
_speech_settings_using_endpoint,
_clear_speech_settings_for_endpoint,
_endpoint_settings_using_endpoint,
_clear_endpoint_settings_for_endpoint,
_clear_user_pref_endpoint_refs,
_PROVIDER_CURATED,
)
from src.llm_core import ANTHROPIC_MODELS
import routes.model_routes as model_routes
import src.database as src_database
import src.endpoint_resolver as endpoint_resolver
import src.llm_core as llm_core
from routes.model_routes import (
_match_provider_curated,
_curate_models,
_visible_models,
_normalize_model_ids,
_api_key_fingerprint,
_is_chat_model,
_classify_endpoint,
_effective_endpoint_kind,
_probe_endpoint,
_ping_endpoint,
_parse_model_list,
_normalize_refresh_mode,
_truthy,
_speech_settings_using_endpoint,
_clear_speech_settings_for_endpoint,
_endpoint_settings_using_endpoint,
_clear_endpoint_settings_for_endpoint,
_clear_user_pref_endpoint_refs,
_PROVIDER_CURATED,
)
from src.llm_core import ANTHROPIC_MODELS
# ── speech endpoint settings ──
@@ -687,8 +689,7 @@ class _PinnedFakeRequest:
def _get_route(path, method):
from routes.model_routes import setup_model_routes
router = setup_model_routes(model_discovery=None)
router = model_routes.setup_model_routes(model_discovery=None)
for route in router.routes:
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
return route.endpoint
@@ -787,6 +788,55 @@ def test_reprobe_preserves_pinned_models(monkeypatch):
assert json.loads(ep.cached_models) == ["m1"]
def test_reprobe_chatgpt_subscription_does_not_hide_models(monkeypatch):
    """Re-probing a chatgpt-subscription endpoint must not hide its models.

    Drives the probe route end-to-end with the real ``_probe_single_model``:
    every model has to be reported as a skipped "ok" result, nothing may be
    counted as hidden, and the stale ``hidden_models`` value must be cleared.
    """

    def _forbid_completion_probe(*args, **kwargs):
        # Any completion probe would be a bug for this provider.
        raise AssertionError("must not probe chatgpt-subscription")

    stale_endpoint = _make_endpoint(
        base_url="https://chatgpt.com/backend-api/codex",
        api_key=None,
        hidden_models=json.dumps(["stale-hidden"]),
    )
    fake_db = _PinnedFakeDb([stale_endpoint])

    # Stub the route's collaborators; `_probe_single_model` stays real.
    route_stubs = (
        ("SessionLocal", lambda: fake_db),
        ("require_admin", lambda request: None),
        ("_normalize_base", lambda url: url.rstrip("/")),
        ("_probe_endpoint", lambda *a, **k: ["gpt-5.1-codex", "gpt-5.1"]),
        ("_is_chat_model", lambda m: True),
    )
    for attr_name, stub in route_stubs:
        monkeypatch.setattr(model_routes, attr_name, stub)
    monkeypatch.setattr(model_routes.httpx, "post", _forbid_completion_probe)

    probe_route = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")
    streaming_response = probe_route("ep1", _PinnedFakeRequest())

    async def _collect_body():
        # Drain the streamed body, normalising bytes chunks to str.
        return [
            piece.decode() if isinstance(piece, bytes) else piece
            async for piece in streaming_response.body_iterator
        ]

    body_chunks = asyncio.run(_collect_body())

    # Decode every `data: ` line of the stream into its JSON event.
    prefix = "data: "
    sse_events = []
    for chunk in body_chunks:
        sse_events.extend(
            json.loads(line[len(prefix):])
            for line in chunk.splitlines()
            if line.startswith(prefix)
        )

    probe_done = next(evt for evt in sse_events if evt.get("type") == "probe_done")
    probe_results = [evt for evt in sse_events if evt.get("type") == "probe_result"]

    # Every model was skipped as ok; none failed → nothing hidden.
    assert probe_done["hidden"] == 0
    assert probe_done["ok"] == len(probe_results) == 2
    assert all(r["status"] == "ok" and r.get("skipped") is True for r in probe_results)
    # The stale hidden_models is cleared, not repopulated with every model.
    assert stale_endpoint.hidden_models is None
def test_visible_models_handles_malformed_strings():
# Non-JSON cached/pinned strings are treated as comma/newline lists and
# never raise; a malformed hidden string is normalized too.
+10
View File
@@ -42,6 +42,10 @@ class TestHostMatch:
class TestDetectProviderRealHosts:
def test_chatgpt_subscription_codex_backend(self):
    """Both the Codex base URL and its ``/responses`` URL detect as chatgpt-subscription."""
    codex_urls = (
        "https://chatgpt.com/backend-api/codex",
        "https://chatgpt.com/backend-api/codex/responses",
    )
    for url in codex_urls:
        assert llm_core._detect_provider(url) == "chatgpt-subscription"
def test_anthropic(self):
    """The real Anthropic API host is detected as the ``anthropic`` provider."""
    detected = llm_core._detect_provider("https://api.anthropic.com")
    assert detected == "anthropic"
@@ -93,6 +97,12 @@ class TestBuildersRejectLookalikeHosts:
def test_real_anthropic_chat(self):
    """The real Anthropic host builds a ``/v1/messages`` chat URL."""
    chat_url = build_chat_url("https://api.anthropic.com")
    assert chat_url == "https://api.anthropic.com/v1/messages"
def test_chatgpt_subscription_chat_uses_responses(self):
    """The Codex backend base URL builds a ``/responses`` chat URL."""
    chat_url = build_chat_url("https://chatgpt.com/backend-api/codex")
    assert chat_url == "https://chatgpt.com/backend-api/codex/responses"
def test_chatgpt_subscription_models_uses_no_live_probe(self):
    """No models URL is built for the Codex backend (``build_models_url`` returns None)."""
    models_url = build_models_url("https://chatgpt.com/backend-api/codex")
    assert models_url is None
def test_lookalike_anthropic_chat_is_openai(self):
    """A lookalike host (``notanthropic.com``) gets the ``/chat/completions`` URL instead."""
    chat_url = build_chat_url("https://notanthropic.com")
    assert chat_url == "https://notanthropic.com/chat/completions"
+157
View File
@@ -0,0 +1,157 @@
"""Node-driven tests for the shared provider device-flow runner."""
import json
import shutil
import subprocess
from pathlib import Path
import pytest
# Repository root: two directory levels above this test file.
_REPO = Path(__file__).resolve().parent.parent
# The ES module under test, imported by the Node scripts below.
_HELPER = _REPO.joinpath("static", "js", "providerDeviceFlow.js")

# Every test here shells out to Node, so skip the module when it is unavailable.
pytestmark = pytest.mark.skipif(shutil.which("node") is None, reason="node not on PATH")
def _run_node(script: str):
    """Run ``script`` as an ES module under Node and return its parsed JSON output.

    The script is fed on stdin to ``node --input-type=module`` with the
    repository root as working directory and a 30 second timeout.

    Args:
        script: JavaScript source; it is expected to print one JSON document.

    Returns:
        The value decoded from the process's stripped stdout.

    Raises:
        AssertionError: if Node exits non-zero (stderr is the message).
    """
    completed = subprocess.run(
        ["node", "--input-type=module"],
        cwd=str(_REPO),
        input=script,
        text=True,
        capture_output=True,
        timeout=30,
    )
    assert completed.returncode == 0, completed.stderr
    printed = completed.stdout.strip()
    return json.loads(printed)
def test_copilot_success_uses_complete_verification_uri():
    """Copilot flow opens ``verification_uri_complete`` and polls until authorized.

    The fake fetch answers the start call with both verification URIs, then
    reports "pending" on the first poll and "authorized" on the second.
    """
    js = f"""
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
const calls = [];
const opened = [];
let polls = 0;
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
const fetchImpl = async (url) => {{
calls.push(url);
if (url.endsWith('/device/start')) {{
return response(true, 200, {{
poll_id: 'poll-1',
user_code: 'GH-CODE',
verification_uri: 'https://github.com/login/device',
verification_uri_complete: 'https://github.com/login/device?user_code=GH-CODE',
interval: 2,
expires_in: 30,
}});
}}
polls += 1;
return response(true, 200, polls === 1
? {{ status: 'pending' }}
: {{ status: 'authorized', endpoint: {{ id: 'ep1', models: ['gpt-4o'] }} }}
);
}};
const result = await runProviderDeviceFlow('copilot', {{
fetchImpl,
openWindow: (url) => opened.push(url),
sleep: async () => {{}},
now: () => 0,
}});
console.log(JSON.stringify({{ result, calls, opened }}));
"""
    report = _run_node(js)
    flow_result = report["result"]

    assert flow_result["status"] == "authorized"
    assert flow_result["endpoint"]["id"] == "ep1"
    # The pre-filled ("complete") URI wins over the plain one.
    assert report["opened"] == ["https://github.com/login/device?user_code=GH-CODE"]
    # One start call, then exactly two polls (pending → authorized).
    expected_calls = [
        "/api/copilot/device/start",
        "/api/copilot/device/poll",
        "/api/copilot/device/poll",
    ]
    assert report["calls"] == expected_calls
def test_chatgpt_success_uses_plain_verification_uri():
    """ChatGPT Subscription flow opens the plain ``verification_uri``.

    The start payload carries no ``verification_uri_complete`` and the first
    poll already reports "authorized".
    """
    js = f"""
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
const opened = [];
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
const fetchImpl = async (url) => {{
if (url.endsWith('/device/start')) {{
return response(true, 200, {{
poll_id: 'poll-1',
user_code: 'OA-CODE',
verification_uri: 'https://auth.openai.com/codex/device',
interval: 2,
expires_in: 30,
}});
}}
return response(true, 200, {{ status: 'authorized', endpoint: {{ id: 'chatgpt', models: ['gpt-5.5'] }} }});
}};
const result = await runProviderDeviceFlow('chatgpt-subscription', {{
fetchImpl,
openWindow: (url) => opened.push(url),
sleep: async () => {{}},
now: () => 0,
}});
console.log(JSON.stringify({{ result, opened }}));
"""
    report = _run_node(js)
    final_status = report["result"]["status"]
    opened_urls = report["opened"]

    assert final_status == "authorized"
    assert opened_urls == ["https://auth.openai.com/codex/device"]
def test_start_errors_surface_backend_detail():
    """A non-ok start response raises an error whose message is the backend ``detail``."""
    js = f"""
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
try {{
await runProviderDeviceFlow('copilot', {{
fetchImpl: async () => response(false, 502, {{ detail: 'GitHub device-code request failed: upstream down' }}),
openWindow: () => {{}},
sleep: async () => {{}},
now: () => 0,
}});
}} catch (err) {{
console.log(JSON.stringify({{ message: err.message }}));
}}
"""
    caught = _run_node(js)
    surfaced_message = caught["message"]
    assert surfaced_message == "GitHub device-code request failed: upstream down"
def test_thrown_fetch_errors_are_preserved():
    """An exception thrown by ``fetchImpl`` propagates with its original message."""
    js = f"""
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
try {{
await runProviderDeviceFlow('chatgpt-subscription', {{
fetchImpl: async () => {{ throw new Error('network offline'); }},
openWindow: () => {{}},
sleep: async () => {{}},
now: () => 0,
}});
}} catch (err) {{
console.log(JSON.stringify({{ message: err.message }}));
}}
"""
    caught = _run_node(js)
    surfaced_message = caught["message"]
    assert surfaced_message == "network offline"
def test_expired_flow_returns_expired_status():
    """A flow that only ever sees "pending" ends with ``{"status": "expired"}``.

    The injected clock starts at 0 and each fake ``sleep`` advances it by
    2000, while the start payload advertises ``expires_in: 1``.
    """
    js = f"""
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
let currentTime = 0;
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
const result = await runProviderDeviceFlow('copilot', {{
fetchImpl: async (url) => url.endsWith('/device/start')
? response(true, 200, {{
poll_id: 'poll-1',
user_code: 'GH-CODE',
verification_uri: 'https://github.com/login/device',
interval: 2,
expires_in: 1,
}})
: response(true, 200, {{ status: 'pending' }}),
openWindow: () => {{}},
sleep: async () => {{ currentTime += 2000; }},
now: () => currentTime,
}});
console.log(JSON.stringify(result));
"""
    flow_result = _run_node(js)
    expected = {"status": "expired"}
    assert flow_result == expected
+27 -1
View File
@@ -24,7 +24,7 @@ _sd = types.ModuleType("src.database")
_sd.ModelEndpoint = MagicMock()
sys.modules.setdefault("src.database", _sd)
from routes.research_routes import _owned_enabled_endpoint # noqa: E402
from routes.research_routes import _owned_enabled_endpoint, _resolve_endpoint_runtime # noqa: E402
class _Predicate:
@@ -129,3 +129,29 @@ def test_null_owner_is_legacy_single_user_noop():
rows = [_ep("ep-x", "bob"), _ep("ep-y", "alice")]
ep = _resolve(rows, None, "ep-x")
assert ep is not None and ep.id == "ep-x"
def test_runtime_resolution_uses_provider_auth_for_chatgpt_subscription(monkeypatch):
    """Runtime resolution takes the bearer from provider auth, not the endpoint.

    The endpoint has no ``api_key``; the patched credential resolver supplies a
    fresh access token. With an empty requested model, the resolved model must
    be the one listed in ``cached_models``.
    """
    codex_base = "https://chatgpt.com/backend-api/codex"

    def _fake_runtime_credentials(auth_id, owner=None):
        return {
            "base_url": codex_base,
            "api_key": "fresh-access-token",
        }

    monkeypatch.setattr(
        "src.chatgpt_subscription.resolve_runtime_credentials",
        _fake_runtime_credentials,
    )

    subscription_endpoint = SimpleNamespace(
        id="ep-chatgpt",
        owner="alice",
        base_url=codex_base,
        api_key=None,
        provider_auth_id="auth-1",
        cached_models='["gpt-5.5"]',
        hidden_models=None,
    )

    chat_url, chosen_model, request_headers = _resolve_endpoint_runtime(
        subscription_endpoint, owner="alice", model=""
    )

    assert chat_url == "https://chatgpt.com/backend-api/codex/responses"
    assert chosen_model == "gpt-5.5"
    assert request_headers["Authorization"] == "Bearer fresh-access-token"
+215
View File
@@ -0,0 +1,215 @@
"""resolve_session_auth must not persist the ChatGPT Subscription bearer.
The ChatGPT Subscription access token is a short-lived OAuth bearer re-resolved
(and refreshed) on every request. resolve_session_auth() may set it on the
in-memory session for the current request, but it must never write it back into
the sessions table; otherwise the live token sits at rest as
"Authorization: Bearer ...". Only the encrypted refresh token in
ProviderAuthSession is allowed to persist.
"""
import types
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import routes.chat_helpers as chat_helpers
import src.endpoint_resolver as endpoint_resolver
from core.database import Base, ModelEndpoint, Session as DbSession
# Base URL given to the ChatGPT Subscription endpoint and session fixtures in
# this module.
_CODEX_BASE = "https://chatgpt.com/backend-api/codex"
def _mem_db(monkeypatch):
    """Create an in-memory SQLite schema and point ``chat_helpers`` at it.

    Returns:
        The session factory installed as ``chat_helpers.SessionLocal``, so the
        test can open its own sessions against the same database.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=memory_engine)
    # Match production SessionLocal (core.database) which is autoflush=False.
    session_factory = sessionmaker(bind=memory_engine, autoflush=False)
    monkeypatch.setattr(chat_helpers, "SessionLocal", session_factory)
    return session_factory
def test_chatgpt_subscription_auth_is_not_written_to_sessions_table(monkeypatch):
    """The resolved bearer is set on the in-memory session but never stored in the DB row."""
    session_factory = _mem_db(monkeypatch)

    # Seed one subscription endpoint and one chat session without headers.
    seed_db = session_factory()
    try:
        subscription_endpoint = ModelEndpoint(
            id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
            provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
        )
        chat_row = DbSession(
            id="sess1", name="chat", endpoint_url=_CODEX_BASE,
            model="gpt-5.1-codex", owner="alice", headers={},
        )
        seed_db.add(subscription_endpoint)
        seed_db.add(chat_row)
        seed_db.commit()
    finally:
        seed_db.close()

    def _live_runtime(ep, owner=None):
        # A live access token is resolved at request time.
        return (_CODEX_BASE, "live-access-token")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _live_runtime)

    request_session = types.SimpleNamespace(
        id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex",
        owner="alice", headers={},
    )
    chat_helpers.resolve_session_auth(request_session, "sess1", owner="alice")

    # In-memory session got request-local auth for this request...
    assert any(name.lower() == "authorization" for name in request_session.headers)
    assert request_session.headers["Authorization"] == "Bearer live-access-token"

    # ...but the DB row must NOT have the bearer persisted.
    verify_db = session_factory()
    try:
        persisted_row = verify_db.query(DbSession).filter(DbSession.id == "sess1").first()
        stored = persisted_row.headers or {}
        assert not any(name.lower() == "authorization" for name in stored), (
            f"ChatGPT bearer leaked into sessions table: {stored}"
        )
    finally:
        verify_db.close()
def test_non_subscription_auth_is_still_persisted_to_sessions_table(monkeypatch):
    """Ordinary endpoints must still have their resolved auth written to the DB.

    The no-persist early return exists for ChatGPT Subscription only. Regular
    endpoints depend on resolve_session_auth() storing the resolved headers in
    the sessions table so they are not re-resolved on every request. This test
    guards against the is_chatgpt_subscription check being widened by mistake.
    """
    base = "https://api.example.com/v1"
    session_factory = _mem_db(monkeypatch)

    # Seed a generic API-key endpoint and a chat session without headers.
    seed_db = session_factory()
    try:
        generic_endpoint = ModelEndpoint(
            id="ep1", name="Generic", base_url=base,
            owner="alice", is_enabled=True, api_key="sk-static",
        )
        chat_row = DbSession(
            id="sess1", name="chat", endpoint_url=base,
            model="gpt-x", owner="alice", headers={},
        )
        seed_db.add(generic_endpoint)
        seed_db.add(chat_row)
        seed_db.commit()
    finally:
        seed_db.close()

    def _static_runtime(ep, owner=None):
        return (base, "sk-static")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _static_runtime)

    request_session = types.SimpleNamespace(
        id="sess1", endpoint_url=base, model="gpt-x", owner="alice", headers={},
    )
    chat_helpers.resolve_session_auth(request_session, "sess1", owner="alice")

    auth_header_names = ("authorization", "x-api-key")

    # In-memory session got auth...
    assert any(name.lower() in auth_header_names for name in request_session.headers)

    # ...AND it was persisted to the DB row (the normal, non-subscription path).
    verify_db = session_factory()
    try:
        persisted_row = verify_db.query(DbSession).filter(DbSession.id == "sess1").first()
        stored = persisted_row.headers or {}
        assert any(name.lower() in auth_header_names for name in stored), (
            f"non-subscription auth was not persisted: {stored}"
        )
    finally:
        verify_db.close()
def test_chatgpt_subscription_clears_previously_persisted_bearer(monkeypatch):
    """Resolving auth strips a bearer that an older code path left in the DB row."""
    session_factory = _mem_db(monkeypatch)

    seed_db = session_factory()
    try:
        subscription_endpoint = ModelEndpoint(
            id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
            provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
        )
        # Simulate the leak: a stale bearer already sitting in the sessions table.
        leaked_row = DbSession(
            id="sess1", name="chat", endpoint_url=_CODEX_BASE,
            model="gpt-5.1-codex", owner="alice",
            headers={"Authorization": "Bearer stale-leaked-token"},
        )
        seed_db.add(subscription_endpoint)
        seed_db.add(leaked_row)
        seed_db.commit()
    finally:
        seed_db.close()

    def _live_runtime(ep, owner=None):
        return (_CODEX_BASE, "live-access-token")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _live_runtime)

    request_session = types.SimpleNamespace(
        id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex",
        owner="alice", headers={},
    )
    chat_helpers.resolve_session_auth(request_session, "sess1", owner="alice")

    # The stale bearer must have been stripped from the DB row.
    verify_db = session_factory()
    try:
        persisted_row = verify_db.query(DbSession).filter(DbSession.id == "sess1").first()
        stored = persisted_row.headers or {}
        assert not any(name.lower() == "authorization" for name in stored), (
            f"stale ChatGPT bearer was not cleared: {stored}"
        )
    finally:
        verify_db.close()
def test_chatgpt_subscription_fallback_auth_is_not_written_to_sessions_table(monkeypatch):
    """Falling back to a subscription endpoint keeps the bearer request-local.

    The session's model and endpoint URL are expected to be persisted by the
    fallback, but no Authorization header may reach the sessions table.
    """
    session_factory = _mem_db(monkeypatch)
    responses_url = _CODEX_BASE + "/responses"

    # The stored session points at an unrelated endpoint; the subscription
    # endpoint is the only ModelEndpoint row available.
    seed_db = session_factory()
    try:
        subscription_endpoint = ModelEndpoint(
            id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
            provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
            cached_models='["gpt-5.1-codex"]',
        )
        old_chat_row = DbSession(
            id="sess1", name="chat", endpoint_url="https://old.example/v1",
            model="old-model", owner="alice", headers={},
        )
        seed_db.add(subscription_endpoint)
        seed_db.add(old_chat_row)
        seed_db.commit()
    finally:
        seed_db.close()

    def _live_runtime(ep, owner=None):
        return (_CODEX_BASE, "live-access-token")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _live_runtime)

    request_session = types.SimpleNamespace(
        id="sess1", endpoint_url="https://old.example/v1", model="old-model",
        owner="alice", headers={},
    )
    fallback = chat_helpers.try_fallback_endpoint(request_session, "sess1")

    expected_fallback = {
        "model": "gpt-5.1-codex",
        "endpoint_url": responses_url,
        "endpoint_name": "ChatGPT Subscription",
    }
    assert fallback == expected_fallback
    assert request_session.headers["Authorization"] == "Bearer live-access-token"

    verify_db = session_factory()
    try:
        persisted_row = verify_db.query(DbSession).filter(DbSession.id == "sess1").first()
        assert persisted_row.model == "gpt-5.1-codex"
        assert persisted_row.endpoint_url == responses_url
        stored = persisted_row.headers or {}
        assert not any(name.lower() == "authorization" for name in stored), (
            f"ChatGPT fallback bearer leaked into sessions table: {stored}"
        )
    finally:
        verify_db.close()
+1 -1
View File
@@ -386,7 +386,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {})
monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester")
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model: None)
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None)
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
+9
View File
@@ -137,3 +137,12 @@ def test_unauthenticated_caller_rejected(monkeypatch):
with pytest.raises(HTTPException) as exc:
SR._verify_session_owner(req, "sid")
assert exc.value.status_code == 401
def test_auth_disabled_allows_owner_stamped_session(monkeypatch):
    """With AUTH_ENABLED=false, an anonymous request may use an owner-stamped session.

    The test passes as long as ``_verify_session_owner`` does not raise.
    """
    monkeypatch.setenv("AUTH_ENABLED", "false")
    admin_owned_factory = _session_local_returning("admin")
    monkeypatch.setattr(SR, "SessionLocal", admin_owned_factory)

    anonymous_request = _req(api_token=False, current_user=None)

    # Single-user/auth-disabled mode should verify existence but not compare owner.
    SR._verify_session_owner(anonymous_request, "sid-owned-by-admin")
+42
View File
@@ -0,0 +1,42 @@
"""Static regressions for `/setup` account sign-in providers."""
from pathlib import Path
# Repository root: two directory levels above this test file.
_REPO = Path(__file__).resolve().parent.parent
# Full source text of the slash-command module; the tests below only grep it.
_SLASH = _REPO.joinpath("static", "js", "slashCommands.js").read_text(encoding="utf-8")
def _between(src: str, start: str, end: str) -> str:
start_idx = src.index(start)
end_idx = src.index(end, start_idx)
return src[start_idx:end_idx]
def test_setup_guide_lists_account_sign_in_providers():
    """The slash-command source wires up both account sign-in providers.

    Checks that the setup-guide function markers exist and that the provider
    attribute, provider keys and ``/setup <provider>`` commands all appear in
    the file.
    """
    # Called only for its structural check: `_between` raises ValueError unless
    # both function markers are present in this order. The returned slice used
    # to be bound to a local that nothing read, so the binding was dropped.
    _between(_SLASH, "function _showSetupEndpointChoices", "async function _hasConfiguredModels")

    expected_fragments = (
        'data-setup-provider="',
        "provider.key",
        "'copilot'",
        "'chatgpt-subscription'",
        "/setup copilot",
        "/setup chatgpt-subscription",
    )
    # These assertions are file-wide, exactly as before.
    for fragment in expected_fragments:
        assert fragment in _SLASH
def test_clicking_account_sign_in_provider_prefills_setup_command_not_api_key():
    """The provider click handler branches on device-auth and builds a ``/setup`` command."""
    handler_source = _between(
        _SLASH,
        "const providerEl = e.target.closest('.setup-clickable-provider')",
        "// 3. Check",
    )
    expected_fragments = (
        "providerEl.dataset.setupProvider",
        "providerEl.dataset.setupKind === 'device-auth'",
        "'/setup ' + providerKey",
    )
    for fragment in expected_fragments:
        assert fragment in handler_source
def test_setup_chatgpt_subscription_prints_auth_url_without_auto_opening_tab():
    """The device-flow setup code prints an escaped auth link and returns early for ChatGPT.

    All fragments are searched for inside the source span that runs from
    ``_setupProviderDeviceFlow`` up to ``_cmdSetup``.
    """
    device_flow_source = _between(
        _SLASH,
        "async function _setupProviderDeviceFlow",
        "async function _cmdSetup",
    )
    expected_fragments = (
        "providerKey === 'chatgpt-subscription'",
        "Open this URL",
        "authUrl",
        'href="\' + uiModule.esc(authUrl || \'\') + \'"',
        "if (providerKey === 'chatgpt-subscription') return;",
    )
    for fragment in expected_fragments:
        assert fragment in device_flow_source
+17
View File
@@ -0,0 +1,17 @@
"""Static regressions for slash autocomplete command-group expansion."""
from pathlib import Path
# Repository root: two directory levels above this test file.
_REPO = Path(__file__).resolve().parent.parent
# Full source text of the slash autocomplete module; the tests below only grep it.
_AC = _REPO.joinpath("static", "js", "slashAutocomplete.js").read_text(encoding="utf-8")
def test_exact_parent_command_expands_subcommands_before_top_level_row_cap():
    """The autocomplete source has the group helper, its prefix match and the capped slice."""
    expected_fragments = (
        "function _exactCommandGroupItems",
        "entry.token.toLowerCase().startsWith(prefix)",
        "items = groupItems.slice(0, MAX_VISIBLE);",
    )
    for fragment in expected_fragments:
        assert fragment in _AC
def test_setup_group_has_room_for_chatgpt_subscription_suggestion():
    """The autocomplete source declares ``MAX_VISIBLE`` as exactly 14."""
    expected_declaration = "const MAX_VISIBLE = 14;"
    assert expected_declaration in _AC