mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
feat: add ChatGPT Subscription provider (#2876)
* feat: Add ChatGPT Subscription support and related features - Introduced a new provider option for ChatGPT Subscription in the endpoint selection UI. - Implemented OAuth flow for ChatGPT Subscription sign-in, including polling for authorization status. - Updated admin interface to handle ChatGPT Subscription, including disabling API key input and providing user guidance. - Enhanced cost tracking logic to differentiate between subscription and non-subscription endpoints. - Added new slash commands for managing skills, including listing, searching, and invoking skills. - Implemented caching for skill catalog to optimize performance. - Updated tests to cover new ChatGPT Subscription functionality and ensure proper endpoint probing. - Refactored existing code to accommodate new features and improve maintainability. * refactor: share provider device-flow setup - reuse one device-flow backend for Copilot and ChatGPT Subscription - add one frontend device-flow helper for Settings and /setup - put GitHub Copilot back into Add Models, now as a dropdown option - make provider selection just select; clicking Add starts sign-in - stop ChatGPT Subscription setup from opening auth tabs automatically - make /setup copilot and /setup chatgpt-subscription work from chat - show ChatGPT Subscription in the /setup suggestions - show the real error message when setup fails - add focused tests for the shared flow and setup UI * feat(chatgpt-subscription): harden credential lifecycle and streamline auth UX Backend: - Resolve runtime bearer for provider-auth endpoints at probe time via a shared _resolve_probe_key() that delegates to resolve_endpoint_runtime, applied across all probe/refresh call sites. - Skip live completion probes and health pings for discovery-only providers (centralized behind _is_discovery_only_provider) — the Codex/Responses API has no such endpoints, so status is derived from cached models. 
- Never persist the short-lived ChatGPT bearer to the plaintext sessions table; proactively clear any stale bearer left by an earlier code path. - Revoke orphaned ProviderAuthSession credentials when the last endpoint backing them is deleted (_delete_orphaned_provider_auth), surfaced via cleared_provider_auth in the delete response. Frontend (admin.js): - Auto-start the device-auth flow on provider selection so the authorization panel (code + Authorize) shows immediately instead of behind a "Sign in" click. - Remove the redundant top button for device-auth providers and move retry into the panel via an inline "Try again". - Drop the self-evident hint text and add an execCommand clipboard fallback so Copy works in non-secure (HTTP/LAN) contexts. * fix: harden chatgpt subscription provider * chore: remove PR media from branch * Fix chatgpt subscription recovery and token handling --------- Co-authored-by: 5p00kyy <admin@5p00ky.dev>
This commit is contained in:
@@ -598,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery))
|
||||
from routes.copilot_routes import setup_copilot_routes
|
||||
app.include_router(setup_copilot_routes())
|
||||
|
||||
# ChatGPT Subscription device-flow login
|
||||
from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes
|
||||
app.include_router(setup_chatgpt_subscription_routes())
|
||||
|
||||
# TTS
|
||||
from routes.tts_routes import setup_tts_routes
|
||||
app.include_router(setup_tts_routes(tts_service))
|
||||
|
||||
@@ -361,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base):
|
||||
# is the historical default. When non-null, the model picker only shows
|
||||
# the endpoint to that user (admins always see everything).
|
||||
owner = Column(String, nullable=True, index=True)
|
||||
# Optional OAuth/session-backed credential row. Used by subscription-backed
|
||||
# providers that need refresh tokens instead of a static API key.
|
||||
provider_auth_id = Column(String, nullable=True, index=True)
|
||||
|
||||
|
||||
class ProviderAuthSession(TimestampMixin, Base):
    """OAuth/session credential row for model providers that refresh tokens.

    Both token columns use the ``EncryptedText`` column type. Rows are looked
    up by ``provider`` and ``owner`` (both indexed).
    """

    __tablename__ = "provider_auth_sessions"

    # Identity and lookup keys.
    id = Column(String, primary_key=True, index=True)
    provider = Column(String, nullable=False, index=True)
    owner = Column(String, nullable=True, index=True)

    # Display label and the provider base URL these credentials belong to.
    label = Column(String, nullable=True)
    base_url = Column(String, nullable=False)

    # Token pair plus the time the pair was last written.
    access_token = Column(EncryptedText, nullable=True)
    refresh_token = Column(EncryptedText, nullable=True)
    last_refresh = Column(DateTime, nullable=True)

    # Free-form auth mode tag (the ChatGPT Subscription flow stores "chatgpt").
    auth_mode = Column(String, nullable=True)
|
||||
|
||||
class McpServer(TimestampMixin, Base):
|
||||
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
||||
@@ -801,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column():
|
||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||
|
||||
|
||||
def _migrate_add_provider_auth_id_column():
    """Add the ``provider_auth_id`` column and its index to ``model_endpoints``.

    Does nothing when the SQLite database file does not exist, when the
    ``model_endpoints`` table has no columns reported (table missing), or
    when the column is already present. Any failure is logged as a warning
    and never raised, so startup continues.
    """
    import sqlite3
    from contextlib import closing

    log = logging.getLogger(__name__)
    db_path = DATABASE_URL.replace("sqlite:///", "")
    if not os.path.exists(db_path):
        return
    try:
        # closing() releases the connection on every exit path, so a failing
        # statement cannot leak the handle (and its file lock).
        with closing(sqlite3.connect(db_path)) as conn:
            # Second field of each table_info row is the column name.
            columns = [row[1] for row in conn.execute("PRAGMA table_info(model_endpoints)")]
            if not columns or "provider_auth_id" in columns:
                return
            conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR")
            conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
            conn.commit()
            log.info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
    except Exception as e:
        log.warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
||||
|
||||
|
||||
def _migrate_add_model_type_column():
|
||||
"""Add model_type column to model_endpoints if it doesn't exist."""
|
||||
import sqlite3
|
||||
@@ -1599,6 +1637,7 @@ def init_db():
|
||||
_migrate_add_model_type_column()
|
||||
_migrate_add_model_endpoint_refresh_columns()
|
||||
_migrate_add_model_endpoint_owner_column()
|
||||
_migrate_add_provider_auth_id_column()
|
||||
_migrate_add_supports_tools_column()
|
||||
_migrate_add_task_run_model_column()
|
||||
_migrate_add_owner_column()
|
||||
|
||||
+86
-28
@@ -196,14 +196,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||
"""
|
||||
import requests as _req
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
||||
from src.endpoint_resolver import (
|
||||
build_chat_url,
|
||||
build_headers,
|
||||
build_models_url,
|
||||
normalize_base,
|
||||
resolve_endpoint_runtime,
|
||||
)
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
|
||||
current_url = sess.endpoint_url or ""
|
||||
owner = getattr(sess, "owner", None)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(
|
||||
q = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True
|
||||
).all()
|
||||
)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -212,26 +224,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
# Skip current endpoint
|
||||
if current_url and base in current_url:
|
||||
continue
|
||||
# Quick ping
|
||||
ping_url = build_models_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
try:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
ping_url = build_models_url(base)
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
if ping_url:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
models = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
models = json.loads(ep.cached_models or "[]")
|
||||
if not models:
|
||||
continue
|
||||
# Found a working endpoint — update session
|
||||
new_model = models[0]
|
||||
chat_url = build_chat_url(base)
|
||||
new_headers = build_headers(ep.api_key, base)
|
||||
new_headers = build_headers(api_key, base)
|
||||
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
|
||||
|
||||
sess.model = new_model
|
||||
sess.endpoint_url = chat_url
|
||||
@@ -243,7 +262,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||
"model": new_model,
|
||||
"endpoint_url": chat_url,
|
||||
"headers": json.dumps(new_headers),
|
||||
"headers": persisted_headers,
|
||||
})
|
||||
_db.commit()
|
||||
finally:
|
||||
@@ -336,16 +355,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _has_auth_keys(headers) -> bool:
|
||||
"""True if a headers dict carries an Authorization/x-api-key entry."""
|
||||
return isinstance(headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in headers
|
||||
)
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||
)
|
||||
if has_auth:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
|
||||
except Exception:
|
||||
is_chatgpt_subscription = False
|
||||
has_auth = _has_auth_keys(sess.headers)
|
||||
if has_auth and not is_chatgpt_subscription:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers, normalize_base
|
||||
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
|
||||
db = SessionLocal()
|
||||
try:
|
||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||
@@ -361,10 +390,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||
for ep in q.all():
|
||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||
continue
|
||||
if not ep.api_key:
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
|
||||
return
|
||||
if not api_key:
|
||||
# No usable key (e.g. ChatGPT Subscription needs re-auth).
|
||||
return
|
||||
sess.headers = build_headers(api_key, base)
|
||||
if is_chatgpt_subscription:
|
||||
# The bearer is short-lived and re-resolved per request, so it
|
||||
# stays request-local and is never written to the plaintext
|
||||
# sessions.headers column. Proactively strip any bearer an
|
||||
# older code path may have persisted so it does not linger.
|
||||
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
stale_q = stale_q.filter(DBSession.owner == owner)
|
||||
stored = stale_q.first()
|
||||
if stored is not None and _has_auth_keys(stored.headers):
|
||||
stale_q.update({"headers": {}})
|
||||
db.commit()
|
||||
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
|
||||
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
|
||||
return
|
||||
base = normalize_base(ep.base_url or "")
|
||||
sess.headers = build_headers(ep.api_key, base)
|
||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
update_q = update_q.filter(DBSession.owner == owner)
|
||||
@@ -408,7 +457,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
owner = getattr(sess, "owner", None)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
endpoints = q.all()
|
||||
for ep in endpoints:
|
||||
try:
|
||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||
@@ -542,7 +596,11 @@ async def build_chat_context(
|
||||
|
||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||
# re-hit slow local /models endpoints on every participant turn.
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
|
||||
sess.endpoint_url,
|
||||
sess.model,
|
||||
owner=getattr(sess, "owner", None),
|
||||
)
|
||||
if norm:
|
||||
sess.model = norm
|
||||
|
||||
|
||||
+57
-12
@@ -169,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
Covers the window between endpoint setup and the first chat send: the
|
||||
picker showed a model in the dropdown but the session record never got
|
||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||
Without this, we'd POST the upstream with model="" and get a generic
|
||||
401/503 instead of using the model the user already picked.
|
||||
|
||||
Returns True iff sess.model was repaired.
|
||||
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
|
||||
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
|
||||
"""
|
||||
if getattr(sess, "model", None):
|
||||
return False
|
||||
current_model = (getattr(sess, "model", "") or "").strip()
|
||||
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||
is_chatgpt_subscription = False
|
||||
if current_model:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
|
||||
if not is_chatgpt_subscription:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Prefer the endpoint whose base URL matches the session — we know the
|
||||
@@ -194,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
break
|
||||
if not ep:
|
||||
return False
|
||||
if not is_chatgpt_subscription:
|
||||
try:
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
|
||||
except Exception:
|
||||
is_chatgpt_subscription = False
|
||||
try:
|
||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||
except Exception:
|
||||
cached = []
|
||||
if not cached:
|
||||
visible = []
|
||||
else:
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if is_chatgpt_subscription:
|
||||
live_models = []
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
if api_key:
|
||||
live_models = fetch_available_models(api_key)
|
||||
if live_models:
|
||||
ep.cached_models = json.dumps(live_models)
|
||||
db.commit()
|
||||
except Exception:
|
||||
live_models = []
|
||||
# ChatGPT Subscription recovery must use the live Codex catalog.
|
||||
# Cached rows are only trusted above to avoid revalidating a model
|
||||
# that is already present in the visible picker list.
|
||||
cached = live_models
|
||||
if not cached:
|
||||
return False
|
||||
try:
|
||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||
except Exception:
|
||||
visible = cached
|
||||
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||
return False
|
||||
if not visible:
|
||||
return False
|
||||
model = visible[0]
|
||||
@@ -213,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
||||
# Persist so the next request, websocket reconnect, or page reload
|
||||
# picks up the same model (we'd otherwise re-pick on every send
|
||||
# and silently switch on the user if the cached order shifts).
|
||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
||||
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||
if owner:
|
||||
db_session_q = db_session_q.filter(DBSession.owner == owner)
|
||||
db_session = db_session_q.first()
|
||||
if db_session:
|
||||
db_session.model = model
|
||||
db_session.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
sess.model = model
|
||||
logger.info(
|
||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
||||
"Recovered session model for %s — picked %r from endpoint %s",
|
||||
session_id, model, ep.id,
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
"""ChatGPT Subscription device-flow setup routes."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
|
||||
from routes.device_flow import (
|
||||
DeviceFlowPoll,
|
||||
DeviceFlowStart,
|
||||
PendingDeviceFlowStore,
|
||||
create_device_flow_router,
|
||||
)
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import chatgpt_subscription
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||
|
||||
|
||||
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
    """Create or update the ChatGPT Subscription credential row and endpoint.

    Args:
        tokens: Token response; must contain non-empty ``access_token`` and
            ``refresh_token`` entries.
        owner: Owner value stored on both rows (the poll handler passes the
            user captured when the flow started, or None).

    Returns:
        Dict with the endpoint ``id``, ``name``, ``base_url`` and the
        discovered ``models`` list.

    Raises:
        ValueError: If either token is missing, or if model discovery
            returns nothing for the account.
    """
    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    if not access_token or not refresh_token:
        raise ValueError("ChatGPT token response was missing access_token or refresh_token")

    base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
    # Discovery runs before any database work, so an account without usable
    # models never gets a credential or endpoint row written.
    models = chatgpt_subscription.fetch_available_models(access_token)
    if not models:
        raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")

    db = SessionLocal()
    try:
        # Reuse the existing credential row for this provider/owner pair.
        auth = (
            db.query(ProviderAuthSession)
            .filter(
                ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
                ProviderAuthSession.owner == owner,
            )
            .first()
        )
        if auth is None:
            auth = ProviderAuthSession(
                id=str(uuid.uuid4())[:8],
                provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
                owner=owner,
                label="ChatGPT Subscription",
                base_url=base,
                auth_mode="chatgpt",
            )
            db.add(auth)
        auth.base_url = base
        auth.access_token = access_token
        auth.refresh_token = refresh_token
        auth.last_refresh = utcnow_naive()
        auth.auth_mode = "chatgpt"

        # Reuse the endpoint already linked to that credential row, if any.
        ep = (
            db.query(ModelEndpoint)
            .filter(
                ModelEndpoint.base_url == base,
                ModelEndpoint.provider_auth_id == auth.id,
                ModelEndpoint.owner == owner,
            )
            .first()
        )
        if ep is None:
            ep = ModelEndpoint(
                id=str(uuid.uuid4())[:8],
                name="ChatGPT Subscription",
                base_url=base,
                model_type="llm",
                endpoint_kind="api",
                owner=owner,
            )
            db.add(ep)
        ep.name = "ChatGPT Subscription"
        ep.base_url = base
        # No static key: the endpoint points at the credential row instead.
        ep.api_key = None
        ep.provider_auth_id = auth.id
        ep.is_enabled = True
        ep.supports_tools = False
        ep.model_type = "llm"
        ep.endpoint_kind = "api"
        ep.model_refresh_mode = "manual"
        ep.cached_models = json.dumps(models)
        db.commit()
        # Build the response while the session is still open.
        result = {
            "id": ep.id,
            "name": ep.name,
            "base_url": ep.base_url,
            "models": models,
        }
    finally:
        db.close()

    # Cache invalidation is best-effort: the rows are already committed, so a
    # failure here must not fail the login. It is logged instead of swallowed.
    try:
        from routes.model_routes import _invalidate_models_cache

        _invalidate_models_cache()
    except Exception:
        logger.debug(
            "Could not invalidate the models cache after ChatGPT Subscription provisioning",
            exc_info=True,
        )
    return result
|
||||
|
||||
|
||||
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
    """Request a ChatGPT device code and describe the pending login.

    The submitted form is unused. Failures from the device-code request are
    converted with ``chatgpt_subscription.to_http_exception``. An incomplete
    reply (missing ``device_auth_id`` or ``user_code``) raises HTTP 502.
    """
    try:
        payload = chatgpt_subscription.request_device_code()
    except Exception as exc:
        raise chatgpt_subscription.to_http_exception(exc)

    device_auth_id = payload.get("device_auth_id")
    user_code = payload.get("user_code")
    if not (device_auth_id and user_code):
        raise HTTPException(502, "ChatGPT did not return a complete device code")

    # Fall back to the issuer's device page when no URI is supplied.
    verification_uri = payload.get("verification_uri")
    if not verification_uri:
        verification_uri = f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"

    # State handed back to the poll handler for this flow.
    pending_state = {
        "device_auth_id": device_auth_id,
        "user_code": user_code,
        "owner": get_current_user(request) or None,
    }
    # Fields placed in the start response.
    client_view = {
        "user_code": user_code,
        "verification_uri": verification_uri,
    }
    poll_interval = int(payload.get("interval") or 5)
    lifetime = int(payload.get("expires_in") or 900)
    return DeviceFlowStart(
        pending=pending_state,
        response=client_view,
        interval=poll_interval,
        expires_in=lifetime,
    )
|
||||
|
||||
|
||||
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
    """Poll ChatGPT once for a pending device login and normalize the outcome.

    A failed poll call is reported as still pending, with the error text as
    detail. Once an authorization code and verifier arrive, they are
    exchanged for tokens and the endpoint is provisioned; a failure there is
    logged and raised via ``chatgpt_subscription.to_http_exception``.
    """
    try:
        payload = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
    except Exception as exc:
        logger.debug("ChatGPT device poll failed: %s", exc)
        return DeviceFlowPoll.pending(str(exc))

    authorization_code = payload.get("authorization_code")
    code_verifier = payload.get("code_verifier")
    if authorization_code and code_verifier:
        try:
            tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
            endpoint = _provision_endpoint(tokens, pending["owner"])
        except Exception as exc:
            logger.exception("ChatGPT Subscription endpoint provisioning failed")
            raise chatgpt_subscription.to_http_exception(exc)
        return DeviceFlowPoll.authorized(endpoint)

    # Not authorized yet: classify by the reported error, else the status.
    state = payload.get("error") or payload.get("status")
    if state is None or state in ("authorization_pending", "pending"):
        return DeviceFlowPoll.pending()
    elif state == "slow_down":
        # A missing or zero interval is passed on as None.
        return DeviceFlowPoll.slow_down(int(payload.get("interval") or 0) or None)
    elif state in ("expired_token", "access_denied", "denied"):
        return DeviceFlowPoll.failed(state)
    else:
        # Unrecognized value: report it but keep the flow pending.
        return DeviceFlowPoll.pending(state or "unknown")
|
||||
|
||||
|
||||
def setup_chatgpt_subscription_routes():
    """Build the ChatGPT Subscription router from the shared device-flow factory.

    Wires this module's start/poll handlers and pending-flow store under the
    ``/api/chatgpt-subscription`` prefix.
    """
    router_options = {
        "prefix": "/api/chatgpt-subscription",
        "tags": ["chatgpt-subscription"],
        "store": _DEVICE_FLOW_STORE,
        "start_flow": _start_device_flow,
        "poll_flow": _poll_device_flow,
    }
    return create_device_flow_router(**router_options)
|
||||
+67
-117
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Request, Form, HTTPException
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from core.database import SessionLocal, ModelEndpoint
|
||||
from core.middleware import require_admin
|
||||
from routes.device_flow import (
|
||||
DeviceFlowPoll,
|
||||
DeviceFlowStart,
|
||||
PendingDeviceFlowStore,
|
||||
create_device_flow_router,
|
||||
)
|
||||
from src.auth_helpers import get_current_user
|
||||
from src import copilot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
||||
# browser. Entries expire with the GitHub device code.
|
||||
#
|
||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
||||
# on a worker that never saw the start, returning "Unknown or expired login
|
||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
||||
_PENDING: Dict[str, Dict] = {}
|
||||
_PENDING_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _prune_expired() -> None:
|
||||
now = time.time()
|
||||
with _PENDING_LOCK:
|
||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
||||
_PENDING.pop(k, None)
|
||||
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||
|
||||
|
||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||
return result
|
||||
|
||||
|
||||
def setup_copilot_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
||||
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = str(form.get("enterprise_url") or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
|
||||
@router.post("/device/start")
|
||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
host = copilot.GITHUB_HOST
|
||||
ent = (enterprise_url or "").strip()
|
||||
if ent:
|
||||
host = copilot.normalize_domain(ent)
|
||||
try:
|
||||
data = copilot.request_device_code(host)
|
||||
except httpx.HTTPStatusError as e:
|
||||
status = e.response.status_code if e.response is not None else "unknown"
|
||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||
except Exception as e:
|
||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
|
||||
device_code = data.get("device_code")
|
||||
if not device_code:
|
||||
raise HTTPException(502, "GitHub did not return a device code")
|
||||
interval = int(data.get("interval") or 5)
|
||||
expires_in = int(data.get("expires_in") or 900)
|
||||
poll_id = uuid.uuid4().hex
|
||||
with _PENDING_LOCK:
|
||||
_PENDING[poll_id] = {
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"interval": interval,
|
||||
"owner": get_current_user(request) or None,
|
||||
"expires_at": time.time() + expires_in,
|
||||
"next_poll_at": 0.0,
|
||||
}
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return {
|
||||
"poll_id": poll_id,
|
||||
# verification_uri_complete embeds the user code, so the browser tab we
|
||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||
# code pre-filled — one click, no manual code entry.
|
||||
return DeviceFlowStart(
|
||||
pending={
|
||||
"device_code": device_code,
|
||||
"host": host,
|
||||
"enterprise_url": ent,
|
||||
"owner": get_current_user(request) or None,
|
||||
},
|
||||
response={
|
||||
"user_code": data.get("user_code"),
|
||||
"verification_uri": data.get("verification_uri"),
|
||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||
"interval": interval,
|
||||
"expires_in": expires_in,
|
||||
}
|
||||
},
|
||||
interval=int(data.get("interval") or 5),
|
||||
expires_in=int(data.get("expires_in") or 900),
|
||||
)
|
||||
|
||||
@router.post("/device/poll")
|
||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
_prune_expired()
|
||||
with _PENDING_LOCK:
|
||||
pending = _PENDING.get(poll_id)
|
||||
if not pending:
|
||||
raise HTTPException(404, "Unknown or expired login session")
|
||||
|
||||
# Enforce GitHub's polling interval server-side so a chatty client
|
||||
# can't trip slow_down.
|
||||
now = time.time()
|
||||
if now < pending.get("next_poll_at", 0):
|
||||
return {"status": "pending"}
|
||||
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
except Exception as e:
|
||||
return DeviceFlowPoll.pending(f"poll error: {e}")
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
return DeviceFlowPoll.authorized(result)
|
||||
|
||||
token = data.get("access_token")
|
||||
if token:
|
||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||
try:
|
||||
result = _provision_endpoint(token, base, pending["owner"])
|
||||
except Exception as e:
|
||||
logger.exception("Copilot endpoint provisioning failed")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "authorized", "endpoint": result}
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
return DeviceFlowPoll.pending()
|
||||
if err == "slow_down":
|
||||
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||
if err in ("expired_token", "access_denied"):
|
||||
return DeviceFlowPoll.failed(err)
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return DeviceFlowPoll.pending(err or "unknown")
|
||||
|
||||
err = data.get("error")
|
||||
if err == "authorization_pending":
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
||||
return {"status": "pending"}
|
||||
if err == "slow_down":
|
||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
||||
with _PENDING_LOCK:
|
||||
if poll_id in _PENDING:
|
||||
_PENDING[poll_id]["interval"] = new_interval
|
||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
||||
return {"status": "pending"}
|
||||
if err in ("expired_token", "access_denied"):
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "failed", "error": err}
|
||||
# Unknown error — surface but keep the session for another try.
|
||||
return {"status": "pending", "detail": err or "unknown"}
|
||||
|
||||
@router.post("/device/cancel")
|
||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||
require_admin(request)
|
||||
with _PENDING_LOCK:
|
||||
_PENDING.pop(poll_id, None)
|
||||
return {"status": "cancelled"}
|
||||
|
||||
return router
|
||||
def setup_copilot_routes():
    """Return the Copilot router built by the shared device-flow factory.

    The Copilot-specific start/poll callbacks and the module's pending-flow
    store are handed to ``create_device_flow_router`` under ``/api/copilot``.
    """
    router_options = {
        "prefix": "/api/copilot",
        "tags": ["copilot"],
        "store": _DEVICE_FLOW_STORE,
        "start_flow": _start_device_flow,
        "poll_flow": _poll_device_flow,
    }
    return create_device_flow_router(**router_options)
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Shared OAuth/device-flow route scaffolding for provider setup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterable, Mapping, Optional
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
|
||||
from core.middleware import require_admin
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DeviceFlowStart:
    """Immutable result a provider's start callback hands to the shared routes.

    The route wrapper stores ``pending`` in the pending-flow store, returns
    ``response`` to the client, and uses ``interval`` / ``expires_in`` for the
    stored poll metadata.
    """

    # Provider payload kept server-side until the flow finishes.
    pending: Mapping[str, Any]
    # Fields copied into the JSON body returned from the start route.
    response: Mapping[str, Any]
    # Poll interval in seconds handed to the store and echoed to the client.
    interval: int = 5
    # Lifetime of the pending entry in seconds.
    expires_in: int = 900
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DeviceFlowPoll:
    """Immutable, provider-independent description of one poll attempt.

    Build instances through the named constructors below; each one fixes
    ``status`` and fills only the fields relevant to that outcome.
    """

    # One of "pending", "slow_down", "authorized" or "failed".
    status: str
    # Provisioned endpoint description, set for "authorized".
    endpoint: Optional[Mapping[str, Any]] = None
    # Error code, set for "failed".
    error: Optional[str] = None
    # Optional free-form detail passed through on pending/slow_down outcomes.
    detail: Optional[str] = None
    # Optional replacement poll interval for "slow_down".
    interval: Optional[int] = None

    @classmethod
    def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Outcome for "not authorized yet"; ``detail`` is optional context."""
        return cls("pending", detail=detail)

    @classmethod
    def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
        """Outcome asking the caller to back off, optionally to ``interval``."""
        return cls("slow_down", detail=detail, interval=interval)

    @classmethod
    def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
        """Successful outcome carrying the provisioned ``endpoint``."""
        return cls("authorized", endpoint=endpoint)

    @classmethod
    def failed(cls, error: str) -> "DeviceFlowPoll":
        """Terminal failure outcome carrying the ``error`` code."""
        return cls("failed", error=error)
|
||||
|
||||
|
||||
class PendingDeviceFlowStore:
    """In-memory registry of device-flow logins that are still in progress.

    Every entry keeps the provider payload apart from the poll bookkeeping
    (``interval``, ``expires_at``, ``next_poll_at``), and callers only ever
    get a copy of the payload. All access to the underlying dict happens
    while holding a ``threading.Lock``.
    """

    def __init__(self, *, time_func: Callable[[], float] = time.time):
        # time_func is injectable so tests can drive the clock.
        self._pending: dict[str, dict[str, Any]] = {}
        self._lock = threading.Lock()
        self._time = time_func

    def _now(self) -> float:
        """Current time from the injected clock, coerced to float."""
        return float(self._time())

    def prune_expired(self) -> None:
        """Drop every entry whose ``expires_at`` lies before the current time."""
        cutoff = self._now()
        with self._lock:
            stale = [
                poll_id
                for poll_id, entry in self._pending.items()
                if entry.get("expires_at", 0) < cutoff
            ]
            for poll_id in stale:
                self._pending.pop(poll_id, None)

    def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
        """Register a new pending flow and return its generated poll id.

        Falsy ``interval`` / ``expires_in`` fall back to 5 and 900 seconds,
        and both are clamped to at least one second.
        """
        self.prune_expired()
        poll_id = uuid.uuid4().hex
        with self._lock:
            record = {
                "payload": dict(payload),
                "interval": max(int(interval or 5), 1),
                "expires_at": self._now() + max(int(expires_in or 900), 1),
                # 0.0 means the first poll is never throttled.
                "next_poll_at": 0.0,
            }
            self._pending[poll_id] = record
        return poll_id

    def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
        """Return a copy of the stored payload, or None if unknown or expired."""
        self.prune_expired()
        with self._lock:
            entry = self._pending.get(poll_id)
            return None if entry is None else dict(entry.get("payload") or {})

    def is_throttled(self, poll_id: str) -> bool:
        """True while the entry's ``next_poll_at`` is still in the future."""
        with self._lock:
            entry = self._pending.get(poll_id)
            if not entry:
                return False
            return self._now() < float(entry.get("next_poll_at") or 0)

    def schedule_next(self, poll_id: str) -> None:
        """Push ``next_poll_at`` one interval past the current time."""
        now = self._now()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return
            entry["next_poll_at"] = now + int(entry.get("interval") or 5)

    def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
        """Adopt a new interval (or the current one plus 5) and reschedule."""
        now = self._now()
        with self._lock:
            entry = self._pending.get(poll_id)
            if entry is None:
                return
            if interval:
                requested = int(interval)
            else:
                requested = int(entry.get("interval") or 5) + 5
            effective = max(requested, 1)
            entry["interval"] = effective
            entry["next_poll_at"] = now + effective

    def pop(self, poll_id: str) -> None:
        """Forget the entry; unknown ids are ignored."""
        with self._lock:
            self._pending.pop(poll_id, None)
|
||||
|
||||
|
||||
async def _maybe_await(value: Any) -> Any:
    """Await ``value`` when it is awaitable; otherwise return it unchanged.

    Lets the shared routes accept both sync and async provider callbacks.
    """
    return (await value) if inspect.isawaitable(value) else value
|
||||
|
||||
|
||||
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
    """Build the JSON body for a "pending" poll result.

    ``detail`` is included only when it is truthy.
    """
    if not detail:
        return {"status": "pending"}
    return {"status": "pending", "detail": detail}
|
||||
|
||||
|
||||
def create_device_flow_router(
    *,
    prefix: str,
    tags: Iterable[str],
    store: PendingDeviceFlowStore,
    start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
    poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
) -> APIRouter:
    """Create standard `/device/start|poll|cancel` routes for a provider.

    ``start_flow`` and ``poll_flow`` may be sync or async; their results are
    passed through ``_maybe_await``. Pending sessions live in ``store`` and
    every route calls ``require_admin`` first.
    """
    router = APIRouter(prefix=prefix, tags=list(tags))

    @router.post("/device/start")
    async def device_start(request: Request):
        require_admin(request)
        form = await request.form()
        start = await _maybe_await(start_flow(request, form))
        interval = int(start.interval or 5)
        expires_in = int(start.expires_in or 900)
        poll_id = store.add(start.pending, interval=interval, expires_in=expires_in)
        # Shared keys are written last so they override provider-supplied ones.
        response = dict(start.response)
        response["poll_id"] = poll_id
        response["interval"] = interval
        response["expires_in"] = expires_in
        return response

    @router.post("/device/poll")
    async def device_poll(request: Request, poll_id: str = Form(...)):
        require_admin(request)
        payload = store.get_payload(poll_id)
        if payload is None:
            raise HTTPException(404, "Unknown or expired login session")
        if store.is_throttled(poll_id):
            # Too early: answer without calling the provider callback.
            return {"status": "pending"}

        try:
            outcome = await _maybe_await(poll_flow(request, payload))
        except Exception:
            # A raising poll callback ends the session before the error propagates.
            store.pop(poll_id)
            raise

        status = outcome.status
        if status in ("authorized", "failed"):
            # Both terminal outcomes discard the pending session.
            store.pop(poll_id)
            if status == "authorized":
                return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
            return {"status": "failed", "error": outcome.error or "denied"}

        if status == "slow_down":
            store.slow_down(poll_id, outcome.interval)
        else:
            # Any other status is treated as "still pending".
            store.schedule_next(poll_id)
        return _pending_response(outcome.detail)

    @router.post("/device/cancel")
    def device_cancel(request: Request, poll_id: str = Form(...)):
        require_admin(request)
        store.pop(poll_id)
        return {"status": "cancelled"}

    return router
|
||||
+107
-13
@@ -283,6 +283,7 @@ _HOST_TO_CURATED = (
|
||||
("fireworks.ai", "fireworks"),
|
||||
("googleapis.com", "google"),
|
||||
("x.ai", "xai"),
|
||||
|
||||
("openrouter.ai", "openrouter"),
|
||||
("ollama.com", "ollama"),
|
||||
("opencode.ai/zen/go", "opencode-go"),
|
||||
@@ -493,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
|
||||
def _is_chat_model(model_id: str) -> bool:
|
||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||
mid = model_id.lower()
|
||||
if mid in {"gpt-5.1-codex"}:
|
||||
return True
|
||||
for prefix in _NON_CHAT_PREFIXES:
|
||||
if mid.startswith(prefix):
|
||||
return False
|
||||
@@ -505,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
    """Remove a ProviderAuthSession row that no endpoint points at any more.

    The row is deleted only when no ``ModelEndpoint`` other than
    ``exclude_ep_id`` has ``provider_auth_id == auth_id``. Pass the endpoint
    currently being deleted as ``exclude_ep_id`` so it does not count as a
    reference to its own auth row.

    Returns True when a row was deleted, False when ``auth_id`` is empty, the
    auth row is still referenced, or no such row exists. The caller commits.
    """
    if not auth_id:
        return False

    from core.database import ProviderAuthSession

    other_reference = (
        db.query(ModelEndpoint.id)
        .filter(
            ModelEndpoint.provider_auth_id == auth_id,
            ModelEndpoint.id != exclude_ep_id,
        )
        .first()
    )
    if other_reference is not None:
        return False

    orphan = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first()
    if orphan is not None:
        db.delete(orphan)
    return orphan is not None
|
||||
|
||||
|
||||
def _is_discovery_only_provider(provider: str) -> bool:
    """Tell whether ``provider`` is limited to model discovery.

    Only ``"chatgpt-subscription"`` qualifies. Callers use this to skip
    completion probes and reachability pings and to derive status from
    cached models instead.
    """
    discovery_only = "chatgpt-subscription"
    return provider == discovery_only
|
||||
|
||||
|
||||
def _resolve_probe_key(ep) -> Optional[str]:
    """Return the API key/bearer that probes for ``ep`` should use.

    The lookup is delegated to ``resolve_endpoint_runtime`` (called with the
    endpoint's ``owner`` attribute, if any); only the key half of its
    ``(base, key)`` result is kept. Any exception is logged as a warning and
    turned into ``None`` so callers can skip the probe instead of failing.
    """
    try:
        from src.endpoint_resolver import resolve_endpoint_runtime

        endpoint_owner = getattr(ep, "owner", None)
        _unused_base, probe_key = resolve_endpoint_runtime(ep, owner=endpoint_owner)
    except Exception as exc:
        logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), exc)
        return None
    return probe_key
|
||||
|
||||
|
||||
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||
provider = _detect_provider(base)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# Responses/Codex API, not chat-completions: a completion probe would
|
||||
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Say OK"},
|
||||
@@ -621,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||
from src.endpoint_resolver import resolve_url
|
||||
base = resolve_url(_normalize_base(base_url))
|
||||
if _detect_provider(base) == "chatgpt-subscription":
|
||||
from src.chatgpt_subscription import fetch_available_models
|
||||
if api_key:
|
||||
return fetch_available_models(api_key, timeout=timeout)
|
||||
return []
|
||||
if _detect_provider(base) == "anthropic":
|
||||
# Try Anthropic's /v1/models endpoint first
|
||||
url = build_models_url(base)
|
||||
@@ -647,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||
return list(ANTHROPIC_MODELS)
|
||||
url = build_models_url(base)
|
||||
if not url:
|
||||
curated_key = _match_provider_curated(base, None)
|
||||
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||
return list(fallback or [])
|
||||
headers = build_headers(api_key, base)
|
||||
try:
|
||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||
@@ -998,6 +1068,17 @@ def setup_model_routes(model_discovery):
|
||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||
if not ok:
|
||||
continue
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||
ep,
|
||||
owner=getattr(ep, "owner", None),
|
||||
)
|
||||
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||
except Exception as e:
|
||||
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||
continue
|
||||
groups.setdefault(info["key"], {
|
||||
"base": info["base"],
|
||||
"api_key": info["api_key"],
|
||||
@@ -1266,12 +1347,20 @@ def setup_model_routes(model_discovery):
|
||||
"endpoint_kind": kind,
|
||||
}
|
||||
try:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
if _is_discovery_only_provider(provider):
|
||||
# No general health endpoint — an unauthenticated GET just
|
||||
# 401s. Report status from cached models instead of pinging.
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
entry["error"] = None
|
||||
entry["model_count"] = cached_count
|
||||
else:
|
||||
t0 = _time.time()
|
||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||
entry["error"] = ping.get("error")
|
||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||
except Exception as e:
|
||||
entry["latency_ms"] = None
|
||||
entry["status"] = "online" if cached_count else "offline"
|
||||
@@ -1304,7 +1393,7 @@ def setup_model_routes(model_discovery):
|
||||
if ep_id and ep_id not in endpoints_cache:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if ep:
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
ep_data = endpoints_cache.get(ep_id)
|
||||
if not ep_data:
|
||||
# Try to find by base_url from the model's endpoint field
|
||||
@@ -1343,7 +1432,7 @@ def setup_model_routes(model_discovery):
|
||||
"id": ep.id,
|
||||
"name": ep.name,
|
||||
"base_url": ep.base_url,
|
||||
"api_key": ep.api_key,
|
||||
"api_key": _resolve_probe_key(ep),
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
@@ -1432,12 +1521,14 @@ def setup_model_routes(model_discovery):
|
||||
# Endpoint counts as reachable if it has any model — including
|
||||
# admin-pinned IDs that a probe would never surface.
|
||||
status = "online" if (all_models or pinned) else "offline"
|
||||
base = _normalize_base(r.base_url)
|
||||
ping = None
|
||||
if not all_models and not pinned and r.is_enabled:
|
||||
# Discovery-only providers have no health endpoint — an
|
||||
# unauthenticated ping just 401s, so don't bother.
|
||||
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||
if ping.get("reachable"):
|
||||
status = "empty"
|
||||
base = _normalize_base(r.base_url)
|
||||
kind = _effective_endpoint_kind(r, base)
|
||||
results.append({
|
||||
"id": r.id,
|
||||
@@ -1713,7 +1804,7 @@ def setup_model_routes(model_discovery):
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found")
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1777,7 +1868,7 @@ def setup_model_routes(model_discovery):
|
||||
category = _classify_endpoint(base, kind)
|
||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||
try:
|
||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
||||
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||
except Exception as exc:
|
||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||
probed = []
|
||||
@@ -2116,7 +2207,9 @@ def setup_model_routes(model_discovery):
|
||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||
auth_id = getattr(ep, "provider_auth_id", None)
|
||||
db.delete(ep)
|
||||
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
|
||||
db.commit()
|
||||
_invalidate_models_cache()
|
||||
_local_probe_cache["data"] = None
|
||||
@@ -2126,6 +2219,7 @@ def setup_model_routes(model_discovery):
|
||||
"cleared_user_preferences": cleared_user_preferences,
|
||||
"cleared_sessions": cleared_sessions,
|
||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||
"cleared_provider_auth": cleared_provider_auth,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
+39
-26
@@ -75,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
|
||||
return owner_filter(q, ModelEndpoint, owner).first()
|
||||
|
||||
|
||||
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
    """Resolve a ModelEndpoint row into ``(chat_url, model, headers)``.

    Credentials come from ``endpoint_resolver.resolve_endpoint_runtime``
    rather than ``ep.api_key``, so endpoints whose ``api_key`` is empty
    because a provider-auth session supplies the credential still work.

    Args:
        ep: The endpoint row; ``cached_models`` is read as a JSON list.
        owner: Passed through to the credential resolver.
        model: Explicit model id. When blank, the first chat model from
            ``ep.cached_models`` is used.

    Returns:
        ``(chat_url, model, headers)``, or ``None`` when credentials cannot
        be resolved or no model can be determined.
    """
    from src.endpoint_resolver import (
        build_chat_url,
        build_headers,
        resolve_endpoint_runtime as resolve_model_endpoint_runtime,
    )

    try:
        base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
    except Exception as e:
        logger.warning("Could not resolve endpoint credentials for research: %s", e)
        return None

    ep_model = (model or "").strip()
    if not ep_model:
        try:
            models = json.loads(ep.cached_models) if ep.cached_models else []
            if models:
                ep_model = _first_chat_model(models)
        except Exception as e:
            # Still best-effort, but record why the cached list was unusable
            # instead of silently discarding the error.
            logger.debug("Could not pick a model from cached_models for research: %s", e)
    if not ep_model:
        return None
    return build_chat_url(base), ep_model, build_headers(api_key, base)
|
||||
|
||||
|
||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
router = APIRouter(tags=["research"])
|
||||
|
||||
@@ -371,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
|
||||
if body.endpoint_id:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped: never resolve another user's private endpoint
|
||||
@@ -380,18 +411,10 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found or disabled")
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = body.model or ""
|
||||
if not ep_model:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
|
||||
if not resolved:
|
||||
raise HTTPException(400, "Endpoint is not configured with a usable model.")
|
||||
ep_url, ep_model, ep_headers = resolved
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
@@ -408,7 +431,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||
if not ep_url:
|
||||
from src.database import SessionLocal
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Owner-scoped first-enabled fallback: the caller's own rows
|
||||
@@ -417,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
||||
ep = _owned_enabled_endpoint(db, user)
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = ""
|
||||
if ep.cached_models:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models)
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
resolved = _resolve_endpoint_runtime(ep, owner=user)
|
||||
if resolved:
|
||||
ep_url, ep_model, ep_headers = resolved
|
||||
finally:
|
||||
db.close()
|
||||
if not ep_url:
|
||||
|
||||
+22
-17
@@ -92,18 +92,13 @@ def _reject_compact_during_active_run(session_id: str) -> None:
|
||||
|
||||
|
||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||
"""Verify the current user owns the session. Raises 404 if not.
|
||||
"""Verify the current user owns the session, honoring single-user modes.
|
||||
|
||||
Ownership is checked against the DB row when one exists (unchanged). If
|
||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
||||
that lives only in ``session_manager`` because it was never persisted, or
|
||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
||||
user can still manage and delete it. Without this fallback such sessions are
|
||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
||||
|
||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
||||
that only care about persisted sessions keep their exact prior behavior.
|
||||
Authenticated requests must match the stored DB or in-memory owner. When
|
||||
auth is disabled and no user is present, treat the app as single-user mode:
|
||||
verify that the session exists, but do not compare its stored owner. This
|
||||
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
|
||||
rows created while auth was previously enabled.
|
||||
"""
|
||||
user = effective_user(request)
|
||||
if not user and not _auth_disabled():
|
||||
@@ -114,13 +109,13 @@ def _verify_session_owner(request: Request, session_id: str, session_manager=Non
|
||||
finally:
|
||||
db.close()
|
||||
if row is not None:
|
||||
if row.owner != user:
|
||||
if user and row.owner != user:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
return
|
||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||
if session_manager is not None:
|
||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
||||
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
|
||||
return
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
|
||||
@@ -372,8 +367,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
pass
|
||||
elif not model_to_use:
|
||||
from src.llm_core import list_model_ids
|
||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers)
|
||||
ids = list_model_ids(
|
||||
endpoint_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers,
|
||||
owner=user,
|
||||
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||
)
|
||||
if not ids:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
# Default to the first CHAT model — endpoints often list embedding/
|
||||
@@ -387,8 +387,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
||||
from src.llm_core import list_model_ids
|
||||
import os as _os
|
||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers)
|
||||
avail = list_model_ids(
|
||||
endpoint_url,
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
headers=validation_headers,
|
||||
owner=user,
|
||||
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||
)
|
||||
if not avail:
|
||||
raise HTTPException(400, "Cannot reach /v1/models")
|
||||
if model_to_use not in avail:
|
||||
|
||||
@@ -1109,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
idx = skills_manager.index_for(owner=user)
|
||||
return {"index": idx, "count": len(idx)}
|
||||
|
||||
@router.get("/slash-catalog")
async def get_slash_catalog(request: Request):
    """List the caller's skills in the shape used for slash commands.

    Entries come from ``skills_manager.index_for(owner)``; category,
    description and usage counters fall back to the matching record from
    ``skills_manager.load(owner)``. Entries without a name are skipped and
    the result is sorted by name.
    """
    user = _owner(request)
    details_by_name = {}
    for record in skills_manager.load(owner=user):
        details_by_name[record.get("name")] = record

    entries = []
    for listed in skills_manager.index_for(owner=user):
        name = (listed.get("name") or "").strip()
        if not name:
            continue
        full = details_by_name.get(name) or {}
        raw_category = listed.get("category") or full.get("category") or "general"
        # A whitespace-only category collapses back to "general".
        category = raw_category.strip() or "general"
        entry = {
            "type": "skill",
            "token": f"/{name}",
            "name": name,
            "category": f"Skills / {category}",
            "help": listed.get("description") or full.get("description") or "",
            "usage": f"/{name} <request>",
            "uses": int(full.get("uses") or 0),
            "last_used": full.get("last_used"),
        }
        entries.append(entry)

    ordered = sorted(entries, key=lambda row: row["name"])
    return {"skills": ordered, "count": len(ordered)}
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_skills(request: Request):
|
||||
"""Read-only list of the agent's built-in tool capabilities (research,
|
||||
@@ -1272,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
||||
_fire_skill_added(user)
|
||||
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
||||
|
||||
@router.post("/{skill_id}/invoke")
async def invoke_skill(request: Request, skill_id: str):
    """Build a skill-pinned prompt for slash-command invocation.

    The skill must appear in ``skills_manager.index_for(owner)`` for the
    caller; its markdown source is embedded in the returned message and a
    use is recorded via ``skills_manager.record_use``.

    The optional JSON body may carry a ``request`` string. A missing or
    unparsable body, a non-object body, or a non-string ``request`` are all
    treated as "no request text".

    Raises:
        HTTPException: 404 when the skill is not invokable for the caller or
            its markdown source cannot be read.
    """
    user = _owner(request)
    try:
        body = await request.json()
    except Exception:
        # The body is optional; an absent or malformed one means no request text.
        body = {}
    raw_request = body.get("request") if isinstance(body, dict) else None
    # Only strings can be stripped; any other JSON type (number, list, ...)
    # previously raised AttributeError here and surfaced as a 500.
    request_text = raw_request.strip() if isinstance(raw_request, str) else ""

    invokable = {
        s.get("name"): s for s in skills_manager.index_for(owner=user)
        if (s.get("name") or "").strip()
    }
    match = invokable.get(skill_id)
    if not match:
        raise HTTPException(404, "Skill is not available for slash invocation")

    name = match.get("name")
    md = skills_manager.read_skill_md(name, owner=user)
    if md is None:
        raise HTTPException(404, "Skill source unavailable")

    skills_manager.record_use(name, owner=user)
    message = (
        "Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
        f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
        + (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
    )
    return {
        "ok": True,
        "type": "skill",
        "name": name,
        "command": f"/{name}",
        "message": message,
    }
|
||||
|
||||
@router.get("/{skill_id}")
|
||||
async def get_skill(request: Request, skill_id: str):
|
||||
user = _owner(request)
|
||||
|
||||
+21
-10
@@ -325,22 +325,33 @@ def setup_webhook_routes(
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
model = body.model or "auto"
|
||||
api_key = ep.api_key
|
||||
if getattr(ep, "provider_auth_id", None):
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
|
||||
endpoint_url = build_chat_url(base_url)
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not resolve endpoint credentials")
|
||||
|
||||
if model == "auto":
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
models_url = build_models_url(base_url)
|
||||
hdrs = build_headers(api_key, base_url)
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not ids:
|
||||
ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
if models_url:
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not ids:
|
||||
ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
import json as _json
|
||||
ids = _json.loads(ep.cached_models or "[]")
|
||||
model = ids[0] if ids else "auto"
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not discover models from endpoint")
|
||||
|
||||
+39
-25
@@ -57,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
||||
# Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
|
||||
|
||||
|
||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||
@@ -98,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
provider = _detect_provider(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
if provider == "anthropic":
|
||||
# Anthropic: match against hardcoded model list
|
||||
@@ -114,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
||||
else:
|
||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||
try:
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
models_url = build_models_url(base)
|
||||
if models_url:
|
||||
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
model_ids = json.loads(ep.cached_models or "[]")
|
||||
except Exception:
|
||||
model_ids = []
|
||||
|
||||
@@ -1121,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
|
||||
total_models = 0
|
||||
|
||||
for ep in endpoints:
|
||||
base = _normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception:
|
||||
continue
|
||||
provider = _detect_provider(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
model_ids = []
|
||||
if provider == "anthropic":
|
||||
model_ids = list(ANTHROPIC_MODELS)
|
||||
else:
|
||||
try:
|
||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
models_url = build_models_url(base)
|
||||
if models_url:
|
||||
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not model_ids:
|
||||
model_ids = [
|
||||
m.get("name") or m.get("model")
|
||||
for m in (data.get("models") or [])
|
||||
if m.get("name") or m.get("model")
|
||||
]
|
||||
else:
|
||||
model_ids = json.loads(ep.cached_models or "[]")
|
||||
except Exception:
|
||||
model_ids = ["(endpoint offline)"]
|
||||
|
||||
|
||||
@@ -0,0 +1,311 @@
|
||||
"""ChatGPT subscription / Codex backend OAuth helpers.
|
||||
|
||||
This provider is intentionally separate from OpenAI API-key endpoints. It uses
|
||||
OpenAI account OAuth device authorization, stores refresh tokens server-side,
|
||||
and resolves a fresh bearer token at request time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
||||
|
||||
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
|
||||
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
|
||||
or "https://chatgpt.com/backend-api/codex"
|
||||
)
|
||||
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
|
||||
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
|
||||
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
|
||||
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
|
||||
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def _refresh_lock_for(auth_id: str) -> threading.Lock:
|
||||
with _AUTH_REFRESH_LOCKS_GUARD:
|
||||
lock = _AUTH_REFRESH_LOCKS.get(auth_id)
|
||||
if lock is None:
|
||||
lock = threading.Lock()
|
||||
_AUTH_REFRESH_LOCKS[auth_id] = lock
|
||||
return lock
|
||||
|
||||
|
||||
class ChatGPTSubscriptionError(RuntimeError):
    """Root of every error raised by the ChatGPT Subscription provider."""


class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
    """The stored OAuth credentials can no longer be used or refreshed."""


class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
    """The upstream reported a quota/rate limit; signing in again will not help."""


class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
    """No auth session matched the requested id (and owner, when given)."""
|
||||
|
||||
|
||||
def is_chatgpt_subscription_base(url: str) -> bool:
    """Tell whether ``url`` points at the ChatGPT Codex backend.

    The host must be exactly ``chatgpt.com`` (case-insensitive, a trailing dot
    is tolerated) and the path must be ``/backend-api/codex`` or something
    below it. Anything that cannot be parsed yields ``False``.
    """
    try:
        from urllib.parse import urlparse

        pieces = urlparse(url or "")
        hostname = (pieces.hostname or "").lower().rstrip(".")
        trimmed_path = (pieces.path or "").rstrip("/")
    except Exception:
        return False
    if hostname != "chatgpt.com":
        return False
    if trimmed_path == "/backend-api/codex":
        return True
    return trimmed_path.startswith("/backend-api/codex/")
|
||||
|
||||
|
||||
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
    """Build the request headers sent to the ChatGPT Codex backend.

    An ``Authorization: Bearer`` header is included only when ``access_token``
    is truthy.
    """
    bearer = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    return {
        "Accept": "application/json, text/event-stream",
        "Origin": "https://chatgpt.com",
        "Referer": "https://chatgpt.com/codex",
        "User-Agent": "Odysseus ChatGPT Subscription",
        **bearer,
    }
|
||||
|
||||
|
||||
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
    """Return the model slugs the Codex backend lists for this token.

    This is a best-effort lookup: a missing token, a transport failure, a
    non-200 status, a non-JSON body or a malformed ``models`` field all yield
    an empty list instead of raising.

    Entries whose ``visibility`` is ``hide``/``hidden`` are skipped. The
    result is ordered by ascending ``priority`` (entries without a usable
    numeric priority sort last), then by slug, with duplicates removed.
    """
    if not access_token:
        return []
    try:
        response = httpx.get(
            "https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
            headers=chatgpt_headers(access_token),
            timeout=timeout,
        )
        if response.status_code != 200:
            return []
        data = response.json()
    except Exception:
        return []

    # ``"models": null`` (or any non-list) must not escape as a TypeError.
    entries = data.get("models") if isinstance(data, dict) else None
    if not isinstance(entries, list):
        return []

    ranked: list[tuple[int, str]] = []
    for item in entries:
        if not isinstance(item, dict):
            continue
        slug = item.get("slug")
        if not isinstance(slug, str) or not slug.strip():
            continue
        visibility = item.get("visibility", "")
        if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}:
            continue
        priority = item.get("priority")
        try:
            rank = int(priority) if isinstance(priority, (int, float)) else 10_000
        except (ValueError, OverflowError):
            # json.loads accepts NaN/Infinity; int() rejects them.
            rank = 10_000
        ranked.append((rank, slug.strip()))

    ranked.sort()
    # dict.fromkeys keeps the first (best-ranked) occurrence of each slug.
    return list(dict.fromkeys(slug for _, slug in ranked))
|
||||
|
||||
|
||||
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
    """Raise the matching provider error for a failed OAuth HTTP response.

    Statuses below 400 return silently. Otherwise the JSON body, when it can
    be read, supplies an error code and a more specific message.

    Raises:
        ChatGPTSubscriptionRateLimited: on HTTP 429.
        ChatGPTSubscriptionReauthRequired: on HTTP 401/403 or an error code
            indicating the grant/token is no longer usable.
        ChatGPTSubscriptionError: for every other failure.
    """
    if response.status_code < 400:
        return

    error_code = ""
    detail = f"ChatGPT Subscription {action} failed with HTTP {response.status_code}."
    try:
        body = response.json()
        error_field = body.get("error") if isinstance(body, dict) else None
        if isinstance(error_field, dict):
            error_code = str(error_field.get("code") or error_field.get("type") or "").strip()
            specific = error_field.get("message")
            if specific:
                detail = f"ChatGPT Subscription {action} failed: {specific}"
        elif isinstance(error_field, str):
            error_code = error_field.strip()
            specific = body.get("error_description") or body.get("message")
            if specific:
                detail = f"ChatGPT Subscription {action} failed: {specific}"
    except Exception:
        # Body parsing is best-effort; the generic message is kept.
        pass

    if response.status_code == 429:
        raise ChatGPTSubscriptionRateLimited(
            "ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
        )
    reauth_codes = {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}
    needs_reauth = response.status_code in (401, 403) or error_code in reauth_codes
    if needs_reauth:
        raise ChatGPTSubscriptionReauthRequired(detail)
    raise ChatGPTSubscriptionError(detail)
|
||||
|
||||
|
||||
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
    """Return the response body as a dict, or raise a provider error.

    HTTP failures are raised by ``_raise_for_oauth_response`` first; a body
    that is not JSON, or not a JSON object, raises ``ChatGPTSubscriptionError``.
    """
    _raise_for_oauth_response(response, action)
    try:
        parsed = response.json()
    except Exception as exc:
        raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
    if isinstance(parsed, dict):
        return parsed
    raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
|
||||
|
||||
|
||||
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
    """Start a device authorization and return the issuer's response.

    The response must carry ``device_auth_id`` and ``user_code``; defaults are
    filled in for ``verification_uri``, ``interval`` and ``expires_in`` when
    the issuer omits them.

    Raises:
        ChatGPTSubscriptionError: when the request fails or required fields
            are missing.
    """
    usercode_url = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode"
    reply = httpx.post(
        usercode_url,
        json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    data = _json_or_error(reply, "device-code request")
    if not (data.get("device_auth_id") and data.get("user_code")):
        raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")
    fallbacks = (
        ("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device"),
        ("interval", 5),
        ("expires_in", 900),
    )
    for key, value in fallbacks:
        data.setdefault(key, value)
    return data
|
||||
|
||||
|
||||
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Poll the issuer once for the outcome of a device authorization.

    HTTP 403 and 404 are reported as a still-pending authorization; any other
    response is parsed (or raised) by ``_json_or_error``.
    """
    poll_url = f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token"
    body = {"device_auth_id": device_auth_id, "user_code": user_code}
    reply = httpx.post(
        poll_url,
        json=body,
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    if reply.status_code not in (403, 404):
        return _json_or_error(reply, "device-code poll")
    return {"status": "pending", "error": "authorization_pending"}
|
||||
|
||||
|
||||
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
    """Exchange an authorization code (with its PKCE verifier) for tokens.

    Raises:
        ChatGPTSubscriptionReauthRequired: when the token endpoint answers
            without an ``access_token``.
        ChatGPTSubscriptionError: for other failures, via ``_json_or_error``.
    """
    form = {
        "grant_type": "authorization_code",
        "code": authorization_code,
        "redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
        "code_verifier": code_verifier,
    }
    reply = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(reply, "token exchange")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
|
||||
|
||||
|
||||
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
    """Obtain a new access token using the refresh-token grant.

    ``access_token`` is accepted for interface compatibility but is not used.

    Raises:
        ChatGPTSubscriptionReauthRequired: when no refresh token is supplied
            or the token endpoint answers without an ``access_token``.
        ChatGPTSubscriptionError: for other failures, via ``_json_or_error``.
    """
    del access_token  # not needed for the refresh grant
    if not refresh_token:
        raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": CHATGPT_OAUTH_CLIENT_ID,
    }
    reply = httpx.post(
        CHATGPT_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data=form,
        timeout=timeout,
    )
    tokens = _json_or_error(reply, "token refresh")
    if tokens.get("access_token"):
        return tokens
    raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
|
||||
|
||||
|
||||
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
|
||||
parts = (token or "").split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError("not a JWT")
|
||||
segment = parts[1]
|
||||
segment += "=" * (-len(segment) % 4)
|
||||
raw = base64.urlsafe_b64decode(segment.encode("ascii"))
|
||||
payload = json.loads(raw.decode("utf-8"))
|
||||
return payload if isinstance(payload, dict) else {}
|
||||
|
||||
|
||||
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
    """Tell whether the token's ``exp`` falls within ``skew_seconds`` from now.

    A token whose payload cannot be decoded, or that carries no ``exp``, is
    treated as expiring.
    """
    try:
        expires_at = int(_decode_jwt_payload(access_token).get("exp") or 0)
    except Exception:
        return True
    deadline = int(time.time()) + int(skew_seconds)
    return expires_at <= deadline
|
||||
|
||||
|
||||
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
    """Load a stored ChatGPT Subscription session and return usable credentials.

    The access token is refreshed (and the new tokens persisted) when
    ``force_refresh`` is set or the stored token is about to expire.

    Args:
        auth_id: Primary key of the ``ProviderAuthSession`` row.
        owner: When truthy, the row must also belong to this owner. When
            falsy, no owner filter is applied.
        force_refresh: Refresh even if the stored token still looks valid.

    Returns:
        A dict with ``provider``, ``base_url``, ``api_key`` (the access
        token) and ``auth_mode``.

    Raises:
        ChatGPTSubscriptionAuthNotFound: when no matching row exists.
        ChatGPTSubscriptionError: (or a subclass) when the refresh fails.
    """
    db = SessionLocal()
    try:
        q = db.query(ProviderAuthSession).filter(
            ProviderAuthSession.id == auth_id,
            ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
        )
        if owner:
            q = q.filter(ProviderAuthSession.owner == owner)
        row = q.first()
        if row is None:
            raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")

        access_token = row.access_token or ""
        if force_refresh or access_token_is_expiring(access_token):
            # One refresh at a time per auth session within this process.
            with _refresh_lock_for(auth_id):
                # Re-read the row after acquiring the lock and check again;
                # presumably so a refresh completed by another caller while
                # we waited is reused instead of repeated.
                db.refresh(row)
                access_token = row.access_token or ""
                refresh_token = row.refresh_token or ""
                if force_refresh or access_token_is_expiring(access_token):
                    refreshed = refresh_oauth_tokens(access_token, refresh_token)
                    row.access_token = refreshed["access_token"]
                    # Keep the stored refresh token unless a new one was issued.
                    if refreshed.get("refresh_token"):
                        row.refresh_token = refreshed["refresh_token"]
                    row.last_refresh = utcnow_naive()
                    db.commit()
                    db.refresh(row)
                    access_token = row.access_token or ""

        return {
            "provider": CHATGPT_SUBSCRIPTION_PROVIDER,
            "base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"),
            "api_key": access_token,
            "auth_mode": row.auth_mode or "chatgpt",
        }
    finally:
        # Always release the session, including on the error paths above.
        db.close()
|
||||
|
||||
|
||||
def to_http_exception(exc: Exception) -> HTTPException:
    """Translate a provider error into the HTTPException sent to the client.

    Rate limits map to 429, missing or unusable credentials to 401 (with a
    reconnect hint), and everything else to 502.
    """
    if isinstance(exc, ChatGPTSubscriptionRateLimited):
        return HTTPException(429, str(exc))
    if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
        hint = "Reconnect the provider."
        detail = str(exc).rstrip()
        # Some messages already end with the hint; do not repeat it.
        if not detail.endswith(hint):
            detail = f"{detail} {hint}"
        return HTTPException(401, detail)
    return HTTPException(502, str(exc))
|
||||
|
||||
|
||||
def build_responses_input(messages: list[dict]) -> list[dict]:
    """Convert chat-style messages into Responses-API ``input`` items.

    Each message becomes ``{"role", "content": [{"type", "text"}]}``:

    * a missing role defaults to ``user`` and ``tool`` is sent as ``user``;
    * assistant text is typed ``output_text``, everything else ``input_text``;
    * list content is flattened to text: dict parts contribute their
      ``text`` (or ``content``) value, other truthy parts their ``str()``;
      parts without text (e.g. images) are skipped;
    * ``None`` content becomes an empty string.

    Entries of ``messages`` that are not dicts are ignored.
    """
    input_items: list[dict] = []
    for msg in messages or []:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role") or "user"
        if role == "tool":
            role = "user"
        content = msg.get("content")
        if isinstance(content, list):
            pieces = []
            for part in content:
                value = (part.get("text") or part.get("content")) if isinstance(part, dict) else part
                if value:
                    pieces.append(str(value))
            text = "\n".join(pieces)
        else:
            text = "" if content is None else str(content)
        input_type = "output_text" if role == "assistant" else "input_text"
        input_items.append({"role": role, "content": [{"type": input_type, "text": text}]})
    return input_items
|
||||
@@ -70,6 +70,25 @@ def _endpoint_enabled_models(ep) -> list:
|
||||
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
|
||||
|
||||
|
||||
def resolve_endpoint_runtime(ep, owner: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """Return the ``(base_url, api_key)`` pair to use for ``ep`` right now.

    Endpoints without a ``provider_auth_id`` use their stored ``base_url`` and
    ``api_key``. Endpoints with one take both values from the owner-scoped
    auth session, falling back to the stored base URL when the session does
    not supply one. Errors from the session lookup propagate to the caller.
    """
    static_base = normalize_base(getattr(ep, "base_url", "") or "")
    static_key = getattr(ep, "api_key", None)
    session_id = getattr(ep, "provider_auth_id", None)
    if not session_id:
        return static_base, static_key

    from src.chatgpt_subscription import resolve_runtime_credentials

    session_creds = resolve_runtime_credentials(session_id, owner=owner)
    runtime_base = normalize_base(session_creds.get("base_url") or static_base)
    return runtime_base, session_creds.get("api_key")
|
||||
|
||||
|
||||
# Cache for Tailscale hostname → IP resolution
|
||||
_tailscale_cache: Dict[str, Optional[str]] = {}
|
||||
|
||||
@@ -133,7 +152,7 @@ def resolve_url(url: str) -> str:
|
||||
def normalize_base(url: str) -> str:
|
||||
"""Strip known API path suffixes from a base URL."""
|
||||
url = (url or "").strip().rstrip("/")
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages", "/responses"]:
|
||||
if url.endswith(suffix):
|
||||
url = url[: -len(suffix)].rstrip("/")
|
||||
for suffix in ["/chat", "/tags", "/generate"]:
|
||||
@@ -158,10 +177,12 @@ def build_chat_url(base: str) -> str:
|
||||
return _anthropic_api_root(base) + "/v1/messages"
|
||||
if provider == "ollama":
|
||||
return _ollama_api_root(base) + "/chat"
|
||||
if provider == "chatgpt-subscription":
|
||||
return base.rstrip("/") + "/responses"
|
||||
return base + "/chat/completions"
|
||||
|
||||
|
||||
def build_models_url(base: str) -> str:
|
||||
def build_models_url(base: str) -> Optional[str]:
|
||||
"""Return the provider-specific model-list endpoint URL for a base."""
|
||||
base = resolve_url(base)
|
||||
provider = _detect_provider(base)
|
||||
@@ -169,6 +190,8 @@ def build_models_url(base: str) -> str:
|
||||
return _anthropic_api_root(base) + "/v1/models"
|
||||
if provider == "ollama":
|
||||
return _ollama_api_root(base) + "/tags"
|
||||
if provider == "chatgpt-subscription":
|
||||
return None
|
||||
return base + "/models"
|
||||
|
||||
|
||||
@@ -184,6 +207,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
||||
if provider == "copilot":
|
||||
from src.copilot import copilot_headers
|
||||
return copilot_headers(api_key)
|
||||
if provider == "chatgpt-subscription":
|
||||
from src.chatgpt_subscription import chatgpt_headers
|
||||
return chatgpt_headers(api_key)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if provider == "openrouter":
|
||||
@@ -262,9 +288,13 @@ def resolve_endpoint(
|
||||
if not ep:
|
||||
return fallback_url, fallback_model, fallback_headers
|
||||
|
||||
base = normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
|
||||
return fallback_url, fallback_model, fallback_headers
|
||||
chat_url = build_chat_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
|
||||
# Discard a configured model the user has since disabled on the
|
||||
# endpoint (e.g. a stale `default_model` left pointing at a now-hidden
|
||||
@@ -308,9 +338,13 @@ def resolve_endpoint_by_id(
|
||||
ep = q.first()
|
||||
if not ep:
|
||||
return None
|
||||
base = normalize_base(ep.base_url)
|
||||
try:
|
||||
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
|
||||
return None
|
||||
chat_url = build_chat_url(base)
|
||||
headers = build_headers(ep.api_key, base)
|
||||
headers = build_headers(api_key, base)
|
||||
m = (model or "").strip()
|
||||
# Drop a model the user disabled on the endpoint, then pick the first
|
||||
# enabled chat model rather than a hidden one.
|
||||
|
||||
+217
-7
@@ -426,6 +426,9 @@ def _detect_provider(url: str) -> str:
|
||||
return "openrouter"
|
||||
if _host_match(url, "groq.com"):
|
||||
return "groq"
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
if is_chatgpt_subscription_base(url):
|
||||
return "chatgpt-subscription"
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(url):
|
||||
return "copilot"
|
||||
@@ -462,6 +465,8 @@ def _provider_label(url: str) -> str:
|
||||
if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go"
|
||||
if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen"
|
||||
if _host_match(url, "groq.com"): return "Groq"
|
||||
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||
if is_chatgpt_subscription_base(url): return "ChatGPT Subscription"
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(url): return "GitHub Copilot"
|
||||
if _host_match(url, "mistral.ai"): return "Mistral"
|
||||
@@ -479,6 +484,77 @@ def _provider_label(url: str) -> str:
|
||||
return host or "provider"
|
||||
|
||||
|
||||
def _normalize_chatgpt_subscription_url(url: str) -> str:
|
||||
base = (url or "").strip().rstrip("/")
|
||||
if base.endswith("/responses"):
|
||||
return base
|
||||
return base + "/responses"
|
||||
|
||||
|
||||
def _message_content_as_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
if part:
|
||||
parts.append(str(part))
|
||||
continue
|
||||
if isinstance(part.get("text"), str):
|
||||
parts.append(part["text"])
|
||||
continue
|
||||
if isinstance(part.get("content"), str):
|
||||
parts.append(part["content"])
|
||||
return "\n".join(parts)
|
||||
return "" if content is None else str(content)
|
||||
|
||||
|
||||
def _chatgpt_subscription_instructions(messages: List[Dict]) -> str:
    """Build the ``instructions`` string from the system messages.

    Non-empty system message texts are joined with blank lines; when there are
    none, a generic assistant prompt is returned.
    """
    collected = []
    for entry in messages or []:
        if (entry.get("role") or "") != "system":
            continue
        body = _message_content_as_text(entry.get("content")).strip()
        if body:
            collected.append(body)
    if not collected:
        return "You are a helpful AI assistant."
    return "\n\n".join(collected)
|
||||
|
||||
|
||||
def _build_chatgpt_responses_payload(
    model: str,
    messages: List[Dict],
    temperature: float,
    max_tokens: int,
    *,
    stream: bool = False,
) -> Dict:
    """Assemble the JSON body for a Codex Responses request.

    System messages are sent as ``instructions``; the remaining messages
    become ``input``. ``temperature`` is included only for models that do not
    restrict it, and ``max_output_tokens`` only for a positive ``max_tokens``.
    ``store`` is always ``False``.
    """
    from src.chatgpt_subscription import build_responses_input

    non_system = [
        entry for entry in (messages or [])
        if (entry.get("role") or "") != "system"
    ]
    body: Dict = {
        "model": model,
        "instructions": _chatgpt_subscription_instructions(messages),
        "input": build_responses_input(non_system),
        "stream": stream,
        "store": False,
    }
    optional_fields = {}
    if not _restricts_temperature(model):
        optional_fields["temperature"] = temperature
    if max_tokens and max_tokens > 0:
        optional_fields["max_output_tokens"] = max_tokens
    body.update(optional_fields)
    return body
|
||||
|
||||
|
||||
def _format_chatgpt_subscription_error(status_code: int, text: str) -> str:
|
||||
if status_code in (401, 403):
|
||||
return "ChatGPT Subscription credentials expired or were rejected. Reconnect the provider."
|
||||
if status_code == 429:
|
||||
return "ChatGPT Subscription quota or rate limit was reached. Retry after the upstream limit resets."
|
||||
return _format_upstream_error(status_code, text, "https://chatgpt.com/backend-api/codex")
|
||||
|
||||
|
||||
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
|
||||
"""Turn an upstream HTTP error into a user-readable sentence.
|
||||
|
||||
@@ -874,7 +950,7 @@ def _normalize_anthropic_url(url: str) -> str:
|
||||
def _model_list_base(url: str) -> str:
|
||||
"""Normalize model/chat URLs to the configured endpoint base."""
|
||||
base = (url or "").strip().rstrip("/")
|
||||
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
|
||||
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages", "/responses"):
|
||||
if base.endswith(suffix):
|
||||
base = base[: -len(suffix)].rstrip("/")
|
||||
for suffix in ("/chat", "/tags", "/generate"):
|
||||
@@ -903,7 +979,12 @@ def _parse_model_cache(raw) -> List[str]:
|
||||
return out
|
||||
|
||||
|
||||
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
||||
def _configured_cached_model_ids(
|
||||
endpoint_url: str,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""Return cached models for a configured endpoint matching endpoint_url."""
|
||||
target = _model_list_base(endpoint_url)
|
||||
if not target:
|
||||
@@ -914,7 +995,13 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
||||
return []
|
||||
db = SessionLocal()
|
||||
try:
|
||||
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||
if endpoint_id:
|
||||
q = q.filter(ModelEndpoint.id == endpoint_id)
|
||||
if owner:
|
||||
from src.auth_helpers import owner_filter
|
||||
q = owner_filter(q, ModelEndpoint, owner)
|
||||
rows = q.all()
|
||||
for ep in rows:
|
||||
if _model_list_base(getattr(ep, "base_url", "")) != target:
|
||||
continue
|
||||
@@ -933,9 +1020,16 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
|
||||
def list_model_ids(
|
||||
base_chat_url: str,
|
||||
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||
headers: Optional[Dict] = None,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""List available model IDs from an endpoint."""
|
||||
cached = _configured_cached_model_ids(base_chat_url)
|
||||
cached = _configured_cached_model_ids(base_chat_url, owner=owner, endpoint_id=endpoint_id)
|
||||
if cached:
|
||||
return cached
|
||||
provider = _detect_provider(base_chat_url)
|
||||
@@ -971,9 +1065,16 @@ def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||
pass
|
||||
return []
|
||||
|
||||
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]:
|
||||
def normalize_model_id(
|
||||
endpoint_url: str,
|
||||
requested: str,
|
||||
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
endpoint_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Normalize a model ID to match available models."""
|
||||
avail = list_model_ids(endpoint_url, timeout)
|
||||
avail = list_model_ids(endpoint_url, timeout, owner=owner, endpoint_id=endpoint_id)
|
||||
if not avail:
|
||||
return None
|
||||
if requested in avail:
|
||||
@@ -1169,6 +1270,49 @@ async def llm_call_async(
|
||||
logger.debug(f"Returning cached response for key: {cache_key}")
|
||||
return cached_response
|
||||
|
||||
if provider == "chatgpt-subscription":
|
||||
# ChatGPT/Codex requires streamed Responses requests even for callers
|
||||
# that want a plain string (auto-title, memory extraction, etc.).
|
||||
# Reuse stream_llm's validated Codex SSE path and collect deltas.
|
||||
parts: List[str] = []
|
||||
async for chunk in stream_llm(
|
||||
url,
|
||||
model,
|
||||
messages_copy,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
):
|
||||
event_is_error = False
|
||||
for line in str(chunk).splitlines():
|
||||
if line.startswith("event:"):
|
||||
event_is_error = line[6:].strip() == "error"
|
||||
continue
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
raw = line[5:].strip()
|
||||
if not raw:
|
||||
continue
|
||||
if raw == "[DONE]":
|
||||
response = "".join(parts)
|
||||
_set_cached_response(cache_key, response)
|
||||
return response
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if event_is_error or data.get("error") or (data.get("status") and data.get("text")):
|
||||
status = int(data.get("status") or 502)
|
||||
text = data.get("text") or data.get("error") or "ChatGPT Subscription request failed"
|
||||
raise HTTPException(status, text)
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, str):
|
||||
parts.append(delta)
|
||||
response = "".join(parts)
|
||||
_set_cached_response(cache_key, response)
|
||||
return response
|
||||
|
||||
if provider == "anthropic":
|
||||
target_url = _normalize_anthropic_url(url)
|
||||
h = _build_anthropic_headers(headers)
|
||||
@@ -1294,6 +1438,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
model, messages_copy, temperature, max_tokens,
|
||||
stream=True, tools=tools, num_ctx=get_context_length(url, model),
|
||||
)
|
||||
elif provider == "chatgpt-subscription":
|
||||
target_url = _normalize_chatgpt_subscription_url(url)
|
||||
h = _provider_headers(provider, headers)
|
||||
payload = _build_chatgpt_responses_payload(model, messages_copy, temperature, max_tokens, stream=True)
|
||||
else:
|
||||
target_url = url
|
||||
payload = {
|
||||
@@ -1325,6 +1473,68 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
return
|
||||
note_model_activity(target_url, model)
|
||||
|
||||
# ── ChatGPT Subscription / Codex Responses streaming ──
|
||||
if provider == "chatgpt-subscription":
|
||||
event_name = ""
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
try:
|
||||
client = _get_http_client()
|
||||
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||
_clear_host_dead(target_url)
|
||||
if r.status_code != 200:
|
||||
raw = (await r.aread()).decode(errors="replace")
|
||||
friendly = _format_chatgpt_subscription_error(r.status_code, raw)
|
||||
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
|
||||
return
|
||||
async for line in r.aiter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith("event:"):
|
||||
event_name = line[6:].strip()
|
||||
continue
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
raw = line[5:].strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
evt = data.get("type") or event_name
|
||||
if evt == "response.output_text.delta":
|
||||
delta = data.get("delta") or ""
|
||||
if delta:
|
||||
yield f'data: {json.dumps({"delta": delta})}\n\n'
|
||||
elif evt == "response.completed":
|
||||
usage = (data.get("response") or {}).get("usage") or data.get("usage") or {}
|
||||
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or input_tokens
|
||||
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") or output_tokens
|
||||
if input_tokens or output_tokens:
|
||||
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": input_tokens, "output_tokens": output_tokens}})}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
elif evt in ("response.failed", "error"):
|
||||
err = data.get("error") or (data.get("response") or {}).get("error") or {}
|
||||
text = err.get("message") if isinstance(err, dict) else str(err or "ChatGPT Subscription request failed")
|
||||
yield f'event: error\ndata: {json.dumps({"status": 502, "text": text})}\n\n'
|
||||
return
|
||||
yield "data: [DONE]\n\n"
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
_cooled = _mark_host_dead(target_url)
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"ChatGPT Subscription stream connect to {target_url} failed: {e}{_tail}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
|
||||
except httpx.ReadTimeout:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
|
||||
except httpx.NetworkError:
|
||||
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
|
||||
except Exception as e:
|
||||
logger.error(f"ChatGPT Subscription stream error: {e}")
|
||||
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
||||
return
|
||||
|
||||
# ── Native Ollama streaming ──
|
||||
if provider == "ollama":
|
||||
_ollama_tool_calls: List[Dict] = []
|
||||
|
||||
@@ -2108,6 +2108,8 @@
|
||||
<option value="https://api.anthropic.com" data-logo="anthropic">Anthropic</option>
|
||||
<option value="https://api.deepseek.com/v1" data-logo="deepseek" selected>DeepSeek</option>
|
||||
<option value="https://api.openai.com/v1" data-logo="openai">OpenAI</option>
|
||||
<option value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot</option>
|
||||
<option value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription</option>
|
||||
<option value="https://openrouter.ai/api/v1" data-logo="openrouter">OpenRouter</option>
|
||||
<option value="https://ollama.com/api" data-logo="ollama">Ollama Cloud</option>
|
||||
<option value="https://api.groq.com/openai/v1" data-logo="groq">Groq</option>
|
||||
@@ -2136,6 +2138,7 @@
|
||||
<button class="admin-btn-add" id="adm-epAddBtn" style="width:55px;text-align:center;">Add</button>
|
||||
</div>
|
||||
<div id="adm-epApiMsg" class="adm-ep-inline-msg"></div>
|
||||
<div id="adm-deviceAuthStatus" class="adm-ep-inline-msg"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
+201
-68
@@ -5,6 +5,7 @@ import uiModule from './ui.js';
|
||||
import settingsModule from './settings.js';
|
||||
import { providerLogo } from './providers.js';
|
||||
import { sortModelObjects } from './modelSort.js';
|
||||
import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js';
|
||||
|
||||
let initialized = false;
|
||||
let modalEl = null;
|
||||
@@ -707,6 +708,80 @@ function initEndpointForm() {
|
||||
const pickerBtn = el('adm-provider-btn');
|
||||
const pickerMenu = el('adm-provider-menu');
|
||||
const pickerCurrent = picker ? picker.querySelector('.adm-provider-current') : null;
|
||||
const DEVICE_AUTH_PROVIDER_VALUES = new Set(Object.keys(PROVIDER_DEVICE_FLOWS));
|
||||
let deviceAuthPolling = false;
|
||||
// Return the <option> currently selected in the provider dropdown, or null
// when the dropdown (or its selectedOptions collection) is unavailable.
function _selectedProviderOption() {
  if (!provider || !provider.selectedOptions) return null;
  return provider.selectedOptions[0];
}
|
||||
// Resolve which device-flow provider (if any) the dropdown points at.
// The option's data-auth-flow attribute is consulted first, then the raw
// option value; an empty string means "not a device-auth provider".
function _selectedDeviceAuthProvider() {
  const selected = _selectedProviderOption();
  let declaredFlow = '';
  if (selected && selected.dataset) declaredFlow = selected.dataset.authFlow;
  if (declaredFlow && DEVICE_AUTH_PROVIDER_VALUES.has(declaredFlow)) {
    return declaredFlow;
  }
  const rawValue = provider.value;
  if (DEVICE_AUTH_PROVIDER_VALUES.has(rawValue)) return rawValue;
  return '';
}
|
||||
// True when the selected provider signs in through a device flow.
function _isDeviceAuthSelected() {
  return Boolean(_selectedDeviceAuthProvider());
}
|
||||
// Sync the API-endpoint form controls with the provider picked in the dropdown.
//
// Device-auth providers (entries of PROVIDER_DEVICE_FLOWS) take no URL or API
// key: the URL field becomes a read-only hint, the key field and Test button
// are disabled, and the endpoint kind is forced to 'api'. Every other provider
// gets the regular editable form back.
//
// The Add button is normally reset to its idle state ("Add", enabled). While a
// device sign-in is running for a device-auth provider the button is left
// untouched: _startProviderDeviceAuth disables it and relabels it before
// calling this function, and resetting it here would immediately undo that.
function _setApiFormForProvider() {
  const deviceAuthProvider = _selectedDeviceAuthProvider();
  const deviceAuthConfig = PROVIDER_DEVICE_FLOWS[deviceAuthProvider] || null;
  const usesDeviceAuth = !!deviceAuthConfig;
  const apiKey = el('adm-epApiKey');
  const testBtn = el('adm-epApiTestBtn');
  const addBtn = el('adm-epAddBtn');
  const status = el('adm-deviceAuthStatus');
  const msg = _endpointMsg('api');

  if (usesDeviceAuth) {
    urlInput.value = '';
    urlInput.placeholder = deviceAuthProvider === 'copilot'
      ? 'GitHub Copilot uses GitHub account sign-in'
      : 'ChatGPT Subscription uses OpenAI account sign-in';
    urlInput.readOnly = true;
    if (apiKey) {
      apiKey.value = '';
      apiKey.placeholder = 'No API key needed';
      apiKey.disabled = true;
    }
    if (kindSel) kindSel.value = 'api';
  } else {
    urlInput.placeholder = 'Base URL or pick provider';
    urlInput.readOnly = false;
    if (apiKey) {
      apiKey.placeholder = 'API key';
      apiKey.disabled = false;
    }
    // Keep the sign-in panel visible if a device flow is still polling.
    if (!deviceAuthPolling && status) status.textContent = '';
  }

  if (testBtn) {
    // Testing a connection is meaningless without a URL/key pair.
    testBtn.disabled = usesDeviceAuth;
    testBtn.style.opacity = usesDeviceAuth ? '0.45' : '';
    testBtn.style.cursor = usesDeviceAuth ? 'not-allowed' : '';
  }

  // Do not clobber the "Starting..." / "Waiting..." state of the button that
  // triggered an in-progress device sign-in.
  const signInRunning = usesDeviceAuth && deviceAuthPolling;
  if (addBtn && !signInRunning) {
    addBtn.disabled = false;
    addBtn.textContent = 'Add';
    addBtn.style.width = '55px';
    addBtn.style.display = '';
  }

  if (msg) {
    msg.textContent = '';
    msg.className = '';
  }
}
|
||||
function _renderPickerMenu() {
|
||||
if (!pickerMenu) return;
|
||||
pickerMenu.innerHTML = Array.from(provider.options).map(o => {
|
||||
@@ -748,9 +823,16 @@ function initEndpointForm() {
|
||||
}
|
||||
|
||||
provider.addEventListener('change', () => {
|
||||
if (_isDeviceAuthSelected()) {
|
||||
_setApiFormForProvider();
|
||||
_renderPickerMenu();
|
||||
_syncPickerCurrent();
|
||||
return;
|
||||
}
|
||||
if (provider.value) urlInput.value = provider.value;
|
||||
else urlInput.value = '';
|
||||
if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy';
|
||||
_setApiFormForProvider();
|
||||
});
|
||||
urlInput.addEventListener('input', () => {
|
||||
if (provider.value && urlInput.value.trim() !== provider.value) {
|
||||
@@ -838,6 +920,12 @@ function initEndpointForm() {
|
||||
const apiCancelTestBtn = el('adm-epApiCancelTestBtn');
|
||||
if (apiTestBtn) {
|
||||
apiTestBtn.addEventListener('click', async () => {
|
||||
if (_isDeviceAuthSelected()) {
|
||||
const msg = _endpointMsg('api');
|
||||
msg.textContent = '';
|
||||
msg.className = '';
|
||||
return;
|
||||
}
|
||||
const msg = _endpointMsg('api');
|
||||
msg.textContent = ''; msg.className = '';
|
||||
const rawUrl = (urlInput.value || provider.value).trim();
|
||||
@@ -885,6 +973,11 @@ function initEndpointForm() {
|
||||
}
|
||||
|
||||
el('adm-epAddBtn').addEventListener('click', async () => {
|
||||
const deviceAuthProvider = _selectedDeviceAuthProvider();
|
||||
if (deviceAuthProvider) {
|
||||
await _startProviderDeviceAuth(deviceAuthProvider, el('adm-epAddBtn'));
|
||||
return;
|
||||
}
|
||||
const msg = _endpointMsg('api');
|
||||
msg.textContent = ''; msg.className = '';
|
||||
const rawUrl = (urlInput.value || provider.value).trim();
|
||||
@@ -936,76 +1029,116 @@ function initEndpointForm() {
|
||||
btn.disabled = false; btn.textContent = 'Add';
|
||||
});
|
||||
|
||||
// GitHub Copilot — device-flow login. Starts the flow, shows the user a
|
||||
// code + verification link, and polls until they authorise (or it expires).
|
||||
const copilotBtn = el('adm-copilotConnectBtn');
|
||||
if (copilotBtn) {
|
||||
let copilotPolling = false;
|
||||
copilotBtn.addEventListener('click', async () => {
|
||||
if (copilotPolling) return;
|
||||
const status = el('adm-copilotStatus');
|
||||
const reset = () => { copilotBtn.disabled = false; copilotBtn.textContent = 'Connect GitHub Copilot'; copilotPolling = false; };
|
||||
status.textContent = ''; status.className = 'adm-ep-inline-msg';
|
||||
copilotBtn.disabled = true; copilotBtn.textContent = 'Starting...';
|
||||
copilotPolling = true;
|
||||
let start;
|
||||
try {
|
||||
const res = await fetch('/api/copilot/device/start', { method: 'POST', body: new FormData(), credentials: 'same-origin' });
|
||||
start = await res.json();
|
||||
if (!res.ok) { status.textContent = start.detail || 'Failed to start login'; status.className = 'admin-error'; reset(); return; }
|
||||
} catch (e) { status.textContent = 'Request failed'; status.className = 'admin-error'; reset(); return; }
|
||||
async function _startProviderDeviceAuth(providerKey, triggerEl = null) {
|
||||
if (deviceAuthPolling) return;
|
||||
const config = PROVIDER_DEVICE_FLOWS[providerKey];
|
||||
if (!config) return;
|
||||
const status = el('adm-deviceAuthStatus') || _endpointMsg('api');
|
||||
if (!status) return;
|
||||
const triggerText = triggerEl ? triggerEl.textContent : '';
|
||||
// Render an error with an inline "Try again" (the top button is hidden for
|
||||
// device-auth providers, so retry lives here). Built with DOM methods, not
|
||||
// innerHTML. Call reset() first so the deviceAuthPolling guard is cleared.
|
||||
const showAuthError = (text) => {
|
||||
status.className = 'admin-error';
|
||||
status.textContent = text + ' ';
|
||||
const retry = document.createElement('button');
|
||||
retry.type = 'button';
|
||||
retry.className = 'admin-btn-sm';
|
||||
retry.textContent = 'Try again';
|
||||
retry.addEventListener('click', () => { _startProviderDeviceAuth(providerKey, triggerEl); });
|
||||
status.appendChild(retry);
|
||||
};
|
||||
const reset = () => {
|
||||
if (triggerEl) {
|
||||
triggerEl.disabled = false;
|
||||
triggerEl.textContent = triggerText || 'Add';
|
||||
}
|
||||
deviceAuthPolling = false;
|
||||
_setApiFormForProvider();
|
||||
};
|
||||
status.textContent = '';
|
||||
status.className = 'adm-ep-inline-msg';
|
||||
if (triggerEl) {
|
||||
triggerEl.disabled = true;
|
||||
triggerEl.textContent = 'Starting...';
|
||||
}
|
||||
deviceAuthPolling = true;
|
||||
_setApiFormForProvider();
|
||||
status.textContent = `Starting ${config.label} sign-in...`;
|
||||
|
||||
const { poll_id, user_code, verification_uri, verification_uri_complete, interval, expires_in } = start;
|
||||
// Prefer the "complete" URL — it embeds the code so the user only has to
|
||||
// click "Authorize" (no manual code entry).
|
||||
const authUrl = verification_uri_complete || verification_uri || '';
|
||||
const esc = (s) => String(s || '').replace(/[<>&"]/g, (c) => ({ '<': '<', '>': '>', '&': '&', '"': '"' }[c]));
|
||||
copilotBtn.textContent = 'Waiting…';
|
||||
|
||||
// Cohesive waiting panel: spinner + status line, the device code as a
|
||||
// copyable chip, and a primary "Authorize on GitHub" action.
|
||||
status.className = '';
|
||||
status.innerHTML =
|
||||
'<div class="adm-copilot-panel">' +
|
||||
'<div class="adm-copilot-wait"><span class="admin-spinner"></span>' +
|
||||
'<span>Waiting for GitHub authorization…</span></div>' +
|
||||
'<div class="adm-copilot-coderow">' +
|
||||
'<span class="adm-copilot-code-label">Code</span>' +
|
||||
'<code class="adm-copilot-code">' + esc(user_code) + '</code>' +
|
||||
'<button type="button" class="admin-btn-sm adm-copilot-copy">Copy</button>' +
|
||||
'</div>' +
|
||||
'<a class="admin-btn-add adm-copilot-auth" href="' + encodeURI(authUrl) + '" target="_blank" rel="noopener">Authorize on GitHub ↗</a>' +
|
||||
'<div class="adm-copilot-hint">A new tab opened on GitHub — approve there to finish. Didn\'t open? Use the button above.</div>' +
|
||||
'</div>';
|
||||
const copyBtn = status.querySelector('.adm-copilot-copy');
|
||||
if (copyBtn) copyBtn.addEventListener('click', async () => {
|
||||
try { await navigator.clipboard.writeText(user_code || ''); copyBtn.textContent = 'Copied'; setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500); } catch (e) {}
|
||||
try {
|
||||
const result = await runProviderDeviceFlow(providerKey, {
|
||||
openWindow: () => {},
|
||||
onStart: ({ start, authUrl }) => {
|
||||
if (triggerEl) triggerEl.textContent = 'Waiting...';
|
||||
status.className = '';
|
||||
const authLabel = providerKey === 'copilot' ? 'Authorize on GitHub' : 'Authorize with OpenAI';
|
||||
const waitLabel = providerKey === 'copilot' ? 'Waiting for GitHub authorization...' : 'Waiting for ChatGPT authorization...';
|
||||
status.innerHTML =
|
||||
'<div class="adm-copilot-panel">' +
|
||||
'<div class="adm-copilot-wait"><span class="admin-spinner"></span>' +
|
||||
'<span>' + esc(waitLabel) + '</span></div>' +
|
||||
'<div class="adm-copilot-coderow">' +
|
||||
'<span class="adm-copilot-code-label">Code</span>' +
|
||||
'<code class="adm-copilot-code">' + esc(start.user_code) + '</code>' +
|
||||
'<button type="button" class="admin-btn-sm adm-device-auth-copy">Copy</button>' +
|
||||
'</div>' +
|
||||
'<a class="admin-btn-add adm-copilot-auth" href="' + encodeURI(authUrl || '') + '" target="_blank" rel="noopener">' + esc(authLabel) + ' ↗</a>' +
|
||||
'</div>';
|
||||
const copyBtn = status.querySelector('.adm-device-auth-copy');
|
||||
if (copyBtn) copyBtn.addEventListener('click', async () => {
|
||||
const code = start.user_code || '';
|
||||
let ok = false;
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(code);
|
||||
ok = true;
|
||||
}
|
||||
} catch (e) {}
|
||||
if (!ok) {
|
||||
// navigator.clipboard is unavailable in non-secure contexts (HTTP
|
||||
// self-host over a LAN IP), so fall back to execCommand('copy').
|
||||
const ta = document.createElement('textarea');
|
||||
ta.value = code;
|
||||
ta.style.cssText = 'position:fixed;top:0;left:0;width:1px;height:1px;padding:0;border:0;opacity:0;font-size:16px;';
|
||||
document.body.appendChild(ta);
|
||||
ta.focus();
|
||||
ta.select();
|
||||
try { ta.setSelectionRange(0, code.length); } catch (e) {}
|
||||
try { ok = document.execCommand('copy'); } catch (e) {}
|
||||
ta.remove();
|
||||
}
|
||||
copyBtn.textContent = ok ? 'Copied' : 'Failed';
|
||||
setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500);
|
||||
});
|
||||
},
|
||||
});
|
||||
try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {}
|
||||
|
||||
const deadline = Date.now() + (expires_in || 900) * 1000;
|
||||
const stepMs = Math.max((interval || 5), 2) * 1000;
|
||||
const done = (cls, text) => { status.className = cls; status.textContent = text; reset(); };
|
||||
const poll = async () => {
|
||||
if (Date.now() > deadline) { done('admin-error', 'Authorization expired — try again.'); return; }
|
||||
try {
|
||||
const fd = new FormData(); fd.append('poll_id', poll_id);
|
||||
const r = await fetch('/api/copilot/device/poll', { method: 'POST', body: fd, credentials: 'same-origin' });
|
||||
const d = await r.json();
|
||||
if (d.status === 'authorized') {
|
||||
const n = ((d.endpoint && d.endpoint.models) || []).length;
|
||||
done('admin-success', '✓ Connected — ' + n + ' Copilot model' + (n !== 1 ? 's' : '') + ' available.');
|
||||
if (d.endpoint && d.endpoint.id) _recentlyAddedEpId = String(d.endpoint.id);
|
||||
await loadEndpoints();
|
||||
await _selectAddedModelInChat(d.endpoint || {});
|
||||
return;
|
||||
}
|
||||
if (d.status === 'failed') { done('admin-error', 'Authorization failed (' + (d.error || 'denied') + ').'); return; }
|
||||
} catch (e) { /* transient — keep polling */ }
|
||||
setTimeout(poll, stepMs);
|
||||
};
|
||||
setTimeout(poll, stepMs);
|
||||
});
|
||||
if (result.status === 'authorized') {
|
||||
const endpoint = result.endpoint || {};
|
||||
const n = ((endpoint && endpoint.models) || []).length;
|
||||
status.className = 'admin-success';
|
||||
status.textContent = 'Connected - ' + n + ' ' + config.label + ' model' + (n !== 1 ? 's' : '') + ' available.';
|
||||
if (endpoint && endpoint.id) _recentlyAddedEpId = String(endpoint.id);
|
||||
await loadEndpoints();
|
||||
await _selectAddedModelInChat(endpoint || {});
|
||||
reset();
|
||||
return;
|
||||
}
|
||||
if (result.status === 'failed') {
|
||||
reset();
|
||||
showAuthError('Authorization failed (' + (result.error || 'denied') + ').');
|
||||
return;
|
||||
}
|
||||
if (result.status === 'expired') {
|
||||
reset();
|
||||
showAuthError('Authorization expired.');
|
||||
return;
|
||||
}
|
||||
} catch (e) {
|
||||
reset();
|
||||
showAuthError(formatDeviceFlowError(e));
|
||||
}
|
||||
}
|
||||
|
||||
// Local "Add" button — sibling form for self-hosted base URLs.
|
||||
|
||||
+39
-15
@@ -680,9 +680,11 @@ export function applyModelColor(roleEl, modelName) {
|
||||
html += '<div><span class="ctx-label">Max tokens</span> ' + _mt.toLocaleString() + ' <span style="opacity:0.4">(configured)</span></div>';
|
||||
}
|
||||
}
|
||||
if (info && info.input != null) html += '<div><span class="ctx-label">Input</span> $' + info.input.toFixed(2) + ' / 1M</div>';
|
||||
if (info && info.output != null) html += '<div><span class="ctx-label">Output</span> $' + info.output.toFixed(2) + ' / 1M</div>';
|
||||
if (!info) html += '<div style="opacity:0.4;font-size:0.85em;margin-top:4px;">No pricing data available</div>';
|
||||
if (isCostTrackedEndpoint(_epUrl)) {
|
||||
if (info && info.input != null) html += '<div><span class="ctx-label">Input</span> $' + info.input.toFixed(2) + ' / 1M</div>';
|
||||
if (info && info.output != null) html += '<div><span class="ctx-label">Output</span> $' + info.output.toFixed(2) + ' / 1M</div>';
|
||||
if (!info) html += '<div style="opacity:0.4;font-size:0.85em;margin-top:4px;">No pricing data available</div>';
|
||||
}
|
||||
popup.innerHTML = html;
|
||||
const rect = roleEl.getBoundingClientRect();
|
||||
popup.style.top = (rect.bottom + 4) + 'px';
|
||||
@@ -735,11 +737,31 @@ export function isLocalEndpoint(url) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Cost for the current turn, returning null (free) for local endpoints. */
|
||||
function _billableCost(model, inputTokens, outputTokens) {
|
||||
const url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
|
||||
/**
 * Tell whether `url` targets the ChatGPT Subscription backend, i.e. host
 * `chatgpt.com` with a path at or below `/backend-api/codex` (trailing
 * slashes ignored). Empty or unparseable URLs yield false.
 */
export function isSubscriptionEndpoint(url) {
  if (!url) return false;
  let target;
  try {
    target = new URL(url);
  } catch (_e) {
    return false;
  }
  if (target.hostname !== 'chatgpt.com') return false;
  const codexRoot = '/backend-api/codex';
  const trimmedPath = target.pathname.replace(/\/+$/, '');
  return trimmedPath === codexRoot || trimmedPath.startsWith(codexRoot + '/');
}
|
||||
|
||||
/**
 * URL of the endpoint used by the current session, or null when the session
 * module (or its getCurrentEndpointUrl accessor) is not available on window.
 */
function _currentEndpointUrl() {
  const session = window.sessionModule;
  return (session && session.getCurrentEndpointUrl)
    ? session.getCurrentEndpointUrl() : null;
}
|
||||
|
||||
/**
 * An endpoint is cost-tracked unless it is local or a subscription endpoint.
 */
export function isCostTrackedEndpoint(url) {
  if (isLocalEndpoint(url)) return false;
  return !isSubscriptionEndpoint(url);
}
|
||||
|
||||
/**
 * Price the current turn against the session's endpoint.
 * Yields null when that endpoint is not cost-tracked; otherwise whatever
 * getModelCost reports for the given model and token counts.
 */
function _billableCost(model, inputTokens, outputTokens) {
  if (isCostTrackedEndpoint(_currentEndpointUrl())) {
    return getModelCost(model, inputTokens, outputTokens);
  }
  return null;
}
|
||||
|
||||
@@ -784,11 +806,10 @@ export function resetSessionCost(sessionId) {
|
||||
export function updateSessionCostUI() {
|
||||
const el = document.getElementById('session-cost-display');
|
||||
if (!el) return;
|
||||
// Local model? It's free — hide the badge and clear any stale cost that a
|
||||
// previous (buggy) cloud-rate billing left in localStorage for this session.
|
||||
const _url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
|
||||
? window.sessionModule.getCurrentEndpointUrl() : null;
|
||||
if (isLocalEndpoint(_url)) {
|
||||
// Non-billable endpoint? Hide the badge and clear stale cost that a previous
|
||||
// cloud-rate calculation may have left in localStorage for this session.
|
||||
const _url = _currentEndpointUrl();
|
||||
if (!isCostTrackedEndpoint(_url)) {
|
||||
const sid = window.sessionModule && window.sessionModule.getCurrentSessionId();
|
||||
if (sid && getSessionCost(sid) > 0) {
|
||||
try {
|
||||
@@ -1708,7 +1729,8 @@ export function displayMetrics(messageElement, metrics) {
|
||||
e.stopPropagation();
|
||||
document.querySelectorAll('.ctx-popup').forEach(p => { if (typeof p._dismiss === 'function') p._dismiss(); else p.remove(); });
|
||||
|
||||
const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : 'n/a';
|
||||
const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : '';
|
||||
const costRows = costStr ? `<div><span class="ctx-label">Cost</span> ${costStr}</div>` : '';
|
||||
const speedStr = tps != null && tps !== 'undefined' ? `${tps} tok/s` : 'n/a';
|
||||
const totalTok = inputTokens + outputTokens;
|
||||
const ctxColor = ctxPct >= 85 ? 'var(--red, #e06c75)' : ctxPct >= 70 ? '#ff9900' : 'var(--color-muted-alt, #6b7280)';
|
||||
@@ -1722,7 +1744,7 @@ export function displayMetrics(messageElement, metrics) {
|
||||
// Session total cost
|
||||
let sessionCostStr = '';
|
||||
const sc = getSessionCost();
|
||||
if (sc > 0) {
|
||||
if (costStr && sc > 0) {
|
||||
sessionCostStr = `<div><span class="ctx-label">Session</span> $${sc < 0.01 ? sc.toFixed(4) : sc.toFixed(3)}</div>`;
|
||||
}
|
||||
|
||||
@@ -1738,7 +1760,7 @@ export function displayMetrics(messageElement, metrics) {
|
||||
<div><span class="ctx-label">Time</span> ${responseTime}s</div>
|
||||
${prepTime != null ? `<div><span class="ctx-label">Prep</span> ${prepTime}s</div>` : ''}
|
||||
${modelWaitTime != null ? `<div><span class="ctx-label">Model wait</span> ${modelWaitTime}s</div>` : ''}
|
||||
<div><span class="ctx-label">Cost</span> ${costStr}</div>
|
||||
${costRows}
|
||||
${sessionCostStr}
|
||||
${prepDetails ? `<div style="margin-top:6px;padding-top:6px;border-top:1px solid var(--border);font-size:0.85em;opacity:0.8;">
|
||||
<div style="font-weight:600;margin-bottom:4px;color:var(--fg);">Agent prep</div>
|
||||
@@ -2392,6 +2414,8 @@ const chatRenderer = {
|
||||
modelColor,
|
||||
applyModelColor,
|
||||
getModelCost,
|
||||
isCostTrackedEndpoint,
|
||||
isSubscriptionEndpoint,
|
||||
getImageCost,
|
||||
getSessionCost,
|
||||
resetSessionCost,
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// Shared DOM-free provider device-flow runner.
|
||||
|
||||
// Registry of providers that sign in through a device flow. Each entry holds
// the display label, the start/poll endpoints, and an authUrl(start) picker
// that extracts the browser URL from the start response.
export const PROVIDER_DEVICE_FLOWS = {
  copilot: {
    label: 'GitHub Copilot',
    startUrl: '/api/copilot/device/start',
    pollUrl: '/api/copilot/device/poll',
    // Prefer the "complete" URL when the start response provides one.
    authUrl: (start) => {
      const complete = start?.verification_uri_complete;
      if (complete) return complete;
      return start?.verification_uri || '';
    },
  },
  'chatgpt-subscription': {
    label: 'ChatGPT Subscription',
    startUrl: '/api/chatgpt-subscription/device/start',
    pollUrl: '/api/chatgpt-subscription/device/poll',
    authUrl: (start) => {
      const plain = start?.verification_uri;
      return plain ? plain : '';
    },
  },
};
|
||||
|
||||
// Build an empty request body: FormData where the runtime has it,
// URLSearchParams otherwise.
function _formData() {
  const BodyCtor = typeof FormData === 'undefined' ? URLSearchParams : FormData;
  return new BodyCtor();
}
|
||||
|
||||
// Parse a response body as JSON, always resolving to an object.
//
// Resolves to {} when the body cannot be parsed or when it parses to
// something that is not an object (a bare `null`, string, number, or
// boolean). Callers read properties such as `poll_id` and `status` straight
// off the result, so handing them `null` would surface as a TypeError
// instead of a meaningful sign-in error.
async function _jsonOrEmpty(response) {
  try {
    const payload = await response.json();
    return payload !== null && typeof payload === 'object' ? payload : {};
  } catch (_) {
    return {};
  }
}
|
||||
|
||||
// Pick a human-readable message out of an error payload. The `detail`,
// `error` and `message` fields are tried in that order; the first one that is
// a non-blank string wins (trimmed). Otherwise `fallback` is returned.
function _messageFromPayload(payload, fallback) {
  if (!payload) return fallback;
  for (const field of ['detail', 'error', 'message']) {
    const candidate = payload[field];
    if (typeof candidate !== 'string') continue;
    const text = candidate.trim();
    if (text) return text;
  }
  return fallback;
}
|
||||
|
||||
// Turn whatever a device flow threw into display text. Non-empty strings pass
// through unchanged; for other values `detail` is preferred over `message`
// (each stringified); anything else becomes `fallback`.
export function formatDeviceFlowError(error, fallback = 'Request failed') {
  if (typeof error === 'string' && error) return error;
  if (!error) return fallback;
  const text = error.detail || error.message;
  return text ? String(text) : fallback;
}
|
||||
|
||||
// Issue a request through `fetchImpl` and resolve with its parsed payload.
// A non-OK response rejects with an Error whose message comes from the
// payload when it carries one, else from `fallback`, else from a generic
// "Request failed (HTTP <status>)" text.
async function _fetchJson(fetchImpl, url, options, fallback) {
  const response = await fetchImpl(url, options);
  const payload = await _jsonOrEmpty(response);
  if (response.ok) return payload;
  const defaultMessage = fallback || `Request failed (HTTP ${response.status})`;
  throw new Error(_messageFromPayload(payload, defaultMessage));
}
|
||||
|
||||
// Promise that settles after `ms` milliseconds (setTimeout-based).
function _defaultSleep(ms) {
  return new Promise((wake) => {
    setTimeout(wake, ms);
  });
}
|
||||
|
||||
// Invoke an optional callback with `payload` and wait for it to finish.
// Anything that is not a function is silently ignored.
async function _callCallback(fn, payload) {
  if (typeof fn !== 'function') return;
  await fn(payload);
}
|
||||
|
||||
// Run a provider device-authorization flow from start to a terminal state.
//
// provider: key into PROVIDER_DEVICE_FLOWS; an unknown key rejects.
// options (all optional):
//   fetchImpl   fetch-compatible function (defaults to the global fetch)
//   openWindow  called with the auth URL after onStart (defaults to window.open)
//   sleep       (ms) => Promise used between polls
//   now         () => current time in ms
//   formData    body sent with the start request
//   onStart     ({ provider, config, start, authUrl }) after the flow starts
//   onWaiting   ({ provider, config, start, authUrl }) before each wait
//   onPoll      ({ provider, config, start, poll }) after each poll response
//
// Resolves with { status: 'authorized', endpoint }, { status: 'failed', error }
// or { status: 'expired' }. Rejects when the start or a poll request fails, or
// when the start response carries no poll id.
export async function runProviderDeviceFlow(provider, options = {}) {
  const cfg = PROVIDER_DEVICE_FLOWS[provider];
  if (!cfg) throw new Error(`Unknown device-flow provider: ${provider}`);

  const fetchImpl = options.fetchImpl || globalThis.fetch?.bind(globalThis);
  if (!fetchImpl) throw new Error('Fetch API is unavailable');

  const openWindow = options.openWindow || ((url) => {
    if (globalThis.window && typeof globalThis.window.open === 'function') {
      globalThis.window.open(url, '_blank', 'noopener');
    }
  });
  const sleep = options.sleep || _defaultSleep;
  const now = options.now || (() => Date.now());
  const formData = options.formData || _formData();

  // Server-supplied durations are untrusted input: a non-numeric value would
  // become NaN, which makes the poll delay collapse to ~0ms and the deadline
  // comparison always false (endless rapid polling). Fall back instead.
  const seconds = (value, fallback) => {
    const parsed = Number(value || fallback);
    return Number.isFinite(parsed) ? parsed : fallback;
  };

  const start = await _fetchJson(fetchImpl, cfg.startUrl, {
    method: 'POST',
    body: formData,
    credentials: 'same-origin',
  }, `Failed to start ${cfg.label} sign-in`);

  if (!start.poll_id) throw new Error(`${cfg.label} sign-in did not return a poll id`);
  const authUrl = cfg.authUrl(start);
  await _callCallback(options.onStart, { provider, config: cfg, start, authUrl });
  if (authUrl) openWindow(authUrl);

  const deadline = now() + seconds(start.expires_in, 900) * 1000;
  // Never poll faster than every 2 seconds, whatever the server suggests.
  let stepMs = Math.max(seconds(start.interval, 5), 2) * 1000;

  while (true) {
    if (now() > deadline) return { status: 'expired' };
    await _callCallback(options.onWaiting, { provider, config: cfg, start, authUrl });
    await sleep(stepMs);
    // Re-check after sleeping so an expired code is not polled once more.
    if (now() > deadline) return { status: 'expired' };

    const fd = _formData();
    fd.append('poll_id', start.poll_id);
    const poll = await _fetchJson(fetchImpl, cfg.pollUrl, {
      method: 'POST',
      body: fd,
      credentials: 'same-origin',
    }, `${cfg.label} sign-in poll failed`);
    await _callCallback(options.onPoll, { provider, config: cfg, start, poll });

    if (poll.status === 'authorized') {
      return { status: 'authorized', endpoint: poll.endpoint || {} };
    }
    if (poll.status === 'failed') {
      return { status: 'failed', error: poll.error || 'denied' };
    }
    // The server may ask for a different cadence on any poll response.
    if (poll.interval) {
      stepMs = Math.max(seconds(poll.interval, 5), 2) * 1000;
    }
  }
}
|
||||
@@ -15,6 +15,10 @@ const _PROVIDERS = [
|
||||
[/opencode/i,
|
||||
'<svg viewBox="0 0 24 30" fill="currentColor"><path d="M18 6H6V24H18V6ZM24 30H0V0H24V30Z"/></svg>'],
|
||||
|
||||
// GitHub / Copilot
|
||||
[/github|copilot/i,
|
||||
'<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .5A12 12 0 0 0 8.2 23.9c.6.1.8-.3.8-.6v-2.1c-3.3.7-4-1.4-4-1.4-.5-1.4-1.3-1.8-1.3-1.8-1.1-.8.1-.8.1-.8 1.2.1 1.9 1.3 1.9 1.3 1.1 1.9 2.9 1.3 3.6 1 .1-.8.4-1.3.8-1.6-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.3-.5-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.4 11.4 0 0 1 6 0C15.3 4.7 16 5 16 5c.6 1.6.2 2.9.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .5Z"/></svg>'],
|
||||
|
||||
// OpenRouter
|
||||
[/openrouter|open router/i,
|
||||
'<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="5" cy="12" r="2.5"/><circle cx="19" cy="6" r="2.5"/><circle cx="19" cy="18" r="2.5"/><path d="M7.5 12h4.5c2 0 2.5-6 4.5-6"/><path d="M12 12c2 0 2.5 6 4.5 6"/></svg>'],
|
||||
@@ -102,6 +106,7 @@ export function providerLogo(modelId) {
|
||||
// doesn't match `x.ai`.
|
||||
const _ENDPOINT_LABELS = [
|
||||
[/(^|\.)githubcopilot\.com$/i, "GitHub Copilot"],
|
||||
[/(^|\.)chatgpt\.com$/i, "ChatGPT Subscription"],
|
||||
[/(^|\.)openrouter\.ai$/i, "OpenRouter"],
|
||||
[/(^|\.)anthropic\.com$/i, "Anthropic"],
|
||||
[/(^|\.)openai\.com$/i, "OpenAI"],
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import { COMMANDS, LEGACY_ALIASES } from './slashCommands.js';
|
||||
|
||||
const POPUP_ID = 'slash-autocomplete';
|
||||
const MAX_VISIBLE = 12;
|
||||
const MAX_VISIBLE = 14;
|
||||
|
||||
// Flatten the registry into a searchable list of leaf entries. Each entry is
|
||||
// either a top-level command or a "cmd sub" pair (so subcommands get their
|
||||
@@ -81,6 +81,23 @@ function _flatten() {
|
||||
return out;
|
||||
}
|
||||
|
||||
// Fetch the skill slash-catalog and convert it into autocomplete entries.
//
// Resolves to [] when the request fails, the response is not OK, or the
// payload has no `skills` array. Each catalog item is validated on its own:
// items that are not objects, that carry neither a token nor a name, or whose
// token does not start with "/" are skipped, so a single malformed item can
// neither produce a "/undefined" suggestion nor discard the whole catalog.
async function _loadSkillEntries() {
  try {
    const res = await fetch('/api/skills/slash-catalog', { credentials: 'same-origin' });
    if (!res.ok) return [];
    const data = await res.json();
    const skills = data && Array.isArray(data.skills) ? data.skills : [];
    const entries = [];
    for (const s of skills) {
      if (!s || typeof s !== 'object') continue;
      let token = '';
      if (typeof s.token === 'string' && s.token) {
        token = s.token;
      } else if (s.name != null && s.name !== '') {
        // No explicit token: derive the slash command from the skill name.
        token = `/${s.name}`;
      }
      if (!token.startsWith('/')) continue;
      entries.push({
        token,
        aliases: [],
        category: s.category || 'Skills',
        help: s.help || 'Run skill',
        usage: s.usage || `${token} <request>`,
      });
    }
    return entries;
  } catch {
    return [];
  }
}
|
||||
|
||||
function _scoreMatch(entry, query) {
|
||||
// query already starts with "/". Match against token + aliases. Prefix wins
|
||||
// over substring; alias match scores slightly lower than token match.
|
||||
@@ -98,6 +115,17 @@ function _scoreMatch(entry, query) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// When `query` is exactly a bare command token (e.g. "/skill") that has
// subcommand entries ("/skill list", ...), return those subcommands followed
// by the command itself. Any other situation yields an empty list.
function _exactCommandGroupItems(all, query) {
  const needle = query.toLowerCase();
  const isBareCommand = /^\/[a-z0-9_-]+$/i.test(needle);
  if (!isBareCommand) return [];

  const childPrefix = needle + ' ';
  const children = [];
  let parent;
  for (const entry of all) {
    const token = entry.token.toLowerCase();
    if (parent === undefined && token === needle) {
      parent = entry;
    }
    if (token.startsWith(childPrefix)) {
      children.push(entry);
    }
  }

  if (!parent) return [];
  if (children.length === 0) return [];
  return children.concat(parent);
}
|
||||
|
||||
function _ensurePopup(textarea) {
|
||||
let el = document.getElementById(POPUP_ID);
|
||||
if (el) return el;
|
||||
@@ -164,7 +192,7 @@ export function initSlashAutocomplete(textarea) {
|
||||
if (!textarea || textarea._slashAcWired) return;
|
||||
textarea._slashAcWired = true;
|
||||
|
||||
const all = _flatten();
|
||||
let all = _flatten();
|
||||
let popup = null;
|
||||
let visible = false;
|
||||
let items = [];
|
||||
@@ -191,12 +219,17 @@ export function initSlashAutocomplete(textarea) {
|
||||
// the menu hides — we don't autocomplete mid-sentence.
|
||||
if (!v.startsWith('/') || v.includes('\n')) { hide(); return; }
|
||||
const query = v.trim();
|
||||
items = all
|
||||
const groupItems = _exactCommandGroupItems(all, query);
|
||||
if (groupItems.length) {
|
||||
items = groupItems.slice(0, MAX_VISIBLE);
|
||||
} else {
|
||||
items = all
|
||||
.map(e => ({ e, s: _scoreMatch(e, query) }))
|
||||
.filter(x => x.s > 0)
|
||||
.sort((a, b) => b.s - a.s)
|
||||
.slice(0, MAX_VISIBLE)
|
||||
.map(x => x.e);
|
||||
}
|
||||
if (!items.length && query.length > 1) { hide(); return; }
|
||||
if (!items.length) {
|
||||
// Just "/" with no matches — fall back to showing everything up to MAX_VISIBLE
|
||||
@@ -207,6 +240,19 @@ export function initSlashAutocomplete(textarea) {
|
||||
_render(popup, items, selectedIdx, query);
|
||||
};
|
||||
|
||||
_loadSkillEntries().then(skillEntries => {
|
||||
if (!skillEntries.length) return;
|
||||
const seen = new Set(all.map(e => e.token));
|
||||
const merged = all.slice();
|
||||
for (const entry of skillEntries) {
|
||||
if (seen.has(entry.token)) continue;
|
||||
seen.add(entry.token);
|
||||
merged.push(entry);
|
||||
}
|
||||
all = merged;
|
||||
if (visible) refresh();
|
||||
});
|
||||
|
||||
const insert = (token) => {
|
||||
textarea.value = token + ' ';
|
||||
textarea.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
+351
-71
@@ -21,6 +21,7 @@ import workspaceModule from './workspace.js';
|
||||
import settingsModule from './settings.js';
|
||||
import cookbookModule from './cookbook.js';
|
||||
import { EVAL_PROMPTS } from './compare/index.js';
|
||||
import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js';
|
||||
|
||||
// ── Module state ──────────────────────────────────────────────────────
|
||||
|
||||
@@ -58,11 +59,28 @@ const SETUP_PROVIDER_URLS = {
|
||||
'opencode-go': { name: 'OpenCode Go', url: 'https://opencode.ai/zen/go/v1' },
|
||||
};
|
||||
const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'ollama', 'xai', 'anthropic', 'groq', 'gemini', 'opencode-zen', 'opencode-go'];
|
||||
const SETUP_PROVIDER_HINT = SETUP_PROVIDER_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_NAMES[SETUP_PROVIDER_NAMES.length - 1];
|
||||
const SETUP_DEVICE_AUTH_PROVIDERS = [
|
||||
{ key: 'copilot', name: 'GitHub Copilot', aliases: ['github'], command: '/setup copilot' },
|
||||
{ key: 'chatgpt-subscription', name: 'ChatGPT Subscription', aliases: ['chatgptsubscription', 'chatgpt-sub', 'codex'], command: '/setup chatgpt-subscription' },
|
||||
];
|
||||
const SETUP_PROVIDER_HINT_NAMES = SETUP_PROVIDER_NAMES.concat(SETUP_DEVICE_AUTH_PROVIDERS.map(provider => provider.key));
|
||||
const SETUP_PROVIDER_HINT = SETUP_PROVIDER_HINT_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_HINT_NAMES[SETUP_PROVIDER_HINT_NAMES.length - 1];
|
||||
const SETUP_LOCAL_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><rect x="2" y="3" width="20" height="14" rx="2"/><path d="M8 21h8"/><path d="M12 17v4"/></svg>';
|
||||
const SETUP_API_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>';
|
||||
const SETUP_SETTINGS_ICON = '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-2px;margin-right:5px;"><circle cx="12" cy="12" r="3"/><path d="M19.4 15a1.65 1.65 0 0 0 .33 1.82l.06.06a2 2 0 0 1-2.83 2.83l-.06-.06a1.65 1.65 0 0 0-1.82-.33 1.65 1.65 0 0 0-1 1.51V21a2 2 0 0 1-4 0v-.09A1.65 1.65 0 0 0 9 19.4a1.65 1.65 0 0 0-1.82.33l-.06.06a2 2 0 0 1-2.83-2.83l.06-.06a1.65 1.65 0 0 0 .33-1.82 1.65 1.65 0 0 0-1.51-1H3a2 2 0 0 1 0-4h.09A1.65 1.65 0 0 0 4.6 9a1.65 1.65 0 0 0-.33-1.82l-.06-.06a2 2 0 0 1 2.83-2.83l.06.06a1.65 1.65 0 0 0 1.82.33H9a1.65 1.65 0 0 0 1-1.51V3a2 2 0 0 1 4 0v.09a1.65 1.65 0 0 0 1 1.51 1.65 1.65 0 0 0 1.82-.33l.06-.06a2 2 0 0 1 2.83 2.83l-.06.06a1.65 1.65 0 0 0-.33 1.82V9a1.65 1.65 0 0 0 1.51 1H21a2 2 0 0 1 0 4h-.09a1.65 1.65 0 0 0-1.51 1z"/></svg>';
|
||||
|
||||
/**
 * Render one clickable chip per entry in SETUP_PROVIDER_NAMES.
 *
 * Each chip is a <span> tagged with data-setup-kind="api-key" and
 * data-setup-provider set to the provider name.
 *
 * @returns {string} The chips' HTML, separated by single spaces.
 */
function _setupApiProviderChips() {
  const chips = [];
  for (const providerName of SETUP_PROVIDER_NAMES) {
    chips.push(
      '<span class="setup-clickable-provider" data-setup-kind="api-key" data-setup-provider="' + providerName + '" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + providerName + '">' + providerName + '</span>'
    );
  }
  return chips.join(' ');
}
|
||||
|
||||
/**
 * Render one clickable chip per entry in SETUP_DEVICE_AUTH_PROVIDERS.
 *
 * Each chip is a <span> tagged with data-setup-kind="device-auth" and
 * data-setup-provider set to the provider key; the visible text is the
 * provider's display name and the tooltip names its slash command.
 *
 * @returns {string} The chips' HTML, separated by single spaces.
 */
function _setupDeviceAuthProviderChips() {
  const chips = [];
  for (const entry of SETUP_DEVICE_AUTH_PROVIDERS) {
    chips.push(
      '<span class="setup-clickable-provider" data-setup-kind="device-auth" data-setup-provider="' + entry.key + '" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Run ' + entry.command + '">' + entry.name + '</span>'
    );
  }
  return chips.join(' ');
}
|
||||
|
||||
function _setupProviderFromInput(input) {
|
||||
const raw = (input || '').trim().toLowerCase().replace(/\s+/g, '');
|
||||
const aliases = {
|
||||
@@ -84,6 +102,17 @@ function _setupProviderFromInput(input) {
|
||||
return SETUP_PROVIDER_URLS[aliases[raw] || raw] || null;
|
||||
}
|
||||
|
||||
/**
 * Resolve free-form user input to a device-auth provider key.
 *
 * The input and every candidate (provider key, display name, aliases) are
 * compared after lower-casing, removing all whitespace, and turning
 * underscores into hyphens.
 *
 * @param {string} input Raw text typed by the user (may be empty/undefined).
 * @returns {string} The matching provider key, or '' when nothing matches.
 */
function _setupDeviceAuthProviderFromInput(input) {
  const canonical = text => text.toLowerCase().replace(/\s+/g, '').replace(/_/g, '-');

  const wanted = canonical((input || '').trim());
  if (!wanted) return '';

  const match = SETUP_DEVICE_AUTH_PROVIDERS.find(entry => {
    const names = [entry.key, entry.name, ...(entry.aliases || [])];
    return names.some(candidate => canonical(String(candidate || '')) === wanted);
  });
  return match ? match.key : '';
}
|
||||
|
||||
function _extractSetupProviderCredential(input) {
|
||||
const raw = (input || '').trim();
|
||||
if (!raw) return null;
|
||||
@@ -158,9 +187,8 @@ function _setupReply(text, remember = true) {
|
||||
}
|
||||
|
||||
function _showSetupEndpointChoices() {
|
||||
const providers = SETUP_PROVIDER_NAMES.map(name =>
|
||||
'<span class="setup-clickable-provider" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
|
||||
).join(' ');
|
||||
const providers = _setupApiProviderChips();
|
||||
const deviceAuthProviders = _setupDeviceAuthProviderChips();
|
||||
return slashReply(
|
||||
'<div class="setup-guide-no-censor" style="display:grid;gap:10px;">' +
|
||||
'<div>' +
|
||||
@@ -178,6 +206,7 @@ function _showSetupEndpointChoices() {
|
||||
'<div>Paste provider name then API key (example):</div>' +
|
||||
'<pre style="margin:4px 0 0;"><code class="setup-clickable-code" style="cursor:pointer;text-decoration:underline;" title="Click to fill in chat">deepseek sk-...</code></pre>' +
|
||||
'<div style="margin-top:8px;font-size:1em;"><span>Supported providers:</span><br>' + providers + '</div>' +
|
||||
'<div style="margin-top:8px;font-size:1em;"><span>Account sign-in:</span><br>' + deviceAuthProviders + '</div>' +
|
||||
'</div>' +
|
||||
'</div>'
|
||||
);
|
||||
@@ -208,9 +237,8 @@ function _showSetupEndpointChoicesStreamed(options = {}) {
|
||||
text: 'deepseek sk-...',
|
||||
copyText: 'deepseek sk-...',
|
||||
},
|
||||
{ kind: 'p', html: '<strong>Supported providers:</strong><br>' + SETUP_PROVIDER_NAMES.map(name =>
|
||||
'<span class="setup-clickable-provider" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
|
||||
).join(' ') },
|
||||
{ kind: 'p', html: '<strong>Supported providers:</strong><br>' + _setupApiProviderChips() },
|
||||
{ kind: 'p', html: '<strong>Account sign-in:</strong><br>' + _setupDeviceAuthProviderChips() },
|
||||
];
|
||||
return typewriterBlocksReply(blocks, { gap: '4px', bodyClass: 'setup-guide-no-censor', interval: 3 });
|
||||
}
|
||||
@@ -231,7 +259,7 @@ async function _hasConfiguredModels() {
|
||||
}
|
||||
|
||||
function _setupProviderPrompt() {
|
||||
const chips = SETUP_PROVIDER_NAMES.map(name =>
|
||||
const chips = SETUP_PROVIDER_HINT_NAMES.map(name =>
|
||||
'<span style="font-weight:650;">' + name + '</span>'
|
||||
).join(' ');
|
||||
slashReply('<b>Supported providers:</b><br>' + chips);
|
||||
@@ -286,6 +314,53 @@ function slashReply(text) {
|
||||
return { el: div, body };
|
||||
}
|
||||
|
||||
let _skillCatalogCache = { at: 0, items: [] };
|
||||
|
||||
/**
 * Fetch the slash-command skill catalog, with a 15-second in-memory cache.
 *
 * @param {boolean} [force=false] Skip the freshness check and always refetch.
 * @returns {Promise<Array>} The `skills` array from the response (or [] when
 *   the payload has none). On any fetch/parse failure the previously cached
 *   items are returned instead and the cache is left untouched.
 */
async function _loadSkillSlashCatalog(force = false) {
  const requestedAt = Date.now();
  const cacheIsFresh = (requestedAt - _skillCatalogCache.at) < 15000;
  if (!force && cacheIsFresh) return _skillCatalogCache.items;

  try {
    const response = await fetch(`${API_BASE}/api/skills/slash-catalog`, { credentials: 'same-origin' });
    if (!response.ok) throw new Error('catalog unavailable');
    const payload = await response.json();
    const skills = Array.isArray(payload.skills) ? payload.skills : [];
    // Stamp the cache with the time the request started, not when it finished.
    _skillCatalogCache = { at: requestedAt, items: skills };
    return skills;
  } catch {
    return _skillCatalogCache.items || [];
  }
}
|
||||
|
||||
/**
 * Put `text` into the chat composer and submit the chat form.
 *
 * @param {string} text Message to place in the #message input.
 * @returns {boolean} false when the #message input or #chat-form is missing,
 *   true once the submit has been triggered.
 */
function _submitComposedMessage(text) {
  const composer = document.getElementById('message');
  const chatForm = document.getElementById('chat-form');
  if (!(composer && chatForm)) return false;

  composer.value = text;
  composer.dispatchEvent(new Event('input', { bubbles: true }));

  // Fall back to a synthetic submit event where requestSubmit() is unavailable.
  if (typeof chatForm.requestSubmit !== 'function') {
    chatForm.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true }));
  } else {
    chatForm.requestSubmit();
  }
  return true;
}
|
||||
|
||||
/**
 * Ask the backend to invoke a skill and submit the message it returns.
 *
 * POSTs `{ request }` to /api/skills/<name>/invoke. On a non-OK response the
 * server's `detail` is shown (escaped) when it is a non-empty string; any
 * other shape falls back to a generic message so structured error payloads
 * are never rendered as "[object Object]". On success the returned
 * `message` is submitted through the chat form.
 *
 * @param {string} name Skill name (URL-encoded before use).
 * @param {string} requestText Free-form request passed to the skill.
 * @param {{esc?: function(string): string}} ctx Command context; `esc` is
 *   used to HTML-escape the server's error detail.
 * @returns {Promise<boolean>} Always true (the command was handled).
 */
async function _invokeSkillByName(name, requestText, ctx) {
  const fallback = 'Skill is not available';
  const res = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(name)}/invoke`, {
    method: 'POST',
    credentials: 'same-origin',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ request: requestText || '' })
  });
  if (!res.ok) {
    const err = await res.json().catch(() => null);
    // Only a plain string detail is safe to display; arrays/objects are not.
    const detail = (typeof err?.detail === 'string' && err.detail) ? err.detail : fallback;
    // Without an escaper, never echo server text into the reply HTML.
    slashReply(ctx?.esc ? ctx.esc(detail) : fallback);
    return true;
  }
  // A 2xx with an empty or non-JSON body is treated as a failed invocation
  // rather than an unhandled exception.
  const data = await res.json().catch(() => null);
  if (!data?.message || !_submitComposedMessage(data.message)) {
    slashReply('Could not start skill invocation.');
  }
  return true;
}
|
||||
|
||||
/** Minimal footer for slash replies: copy + dismiss */
|
||||
function _slashFooter(msgEl) {
|
||||
const footer = document.createElement('div');
|
||||
@@ -681,6 +756,13 @@ async function handleSetupWizard(mode, input) {
|
||||
await _setupProviderPrompt();
|
||||
return;
|
||||
}
|
||||
const deviceAuthProvider = _setupDeviceAuthProviderFromInput(input);
|
||||
if (deviceAuthProvider) {
|
||||
_addMessage('user', input);
|
||||
setupMode = false;
|
||||
await _setupProviderDeviceFlow(deviceAuthProvider);
|
||||
return;
|
||||
}
|
||||
const paired = _extractSetupProviderCredential(input);
|
||||
const provider = paired?.provider || _setupProviderFromInput(input);
|
||||
if (!provider) {
|
||||
@@ -1429,6 +1511,42 @@ async function _cmdModels(args, ctx) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
 * /model — show the current chat model and endpoint.
 *
 * `/model list` and `/model ls` delegate to _cmdModels with the remaining
 * arguments.
 *
 * @param {string[]} args Command arguments.
 * @param {object} ctx Command context; `ctx.esc` escapes displayed values.
 * @returns {Promise<boolean>} true, or whatever _cmdModels returns.
 */
async function _cmdModel(args, ctx) {
  const subcommand = (args[0] || '').toLowerCase();
  if (['list', 'ls'].includes(subcommand)) return _cmdModels(args.slice(1), ctx);

  const currentModel = sessionModule.getCurrentModel ? sessionModule.getCurrentModel() : '';
  const currentEndpoint = sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : '';

  const modelLine = `Current model: ${ctx.esc(currentModel || 'None selected')}`;
  const endpointLine = currentEndpoint ? `Endpoint: ${ctx.esc(currentEndpoint)}` : 'Endpoint: not available';
  const report = [
    modelLine,
    endpointLine,
    '',
    'Usage: /model list to show all available models'
  ].join('\n');

  slashReply(`<pre>${report}</pre>`);
  return true;
}
|
||||
|
||||
/**
 * /mcp — list configured MCP servers with their status and tool counts.
 *
 * Reads /api/mcp/servers; each server becomes one escaped line of the form
 * "<name> - <status> (<enabled>/<total> tools)".
 *
 * @param {string[]} args Unused.
 * @param {object} ctx Command context; `ctx.esc` escapes each line.
 * @returns {Promise<boolean>} Always true.
 */
async function _cmdMcp(args, ctx) {
  const response = await fetch(`${API_BASE}/api/mcp/servers`, { credentials: 'same-origin' });
  if (!response.ok) {
    slashReply('MCP status is unavailable for this user.');
    return true;
  }

  const servers = await response.json();
  const hasServers = Array.isArray(servers) && servers.length;
  if (!hasServers) {
    slashReply('No MCP servers configured.');
    return true;
  }

  const describe = server => {
    const label = server.name || server.id || 'MCP server';
    const status = server.status || (server.is_enabled ? 'enabled' : 'disabled');
    // enabled_tool_count falls back to tool_count; the total falls back to the enabled count.
    const enabledCount = Number(server.enabled_tool_count ?? server.tool_count ?? 0);
    const totalCount = Number(server.tool_count ?? enabledCount);
    return `${label} - ${status} (${enabledCount}/${totalCount} tools)`;
  };

  const rows = servers.map(describe).map(row => ctx.esc(row));
  slashReply(`<pre>${rows.join('\n')}</pre>`);
  return true;
}
|
||||
|
||||
// ── Memory ──
|
||||
|
||||
async function _cmdMemoryList(args, ctx) {
|
||||
@@ -1507,6 +1625,73 @@ async function _cmdMemorySearch(args, ctx) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// ── Skills ──
|
||||
|
||||
/**
 * /skills — list, search, view, or run skills.
 *
 * Subcommands (default `list`):
 *   list | ls            force-refresh the slash catalog and print it
 *   search | find QUERY  POST the query to /api/skills/search
 *   view | cat | show N  print the skill's markdown
 *   use | run N [REQ]    invoke the skill via _invokeSkillByName
 * Anything else prints the usage line.
 *
 * @param {string[]} args Subcommand followed by its arguments.
 * @param {object} ctx Command context; `ctx.esc` escapes displayed values.
 * @returns {Promise<boolean>} true once handled (or the result of
 *   _invokeSkillByName for `use`/`run`).
 */
async function _cmdSkills(args, ctx) {
  const action = (args[0] || 'list').toLowerCase();
  const params = args.slice(1);

  switch (action) {
    case 'list':
    case 'ls': {
      const catalog = await _loadSkillSlashCatalog(true);
      if (!catalog.length) {
        slashReply('No published skills available for slash commands.');
        return true;
      }
      const rows = catalog.map(skill => {
        const useCount = Number(skill.uses || 0);
        const suffix = useCount > 0 ? ` uses:${useCount}` : '';
        const tokenColumn = ctx.esc(String(skill.token || '').padEnd(24));
        const helpColumn = ctx.esc(skill.help || '');
        return `${tokenColumn}${helpColumn}${suffix}`;
      });
      slashReply(`<pre>${rows.join('\n')}</pre>`);
      return true;
    }

    case 'search':
    case 'find': {
      const query = params.join(' ').trim();
      if (!query) {
        slashReply('Usage: /skills search query');
        return true;
      }
      const response = await fetch(`${API_BASE}/api/skills/search`, {
        method: 'POST',
        credentials: 'same-origin',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ query })
      });
      if (!response.ok) {
        slashReply('Skill search failed.');
        return true;
      }
      const payload = await response.json();
      const matches = Array.isArray(payload.skills) ? payload.skills : [];
      if (!matches.length) {
        slashReply(`No skills found for "${ctx.esc(query)}".`);
        return true;
      }
      const rows = matches.map(skill => {
        const commandColumn = `/${skill.name || skill.id || ''}`.padEnd(24);
        return ctx.esc(commandColumn) + ctx.esc(skill.description || '');
      });
      slashReply(`<pre>${rows.join('\n')}</pre>`);
      return true;
    }

    case 'view':
    case 'cat':
    case 'show': {
      const skillName = (params[0] || '').trim();
      if (!skillName) {
        slashReply('Usage: /skills view name');
        return true;
      }
      const response = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(skillName)}/markdown`, { credentials: 'same-origin' });
      if (!response.ok) {
        slashReply(`Skill "${ctx.esc(skillName)}" was not found.`);
        return true;
      }
      const payload = await response.json();
      slashReply(`<pre>${ctx.esc(payload.markdown || '')}</pre>`);
      return true;
    }

    case 'use':
    case 'run': {
      const skillName = (params[0] || '').trim();
      if (!skillName) {
        slashReply('Usage: /skills use name request');
        return true;
      }
      return _invokeSkillByName(skillName, params.slice(1).join(' ').trim(), ctx);
    }

    default:
      slashReply('Usage: /skills list | search query | view name | use name request');
      return true;
  }
}
|
||||
|
||||
/**
 * /reload-skills — force-refresh the slash skill catalog and report its size.
 *
 * @param {string[]} args Unused.
 * @param {object} ctx Unused.
 * @returns {Promise<boolean>} Always true.
 */
async function _cmdReloadSkills(args, ctx) {
  const catalog = await _loadSkillSlashCatalog(true);
  const count = catalog.length;
  const plural = count === 1 ? '' : 's';
  slashReply(`Reloaded skills. ${count} skill command${plural} available.`);
  return true;
}
|
||||
|
||||
// ── Note (quick Notes shortcut) ──
|
||||
|
||||
async function _cmdNote(args, ctx) {
|
||||
@@ -1799,6 +1984,53 @@ Uploads: ${d.uploads || '?'}</pre>`);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
 * /usage — show locally recorded usage for the current chat session.
 *
 * Looks the session up in sessionModule's list first, then in
 * /api/sessions; lookup failures are ignored and the report falls back to
 * placeholder values. The cost line is only computed when chatRenderer
 * reports the endpoint as cost-tracked (or has no such check).
 *
 * @param {string[]} args Unused.
 * @param {object} ctx Command context; reads `ctx.sid` and `ctx.esc`.
 * @returns {Promise<boolean>} Always true.
 */
async function _cmdUsage(args, ctx) {
  const sessionId = ctx.sid;
  if (!sessionId) {
    slashReply('No active session.');
    return true;
  }

  // Best-effort lookup: any error simply leaves the session unknown.
  const lookupSession = async () => {
    try {
      const known = sessionModule.getSessions ? sessionModule.getSessions() : [];
      const local = (known || []).find(s => s.id === sessionId) || null;
      if (local) return local;

      const response = await fetch(`${API_BASE}/api/sessions`, { credentials: 'same-origin' });
      if (!response.ok) return null;
      const payload = await response.json();
      const listed = Array.isArray(payload) ? payload : (payload.sessions || payload.items || []);
      return listed.find(s => s.id === sessionId) || null;
    } catch (_) {
      return null;
    }
  };
  const session = await lookupSession();

  const model = session?.model || 'Unknown';
  const endpointUrl = session?.endpoint_url || (
    sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : ''
  );
  const messageCount = Number(session?.message_count || 0);
  const totalTokens = Number(session?.total_tokens || 0);

  const costTracked = chatRenderer.isCostTrackedEndpoint ? chatRenderer.isCostTrackedEndpoint(endpointUrl) : true;
  const cost = costTracked && chatRenderer.getSessionCost ? Number(chatRenderer.getSessionCost(sessionId) || 0) : 0;

  let costLine;
  if (!costTracked) {
    costLine = 'Estimated local cost: not tracked for this endpoint';
  } else if (cost > 0) {
    // Sub-cent amounts get one extra decimal place.
    costLine = `Estimated local cost: $${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}`;
  } else {
    costLine = 'Estimated local cost: unavailable or zero';
  }

  const report = [
    `Session: ${ctx.esc(session?.name || 'Current chat')}`,
    `Model: ${ctx.esc(model)}`,
    `Messages: ${messageCount.toLocaleString()}`,
    `Recorded tokens: ${totalTokens.toLocaleString()}`,
    costLine,
    '',
    'Provider account usage is not available from here; check the provider dashboard for account quota/billing.'
  ].join('\n');
  slashReply(`<pre>${report}</pre>`);
  return true;
}
|
||||
|
||||
// ── Context compaction ──
|
||||
|
||||
async function _cmdCompact(args, ctx) {
|
||||
@@ -4783,39 +5015,53 @@ function _clearSetupCommandInput() {
|
||||
}
|
||||
}
|
||||
|
||||
// GitHub Copilot device-flow sign-in, driven from chat (mirrors the Settings
|
||||
// "Connect GitHub Copilot" button). Replies via the setup guide messages.
|
||||
async function _setupCopilot() {
|
||||
async function _setupProviderDeviceFlow(providerKey) {
|
||||
_clearSetupGuideMessages();
|
||||
await _setupReply('Starting GitHub Copilot sign-in…');
|
||||
let start;
|
||||
const config = PROVIDER_DEVICE_FLOWS[providerKey];
|
||||
if (!config) {
|
||||
await _setupReply('Provider not recognised.');
|
||||
return;
|
||||
}
|
||||
await _setupReply(`Starting ${config.label} sign-in...`);
|
||||
try {
|
||||
const r = await fetch(`${API_BASE}/api/copilot/device/start`, { method: 'POST', body: new FormData(), credentials: 'same-origin' });
|
||||
start = await r.json();
|
||||
if (!r.ok) { await _setupReply(start.detail || 'Failed to start Copilot sign-in.'); return; }
|
||||
} catch (e) { await _setupReply('Request failed.'); return; }
|
||||
const authUrl = start.verification_uri_complete || start.verification_uri || '';
|
||||
await _setupReply(`Opening GitHub — approve the request (code ${start.user_code}). Waiting…`);
|
||||
try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {}
|
||||
const deadline = Date.now() + (start.expires_in || 900) * 1000;
|
||||
const stepMs = Math.max((start.interval || 5), 2) * 1000;
|
||||
const poll = async () => {
|
||||
if (Date.now() > deadline) { await _setupReply('Copilot sign-in expired — run /setup copilot again.'); return; }
|
||||
try {
|
||||
const fd = new FormData(); fd.append('poll_id', start.poll_id);
|
||||
const r = await fetch(`${API_BASE}/api/copilot/device/poll`, { method: 'POST', body: fd, credentials: 'same-origin' });
|
||||
const d = await r.json();
|
||||
if (d.status === 'authorized') {
|
||||
const n = ((d.endpoint && d.endpoint.models) || []).length;
|
||||
await _setupReply(`Connected — ${n} Copilot model${n !== 1 ? 's' : ''} available.`);
|
||||
if (modelsModule) modelsModule.refreshModels(true);
|
||||
return;
|
||||
}
|
||||
if (d.status === 'failed') { await _setupReply('Copilot sign-in failed (' + (d.error || 'denied') + ').'); return; }
|
||||
} catch (e) { /* transient — keep polling */ }
|
||||
setTimeout(poll, stepMs);
|
||||
};
|
||||
setTimeout(poll, stepMs);
|
||||
const result = await runProviderDeviceFlow(providerKey, {
|
||||
onStart: async ({ start, authUrl }) => {
|
||||
const place = providerKey === 'copilot' ? 'GitHub' : 'OpenAI';
|
||||
const action = providerKey === 'copilot' ? 'approve the request' : 'enter the code';
|
||||
if (providerKey === 'chatgpt-subscription') {
|
||||
slashReply(
|
||||
'<div class="setup-guide-no-censor" style="display:grid;gap:6px;">' +
|
||||
'<div>Open this URL in your browser, enter the code, then come back here. Waiting...</div>' +
|
||||
'<div>Code: <code>' + uiModule.esc(start.user_code || '') + '</code></div>' +
|
||||
'<div><a href="' + uiModule.esc(authUrl || '') + '" target="_blank" rel="noopener noreferrer">' + uiModule.esc(authUrl || '') + '</a></div>' +
|
||||
'</div>'
|
||||
);
|
||||
return;
|
||||
}
|
||||
await _setupReply(`Opening ${place} - ${action} (code ${start.user_code}). Waiting...`);
|
||||
},
|
||||
openWindow: (url) => {
|
||||
if (providerKey === 'chatgpt-subscription') return;
|
||||
try { if (url) window.open(url, '_blank', 'noopener'); } catch (e) {}
|
||||
},
|
||||
});
|
||||
if (result.status === 'authorized') {
|
||||
const n = ((result.endpoint && result.endpoint.models) || []).length;
|
||||
await _setupReply(`Connected - ${n} ${config.label} model${n !== 1 ? 's' : ''} available.`);
|
||||
if (modelsModule) modelsModule.refreshModels(true);
|
||||
return;
|
||||
}
|
||||
if (result.status === 'failed') {
|
||||
await _setupReply(`${config.label} sign-in failed (${result.error || 'denied'}).`);
|
||||
return;
|
||||
}
|
||||
if (result.status === 'expired') {
|
||||
await _setupReply(`${config.label} sign-in expired - run /setup ${providerKey} again.`);
|
||||
return;
|
||||
}
|
||||
} catch (e) {
|
||||
await _setupReply(formatDeviceFlowError(e));
|
||||
}
|
||||
}
|
||||
|
||||
async function _cmdSetup(args, ctx) {
|
||||
@@ -4823,7 +5069,11 @@ async function _cmdSetup(args, ctx) {
|
||||
_clearSetupCommandInput();
|
||||
const topic = (args[0] || '').trim().toLowerCase();
|
||||
const topicArgs = args.slice(1);
|
||||
if (topic === 'copilot' || topic === 'github') { await _setupCopilot(); return true; }
|
||||
const deviceAuthProvider = _setupDeviceAuthProviderFromInput(topic);
|
||||
if (deviceAuthProvider) {
|
||||
await _setupProviderDeviceFlow(deviceAuthProvider);
|
||||
return true;
|
||||
}
|
||||
const provider = _setupProviderFromInput(topic);
|
||||
if (provider) {
|
||||
_clearSetupGuideMessages();
|
||||
@@ -5463,8 +5713,20 @@ async function _cmdHelp(args, ctx) {
|
||||
lines.push('');
|
||||
}
|
||||
}
|
||||
const skillCommands = await _loadSkillSlashCatalog(false);
|
||||
if (skillCommands.length) {
|
||||
lines.push('Skills:');
|
||||
for (const skill of skillCommands.slice(0, 20)) {
|
||||
const token = String(skill.token || '').padEnd(21);
|
||||
lines.push(` ${ctx.esc(token)}${ctx.esc(skill.help || '')}`);
|
||||
}
|
||||
if (skillCommands.length > 20) {
|
||||
lines.push(` ... ${skillCommands.length - 20} more. Use /skills list`);
|
||||
}
|
||||
lines.push('');
|
||||
}
|
||||
lines.push('Tip: /<command> --help for details');
|
||||
lines.push('Shortcuts: /new /rename /fork /web /bash /memories /forget');
|
||||
lines.push('Shortcuts: /new /rename /fork /web /bash /memories /skills');
|
||||
slashReply(`<pre style="line-height:1.7">${lines.join('\n')}</pre>`);
|
||||
return true;
|
||||
}
|
||||
@@ -5539,6 +5801,20 @@ const COMMANDS = {
|
||||
'search': { handler: _cmdMemorySearch, alias: ['grep'], help: 'Search memories', usage: '/memory search q' }
|
||||
}
|
||||
},
|
||||
skills: {
|
||||
alias: ['skill'],
|
||||
category: 'Memory',
|
||||
help: 'List, search, inspect, or run skills',
|
||||
handler: _cmdSkills,
|
||||
usage: '/skills list | search query | view name | use name request',
|
||||
},
|
||||
'reload-skills': {
|
||||
alias: ['reload_skills'],
|
||||
category: 'Memory',
|
||||
help: 'Refresh the slash skill catalog',
|
||||
handler: _cmdReloadSkills,
|
||||
usage: '/reload-skills',
|
||||
},
|
||||
rag: {
|
||||
alias: [],
|
||||
category: 'RAG',
|
||||
@@ -5572,7 +5848,7 @@ const COMMANDS = {
|
||||
category: 'Getting started',
|
||||
help: 'Add local or API model endpoints',
|
||||
handler: _cmdSetup,
|
||||
usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup endpoint',
|
||||
usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup chatgpt-subscription',
|
||||
// Provider subs so the autocomplete popup surfaces "/setup deepseek",
|
||||
// "/setup openai", etc. when the user types "/setup de". Each sub's
|
||||
// handler is a thin wrapper that re-prepends the sub name and
|
||||
@@ -5590,6 +5866,7 @@ const COMMANDS = {
|
||||
xai: { help: 'xAI (Grok)', alias: ['grok'], usage: '/setup xai xai-...', handler: (a, c) => _cmdSetup(['xai', ...a], c) },
|
||||
ollama: { help: 'Ollama Cloud', usage: '/setup ollama KEY', handler: (a, c) => _cmdSetup(['ollama', ...a], c) },
|
||||
copilot: { help: 'GitHub Copilot', usage: '/setup copilot', handler: (a, c) => _cmdSetup(['copilot', ...a], c) },
|
||||
'chatgpt-subscription': { help: 'ChatGPT Subscription', alias: ['codex'], usage: '/setup chatgpt-subscription', handler: (a, c) => _cmdSetup(['chatgpt-subscription', ...a], c) },
|
||||
local: { help: 'Local model server (vLLM / LM Studio / llama.cpp / Ollama)',
|
||||
usage: '/setup local http://localhost:8000/v1',
|
||||
handler: (a, c) => _cmdSetup(['local', ...a], c) },
|
||||
@@ -5767,8 +6044,22 @@ const COMMANDS = {
|
||||
handler: (args, ctx) => _cmdToolPanel('compare', args, ctx),
|
||||
usage: '/compare'
|
||||
},
|
||||
mcp: {
|
||||
alias: [],
|
||||
category: 'Tools',
|
||||
help: 'Show MCP server status',
|
||||
handler: _cmdMcp,
|
||||
usage: '/mcp'
|
||||
},
|
||||
model: {
|
||||
alias: [],
|
||||
category: 'Settings',
|
||||
help: 'Show current chat model',
|
||||
handler: _cmdModel,
|
||||
usage: '/model · /model list'
|
||||
},
|
||||
models: {
|
||||
alias: ['model'],
|
||||
alias: [],
|
||||
category: 'Settings',
|
||||
help: 'List available models',
|
||||
handler: _cmdModels,
|
||||
@@ -5799,10 +6090,16 @@ const COMMANDS = {
|
||||
handler: _cmdStats,
|
||||
usage: '/stats'
|
||||
},
|
||||
usage: {
|
||||
alias: ['cost', 'tokens'],
|
||||
category: 'Utility',
|
||||
help: 'Show local usage for the current chat',
|
||||
handler: _cmdUsage,
|
||||
usage: '/usage'
|
||||
},
|
||||
compact: {
|
||||
alias: [],
|
||||
category: 'Utility',
|
||||
hidden: true,
|
||||
help: 'Compact older chat messages',
|
||||
handler: _cmdCompact,
|
||||
usage: '/compact'
|
||||
@@ -6075,33 +6372,13 @@ async function handleSlashCommand(input) {
|
||||
}
|
||||
|
||||
// --- 4. Skill invocation: /<skill-name> [request] ---
|
||||
// If `rawCmd` matches a published skill, pin its SKILL.md to the user's
|
||||
// message and re-submit. Lets you fire a stored procedure on demand
|
||||
// without the model having to discover the skill itself.
|
||||
// If `rawCmd` matches a published skill, the backend records usage and
|
||||
// returns a skill-pinned message to submit as the next agent turn.
|
||||
try {
|
||||
const skillRes = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(rawCmd)}/markdown`, { credentials: 'same-origin' });
|
||||
if (skillRes.ok) {
|
||||
const skillData = await skillRes.json();
|
||||
const md = skillData.markdown || '';
|
||||
if (md) {
|
||||
_showUser();
|
||||
const request = args.join(' ').trim();
|
||||
const msgInput = document.getElementById('message');
|
||||
const composed =
|
||||
`Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n` +
|
||||
`--- BEGIN SKILL ---\n${md}\n--- END SKILL ---\n\n` +
|
||||
(request ? `Request: ${request}` : `Request: (use the skill as appropriate)`);
|
||||
if (msgInput) {
|
||||
msgInput.value = composed;
|
||||
const form = document.getElementById('chat-form');
|
||||
if (form && typeof form.requestSubmit === 'function') {
|
||||
form.requestSubmit();
|
||||
} else if (form) {
|
||||
form.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true }));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
const catalog = await _loadSkillSlashCatalog(false);
|
||||
if (catalog.some(s => s.name === rawCmd)) {
|
||||
_showUser();
|
||||
return await _invokeSkillByName(rawCmd, args.join(' ').trim(), ctx);
|
||||
}
|
||||
} catch (_) { /* fall through to fuzzy match */ }
|
||||
|
||||
@@ -6158,10 +6435,13 @@ export function initSlashCommands(deps) {
|
||||
const providerEl = e.target.closest('.setup-clickable-provider');
|
||||
if (providerEl) {
|
||||
e.preventDefault();
|
||||
const providerKey = providerEl.dataset.setupProvider || providerEl.textContent.trim();
|
||||
const providerName = providerEl.textContent.trim();
|
||||
const messageInput = document.getElementById('message');
|
||||
if (messageInput) {
|
||||
const text = providerName + ' sk-';
|
||||
const text = providerEl.dataset.setupKind === 'device-auth'
|
||||
? '/setup ' + providerKey
|
||||
: providerName + ' sk-';
|
||||
messageInput.value = text;
|
||||
messageInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
messageInput.focus();
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
"""Static regressions for Add Models provider device-flow UX."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_INDEX = (_REPO / "static" / "index.html").read_text(encoding="utf-8")
|
||||
_ADMIN = (_REPO / "static" / "js" / "admin.js").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _between(src: str, start: str, end: str) -> str:
|
||||
start_idx = src.index(start)
|
||||
end_idx = src.index(end, start_idx)
|
||||
return src[start_idx:end_idx]
|
||||
|
||||
|
||||
def test_copilot_and_chatgpt_subscription_are_dropdown_device_auth_options():
    """index.html offers both device-auth providers and a shared status node."""
    required_markup = (
        'value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot',
        'value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription',
        'id="adm-deviceAuthStatus"',
    )
    for fragment in required_markup:
        assert fragment in _INDEX
|
||||
|
||||
|
||||
def test_provider_selection_is_inert_and_add_button_starts_device_flow():
    """Changing the provider must not start sign-in; clicking Add must."""
    on_change = _between(
        _ADMIN,
        "provider.addEventListener('change'",
        "urlInput.addEventListener('input'",
    )
    on_add_click = _between(
        _ADMIN,
        "el('adm-epAddBtn').addEventListener('click'",
        "async function _startProviderDeviceAuth",
    )

    assert "_startProviderDeviceAuth" not in on_change
    assert "_startProviderDeviceAuth(deviceAuthProvider" in on_add_click
|
||||
|
||||
|
||||
def test_device_auth_selection_disables_and_dims_api_test_button():
    """_setApiFormForProvider both disables/dims and restores the test button."""
    form_src = _between(_ADMIN, "function _setApiFormForProvider()", "function _renderPickerMenu()")

    expected_statements = (
        # Disabled state.
        "testBtn.disabled = true",
        "testBtn.style.opacity = '0.45'",
        "testBtn.style.cursor = 'not-allowed'",
        # Restored state.
        "testBtn.disabled = false",
        "testBtn.style.opacity = ''",
        "testBtn.style.cursor = ''",
    )
    for statement in expected_statements:
        assert statement in form_src
|
||||
|
||||
|
||||
def test_device_auth_keeps_manual_auth_button_without_auto_opening_tab():
    """The device-auth starter keeps manual authorize buttons and a no-op openWindow."""
    auth_src = _between(_ADMIN, "async function _startProviderDeviceAuth", "// Local \"Add\" button")

    required = (
        "Authorize with OpenAI",
        "Authorize on GitHub",
        "adm-copilot-panel",
        "adm-device-auth-copy",
        "openWindow: () => {}",
    )
    for needle in required:
        assert needle in auth_src

    assert "A new tab opened" not in auth_src
|
||||
|
||||
|
||||
def test_loud_oauth_copy_and_removed_button_hooks_do_not_return():
    """Removed OAuth copy and old connect-button ids stay out of both files."""
    forbidden = (
        "Click Add to start",
        "uses account sign-in",
        "Uses ChatGPT/Codex OAuth, not an OpenAI API key.",
        "adm-chatgptStatus",
        "adm-chatgptConnectBtn",
        "adm-copilotConnectBtn",
        "adm-copilotStatus",
    )
    for source_text in (_INDEX, _ADMIN):
        leaked = [needle for needle in forbidden if needle in source_text]
        assert not leaked
|
||||
@@ -0,0 +1,280 @@
|
||||
"""DB-backed ChatGPT Subscription endpoint provisioning tests."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.database import Base, ModelEndpoint, ProviderAuthSession
|
||||
import routes.chatgpt_subscription_routes as csr
|
||||
|
||||
|
||||
def _mem_db(monkeypatch):
    """Create an in-memory SQLite schema and patch it in as ``csr.SessionLocal``.

    Returns the session factory so tests can open their own sessions.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=memory_engine)
    # Match production (core.database SessionLocal is autoflush=False): a pending
    # db.delete(ep) is NOT flushed before the orphan-auth reference-count SELECT,
    # which is exactly why _delete_orphaned_provider_auth needs exclude_ep_id.
    session_factory = sessionmaker(bind=memory_engine, autoflush=False)
    monkeypatch.setattr(csr, "SessionLocal", session_factory)
    return session_factory
|
||||
|
||||
|
||||
def test_provision_creates_owner_scoped_auth_session_and_endpoint(monkeypatch):
    """Provisioning with fresh tokens stores an auth session and a linked endpoint.

    Model discovery is stubbed to return two model ids; the returned summary and
    the persisted ``ProviderAuthSession`` / ``ModelEndpoint`` rows are checked.
    """
    session_factory = _mem_db(monkeypatch)
    provider = csr.chatgpt_subscription
    monkeypatch.setattr(provider, "fetch_available_models", lambda token: ["gpt-5.5", "o4-mini"])

    tokens = {"access_token": "AT", "refresh_token": "RT"}
    summary = csr._provision_endpoint(tokens, "alice")

    assert summary["name"] == "ChatGPT Subscription"
    assert summary["base_url"] == provider.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
    assert summary["models"] == ["gpt-5.5", "o4-mini"]

    session = session_factory()
    try:
        auth_row = session.query(ProviderAuthSession).first()
        endpoint_row = (
            session.query(ModelEndpoint)
            .filter(ModelEndpoint.id == summary["id"])
            .first()
        )

        # Credentials live on the auth session row, scoped to the owner.
        assert auth_row is not None
        assert auth_row.owner == "alice"
        assert auth_row.provider == provider.CHATGPT_SUBSCRIPTION_PROVIDER
        assert auth_row.access_token == "AT"
        assert auth_row.refresh_token == "RT"
        assert auth_row.auth_mode == "chatgpt"

        # The endpoint carries no static key and points at the auth session.
        assert endpoint_row is not None
        assert endpoint_row.owner == "alice"
        assert endpoint_row.api_key is None
        assert endpoint_row.provider_auth_id == auth_row.id
        assert endpoint_row.endpoint_kind == "api"
        assert endpoint_row.model_refresh_mode == "manual"
        assert endpoint_row.supports_tools is False
        assert json.loads(endpoint_row.cached_models) == ["gpt-5.5", "o4-mini"]
    finally:
        session.close()
|
||||
|
||||
|
||||
def test_provision_refreshes_existing_auth_session_and_endpoint(monkeypatch):
    """Provisioning twice for one owner reuses the rows and updates the tokens."""
    session_factory = _mem_db(monkeypatch)
    monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5"])

    old_tokens = {"access_token": "OLD", "refresh_token": "OLD-RT"}
    new_tokens = {"access_token": "NEW", "refresh_token": "NEW-RT"}
    initial = csr._provision_endpoint(old_tokens, "bob")
    repeated = csr._provision_endpoint(new_tokens, "bob")

    # Same endpoint id both times: the second call must not create a duplicate.
    assert initial["id"] == repeated["id"]

    session = session_factory()
    try:
        bob_auth = (
            session.query(ProviderAuthSession)
            .filter(ProviderAuthSession.owner == "bob")
            .all()
        )
        bob_endpoints = (
            session.query(ModelEndpoint)
            .filter(ModelEndpoint.owner == "bob")
            .all()
        )
        assert len(bob_auth) == 1
        assert len(bob_endpoints) == 1

        only_auth = bob_auth[0]
        assert only_auth.access_token == "NEW"
        assert only_auth.refresh_token == "NEW-RT"
        assert bob_endpoints[0].provider_auth_id == only_auth.id
    finally:
        session.close()
|
||||
|
||||
|
||||
def test_provision_rejects_missing_tokens(monkeypatch):
    """A token payload without a refresh token is rejected with ``ValueError``."""
    _mem_db(monkeypatch)
    incomplete_tokens = {"access_token": "AT"}
    expected_message = "missing access_token or refresh_token"

    with pytest.raises(ValueError, match=expected_message):
        csr._provision_endpoint(incomplete_tokens, "alice")
|
||||
|
||||
|
||||
def test_provision_rejects_accounts_without_usable_models(monkeypatch):
    """Provisioning raises ``ValueError`` when discovery returns no models."""
    _mem_db(monkeypatch)
    # Stub discovery so the account appears to have nothing usable.
    monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: [])

    tokens = {"access_token": "AT", "refresh_token": "RT"}
    with pytest.raises(ValueError, match="no usable Codex models"):
        csr._provision_endpoint(tokens, "alice")
|
||||
|
||||
|
||||
def _add_auth_and_endpoints(db, *, auth_id="auth1", ep_ids=("ep1",)):
    """Seed one auth session plus endpoints that reference it, then commit.

    Args:
        db: Open session to add the rows to.
        auth_id: Primary key for the ``ProviderAuthSession`` row.
        ep_ids: Ids of the ``ModelEndpoint`` rows to create; each one points at
            ``auth_id`` through ``provider_auth_id``.
    """
    codex_url = "https://chatgpt.com/backend-api/codex"

    auth_row = ProviderAuthSession(
        id=auth_id,
        provider=csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
        owner="alice",
        base_url=codex_url,
        refresh_token="RT",
        auth_mode="chatgpt",
    )
    db.add(auth_row)

    for endpoint_id in ep_ids:
        endpoint_row = ModelEndpoint(
            id=endpoint_id,
            name="ChatGPT Subscription",
            base_url=codex_url,
            provider_auth_id=auth_id,
            owner="alice",
        )
        db.add(endpoint_row)

    db.commit()
|
||||
|
||||
|
||||
def test_delete_orphaned_provider_auth_revokes_when_last_endpoint_removed(monkeypatch):
    """The auth row is cleared once its only referencing endpoint is deleted."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1",))

        # Per the original note this mirrors the production delete route:
        # db.delete(ep) is issued (not yet flushed/committed) BEFORE the orphan
        # check runs.
        doomed = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(doomed)

        # ep1 is the only endpoint referencing auth1, so the helper reports True.
        cleared = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep1")
        assert cleared is True

        session.commit()
        leftover = (
            session.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert leftover is None
    finally:
        session.close()
|
||||
|
||||
|
||||
def test_delete_orphaned_provider_auth_requires_exclude_ep_id_for_pending_delete(monkeypatch):
    """Without ``exclude_ep_id`` a pending delete does not release the auth row."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1",))
        pending = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(pending)

        # The session uses autoflush=False, so the un-flushed delete leaves ep1
        # visible to the reference-count SELECT. The helper must therefore
        # conservatively KEEP the auth row; this is the case exclude_ep_id
        # exists to handle.
        cleared = _delete_orphaned_provider_auth(session, "auth1")
        assert cleared is False

        still_there = (
            session.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert still_there is not None
    finally:
        session.close()
|
||||
|
||||
|
||||
def test_delete_orphaned_provider_auth_keeps_auth_while_another_endpoint_uses_it(monkeypatch):
    """Excluding one of two endpoints leaves the shared auth row in place."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1", "ep2"))

        # ep2 still references auth1, so removing ep1 must NOT revoke it.
        cleared = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep1")
        assert cleared is False

        shared_auth = (
            session.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert shared_auth is not None
    finally:
        session.close()
|
||||
|
||||
|
||||
def test_delete_orphaned_provider_auth_noop_without_auth_id(monkeypatch):
    """Passing ``None`` as the auth id makes the helper return ``False``."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        outcome = _delete_orphaned_provider_auth(session, None, exclude_ep_id="ep1")
        assert outcome is False
    finally:
        session.close()
|
||||
|
||||
|
||||
def test_delete_orphaned_provider_auth_noop_when_auth_row_missing(monkeypatch):
    """A dangling ``provider_auth_id`` yields ``False`` instead of an error."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        # The endpoint points at an auth id that has no ProviderAuthSession row.
        dangling = ModelEndpoint(
            id="ep1",
            name="ChatGPT Subscription",
            base_url="https://chatgpt.com/backend-api/codex",
            provider_auth_id="ghost",
            owner="alice",
        )
        session.add(dangling)
        session.commit()

        stored = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(stored)

        # Nothing else references "ghost" and no auth row exists: expect a no-op.
        outcome = _delete_orphaned_provider_auth(session, "ghost", exclude_ep_id="ep1")
        assert outcome is False
    finally:
        session.close()
|
||||
|
||||
|
||||
def _delete_route(monkeypatch, TestSessionLocal):
    """Return the handler for DELETE /api/model-endpoints/{ep_id}, bound to the test DB.

    Patches ``routes.model_routes`` so it uses ``TestSessionLocal`` and skips the
    admin check, and stubs the settings, prefs and session-manager hooks so the
    test only exercises the provider-auth revocation wiring.

    Raises:
        AssertionError: If the router exposes no matching DELETE route.
    """
    import routes.model_routes as mr
    import routes.prefs_routes as prefs_routes
    import src.ai_interaction as ai_interaction

    # Database and auth gate.
    monkeypatch.setattr(mr, "SessionLocal", TestSessionLocal)
    monkeypatch.setattr(mr, "require_admin", lambda request: None)
    # Settings and preference persistence become in-memory no-ops.
    monkeypatch.setattr(mr, "_load_settings", lambda: {})
    monkeypatch.setattr(mr, "_save_settings", lambda settings: None)
    monkeypatch.setattr(prefs_routes, "_load", lambda: {})
    monkeypatch.setattr(prefs_routes, "_save", lambda prefs: None)
    # No session manager is available in this test.
    monkeypatch.setattr(ai_interaction, "get_session_manager", lambda: None)

    router = mr.setup_model_routes(model_discovery=None)
    wanted_path = "/api/model-endpoints/{ep_id}"
    for candidate in router.routes:
        if getattr(candidate, "path", "") != wanted_path:
            continue
        if "DELETE" not in getattr(candidate, "methods", set()):
            continue
        return candidate.endpoint

    raise AssertionError("DELETE /api/model-endpoints/{ep_id} not found")
|
||||
|
||||
|
||||
def test_delete_endpoint_route_revokes_orphaned_provider_auth(monkeypatch):
    """Deleting the sole endpoint through the route also removes its auth row."""
    session_factory = _mem_db(monkeypatch)

    seeding = session_factory()
    try:
        _add_auth_and_endpoints(seeding, auth_id="auth1", ep_ids=("ep1",))
    finally:
        seeding.close()

    handler = _delete_route(monkeypatch, session_factory)
    outcome = handler("ep1", object())

    assert outcome["deleted"] is True
    # auth1 had exactly one endpoint, so the route reports it as cleared.
    assert outcome["cleared_provider_auth"] is True

    verifying = session_factory()
    try:
        auth_row = (
            verifying.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert auth_row is None
        endpoint_row = (
            verifying.query(ModelEndpoint)
            .filter(ModelEndpoint.id == "ep1")
            .first()
        )
        assert endpoint_row is None
    finally:
        verifying.close()
|
||||
|
||||
|
||||
def test_delete_endpoint_route_keeps_auth_when_shared(monkeypatch):
    """Deleting one of two endpoints through the route keeps the shared auth row."""
    session_factory = _mem_db(monkeypatch)

    seeding = session_factory()
    try:
        _add_auth_and_endpoints(seeding, auth_id="auth1", ep_ids=("ep1", "ep2"))
    finally:
        seeding.close()

    handler = _delete_route(monkeypatch, session_factory)
    outcome = handler("ep1", object())

    assert outcome["deleted"] is True
    # ep2 still references auth1, so the credentials must NOT be revoked.
    assert outcome["cleared_provider_auth"] is False

    verifying = session_factory()
    try:
        auth_row = (
            verifying.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert auth_row is not None
    finally:
        verifying.close()
|
||||
|
||||
|
||||
def test_delete_orphaned_provider_auth_revokes_only_after_last_of_several(monkeypatch):
    """With two endpoints, the auth row survives the first delete and not the second."""
    from routes.model_routes import _delete_orphaned_provider_auth

    session = _mem_db(monkeypatch)()
    try:
        _add_auth_and_endpoints(session, auth_id="auth1", ep_ids=("ep1", "ep2"))

        # First removal: ep2 still references auth1, so the row survives.
        first_ep = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
        session.delete(first_ep)
        first_result = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep1")
        assert first_result is False
        session.commit()
        after_first = (
            session.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert after_first is not None

        # Second removal: ep2 was the last reference, so the row is cleared.
        last_ep = session.query(ModelEndpoint).filter(ModelEndpoint.id == "ep2").first()
        session.delete(last_ep)
        second_result = _delete_orphaned_provider_auth(session, "auth1", exclude_ep_id="ep2")
        assert second_result is True
        session.commit()
        after_second = (
            session.query(ProviderAuthSession)
            .filter(ProviderAuthSession.id == "auth1")
            .first()
        )
        assert after_second is None
    finally:
        session.close()
|
||||
@@ -0,0 +1,138 @@
|
||||
"""Shared device-flow route helper regressions."""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from routes import device_flow
|
||||
|
||||
|
||||
def _client(monkeypatch, now_ref, start_flow, poll_flow):
    """Build a ``TestClient`` around a device-flow router with a controllable clock.

    Args:
        monkeypatch: pytest fixture, used to replace ``device_flow.require_admin``
            with a callable that returns ``None``.
        now_ref: One-element list; the pending store reads ``now_ref[0]`` as its
            current time.
        start_flow: Callback handed to the router as ``start_flow``.
        poll_flow: Callback handed to the router as ``poll_flow``.
    """
    pending_store = device_flow.PendingDeviceFlowStore(time_func=lambda: now_ref[0])
    flow_router = device_flow.create_device_flow_router(
        prefix="/api/test-device",
        tags=["test-device"],
        store=pending_store,
        start_flow=start_flow,
        poll_flow=poll_flow,
    )

    test_app = FastAPI()
    test_app.include_router(flow_router)
    monkeypatch.setattr(device_flow, "require_admin", lambda request: None)
    return TestClient(test_app)
|
||||
|
||||
|
||||
def _start(_request, _form):
    """Return a fixed ``DeviceFlowStart`` used as the start callback in these tests.

    Both arguments are ignored. The result has a 5-unit poll interval and a
    20-unit expiry, with separate ``pending`` and ``response`` payloads.
    """
    pending_state = {"secret": "server-only", "owner": "alice"}
    client_payload = {
        "user_code": "ABCD-EFGH",
        "verification_uri": "https://example.test/device",
    }
    return device_flow.DeviceFlowStart(
        pending=pending_state,
        response=client_payload,
        interval=5,
        expires_in=20,
    )
|
||||
|
||||
|
||||
def test_pending_poll_is_throttled_until_interval(monkeypatch):
    """The poll callback is only invoked again after the interval has elapsed."""
    clock = [100.0]
    seen_pending = []

    def poll(_request, pending):
        # Record a copy of the pending payload handed to the callback.
        seen_pending.append(dict(pending))
        return device_flow.DeviceFlowPoll.pending()

    client = _client(monkeypatch, clock, _start, poll)
    started = client.post("/api/test-device/device/start").json()

    def poll_once():
        return client.post("/api/test-device/device/poll", data={"poll_id": started["poll_id"]})

    # First poll reaches the callback with the stored pending payload.
    assert poll_once().json() == {"status": "pending"}
    assert seen_pending == [{"secret": "server-only", "owner": "alice"}]

    # An immediate second poll is answered without calling the callback.
    assert poll_once().json() == {"status": "pending"}
    assert len(seen_pending) == 1

    # After advancing the clock by the interval the callback runs again.
    clock[0] += 5
    assert poll_once().json() == {"status": "pending"}
    assert len(seen_pending) == 2
|
||||
|
||||
|
||||
def test_slow_down_updates_poll_interval(monkeypatch):
    """A ``slow_down`` result raises the throttle interval for later polls."""
    clock = [100.0]
    poll_times = []

    def poll(_request, _pending):
        poll_times.append(clock[0])
        if len(poll_times) != 1:
            return device_flow.DeviceFlowPoll.authorized({"id": "ep1", "models": ["gpt-4o"]})
        # The very first call asks the client to back off to 10.
        return device_flow.DeviceFlowPoll.slow_down(interval=10)

    client = _client(monkeypatch, clock, _start, poll)
    poll_id = client.post("/api/test-device/device/start").json()["poll_id"]

    def poll_once():
        return client.post("/api/test-device/device/poll", data={"poll_id": poll_id})

    assert poll_once().json() == {"status": "pending"}

    # 9 elapsed is still inside the new interval: the callback is not called.
    clock[0] += 9
    assert poll_once().json() == {"status": "pending"}
    assert len(poll_times) == 1

    # One more tick reaches 10 and the callback's second result is returned.
    clock[0] += 1
    expected = {
        "status": "authorized",
        "endpoint": {"id": "ep1", "models": ["gpt-4o"]},
    }
    assert poll_once().json() == expected
|
||||
|
||||
|
||||
def test_authorized_and_failed_polls_remove_pending_session(monkeypatch):
    """After an authorized or failed poll, polling the same id returns 404."""
    clock = [100.0]
    scripted_results = [
        device_flow.DeviceFlowPoll.authorized({"id": "ep1"}),
        device_flow.DeviceFlowPoll.failed("access_denied"),
    ]

    def poll(_request, _pending):
        # Hand out the scripted results in order, one per callback invocation.
        return scripted_results.pop(0)

    client = _client(monkeypatch, clock, _start, poll)
    start_url = "/api/test-device/device/start"
    poll_url = "/api/test-device/device/poll"
    first = client.post(start_url).json()["poll_id"]
    second = client.post(start_url).json()["poll_id"]

    authorized_reply = client.post(poll_url, data={"poll_id": first})
    assert authorized_reply.json()["status"] == "authorized"
    assert client.post(poll_url, data={"poll_id": first}).status_code == 404

    failed_reply = client.post(poll_url, data={"poll_id": second})
    assert failed_reply.json() == {
        "status": "failed",
        "error": "access_denied",
    }
    assert client.post(poll_url, data={"poll_id": second}).status_code == 404
|
||||
|
||||
|
||||
def test_cancel_and_expiry_remove_pending_session(monkeypatch):
    """Cancelled and expired flows both answer later polls with 404."""
    clock = [100.0]

    def poll(_request, _pending):
        return device_flow.DeviceFlowPoll.pending()

    client = _client(monkeypatch, clock, _start, poll)
    start_url = "/api/test-device/device/start"
    poll_url = "/api/test-device/device/poll"

    # Explicit cancel.
    cancelled = client.post(start_url).json()["poll_id"]
    cancel_reply = client.post("/api/test-device/device/cancel", data={"poll_id": cancelled})
    assert cancel_reply.json() == {"status": "cancelled"}
    assert client.post(poll_url, data={"poll_id": cancelled}).status_code == 404

    # Expiry: _start sets expires_in=20, so 21 later the flow is gone.
    expired = client.post(start_url).json()["poll_id"]
    clock[0] += 21
    assert client.post(poll_url, data={"poll_id": expired}).status_code == 404
|
||||
|
||||
|
||||
def test_routes_are_admin_gated(monkeypatch):
    """Start, poll and cancel all return 403 when ``require_admin`` raises."""
    clock = [100.0]

    def poll(_request, _pending):
        return device_flow.DeviceFlowPoll.pending()

    client = _client(monkeypatch, clock, _start, poll)

    def deny(_request):
        raise HTTPException(403, "admin required")

    # _client installed a permissive gate; replace it with one that rejects.
    monkeypatch.setattr(device_flow, "require_admin", deny)

    start_reply = client.post("/api/test-device/device/start")
    assert start_reply.status_code == 403

    unknown_flow = {"poll_id": "missing"}
    poll_reply = client.post("/api/test-device/device/poll", data=unknown_flow)
    assert poll_reply.status_code == 403
    cancel_reply = client.post("/api/test-device/device/cancel", data=unknown_flow)
    assert cancel_reply.status_code == 403
|
||||
@@ -25,32 +25,36 @@ from unittest.mock import MagicMock
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules
|
||||
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state
|
||||
|
||||
# Match test_model_routes.py: if another test stubbed src.endpoint_resolver
|
||||
# during collection, drop the stub so the real URL helpers load here.
|
||||
clear_fake_endpoint_resolver_modules()
|
||||
with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"):
|
||||
# Match test_model_routes.py: if another test stubbed src.endpoint_resolver
|
||||
# during collection, drop the stub so the real URL helpers load here.
|
||||
clear_fake_endpoint_resolver_modules()
|
||||
|
||||
if "core.database" not in sys.modules:
|
||||
_core_db = types.ModuleType("core.database")
|
||||
for _name in [
|
||||
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
||||
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
||||
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer",
|
||||
]:
|
||||
setattr(_core_db, _name, MagicMock())
|
||||
sys.modules["core.database"] = _core_db
|
||||
if "core.database" not in sys.modules:
|
||||
_core_db = types.ModuleType("core.database")
|
||||
for _name in [
|
||||
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
||||
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
||||
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer",
|
||||
"ProviderAuthSession", "Base",
|
||||
]:
|
||||
setattr(_core_db, _name, MagicMock())
|
||||
_core_db.utcnow_naive = MagicMock()
|
||||
sys.modules["core.database"] = _core_db
|
||||
|
||||
import routes.model_routes as model_routes
|
||||
import src.endpoint_resolver as endpoint_resolver
|
||||
from routes.model_routes import (
|
||||
_probe_endpoint,
|
||||
_ping_endpoint,
|
||||
_probe_single_model,
|
||||
_classify_endpoint,
|
||||
_rewrite_loopback_for_docker,
|
||||
_PROVIDER_CURATED,
|
||||
)
|
||||
import routes.model_routes as model_routes
|
||||
import src.endpoint_resolver as endpoint_resolver
|
||||
from routes.model_routes import (
|
||||
_probe_endpoint,
|
||||
_ping_endpoint,
|
||||
_probe_single_model,
|
||||
_resolve_probe_key,
|
||||
_classify_endpoint,
|
||||
_rewrite_loopback_for_docker,
|
||||
_PROVIDER_CURATED,
|
||||
)
|
||||
|
||||
|
||||
def _patch_resolve(monkeypatch):
|
||||
@@ -117,6 +121,26 @@ class TestProbeEndpointParsing:
|
||||
)
|
||||
assert _probe_endpoint("https://api.example.com/v1") == []
|
||||
|
||||
def test_chatgpt_subscription_probe_uses_discovery_only(self, monkeypatch):
    """Probing the Codex backend URL delegates to ``fetch_available_models``.

    The stub records its arguments so the forwarded token and timeout can be
    checked alongside the returned model list.
    """
    _patch_resolve(monkeypatch)
    recorded = []

    def fake_fetch(access_token, timeout=5):
        recorded.append((access_token, timeout))
        return ["gpt-5.5"]

    monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", fake_fetch)

    discovered = _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS", timeout=7)
    assert discovered == ["gpt-5.5"]
    assert recorded == [("ACCESS", 7)]
|
||||
|
||||
def test_chatgpt_subscription_probe_without_discovery_returns_empty(self, monkeypatch):
    """An empty discovery result gives an empty probe, with or without a key."""
    _patch_resolve(monkeypatch)

    def no_models(access_token, timeout=5):
        return []

    monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", no_models)

    codex_url = "https://chatgpt.com/backend-api/codex"
    assert _probe_endpoint(codex_url, "ACCESS") == []
    assert _probe_endpoint(codex_url) == []
|
||||
|
||||
|
||||
# ── _ping_endpoint: reachability classification ──
|
||||
|
||||
@@ -321,6 +345,51 @@ class TestProbeSingleModel:
|
||||
_probe_single_model("https://api.anthropic.com/v1", "sk-ant", "claude-sonnet-4-5", with_tools=True)
|
||||
assert "input_schema" in captured["payload"]["tools"][0]
|
||||
|
||||
def test_chatgpt_subscription_skips_completion_probe(self, monkeypatch):
    """``_probe_single_model`` must not issue an HTTP POST for this provider.

    Per the original note, this provider speaks the Responses/Codex API: a
    chat-completions probe would 400 and, through the re-probe flow, hide every
    model. The probe therefore has to short-circuit as discovery-only.
    """
    _patch_resolve(monkeypatch)

    def fail_on_post(*args, **kwargs):
        raise AssertionError("must not send a completion probe for chatgpt-subscription")

    monkeypatch.setattr(model_routes.httpx, "post", fail_on_post)

    outcome = _probe_single_model("https://chatgpt.com/backend-api/codex", None, "gpt-5.1-codex")
    assert outcome["status"] == "ok"
    assert outcome.get("skipped") is True
    # latency_ms is pinned as well; the original note says downstream JSON/UI
    # code reads it.
    assert outcome["latency_ms"] == 0
|
||||
|
||||
|
||||
# ── _resolve_probe_key: static key vs provider-auth runtime token ──
|
||||
|
||||
class TestResolveProbeKey:
    """``_resolve_probe_key``: static API key versus provider-auth runtime token."""

    def test_static_endpoint_uses_api_key(self):
        """Without a ``provider_auth_id`` the endpoint's own ``api_key`` is returned."""
        static_ep = types.SimpleNamespace(
            id="e1",
            api_key="sk-static",
            provider_auth_id=None,
            owner=None,
        )
        assert _resolve_probe_key(static_ep) == "sk-static"

    def test_provider_auth_endpoint_resolves_runtime_token(self, monkeypatch):
        """A provider-auth endpoint gets its key from ``resolve_endpoint_runtime``."""
        auth_ep = types.SimpleNamespace(
            id="e2",
            api_key=None,
            provider_auth_id="auth123",
            owner="alice",
        )
        captured = {}

        def fake_runtime(endpoint, owner=None):
            # Remember which owner the resolver was asked about.
            captured["owner"] = owner
            return ("https://chatgpt.com/backend-api/codex", "live-bearer")

        monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", fake_runtime)

        assert _resolve_probe_key(auth_ep) == "live-bearer"
        assert captured["owner"] == "alice"

    def test_provider_auth_resolution_failure_returns_none(self, monkeypatch):
        """A resolver that raises results in ``None`` rather than an exception."""
        auth_ep = types.SimpleNamespace(
            id="e3",
            api_key=None,
            provider_auth_id="auth123",
            owner=None,
        )

        def failing_runtime(endpoint, owner=None):
            raise RuntimeError("reauth required")

        monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", failing_runtime)

        assert _resolve_probe_key(auth_ep) is None
|
||||
|
||||
|
||||
# ── _classify_endpoint: Tailscale CGNAT range ──
|
||||
|
||||
|
||||
@@ -75,6 +75,28 @@ def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
|
||||
assert payload["temperature"] == 1.2
|
||||
|
||||
|
||||
def test_chatgpt_subscription_payload_uses_max_output_tokens():
    """A positive ``max_tokens`` appears in the payload as ``max_output_tokens``."""
    conversation = [{"role": "user", "content": "Say OK"}]
    built = llm_core._build_chatgpt_responses_payload(
        "gpt-5.1-codex",
        conversation,
        temperature=0.2,
        max_tokens=37,
    )

    assert built["max_output_tokens"] == 37
|
||||
|
||||
|
||||
def test_chatgpt_subscription_payload_omits_empty_max_output_tokens():
    """``max_tokens=0`` leaves ``max_output_tokens`` out of the payload."""
    conversation = [{"role": "user", "content": "Say OK"}]
    built = llm_core._build_chatgpt_responses_payload(
        "gpt-5.1-codex",
        conversation,
        temperature=0.2,
        max_tokens=0,
    )

    assert "max_output_tokens" not in built
|
||||
|
||||
|
||||
def _anthropic_payload(temperature):
|
||||
return llm_core._build_anthropic_payload(
|
||||
"claude-3-5-sonnet",
|
||||
|
||||
+92
-42
@@ -11,49 +11,51 @@ from types import SimpleNamespace
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules
|
||||
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state
|
||||
|
||||
# Other tests stub this module during collection. These helper tests need
|
||||
# the real URL normalization helpers so Anthropic /v1 handling is covered.
|
||||
clear_fake_endpoint_resolver_modules()
|
||||
with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"):
|
||||
# Other tests stub this module during collection. These helper tests need
|
||||
# the real URL normalization helpers so Anthropic /v1 handling is covered.
|
||||
clear_fake_endpoint_resolver_modules()
|
||||
|
||||
if "core.database" not in sys.modules:
|
||||
_core_db = types.ModuleType("core.database")
|
||||
for _name in [
|
||||
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
||||
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
||||
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun",
|
||||
"McpServer",
|
||||
]:
|
||||
setattr(_core_db, _name, MagicMock())
|
||||
sys.modules["core.database"] = _core_db
|
||||
if "core.database" not in sys.modules:
|
||||
_core_db = types.ModuleType("core.database")
|
||||
for _name in [
|
||||
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
||||
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
||||
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun",
|
||||
"McpServer", "ProviderAuthSession", "Base",
|
||||
]:
|
||||
setattr(_core_db, _name, MagicMock())
|
||||
_core_db.utcnow_naive = MagicMock()
|
||||
sys.modules["core.database"] = _core_db
|
||||
|
||||
import routes.model_routes as model_routes
|
||||
import src.database as src_database
|
||||
import src.endpoint_resolver as endpoint_resolver
|
||||
import src.llm_core as llm_core
|
||||
from routes.model_routes import (
|
||||
_match_provider_curated,
|
||||
_curate_models,
|
||||
_visible_models,
|
||||
_normalize_model_ids,
|
||||
_api_key_fingerprint,
|
||||
_is_chat_model,
|
||||
_classify_endpoint,
|
||||
_effective_endpoint_kind,
|
||||
_probe_endpoint,
|
||||
_ping_endpoint,
|
||||
_parse_model_list,
|
||||
_normalize_refresh_mode,
|
||||
_truthy,
|
||||
_speech_settings_using_endpoint,
|
||||
_clear_speech_settings_for_endpoint,
|
||||
_endpoint_settings_using_endpoint,
|
||||
_clear_endpoint_settings_for_endpoint,
|
||||
_clear_user_pref_endpoint_refs,
|
||||
_PROVIDER_CURATED,
|
||||
)
|
||||
from src.llm_core import ANTHROPIC_MODELS
|
||||
import routes.model_routes as model_routes
|
||||
import src.database as src_database
|
||||
import src.endpoint_resolver as endpoint_resolver
|
||||
import src.llm_core as llm_core
|
||||
from routes.model_routes import (
|
||||
_match_provider_curated,
|
||||
_curate_models,
|
||||
_visible_models,
|
||||
_normalize_model_ids,
|
||||
_api_key_fingerprint,
|
||||
_is_chat_model,
|
||||
_classify_endpoint,
|
||||
_effective_endpoint_kind,
|
||||
_probe_endpoint,
|
||||
_ping_endpoint,
|
||||
_parse_model_list,
|
||||
_normalize_refresh_mode,
|
||||
_truthy,
|
||||
_speech_settings_using_endpoint,
|
||||
_clear_speech_settings_for_endpoint,
|
||||
_endpoint_settings_using_endpoint,
|
||||
_clear_endpoint_settings_for_endpoint,
|
||||
_clear_user_pref_endpoint_refs,
|
||||
_PROVIDER_CURATED,
|
||||
)
|
||||
from src.llm_core import ANTHROPIC_MODELS
|
||||
|
||||
|
||||
# ── speech endpoint settings ──
|
||||
@@ -687,8 +689,7 @@ class _PinnedFakeRequest:
|
||||
|
||||
|
||||
def _get_route(path, method):
|
||||
from routes.model_routes import setup_model_routes
|
||||
router = setup_model_routes(model_discovery=None)
|
||||
router = model_routes.setup_model_routes(model_discovery=None)
|
||||
for route in router.routes:
|
||||
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
|
||||
return route.endpoint
|
||||
@@ -787,6 +788,55 @@ def test_reprobe_preserves_pinned_models(monkeypatch):
|
||||
assert json.loads(ep.cached_models) == ["m1"]
|
||||
|
||||
|
||||
def test_reprobe_chatgpt_subscription_does_not_hide_models(monkeypatch):
    """Re-probing a chatgpt-subscription endpoint must leave no model hidden.

    Runs the probe route with the real ``_probe_single_model`` (only discovery
    and HTTP are stubbed) and checks the streamed events: both models are
    reported as skipped/ok, nothing is hidden, and the stale ``hidden_models``
    value on the endpoint is reset to ``None``.
    """
    ep = _make_endpoint(
        base_url="https://chatgpt.com/backend-api/codex",
        api_key=None,
        hidden_models=json.dumps(["stale-hidden"]),
    )
    fake_db = _PinnedFakeDb([ep])

    def forbid_completion_probe(*a, **k):
        # Any completion probe would be a bug for this provider.
        raise AssertionError("must not probe chatgpt-subscription")

    monkeypatch.setattr(model_routes, "SessionLocal", lambda: fake_db)
    monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
    monkeypatch.setattr(model_routes, "_normalize_base", lambda url: url.rstrip("/"))
    monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["gpt-5.1-codex", "gpt-5.1"])
    monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
    monkeypatch.setattr(model_routes.httpx, "post", forbid_completion_probe)
    probe_route = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")

    response = probe_route("ep1", _PinnedFakeRequest())
    raw_chunks = []

    async def _drain():
        async for piece in response.body_iterator:
            raw_chunks.append(piece.decode() if isinstance(piece, bytes) else piece)

    asyncio.run(_drain())

    # Decode the JSON payload of every "data: " line in the streamed body.
    marker = "data: "
    events = [
        json.loads(line[len(marker):])
        for chunk in raw_chunks
        for line in chunk.splitlines()
        if line.startswith(marker)
    ]

    done = next(event for event in events if event.get("type") == "probe_done")
    results = [event for event in events if event.get("type") == "probe_result"]

    # Every model was skipped as ok; none failed, so nothing is hidden.
    assert done["hidden"] == 0
    assert done["ok"] == len(results) == 2
    assert all(item["status"] == "ok" and item.get("skipped") is True for item in results)
    # The stale hidden_models value is cleared, not repopulated with every model.
    assert ep.hidden_models is None
|
||||
|
||||
|
||||
def test_visible_models_handles_malformed_strings():
|
||||
# Non-JSON cached/pinned strings are treated as comma/newline lists and
|
||||
# never raise; a malformed hidden string is normalized too.
|
||||
|
||||
@@ -42,6 +42,10 @@ class TestHostMatch:
|
||||
|
||||
|
||||
class TestDetectProviderRealHosts:
|
||||
def test_chatgpt_subscription_codex_backend(self):
    """Both the Codex base URL and its ``/responses`` path map to this provider."""
    codex_urls = (
        "https://chatgpt.com/backend-api/codex",
        "https://chatgpt.com/backend-api/codex/responses",
    )
    for url in codex_urls:
        assert llm_core._detect_provider(url) == "chatgpt-subscription"
|
||||
|
||||
def test_anthropic(self):
|
||||
assert llm_core._detect_provider("https://api.anthropic.com") == "anthropic"
|
||||
|
||||
@@ -93,6 +97,12 @@ class TestBuildersRejectLookalikeHosts:
|
||||
def test_real_anthropic_chat(self):
|
||||
assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages"
|
||||
|
||||
def test_chatgpt_subscription_chat_uses_responses(self):
    """The chat URL for the Codex backend is its ``/responses`` path."""
    chat_url = build_chat_url("https://chatgpt.com/backend-api/codex")
    assert chat_url == "https://chatgpt.com/backend-api/codex/responses"
|
||||
|
||||
def test_chatgpt_subscription_models_uses_no_live_probe(self):
    # No models URL is built for the Codex backend, so there is nothing to probe.
    models_url = build_models_url("https://chatgpt.com/backend-api/codex")
    assert models_url is None
|
||||
|
||||
def test_lookalike_anthropic_chat_is_openai(self):
    # A host that merely contains "anthropic" gets the OpenAI-style path.
    chat_url = build_chat_url("https://notanthropic.com")
    assert chat_url == "https://notanthropic.com/chat/completions"
|
||||
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Node-driven tests for the shared provider device-flow runner."""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_HELPER = _REPO / "static" / "js" / "providerDeviceFlow.js"
|
||||
pytestmark = pytest.mark.skipif(not shutil.which("node"), reason="node not on PATH")
|
||||
|
||||
|
||||
def _run_node(script: str):
    """Run *script* as an ES module under Node and return its parsed JSON output.

    The script is piped to ``node --input-type=module`` on stdin, with the
    repository root as the working directory. It must exit with status 0 and
    print one JSON document on stdout.

    Raises:
        AssertionError: if Node exits non-zero (stderr becomes the message) or
            prints nothing on stdout.
        subprocess.TimeoutExpired: if the script runs longer than 30 seconds.
    """
    proc = subprocess.run(
        ["node", "--input-type=module"],
        input=script,
        capture_output=True,
        text=True,
        # Use UTF-8 for both pipes instead of the locale encoding, so
        # non-ASCII paths and messages survive on every platform.
        encoding="utf-8",
        cwd=str(_REPO),
        timeout=30,
    )
    assert proc.returncode == 0, proc.stderr
    stdout = proc.stdout.strip()
    # Fail with a readable message instead of a bare JSONDecodeError when the
    # script exits cleanly without printing a result.
    assert stdout, f"node script printed nothing on stdout; stderr: {proc.stderr!r}"
    return json.loads(stdout)
|
||||
|
||||
|
||||
def test_copilot_success_uses_complete_verification_uri():
    """Copilot flow opens ``verification_uri_complete`` and polls until authorized.

    The fake fetch answers the start call with both a plain and a complete
    verification URI, then reports ``pending`` once before ``authorized``.
    """
    # The helper is imported through a file:// URL: a bare absolute path is not
    # a valid ESM specifier on Windows (drive letter parsed as a URL scheme),
    # and a quote in the path would break the JS string literal.
    js = f"""
    import {{ runProviderDeviceFlow }} from '{_HELPER.as_uri()}';
    const calls = [];
    const opened = [];
    let polls = 0;
    const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
    const fetchImpl = async (url) => {{
      calls.push(url);
      if (url.endsWith('/device/start')) {{
        return response(true, 200, {{
          poll_id: 'poll-1',
          user_code: 'GH-CODE',
          verification_uri: 'https://github.com/login/device',
          verification_uri_complete: 'https://github.com/login/device?user_code=GH-CODE',
          interval: 2,
          expires_in: 30,
        }});
      }}
      polls += 1;
      return response(true, 200, polls === 1
        ? {{ status: 'pending' }}
        : {{ status: 'authorized', endpoint: {{ id: 'ep1', models: ['gpt-4o'] }} }}
      );
    }};
    const result = await runProviderDeviceFlow('copilot', {{
      fetchImpl,
      openWindow: (url) => opened.push(url),
      sleep: async () => {{}},
      now: () => 0,
    }});
    console.log(JSON.stringify({{ result, calls, opened }}));
    """
    out = _run_node(js)
    assert out["result"]["status"] == "authorized"
    assert out["result"]["endpoint"]["id"] == "ep1"
    # The pre-filled URI is the one opened, not the plain one.
    assert out["opened"] == ["https://github.com/login/device?user_code=GH-CODE"]
    # One start call, then one poll per status returned by the fake.
    assert out["calls"] == ["/api/copilot/device/start", "/api/copilot/device/poll", "/api/copilot/device/poll"]
|
||||
|
||||
|
||||
def test_chatgpt_success_uses_plain_verification_uri():
    """ChatGPT Subscription flow falls back to the plain ``verification_uri``.

    The fake start response carries no ``verification_uri_complete``, and the
    first poll already reports ``authorized``.
    """
    # The helper is imported through a file:// URL: a bare absolute path is not
    # a valid ESM specifier on Windows, and a quote in the path would break the
    # JS string literal.
    js = f"""
    import {{ runProviderDeviceFlow }} from '{_HELPER.as_uri()}';
    const opened = [];
    const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
    const fetchImpl = async (url) => {{
      if (url.endsWith('/device/start')) {{
        return response(true, 200, {{
          poll_id: 'poll-1',
          user_code: 'OA-CODE',
          verification_uri: 'https://auth.openai.com/codex/device',
          interval: 2,
          expires_in: 30,
        }});
      }}
      return response(true, 200, {{ status: 'authorized', endpoint: {{ id: 'chatgpt', models: ['gpt-5.5'] }} }});
    }};
    const result = await runProviderDeviceFlow('chatgpt-subscription', {{
      fetchImpl,
      openWindow: (url) => opened.push(url),
      sleep: async () => {{}},
      now: () => 0,
    }});
    console.log(JSON.stringify({{ result, opened }}));
    """
    out = _run_node(js)
    assert out["result"]["status"] == "authorized"
    assert out["opened"] == ["https://auth.openai.com/codex/device"]
|
||||
|
||||
|
||||
def test_start_errors_surface_backend_detail():
    """A non-OK start response rejects with the backend's ``detail`` text."""
    # The helper is imported through a file:// URL so the specifier is valid on
    # every platform. If the flow resolves instead of throwing, the script
    # prints a null message, so the assertion below fails with a readable
    # diff instead of the script producing no output.
    js = f"""
    import {{ runProviderDeviceFlow }} from '{_HELPER.as_uri()}';
    const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
    try {{
      await runProviderDeviceFlow('copilot', {{
        fetchImpl: async () => response(false, 502, {{ detail: 'GitHub device-code request failed: upstream down' }}),
        openWindow: () => {{}},
        sleep: async () => {{}},
        now: () => 0,
      }});
      console.log(JSON.stringify({{ message: null }}));
    }} catch (err) {{
      console.log(JSON.stringify({{ message: err.message }}));
    }}
    """
    out = _run_node(js)
    assert out["message"] == "GitHub device-code request failed: upstream down"
|
||||
|
||||
|
||||
def test_thrown_fetch_errors_are_preserved():
    """An exception thrown by fetch propagates with its original message."""
    # The helper is imported through a file:// URL so the specifier is valid on
    # every platform. If the flow resolves instead of throwing, the script
    # prints a null message, so the assertion below fails with a readable
    # diff instead of the script producing no output.
    js = f"""
    import {{ runProviderDeviceFlow }} from '{_HELPER.as_uri()}';
    try {{
      await runProviderDeviceFlow('chatgpt-subscription', {{
        fetchImpl: async () => {{ throw new Error('network offline'); }},
        openWindow: () => {{}},
        sleep: async () => {{}},
        now: () => 0,
      }});
      console.log(JSON.stringify({{ message: null }}));
    }} catch (err) {{
      console.log(JSON.stringify({{ message: err.message }}));
    }}
    """
    out = _run_node(js)
    assert out["message"] == "network offline"
|
||||
|
||||
|
||||
def test_expired_flow_returns_expired_status():
    """A flow that stays pending past ``expires_in`` resolves to ``expired``.

    The injected clock advances 2000 per sleep while the start response
    allows only ``expires_in: 1``, and every poll answers ``pending``.
    """
    # The helper is imported through a file:// URL: a bare absolute path is not
    # a valid ESM specifier on Windows, and a quote in the path would break the
    # JS string literal.
    js = f"""
    import {{ runProviderDeviceFlow }} from '{_HELPER.as_uri()}';
    let currentTime = 0;
    const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
    const result = await runProviderDeviceFlow('copilot', {{
      fetchImpl: async (url) => url.endsWith('/device/start')
        ? response(true, 200, {{
            poll_id: 'poll-1',
            user_code: 'GH-CODE',
            verification_uri: 'https://github.com/login/device',
            interval: 2,
            expires_in: 1,
          }})
        : response(true, 200, {{ status: 'pending' }}),
      openWindow: () => {{}},
      sleep: async () => {{ currentTime += 2000; }},
      now: () => currentTime,
    }});
    console.log(JSON.stringify(result));
    """
    out = _run_node(js)
    assert out == {"status": "expired"}
|
||||
@@ -24,7 +24,7 @@ _sd = types.ModuleType("src.database")
|
||||
_sd.ModelEndpoint = MagicMock()
|
||||
sys.modules.setdefault("src.database", _sd)
|
||||
|
||||
from routes.research_routes import _owned_enabled_endpoint # noqa: E402
|
||||
from routes.research_routes import _owned_enabled_endpoint, _resolve_endpoint_runtime # noqa: E402
|
||||
|
||||
|
||||
class _Predicate:
|
||||
@@ -129,3 +129,29 @@ def test_null_owner_is_legacy_single_user_noop():
|
||||
rows = [_ep("ep-x", "bob"), _ep("ep-y", "alice")]
|
||||
ep = _resolve(rows, None, "ep-x")
|
||||
assert ep is not None and ep.id == "ep-x"
|
||||
|
||||
|
||||
def test_runtime_resolution_uses_provider_auth_for_chatgpt_subscription(monkeypatch):
    """Provider-auth endpoints resolve their bearer from the auth session.

    The endpoint has no static API key; the patched credential resolver
    supplies the token, and an empty requested model falls back to the
    cached model list.
    """
    ep = SimpleNamespace(
        id="ep-chatgpt",
        owner="alice",
        base_url="https://chatgpt.com/backend-api/codex",
        api_key=None,
        provider_auth_id="auth-1",
        cached_models='["gpt-5.5"]',
        hidden_models=None,
    )

    def _fresh_credentials(auth_id, owner=None):
        # Stand-in for the real resolver: always returns a freshly issued token.
        return {
            "base_url": "https://chatgpt.com/backend-api/codex",
            "api_key": "fresh-access-token",
        }

    monkeypatch.setattr(
        "src.chatgpt_subscription.resolve_runtime_credentials",
        _fresh_credentials,
    )

    resolved = _resolve_endpoint_runtime(ep, owner="alice", model="")
    url, model, headers = resolved

    assert url == "https://chatgpt.com/backend-api/codex/responses"
    assert model == "gpt-5.5"
    assert headers["Authorization"] == "Bearer fresh-access-token"
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
"""resolve_session_auth must not persist the ChatGPT Subscription bearer.
|
||||
|
||||
The ChatGPT Subscription access token is a short-lived OAuth bearer re-resolved
|
||||
(and refreshed) on every request. resolve_session_auth() may set it on the
|
||||
in-memory session for the current request, but it must never write it back into
|
||||
the sessions table — otherwise the live token sits at rest as
|
||||
"Authorization: Bearer ...". Only the encrypted refresh token in
|
||||
ProviderAuthSession is allowed to persist.
|
||||
"""
|
||||
|
||||
import types
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import routes.chat_helpers as chat_helpers
|
||||
import src.endpoint_resolver as endpoint_resolver
|
||||
from core.database import Base, ModelEndpoint, Session as DbSession
|
||||
|
||||
_CODEX_BASE = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def _mem_db(monkeypatch):
    """Create an in-memory SQLite schema and point chat_helpers at it.

    Returns the session factory so the test can seed and inspect rows through
    the same database that ``chat_helpers.SessionLocal`` now uses.
    """
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(bind=memory_engine)
    # autoflush=False mirrors the production SessionLocal in core.database.
    session_factory = sessionmaker(bind=memory_engine, autoflush=False)
    monkeypatch.setattr(chat_helpers, "SessionLocal", session_factory)
    return session_factory
|
||||
|
||||
|
||||
def test_chatgpt_subscription_auth_is_not_written_to_sessions_table(monkeypatch):
    """The resolved bearer reaches the in-memory session but never the DB row."""
    make_session = _mem_db(monkeypatch)

    # Seed one subscription endpoint and one chat session that points at it.
    seed = make_session()
    try:
        endpoint_row = ModelEndpoint(
            id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
            provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
        )
        session_row = DbSession(
            id="sess1", name="chat", endpoint_url=_CODEX_BASE,
            model="gpt-5.1-codex", owner="alice", headers={},
        )
        seed.add(endpoint_row)
        seed.add(session_row)
        seed.commit()
    finally:
        seed.close()

    def _live_runtime(ep, owner=None):
        # A live access token is resolved at request time.
        return (_CODEX_BASE, "live-access-token")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _live_runtime)

    sess = types.SimpleNamespace(
        id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex",
        owner="alice", headers={},
    )
    chat_helpers.resolve_session_auth(sess, "sess1", owner="alice")

    # The in-memory session carries request-local auth for this request...
    in_memory_names = {name.lower() for name in sess.headers}
    assert "authorization" in in_memory_names
    assert sess.headers["Authorization"] == "Bearer live-access-token"

    # ...but the DB row must NOT have the bearer persisted.
    check = make_session()
    try:
        row = check.query(DbSession).filter(DbSession.id == "sess1").first()
        stored = row.headers or {}
        leaked = [name for name in stored if name.lower() == "authorization"]
        assert not leaked, (
            f"ChatGPT bearer leaked into sessions table: {stored}"
        )
    finally:
        check.close()
|
||||
|
||||
|
||||
def test_non_subscription_auth_is_still_persisted_to_sessions_table(monkeypatch):
    """Pin the persistence path for ordinary, static-key endpoints.

    The "do not persist" rule applies to ChatGPT Subscription only. A regular
    endpoint must still have its resolved auth headers written to the sessions
    table, so a widened subscription guard would be caught here.
    """
    base = "https://api.example.com/v1"
    auth_header_names = ("authorization", "x-api-key")
    make_session = _mem_db(monkeypatch)

    seed = make_session()
    try:
        endpoint_row = ModelEndpoint(
            id="ep1", name="Generic", base_url=base,
            owner="alice", is_enabled=True, api_key="sk-static",
        )
        session_row = DbSession(
            id="sess1", name="chat", endpoint_url=base,
            model="gpt-x", owner="alice", headers={},
        )
        seed.add(endpoint_row)
        seed.add(session_row)
        seed.commit()
    finally:
        seed.close()

    def _static_runtime(ep, owner=None):
        return (base, "sk-static")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _static_runtime)

    sess = types.SimpleNamespace(
        id="sess1", endpoint_url=base, model="gpt-x", owner="alice", headers={},
    )
    chat_helpers.resolve_session_auth(sess, "sess1", owner="alice")

    # The in-memory session got auth...
    in_memory_hits = [name for name in sess.headers if name.lower() in auth_header_names]
    assert in_memory_hits

    # ...AND it was persisted to the DB row (the normal, non-subscription path).
    check = make_session()
    try:
        row = check.query(DbSession).filter(DbSession.id == "sess1").first()
        stored = row.headers or {}
        persisted_hits = [name for name in stored if name.lower() in auth_header_names]
        assert persisted_hits, (
            f"non-subscription auth was not persisted: {stored}"
        )
    finally:
        check.close()
|
||||
|
||||
|
||||
def test_chatgpt_subscription_clears_previously_persisted_bearer(monkeypatch):
    """A bearer already stored in the row is removed by the next resolve."""
    make_session = _mem_db(monkeypatch)

    seed = make_session()
    try:
        endpoint_row = ModelEndpoint(
            id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
            provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
        )
        # Simulate the leak: a stale bearer already sitting in the sessions table.
        leaky_session_row = DbSession(
            id="sess1", name="chat", endpoint_url=_CODEX_BASE,
            model="gpt-5.1-codex", owner="alice",
            headers={"Authorization": "Bearer stale-leaked-token"},
        )
        seed.add(endpoint_row)
        seed.add(leaky_session_row)
        seed.commit()
    finally:
        seed.close()

    def _live_runtime(ep, owner=None):
        return (_CODEX_BASE, "live-access-token")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _live_runtime)

    sess = types.SimpleNamespace(
        id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex",
        owner="alice", headers={},
    )
    chat_helpers.resolve_session_auth(sess, "sess1", owner="alice")

    # The stale bearer must have been stripped from the DB row.
    check = make_session()
    try:
        row = check.query(DbSession).filter(DbSession.id == "sess1").first()
        stored = row.headers or {}
        leftover = [name for name in stored if name.lower() == "authorization"]
        assert not leftover, (
            f"stale ChatGPT bearer was not cleared: {stored}"
        )
    finally:
        check.close()
|
||||
|
||||
|
||||
def test_chatgpt_subscription_fallback_auth_is_not_written_to_sessions_table(monkeypatch):
    """Falling back to the subscription endpoint keeps the bearer request-local.

    The stored session points at an unrelated URL. After the fallback, the row
    must have the new model and endpoint URL but no Authorization header.
    """
    fallback_url = _CODEX_BASE + "/responses"
    make_session = _mem_db(monkeypatch)

    seed = make_session()
    try:
        endpoint_row = ModelEndpoint(
            id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
            provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
            cached_models='["gpt-5.1-codex"]',
        )
        session_row = DbSession(
            id="sess1", name="chat", endpoint_url="https://old.example/v1",
            model="old-model", owner="alice", headers={},
        )
        seed.add(endpoint_row)
        seed.add(session_row)
        seed.commit()
    finally:
        seed.close()

    def _live_runtime(ep, owner=None):
        return (_CODEX_BASE, "live-access-token")

    monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", _live_runtime)

    sess = types.SimpleNamespace(
        id="sess1", endpoint_url="https://old.example/v1", model="old-model",
        owner="alice", headers={},
    )
    result = chat_helpers.try_fallback_endpoint(sess, "sess1")

    expected = {
        "model": "gpt-5.1-codex",
        "endpoint_url": fallback_url,
        "endpoint_name": "ChatGPT Subscription",
    }
    assert result == expected
    assert sess.headers["Authorization"] == "Bearer live-access-token"

    check = make_session()
    try:
        row = check.query(DbSession).filter(DbSession.id == "sess1").first()
        assert row.model == "gpt-5.1-codex"
        assert row.endpoint_url == fallback_url
        stored = row.headers or {}
        leaked = [name for name in stored if name.lower() == "authorization"]
        assert not leaked, (
            f"ChatGPT fallback bearer leaked into sessions table: {stored}"
        )
    finally:
        check.close()
|
||||
@@ -386,7 +386,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
|
||||
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
|
||||
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {})
|
||||
monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester")
|
||||
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model: None)
|
||||
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None)
|
||||
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
|
||||
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
|
||||
|
||||
|
||||
@@ -137,3 +137,12 @@ def test_unauthenticated_caller_rejected(monkeypatch):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
SR._verify_session_owner(req, "sid")
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_auth_disabled_allows_owner_stamped_session(monkeypatch):
    """With AUTH_ENABLED=false, an owner-stamped session passes verification.

    There is no explicit assertion: the test passes as long as
    ``_verify_session_owner`` returns without raising.
    """
    monkeypatch.setenv("AUTH_ENABLED", "false")
    owned_by_admin = _session_local_returning("admin")
    monkeypatch.setattr(SR, "SessionLocal", owned_by_admin)
    anonymous_request = _req(api_token=False, current_user=None)

    # Single-user/auth-disabled mode should verify existence but not compare owner.
    SR._verify_session_owner(anonymous_request, "sid-owned-by-admin")
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Static regressions for `/setup` account sign-in providers."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_SLASH = (_REPO / "static" / "js" / "slashCommands.js").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _between(src: str, start: str, end: str) -> str:
|
||||
start_idx = src.index(start)
|
||||
end_idx = src.index(end, start_idx)
|
||||
return src[start_idx:end_idx]
|
||||
|
||||
|
||||
def test_setup_guide_lists_account_sign_in_providers():
    """The slash-command source mentions both account sign-in providers."""
    # Called only for its check: _between raises ValueError unless the guide
    # renderer exists and is followed by _hasConfiguredModels. The result was
    # never used, so it is no longer bound to a name.
    # NOTE(review): the snippet checks below run against the whole file, not
    # this section. Confirm whether they were meant to be scoped to the guide.
    _between(_SLASH, "function _showSetupEndpointChoices", "async function _hasConfiguredModels")

    required_snippets = (
        'data-setup-provider="',
        "provider.key",
        "'copilot'",
        "'chatgpt-subscription'",
        "/setup copilot",
        "/setup chatgpt-subscription",
    )
    for snippet in required_snippets:
        assert snippet in _SLASH
|
||||
|
||||
|
||||
def test_clicking_account_sign_in_provider_prefills_setup_command_not_api_key():
    """The provider click handler builds a ``/setup <provider>`` command."""
    handler_src = _between(
        _SLASH,
        "const providerEl = e.target.closest('.setup-clickable-provider')",
        "// 3. Check",
    )

    expected_fragments = (
        "providerEl.dataset.setupProvider",
        "providerEl.dataset.setupKind === 'device-auth'",
        "'/setup ' + providerKey",
    )
    for fragment in expected_fragments:
        assert fragment in handler_src
|
||||
|
||||
|
||||
def test_setup_chatgpt_subscription_prints_auth_url_without_auto_opening_tab():
    """The device-flow setup shows the auth URL and returns early for ChatGPT."""
    flow_src = _between(_SLASH, "async function _setupProviderDeviceFlow", "async function _cmdSetup")

    expected_fragments = (
        "providerKey === 'chatgpt-subscription'",
        "Open this URL",
        "authUrl",
        # The link target must go through the HTML-escaping helper.
        'href="\' + uiModule.esc(authUrl || \'\') + \'"',
        "if (providerKey === 'chatgpt-subscription') return;",
    )
    for fragment in expected_fragments:
        assert fragment in flow_src
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Static regressions for slash autocomplete command-group expansion."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_REPO = Path(__file__).resolve().parent.parent
|
||||
_AC = (_REPO / "static" / "js" / "slashAutocomplete.js").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_exact_parent_command_expands_subcommands_before_top_level_row_cap():
    """The autocomplete source has the command-group expansion pieces."""
    expected_fragments = (
        "function _exactCommandGroupItems",
        "entry.token.toLowerCase().startsWith(prefix)",
        "items = groupItems.slice(0, MAX_VISIBLE);",
    )
    for fragment in expected_fragments:
        assert fragment in _AC
|
||||
|
||||
|
||||
def test_setup_group_has_room_for_chatgpt_subscription_suggestion():
    """Pin the visible-row cap at the value the source is expected to declare."""
    row_cap_declaration = "const MAX_VISIBLE = 14;"
    assert row_cap_declaration in _AC
|
||||
Reference in New Issue
Block a user