mirror of
https://github.com/pewdiepie-archdaemon/odysseus.git
synced 2026-06-17 10:15:27 -04:00
feat: add ChatGPT Subscription provider (#2876)
* feat: Add ChatGPT Subscription support and related features - Introduced a new provider option for ChatGPT Subscription in the endpoint selection UI. - Implemented OAuth flow for ChatGPT Subscription sign-in, including polling for authorization status. - Updated admin interface to handle ChatGPT Subscription, including disabling API key input and providing user guidance. - Enhanced cost tracking logic to differentiate between subscription and non-subscription endpoints. - Added new slash commands for managing skills, including listing, searching, and invoking skills. - Implemented caching for skill catalog to optimize performance. - Updated tests to cover new ChatGPT Subscription functionality and ensure proper endpoint probing. - Refactored existing code to accommodate new features and improve maintainability. * refactor: share provider device-flow setup - reuse one device-flow backend for Copilot and ChatGPT Subscription - add one frontend device-flow helper for Settings and /setup - put GitHub Copilot back into Add Models, now as a dropdown option - make provider selection just select; clicking Add starts sign-in - stop ChatGPT Subscription setup from opening auth tabs automatically - make /setup copilot and /setup chatgpt-subscription work from chat - show ChatGPT Subscription in the /setup suggestions - show the real error message when setup fails - add focused tests for the shared flow and setup UI * feat(chatgpt-subscription): harden credential lifecycle and streamline auth UX Backend: - Resolve runtime bearer for provider-auth endpoints at probe time via a shared _resolve_probe_key() that delegates to resolve_endpoint_runtime, applied across all probe/refresh call sites. - Skip live completion probes and health pings for discovery-only providers (centralized behind _is_discovery_only_provider) — the Codex/Responses API has no such endpoints, so status is derived from cached models. - Never persist the short lived ChatGPT bearer to the plaintext sessions table; proactively clear any stale bearer left by an earlier code path. - Revoke orphaned ProviderAuthSession credentials when the last endpoint backing them is deleted (_delete_orphaned_provider_auth), surfaced via cleared_provider_auth in the delete response. Frontend (admin.js): - Auto-start the device-auth flow on provider selection so the authorization panel (code + Authorize) shows immediately instead of behind a "Sign in" click. - Remove the redundant top button for device auth providers, move retry into the panel via an inline "Try again". - Drop the self-evident hint text and add an execCommand clipboard fallback so Copy works in non-secure (HTTP/LAN) contexts. * fix: harden chatgpt subscription provider * chore: remove PR media from branch * Fix chatgpt subscription recovery and token handling --------- Co-authored-by: 5p00kyy <admin@5p00ky.dev>
This commit is contained in:
@@ -598,6 +598,10 @@ app.include_router(setup_model_routes(model_discovery))
|
|||||||
from routes.copilot_routes import setup_copilot_routes
|
from routes.copilot_routes import setup_copilot_routes
|
||||||
app.include_router(setup_copilot_routes())
|
app.include_router(setup_copilot_routes())
|
||||||
|
|
||||||
|
# ChatGPT Subscription device-flow login
|
||||||
|
from routes.chatgpt_subscription_routes import setup_chatgpt_subscription_routes
|
||||||
|
app.include_router(setup_chatgpt_subscription_routes())
|
||||||
|
|
||||||
# TTS
|
# TTS
|
||||||
from routes.tts_routes import setup_tts_routes
|
from routes.tts_routes import setup_tts_routes
|
||||||
app.include_router(setup_tts_routes(tts_service))
|
app.include_router(setup_tts_routes(tts_service))
|
||||||
|
|||||||
@@ -361,6 +361,24 @@ class ModelEndpoint(TimestampMixin, Base):
|
|||||||
# is the historical default. When non-null, the model picker only shows
|
# is the historical default. When non-null, the model picker only shows
|
||||||
# the endpoint to that user (admins always see everything).
|
# the endpoint to that user (admins always see everything).
|
||||||
owner = Column(String, nullable=True, index=True)
|
owner = Column(String, nullable=True, index=True)
|
||||||
|
# Optional OAuth/session-backed credential row. Used by subscription-backed
|
||||||
|
# providers that need refresh tokens instead of a static API key.
|
||||||
|
provider_auth_id = Column(String, nullable=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderAuthSession(TimestampMixin, Base):
|
||||||
|
"""Encrypted OAuth/session credentials for refresh-aware model providers."""
|
||||||
|
__tablename__ = "provider_auth_sessions"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True, index=True)
|
||||||
|
provider = Column(String, nullable=False, index=True)
|
||||||
|
owner = Column(String, nullable=True, index=True)
|
||||||
|
label = Column(String, nullable=True)
|
||||||
|
base_url = Column(String, nullable=False)
|
||||||
|
access_token = Column(EncryptedText, nullable=True)
|
||||||
|
refresh_token = Column(EncryptedText, nullable=True)
|
||||||
|
last_refresh = Column(DateTime, nullable=True)
|
||||||
|
auth_mode = Column(String, nullable=True)
|
||||||
|
|
||||||
class McpServer(TimestampMixin, Base):
|
class McpServer(TimestampMixin, Base):
|
||||||
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
"""Admin-configured MCP (Model Context Protocol) tool servers."""
|
||||||
@@ -801,6 +819,26 @@ def _migrate_add_model_endpoint_owner_column():
|
|||||||
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
logging.getLogger(__name__).warning(f"model_endpoints.owner migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_add_provider_auth_id_column():
|
||||||
|
"""Add provider_auth_id column to model_endpoints if it doesn't exist."""
|
||||||
|
import sqlite3
|
||||||
|
db_path = DATABASE_URL.replace("sqlite:///", "")
|
||||||
|
if not os.path.exists(db_path):
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.execute("PRAGMA table_info(model_endpoints)")
|
||||||
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
|
if columns and "provider_auth_id" not in columns:
|
||||||
|
conn.execute("ALTER TABLE model_endpoints ADD COLUMN provider_auth_id VARCHAR")
|
||||||
|
conn.execute("CREATE INDEX IF NOT EXISTS ix_model_endpoints_provider_auth_id ON model_endpoints(provider_auth_id)")
|
||||||
|
conn.commit()
|
||||||
|
logging.getLogger(__name__).info("Migrated: added 'provider_auth_id' column + index to model_endpoints")
|
||||||
|
conn.close()
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(f"model_endpoints.provider_auth_id migration failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _migrate_add_model_type_column():
|
def _migrate_add_model_type_column():
|
||||||
"""Add model_type column to model_endpoints if it doesn't exist."""
|
"""Add model_type column to model_endpoints if it doesn't exist."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
@@ -1599,6 +1637,7 @@ def init_db():
|
|||||||
_migrate_add_model_type_column()
|
_migrate_add_model_type_column()
|
||||||
_migrate_add_model_endpoint_refresh_columns()
|
_migrate_add_model_endpoint_refresh_columns()
|
||||||
_migrate_add_model_endpoint_owner_column()
|
_migrate_add_model_endpoint_owner_column()
|
||||||
|
_migrate_add_provider_auth_id_column()
|
||||||
_migrate_add_supports_tools_column()
|
_migrate_add_supports_tools_column()
|
||||||
_migrate_add_task_run_model_column()
|
_migrate_add_task_run_model_column()
|
||||||
_migrate_add_owner_column()
|
_migrate_add_owner_column()
|
||||||
|
|||||||
+86
-28
@@ -196,14 +196,26 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||||
"""
|
"""
|
||||||
import requests as _req
|
import requests as _req
|
||||||
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, normalize_base
|
from src.endpoint_resolver import (
|
||||||
|
build_chat_url,
|
||||||
|
build_headers,
|
||||||
|
build_models_url,
|
||||||
|
normalize_base,
|
||||||
|
resolve_endpoint_runtime,
|
||||||
|
)
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
|
||||||
current_url = sess.endpoint_url or ""
|
current_url = sess.endpoint_url or ""
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(
|
q = db.query(ModelEndpoint).filter(
|
||||||
ModelEndpoint.is_enabled == True
|
ModelEndpoint.is_enabled == True
|
||||||
).all()
|
)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -212,26 +224,33 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
# Skip current endpoint
|
# Skip current endpoint
|
||||||
if current_url and base in current_url:
|
if current_url and base in current_url:
|
||||||
continue
|
continue
|
||||||
# Quick ping
|
|
||||||
ping_url = build_models_url(base)
|
|
||||||
headers = build_headers(ep.api_key, base)
|
|
||||||
try:
|
try:
|
||||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
r.raise_for_status()
|
except Exception:
|
||||||
data = r.json()
|
continue
|
||||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
ping_url = build_models_url(base)
|
||||||
if not models:
|
headers = build_headers(api_key, base)
|
||||||
models = [
|
try:
|
||||||
m.get("name") or m.get("model")
|
if ping_url:
|
||||||
for m in (data.get("models") or [])
|
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||||
if m.get("name") or m.get("model")
|
r.raise_for_status()
|
||||||
]
|
data = r.json()
|
||||||
|
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
|
if not models:
|
||||||
|
models = [
|
||||||
|
m.get("name") or m.get("model")
|
||||||
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
models = json.loads(ep.cached_models or "[]")
|
||||||
if not models:
|
if not models:
|
||||||
continue
|
continue
|
||||||
# Found a working endpoint — update session
|
# Found a working endpoint — update session
|
||||||
new_model = models[0]
|
new_model = models[0]
|
||||||
chat_url = build_chat_url(base)
|
chat_url = build_chat_url(base)
|
||||||
new_headers = build_headers(ep.api_key, base)
|
new_headers = build_headers(api_key, base)
|
||||||
|
persisted_headers = {} if is_chatgpt_subscription_base(base) else new_headers
|
||||||
|
|
||||||
sess.model = new_model
|
sess.model = new_model
|
||||||
sess.endpoint_url = chat_url
|
sess.endpoint_url = chat_url
|
||||||
@@ -243,7 +262,7 @@ def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
|||||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||||
"model": new_model,
|
"model": new_model,
|
||||||
"endpoint_url": chat_url,
|
"endpoint_url": chat_url,
|
||||||
"headers": json.dumps(new_headers),
|
"headers": persisted_headers,
|
||||||
})
|
})
|
||||||
_db.commit()
|
_db.commit()
|
||||||
finally:
|
finally:
|
||||||
@@ -336,16 +355,26 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _has_auth_keys(headers) -> bool:
|
||||||
|
"""True if a headers dict carries an Authorization/x-api-key entry."""
|
||||||
|
return isinstance(headers, dict) and any(
|
||||||
|
k.lower() in ('authorization', 'x-api-key') for k in headers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
try:
|
||||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
)
|
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(sess, "endpoint_url", "") or "")
|
||||||
if has_auth:
|
except Exception:
|
||||||
|
is_chatgpt_subscription = False
|
||||||
|
has_auth = _has_auth_keys(sess.headers)
|
||||||
|
if has_auth and not is_chatgpt_subscription:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import build_headers, normalize_base
|
from src.endpoint_resolver import build_headers, resolve_endpoint_runtime
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
target_url = getattr(sess, "endpoint_url", "") or ""
|
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||||
@@ -361,10 +390,30 @@ def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
|||||||
for ep in q.all():
|
for ep in q.all():
|
||||||
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||||
continue
|
continue
|
||||||
if not ep.api_key:
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to resolve provider auth for session %s: %s", session_id, e)
|
||||||
|
return
|
||||||
|
if not api_key:
|
||||||
|
# No usable key (e.g. ChatGPT Subscription needs re-auth).
|
||||||
|
return
|
||||||
|
sess.headers = build_headers(api_key, base)
|
||||||
|
if is_chatgpt_subscription:
|
||||||
|
# The bearer is short-lived and re-resolved per request, so it
|
||||||
|
# stays request-local and is never written to the plaintext
|
||||||
|
# sessions.headers column. Proactively strip any bearer an
|
||||||
|
# older code path may have persisted so it does not linger.
|
||||||
|
stale_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
stale_q = stale_q.filter(DBSession.owner == owner)
|
||||||
|
stored = stale_q.first()
|
||||||
|
if stored is not None and _has_auth_keys(stored.headers):
|
||||||
|
stale_q.update({"headers": {}})
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Cleared persisted ChatGPT Subscription bearer from session {session_id}")
|
||||||
|
logger.debug(f"Resolved request-local ChatGPT Subscription auth for session {session_id}")
|
||||||
return
|
return
|
||||||
base = normalize_base(ep.base_url or "")
|
|
||||||
sess.headers = build_headers(ep.api_key, base)
|
|
||||||
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
if owner:
|
if owner:
|
||||||
update_q = update_q.filter(DBSession.owner == owner)
|
update_q = update_q.filter(DBSession.owner == owner)
|
||||||
@@ -408,7 +457,12 @@ def _normalize_model_id_from_cache(sess) -> Optional[str]:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
try:
|
try:
|
||||||
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
if normalize_base(getattr(ep, "base_url", "") or "") != session_base:
|
||||||
@@ -542,7 +596,11 @@ async def build_chat_context(
|
|||||||
|
|
||||||
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
# Normalize model ID. Prefer cached endpoint models so group chat does not
|
||||||
# re-hit slow local /models endpoints on every participant turn.
|
# re-hit slow local /models endpoints on every participant turn.
|
||||||
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(sess.endpoint_url, sess.model)
|
norm = _normalize_model_id_from_cache(sess) or normalize_model_id(
|
||||||
|
sess.endpoint_url,
|
||||||
|
sess.model,
|
||||||
|
owner=getattr(sess, "owner", None),
|
||||||
|
)
|
||||||
if norm:
|
if norm:
|
||||||
sess.model = norm
|
sess.model = norm
|
||||||
|
|
||||||
|
|||||||
+57
-12
@@ -169,13 +169,20 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
Covers the window between endpoint setup and the first chat send: the
|
Covers the window between endpoint setup and the first chat send: the
|
||||||
picker showed a model in the dropdown but the session record never got
|
picker showed a model in the dropdown but the session record never got
|
||||||
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
written (Issue #587 — UI uses the cached endpoint list, not s.model).
|
||||||
Without this, we'd POST the upstream with model="" and get a generic
|
For ChatGPT Subscription, also repairs stale OpenAI API model names such as
|
||||||
401/503 instead of using the model the user already picked.
|
``gpt-5`` that are not accepted by the Codex-backed ChatGPT account route.
|
||||||
|
|
||||||
Returns True iff sess.model was repaired.
|
|
||||||
"""
|
"""
|
||||||
if getattr(sess, "model", None):
|
current_model = (getattr(sess, "model", "") or "").strip()
|
||||||
return False
|
endpoint_url = (getattr(sess, "endpoint_url", "") or "").strip()
|
||||||
|
is_chatgpt_subscription = False
|
||||||
|
if current_model:
|
||||||
|
try:
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
is_chatgpt_subscription = is_chatgpt_subscription_base(endpoint_url)
|
||||||
|
if not is_chatgpt_subscription:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Prefer the endpoint whose base URL matches the session — we know the
|
# Prefer the endpoint whose base URL matches the session — we know the
|
||||||
@@ -194,16 +201,51 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
break
|
break
|
||||||
if not ep:
|
if not ep:
|
||||||
return False
|
return False
|
||||||
|
if not is_chatgpt_subscription:
|
||||||
|
try:
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
is_chatgpt_subscription = is_chatgpt_subscription_base(getattr(ep, "base_url", "") or endpoint_url)
|
||||||
|
except Exception:
|
||||||
|
is_chatgpt_subscription = False
|
||||||
try:
|
try:
|
||||||
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
cached = json.loads(ep.cached_models) if isinstance(ep.cached_models, str) else (ep.cached_models or [])
|
||||||
except Exception:
|
except Exception:
|
||||||
cached = []
|
cached = []
|
||||||
if not cached:
|
if not cached:
|
||||||
|
visible = []
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||||
|
except Exception:
|
||||||
|
visible = cached
|
||||||
|
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||||
return False
|
return False
|
||||||
try:
|
if is_chatgpt_subscription:
|
||||||
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
live_models = []
|
||||||
except Exception:
|
if getattr(ep, "provider_auth_id", None):
|
||||||
visible = cached
|
try:
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
if api_key:
|
||||||
|
live_models = fetch_available_models(api_key)
|
||||||
|
if live_models:
|
||||||
|
ep.cached_models = json.dumps(live_models)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
live_models = []
|
||||||
|
# ChatGPT Subscription recovery must use the live Codex catalog.
|
||||||
|
# Cached rows are only trusted above to avoid revalidating a model
|
||||||
|
# that is already present in the visible picker list.
|
||||||
|
cached = live_models
|
||||||
|
if not cached:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
visible = _visible_models(cached, getattr(ep, "hidden_models", None))
|
||||||
|
except Exception:
|
||||||
|
visible = cached
|
||||||
|
if current_model and current_model in {str(item).strip() for item in visible}:
|
||||||
|
return False
|
||||||
if not visible:
|
if not visible:
|
||||||
return False
|
return False
|
||||||
model = visible[0]
|
model = visible[0]
|
||||||
@@ -213,14 +255,17 @@ def _recover_empty_session_model(sess, session_id: str, owner: str | None = None
|
|||||||
# Persist so the next request, websocket reconnect, or page reload
|
# Persist so the next request, websocket reconnect, or page reload
|
||||||
# picks up the same model (we'd otherwise re-pick on every send
|
# picks up the same model (we'd otherwise re-pick on every send
|
||||||
# and silently switch on the user if the cached order shifts).
|
# and silently switch on the user if the cached order shifts).
|
||||||
db_session = db.query(DBSession).filter(DBSession.id == session_id).first()
|
db_session_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
db_session_q = db_session_q.filter(DBSession.owner == owner)
|
||||||
|
db_session = db_session_q.first()
|
||||||
if db_session:
|
if db_session:
|
||||||
db_session.model = model
|
db_session.model = model
|
||||||
db_session.updated_at = datetime.utcnow()
|
db_session.updated_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
sess.model = model
|
sess.model = model
|
||||||
logger.info(
|
logger.info(
|
||||||
"Recovered empty session model for %s — picked %r from endpoint %s",
|
"Recovered session model for %s — picked %r from endpoint %s",
|
||||||
session_id, model, ep.id,
|
session_id, model, ep.id,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -0,0 +1,170 @@
|
|||||||
|
"""ChatGPT Subscription device-flow setup routes."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from core.database import ModelEndpoint, ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
from routes.device_flow import (
|
||||||
|
DeviceFlowPoll,
|
||||||
|
DeviceFlowStart,
|
||||||
|
PendingDeviceFlowStore,
|
||||||
|
create_device_flow_router,
|
||||||
|
)
|
||||||
|
from src.auth_helpers import get_current_user
|
||||||
|
from src import chatgpt_subscription
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||||
|
|
||||||
|
|
||||||
|
def _provision_endpoint(tokens: Dict, owner: Optional[str]) -> Dict:
|
||||||
|
access_token = tokens.get("access_token")
|
||||||
|
refresh_token = tokens.get("refresh_token")
|
||||||
|
if not access_token or not refresh_token:
|
||||||
|
raise ValueError("ChatGPT token response was missing access_token or refresh_token")
|
||||||
|
|
||||||
|
base = chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
|
||||||
|
models = chatgpt_subscription.fetch_available_models(access_token)
|
||||||
|
if not models:
|
||||||
|
raise ValueError("ChatGPT Subscription connected, but no usable Codex models were discovered for this account.")
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
auth = (
|
||||||
|
db.query(ProviderAuthSession)
|
||||||
|
.filter(
|
||||||
|
ProviderAuthSession.provider == chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
ProviderAuthSession.owner == owner,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if auth is None:
|
||||||
|
auth = ProviderAuthSession(
|
||||||
|
id=str(uuid.uuid4())[:8],
|
||||||
|
provider=chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
owner=owner,
|
||||||
|
label="ChatGPT Subscription",
|
||||||
|
base_url=base,
|
||||||
|
auth_mode="chatgpt",
|
||||||
|
)
|
||||||
|
db.add(auth)
|
||||||
|
auth.base_url = base
|
||||||
|
auth.access_token = access_token
|
||||||
|
auth.refresh_token = refresh_token
|
||||||
|
auth.last_refresh = utcnow_naive()
|
||||||
|
auth.auth_mode = "chatgpt"
|
||||||
|
|
||||||
|
ep = (
|
||||||
|
db.query(ModelEndpoint)
|
||||||
|
.filter(
|
||||||
|
ModelEndpoint.base_url == base,
|
||||||
|
ModelEndpoint.provider_auth_id == auth.id,
|
||||||
|
ModelEndpoint.owner == owner,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if ep is None:
|
||||||
|
ep = ModelEndpoint(
|
||||||
|
id=str(uuid.uuid4())[:8],
|
||||||
|
name="ChatGPT Subscription",
|
||||||
|
base_url=base,
|
||||||
|
model_type="llm",
|
||||||
|
endpoint_kind="api",
|
||||||
|
owner=owner,
|
||||||
|
)
|
||||||
|
db.add(ep)
|
||||||
|
ep.name = "ChatGPT Subscription"
|
||||||
|
ep.base_url = base
|
||||||
|
ep.api_key = None
|
||||||
|
ep.provider_auth_id = auth.id
|
||||||
|
ep.is_enabled = True
|
||||||
|
ep.supports_tools = False
|
||||||
|
ep.model_type = "llm"
|
||||||
|
ep.endpoint_kind = "api"
|
||||||
|
ep.model_refresh_mode = "manual"
|
||||||
|
ep.cached_models = json.dumps(models)
|
||||||
|
db.commit()
|
||||||
|
result = {
|
||||||
|
"id": ep.id,
|
||||||
|
"name": ep.name,
|
||||||
|
"base_url": ep.base_url,
|
||||||
|
"models": models,
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from routes.model_routes import _invalidate_models_cache
|
||||||
|
|
||||||
|
_invalidate_models_cache()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _start_device_flow(request: Request, _form) -> DeviceFlowStart:
|
||||||
|
try:
|
||||||
|
data = chatgpt_subscription.request_device_code()
|
||||||
|
except Exception as exc:
|
||||||
|
raise chatgpt_subscription.to_http_exception(exc)
|
||||||
|
|
||||||
|
device_auth_id = data.get("device_auth_id")
|
||||||
|
user_code = data.get("user_code")
|
||||||
|
if not device_auth_id or not user_code:
|
||||||
|
raise HTTPException(502, "ChatGPT did not return a complete device code")
|
||||||
|
verification_uri = data.get("verification_uri") or f"{chatgpt_subscription.CHATGPT_OAUTH_ISSUER}/codex/device"
|
||||||
|
return DeviceFlowStart(
|
||||||
|
pending={
|
||||||
|
"device_auth_id": device_auth_id,
|
||||||
|
"user_code": user_code,
|
||||||
|
"owner": get_current_user(request) or None,
|
||||||
|
},
|
||||||
|
response={
|
||||||
|
"user_code": user_code,
|
||||||
|
"verification_uri": verification_uri,
|
||||||
|
},
|
||||||
|
interval=int(data.get("interval") or 5),
|
||||||
|
expires_in=int(data.get("expires_in") or 900),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||||
|
try:
|
||||||
|
data = chatgpt_subscription.poll_device_auth(pending["device_auth_id"], pending["user_code"])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("ChatGPT device poll failed: %s", exc)
|
||||||
|
return DeviceFlowPoll.pending(str(exc))
|
||||||
|
|
||||||
|
authorization_code = data.get("authorization_code")
|
||||||
|
code_verifier = data.get("code_verifier")
|
||||||
|
if authorization_code and code_verifier:
|
||||||
|
try:
|
||||||
|
tokens = chatgpt_subscription.exchange_authorization_code(authorization_code, code_verifier)
|
||||||
|
result = _provision_endpoint(tokens, pending["owner"])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("ChatGPT Subscription endpoint provisioning failed")
|
||||||
|
raise chatgpt_subscription.to_http_exception(exc)
|
||||||
|
return DeviceFlowPoll.authorized(result)
|
||||||
|
|
||||||
|
err = data.get("error") or data.get("status")
|
||||||
|
if err in ("authorization_pending", "pending", None):
|
||||||
|
return DeviceFlowPoll.pending()
|
||||||
|
if err == "slow_down":
|
||||||
|
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||||
|
if err in ("expired_token", "access_denied", "denied"):
|
||||||
|
return DeviceFlowPoll.failed(err)
|
||||||
|
return DeviceFlowPoll.pending(err or "unknown")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_chatgpt_subscription_routes():
|
||||||
|
return create_device_flow_router(
|
||||||
|
prefix="/api/chatgpt-subscription",
|
||||||
|
tags=["chatgpt-subscription"],
|
||||||
|
store=_DEVICE_FLOW_STORE,
|
||||||
|
start_flow=_start_device_flow,
|
||||||
|
poll_flow=_poll_device_flow,
|
||||||
|
)
|
||||||
+67
-117
@@ -20,39 +20,26 @@ All routes are admin-gated (endpoint/provider management is an admin action).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
import threading
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Request, Form, HTTPException
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
from core.database import SessionLocal, ModelEndpoint
|
from core.database import SessionLocal, ModelEndpoint
|
||||||
from core.middleware import require_admin
|
from routes.device_flow import (
|
||||||
|
DeviceFlowPoll,
|
||||||
|
DeviceFlowStart,
|
||||||
|
PendingDeviceFlowStore,
|
||||||
|
create_device_flow_router,
|
||||||
|
)
|
||||||
from src.auth_helpers import get_current_user
|
from src.auth_helpers import get_current_user
|
||||||
from src import copilot
|
from src import copilot
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
|
_DEVICE_FLOW_STORE = PendingDeviceFlowStore()
|
||||||
# bearer-like secret, so it lives here (server memory) rather than in the
|
|
||||||
# browser. Entries expire with the GitHub device code.
|
|
||||||
#
|
|
||||||
# NOTE: this is per-process state. The device flow assumes a single worker
|
|
||||||
# (Odysseus' default): with multiple uvicorn workers, the poll request can land
|
|
||||||
# on a worker that never saw the start, returning "Unknown or expired login
|
|
||||||
# session". Move this to a shared store (DB/Redis) if running multi-worker.
|
|
||||||
_PENDING: Dict[str, Dict] = {}
|
|
||||||
_PENDING_LOCK = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def _prune_expired() -> None:
|
|
||||||
now = time.time()
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
for k in [k for k, v in _PENDING.items() if v.get("expires_at", 0) < now]:
|
|
||||||
_PENDING.pop(k, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
||||||
@@ -112,112 +99,75 @@ def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def setup_copilot_routes() -> APIRouter:
|
def _start_device_flow(request: Request, form) -> DeviceFlowStart:
|
||||||
router = APIRouter(prefix="/api/copilot", tags=["copilot"])
|
host = copilot.GITHUB_HOST
|
||||||
|
ent = str(form.get("enterprise_url") or "").strip()
|
||||||
|
if ent:
|
||||||
|
host = copilot.normalize_domain(ent)
|
||||||
|
try:
|
||||||
|
data = copilot.request_device_code(host)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
status = e.response.status_code if e.response is not None else "unknown"
|
||||||
|
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
||||||
|
|
||||||
@router.post("/device/start")
|
device_code = data.get("device_code")
|
||||||
def device_start(request: Request, enterprise_url: str = Form("")):
|
if not device_code:
|
||||||
require_admin(request)
|
raise HTTPException(502, "GitHub did not return a device code")
|
||||||
_prune_expired()
|
|
||||||
host = copilot.GITHUB_HOST
|
|
||||||
ent = (enterprise_url or "").strip()
|
|
||||||
if ent:
|
|
||||||
host = copilot.normalize_domain(ent)
|
|
||||||
try:
|
|
||||||
data = copilot.request_device_code(host)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
status = e.response.status_code if e.response is not None else "unknown"
|
|
||||||
raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(502, f"GitHub device-code request failed: {e}")
|
|
||||||
|
|
||||||
device_code = data.get("device_code")
|
# verification_uri_complete embeds the user code, so the browser tab we
|
||||||
if not device_code:
|
# open lands the user straight on GitHub's "Authorize" screen with the
|
||||||
raise HTTPException(502, "GitHub did not return a device code")
|
# code pre-filled — one click, no manual code entry.
|
||||||
interval = int(data.get("interval") or 5)
|
return DeviceFlowStart(
|
||||||
expires_in = int(data.get("expires_in") or 900)
|
pending={
|
||||||
poll_id = uuid.uuid4().hex
|
"device_code": device_code,
|
||||||
with _PENDING_LOCK:
|
"host": host,
|
||||||
_PENDING[poll_id] = {
|
"enterprise_url": ent,
|
||||||
"device_code": device_code,
|
"owner": get_current_user(request) or None,
|
||||||
"host": host,
|
},
|
||||||
"enterprise_url": ent,
|
response={
|
||||||
"interval": interval,
|
|
||||||
"owner": get_current_user(request) or None,
|
|
||||||
"expires_at": time.time() + expires_in,
|
|
||||||
"next_poll_at": 0.0,
|
|
||||||
}
|
|
||||||
# verification_uri_complete embeds the user code, so the browser tab we
|
|
||||||
# open lands the user straight on GitHub's "Authorize" screen with the
|
|
||||||
# code pre-filled — one click, no manual code entry.
|
|
||||||
return {
|
|
||||||
"poll_id": poll_id,
|
|
||||||
"user_code": data.get("user_code"),
|
"user_code": data.get("user_code"),
|
||||||
"verification_uri": data.get("verification_uri"),
|
"verification_uri": data.get("verification_uri"),
|
||||||
"verification_uri_complete": data.get("verification_uri_complete"),
|
"verification_uri_complete": data.get("verification_uri_complete"),
|
||||||
"interval": interval,
|
},
|
||||||
"expires_in": expires_in,
|
interval=int(data.get("interval") or 5),
|
||||||
}
|
expires_in=int(data.get("expires_in") or 900),
|
||||||
|
)
|
||||||
|
|
||||||
@router.post("/device/poll")
|
|
||||||
def device_poll(request: Request, poll_id: str = Form(...)):
|
|
||||||
require_admin(request)
|
|
||||||
_prune_expired()
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
pending = _PENDING.get(poll_id)
|
|
||||||
if not pending:
|
|
||||||
raise HTTPException(404, "Unknown or expired login session")
|
|
||||||
|
|
||||||
# Enforce GitHub's polling interval server-side so a chatty client
|
def _poll_device_flow(_request: Request, pending: Dict) -> DeviceFlowPoll:
|
||||||
# can't trip slow_down.
|
try:
|
||||||
now = time.time()
|
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
||||||
if now < pending.get("next_poll_at", 0):
|
except Exception as e:
|
||||||
return {"status": "pending"}
|
return DeviceFlowPoll.pending(f"poll error: {e}")
|
||||||
|
|
||||||
|
token = data.get("access_token")
|
||||||
|
if token:
|
||||||
|
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
||||||
try:
|
try:
|
||||||
data = copilot.poll_access_token(pending["host"], pending["device_code"])
|
result = _provision_endpoint(token, base, pending["owner"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"status": "pending", "detail": f"poll error: {e}"}
|
logger.exception("Copilot endpoint provisioning failed")
|
||||||
|
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
||||||
|
return DeviceFlowPoll.authorized(result)
|
||||||
|
|
||||||
token = data.get("access_token")
|
err = data.get("error")
|
||||||
if token:
|
if err == "authorization_pending":
|
||||||
base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
|
return DeviceFlowPoll.pending()
|
||||||
try:
|
if err == "slow_down":
|
||||||
result = _provision_endpoint(token, base, pending["owner"])
|
return DeviceFlowPoll.slow_down(int(data.get("interval") or 0) or None)
|
||||||
except Exception as e:
|
if err in ("expired_token", "access_denied"):
|
||||||
logger.exception("Copilot endpoint provisioning failed")
|
return DeviceFlowPoll.failed(err)
|
||||||
with _PENDING_LOCK:
|
# Unknown error — surface but keep the session for another try.
|
||||||
_PENDING.pop(poll_id, None)
|
return DeviceFlowPoll.pending(err or "unknown")
|
||||||
raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
_PENDING.pop(poll_id, None)
|
|
||||||
return {"status": "authorized", "endpoint": result}
|
|
||||||
|
|
||||||
err = data.get("error")
|
|
||||||
if err == "authorization_pending":
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
if poll_id in _PENDING:
|
|
||||||
_PENDING[poll_id]["next_poll_at"] = now + pending["interval"]
|
|
||||||
return {"status": "pending"}
|
|
||||||
if err == "slow_down":
|
|
||||||
new_interval = int(data.get("interval") or (pending["interval"] + 5))
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
if poll_id in _PENDING:
|
|
||||||
_PENDING[poll_id]["interval"] = new_interval
|
|
||||||
_PENDING[poll_id]["next_poll_at"] = now + new_interval
|
|
||||||
return {"status": "pending"}
|
|
||||||
if err in ("expired_token", "access_denied"):
|
|
||||||
with _PENDING_LOCK:
|
|
||||||
_PENDING.pop(poll_id, None)
|
|
||||||
return {"status": "failed", "error": err}
|
|
||||||
# Unknown error — surface but keep the session for another try.
|
|
||||||
return {"status": "pending", "detail": err or "unknown"}
|
|
||||||
|
|
||||||
@router.post("/device/cancel")
|
def setup_copilot_routes():
|
||||||
def device_cancel(request: Request, poll_id: str = Form(...)):
|
return create_device_flow_router(
|
||||||
require_admin(request)
|
prefix="/api/copilot",
|
||||||
with _PENDING_LOCK:
|
tags=["copilot"],
|
||||||
_PENDING.pop(poll_id, None)
|
store=_DEVICE_FLOW_STORE,
|
||||||
return {"status": "cancelled"}
|
start_flow=_start_device_flow,
|
||||||
|
poll_flow=_poll_device_flow,
|
||||||
return router
|
)
|
||||||
|
|||||||
@@ -0,0 +1,193 @@
|
|||||||
|
"""Shared OAuth/device-flow route scaffolding for provider setup."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Iterable, Mapping, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
|
|
||||||
|
from core.middleware import require_admin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DeviceFlowStart:
|
||||||
|
"""Provider-specific start result consumed by the shared route wrapper."""
|
||||||
|
|
||||||
|
pending: Mapping[str, Any]
|
||||||
|
response: Mapping[str, Any]
|
||||||
|
interval: int = 5
|
||||||
|
expires_in: int = 900
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DeviceFlowPoll:
|
||||||
|
"""Normalized provider poll outcome."""
|
||||||
|
|
||||||
|
status: str
|
||||||
|
endpoint: Optional[Mapping[str, Any]] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
detail: Optional[str] = None
|
||||||
|
interval: Optional[int] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pending(cls, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="pending", detail=detail)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def slow_down(cls, interval: Optional[int] = None, detail: Optional[str] = None) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="slow_down", interval=interval, detail=detail)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def authorized(cls, endpoint: Mapping[str, Any]) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="authorized", endpoint=endpoint)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def failed(cls, error: str) -> "DeviceFlowPoll":
|
||||||
|
return cls(status="failed", error=error)
|
||||||
|
|
||||||
|
|
||||||
|
class PendingDeviceFlowStore:
|
||||||
|
"""Thread-safe in-memory pending device-flow store.
|
||||||
|
|
||||||
|
Device codes and provider-side secrets stay inside this process. Each entry
|
||||||
|
stores provider payload separately from poll metadata so provider callbacks
|
||||||
|
only receive the fields they created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, time_func: Callable[[], float] = time.time):
|
||||||
|
self._pending: dict[str, dict[str, Any]] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._time = time_func
|
||||||
|
|
||||||
|
def _now(self) -> float:
|
||||||
|
return float(self._time())
|
||||||
|
|
||||||
|
def prune_expired(self) -> None:
|
||||||
|
now = self._now()
|
||||||
|
with self._lock:
|
||||||
|
for key in [k for k, v in self._pending.items() if v.get("expires_at", 0) < now]:
|
||||||
|
self._pending.pop(key, None)
|
||||||
|
|
||||||
|
def add(self, payload: Mapping[str, Any], *, interval: int, expires_in: int) -> str:
|
||||||
|
self.prune_expired()
|
||||||
|
poll_id = uuid.uuid4().hex
|
||||||
|
with self._lock:
|
||||||
|
self._pending[poll_id] = {
|
||||||
|
"payload": dict(payload),
|
||||||
|
"interval": max(int(interval or 5), 1),
|
||||||
|
"expires_at": self._now() + max(int(expires_in or 900), 1),
|
||||||
|
"next_poll_at": 0.0,
|
||||||
|
}
|
||||||
|
return poll_id
|
||||||
|
|
||||||
|
def get_payload(self, poll_id: str) -> Optional[dict[str, Any]]:
|
||||||
|
self.prune_expired()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
return dict(entry.get("payload") or {})
|
||||||
|
|
||||||
|
def is_throttled(self, poll_id: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
return bool(entry and self._now() < float(entry.get("next_poll_at") or 0))
|
||||||
|
|
||||||
|
def schedule_next(self, poll_id: str) -> None:
|
||||||
|
now = self._now()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
if entry is not None:
|
||||||
|
entry["next_poll_at"] = now + int(entry.get("interval") or 5)
|
||||||
|
|
||||||
|
def slow_down(self, poll_id: str, interval: Optional[int] = None) -> None:
|
||||||
|
now = self._now()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._pending.get(poll_id)
|
||||||
|
if entry is not None:
|
||||||
|
new_interval = int(interval or (int(entry.get("interval") or 5) + 5))
|
||||||
|
entry["interval"] = max(new_interval, 1)
|
||||||
|
entry["next_poll_at"] = now + entry["interval"]
|
||||||
|
|
||||||
|
def pop(self, poll_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._pending.pop(poll_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
async def _maybe_await(value: Any) -> Any:
|
||||||
|
if inspect.isawaitable(value):
|
||||||
|
return await value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _pending_response(detail: Optional[str] = None) -> dict[str, Any]:
|
||||||
|
response: dict[str, Any] = {"status": "pending"}
|
||||||
|
if detail:
|
||||||
|
response["detail"] = detail
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_device_flow_router(
|
||||||
|
*,
|
||||||
|
prefix: str,
|
||||||
|
tags: Iterable[str],
|
||||||
|
store: PendingDeviceFlowStore,
|
||||||
|
start_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowStart],
|
||||||
|
poll_flow: Callable[[Request, Mapping[str, Any]], DeviceFlowPoll],
|
||||||
|
) -> APIRouter:
|
||||||
|
"""Create standard `/device/start|poll|cancel` routes for a provider."""
|
||||||
|
|
||||||
|
router = APIRouter(prefix=prefix, tags=list(tags))
|
||||||
|
|
||||||
|
@router.post("/device/start")
|
||||||
|
async def device_start(request: Request):
|
||||||
|
require_admin(request)
|
||||||
|
form = await request.form()
|
||||||
|
start = await _maybe_await(start_flow(request, form))
|
||||||
|
interval = int(start.interval or 5)
|
||||||
|
expires_in = int(start.expires_in or 900)
|
||||||
|
poll_id = store.add(start.pending, interval=interval, expires_in=expires_in)
|
||||||
|
response = dict(start.response)
|
||||||
|
response.update({"poll_id": poll_id, "interval": interval, "expires_in": expires_in})
|
||||||
|
return response
|
||||||
|
|
||||||
|
@router.post("/device/poll")
|
||||||
|
async def device_poll(request: Request, poll_id: str = Form(...)):
|
||||||
|
require_admin(request)
|
||||||
|
payload = store.get_payload(poll_id)
|
||||||
|
if payload is None:
|
||||||
|
raise HTTPException(404, "Unknown or expired login session")
|
||||||
|
if store.is_throttled(poll_id):
|
||||||
|
return {"status": "pending"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
outcome = await _maybe_await(poll_flow(request, payload))
|
||||||
|
except Exception:
|
||||||
|
store.pop(poll_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if outcome.status == "authorized":
|
||||||
|
store.pop(poll_id)
|
||||||
|
return {"status": "authorized", "endpoint": dict(outcome.endpoint or {})}
|
||||||
|
if outcome.status == "failed":
|
||||||
|
store.pop(poll_id)
|
||||||
|
return {"status": "failed", "error": outcome.error or "denied"}
|
||||||
|
if outcome.status == "slow_down":
|
||||||
|
store.slow_down(poll_id, outcome.interval)
|
||||||
|
return _pending_response(outcome.detail)
|
||||||
|
|
||||||
|
store.schedule_next(poll_id)
|
||||||
|
return _pending_response(outcome.detail)
|
||||||
|
|
||||||
|
@router.post("/device/cancel")
|
||||||
|
def device_cancel(request: Request, poll_id: str = Form(...)):
|
||||||
|
require_admin(request)
|
||||||
|
store.pop(poll_id)
|
||||||
|
return {"status": "cancelled"}
|
||||||
|
|
||||||
|
return router
|
||||||
+107
-13
@@ -283,6 +283,7 @@ _HOST_TO_CURATED = (
|
|||||||
("fireworks.ai", "fireworks"),
|
("fireworks.ai", "fireworks"),
|
||||||
("googleapis.com", "google"),
|
("googleapis.com", "google"),
|
||||||
("x.ai", "xai"),
|
("x.ai", "xai"),
|
||||||
|
|
||||||
("openrouter.ai", "openrouter"),
|
("openrouter.ai", "openrouter"),
|
||||||
("ollama.com", "ollama"),
|
("ollama.com", "ollama"),
|
||||||
("opencode.ai/zen/go", "opencode-go"),
|
("opencode.ai/zen/go", "opencode-go"),
|
||||||
@@ -493,6 +494,8 @@ _NON_CHAT_EXACT_PREFIXES = (
|
|||||||
def _is_chat_model(model_id: str) -> bool:
|
def _is_chat_model(model_id: str) -> bool:
|
||||||
"""Return True if the model ID looks like a chat/completions-capable model."""
|
"""Return True if the model ID looks like a chat/completions-capable model."""
|
||||||
mid = model_id.lower()
|
mid = model_id.lower()
|
||||||
|
if mid in {"gpt-5.1-codex"}:
|
||||||
|
return True
|
||||||
for prefix in _NON_CHAT_PREFIXES:
|
for prefix in _NON_CHAT_PREFIXES:
|
||||||
if mid.startswith(prefix):
|
if mid.startswith(prefix):
|
||||||
return False
|
return False
|
||||||
@@ -505,9 +508,67 @@ def _is_chat_model(model_id: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _probe_single_model(base: str, api_key: str, model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
def _delete_orphaned_provider_auth(db, auth_id: Optional[str], exclude_ep_id: Optional[str] = None) -> bool:
|
||||||
|
"""Delete a ProviderAuthSession once no endpoint still references it.
|
||||||
|
|
||||||
|
Subscription providers (e.g. ChatGPT Subscription) keep their refresh token
|
||||||
|
in ProviderAuthSession rather than ModelEndpoint.api_key. When the last
|
||||||
|
endpoint backed by that auth row is removed, the stored credentials should
|
||||||
|
be cleared instead of lingering. Returns True if a row was deleted.
|
||||||
|
``exclude_ep_id`` drops the endpoint currently being deleted from the
|
||||||
|
reference count so it does not keep its own auth alive.
|
||||||
|
"""
|
||||||
|
if not auth_id:
|
||||||
|
return False
|
||||||
|
from core.database import ProviderAuthSession
|
||||||
|
still_referenced = db.query(ModelEndpoint.id).filter(
|
||||||
|
ModelEndpoint.provider_auth_id == auth_id,
|
||||||
|
ModelEndpoint.id != exclude_ep_id,
|
||||||
|
).first()
|
||||||
|
if still_referenced is not None:
|
||||||
|
return False
|
||||||
|
auth_row = db.query(ProviderAuthSession).filter(ProviderAuthSession.id == auth_id).first()
|
||||||
|
if auth_row is None:
|
||||||
|
return False
|
||||||
|
db.delete(auth_row)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_discovery_only_provider(provider: str) -> bool:
|
||||||
|
"""Provider that only supports model discovery, not live probing.
|
||||||
|
|
||||||
|
ChatGPT Subscription speaks the Responses/Codex API and has no
|
||||||
|
chat-completions or general health endpoint, so completion probes and
|
||||||
|
reachability pings are skipped — status is derived from cached models.
|
||||||
|
"""
|
||||||
|
return provider == "chatgpt-subscription"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_probe_key(ep) -> Optional[str]:
|
||||||
|
"""API key/bearer to probe an endpoint with.
|
||||||
|
|
||||||
|
Delegates to ``resolve_endpoint_runtime``, which already returns the static
|
||||||
|
``ModelEndpoint.api_key`` for keyed endpoints and resolves (and refreshes)
|
||||||
|
the runtime bearer for session-backed providers (e.g. ChatGPT Subscription).
|
||||||
|
Returns None if resolution fails (e.g. re-auth required) so probing skips
|
||||||
|
rather than raising. Reads only already-loaded scalar attributes of ``ep``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
_base, key = resolve_endpoint_runtime(ep, owner=getattr(ep, "owner", None))
|
||||||
|
return key
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Probe key resolution failed for %s: %s", getattr(ep, "id", "?"), e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _probe_single_model(base: str, api_key: Optional[str], model_id: str, timeout: int = 10, with_tools: bool = False) -> dict:
|
||||||
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
"""Send a realistic completion request to a single model. Returns {status, latency_ms, error?}."""
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
|
if _is_discovery_only_provider(provider):
|
||||||
|
# Responses/Codex API, not chat-completions: a completion probe would
|
||||||
|
# 400 and the re-probe flow would then hide every model. Discovery-only.
|
||||||
|
return {"status": "ok", "latency_ms": 0, "skipped": True}
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Say OK"},
|
{"role": "user", "content": "Say OK"},
|
||||||
@@ -621,6 +682,11 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
For Anthropic, queries their /v1/models API, falling back to hardcoded list."""
|
||||||
from src.endpoint_resolver import resolve_url
|
from src.endpoint_resolver import resolve_url
|
||||||
base = resolve_url(_normalize_base(base_url))
|
base = resolve_url(_normalize_base(base_url))
|
||||||
|
if _detect_provider(base) == "chatgpt-subscription":
|
||||||
|
from src.chatgpt_subscription import fetch_available_models
|
||||||
|
if api_key:
|
||||||
|
return fetch_available_models(api_key, timeout=timeout)
|
||||||
|
return []
|
||||||
if _detect_provider(base) == "anthropic":
|
if _detect_provider(base) == "anthropic":
|
||||||
# Try Anthropic's /v1/models endpoint first
|
# Try Anthropic's /v1/models endpoint first
|
||||||
url = build_models_url(base)
|
url = build_models_url(base)
|
||||||
@@ -647,6 +713,10 @@ def _probe_endpoint(base_url: str, api_key: str = None, timeout: int = 5) -> Lis
|
|||||||
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
logger.warning(f"Anthropic /v1/models failed, using hardcoded list: {e}")
|
||||||
return list(ANTHROPIC_MODELS)
|
return list(ANTHROPIC_MODELS)
|
||||||
url = build_models_url(base)
|
url = build_models_url(base)
|
||||||
|
if not url:
|
||||||
|
curated_key = _match_provider_curated(base, None)
|
||||||
|
fallback = _PROVIDER_CURATED.get(curated_key) if curated_key else None
|
||||||
|
return list(fallback or [])
|
||||||
headers = build_headers(api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
try:
|
try:
|
||||||
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
r = httpx.get(url, headers=headers, timeout=timeout, verify=llm_verify())
|
||||||
@@ -998,6 +1068,17 @@ def setup_model_routes(model_discovery):
|
|||||||
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
ok, info = _should_refresh_endpoint(ep, now, force=force)
|
||||||
if not ok:
|
if not ok:
|
||||||
continue
|
continue
|
||||||
|
if getattr(ep, "provider_auth_id", None):
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
info["base"], info["api_key"] = resolve_endpoint_runtime(
|
||||||
|
ep,
|
||||||
|
owner=getattr(ep, "owner", None),
|
||||||
|
)
|
||||||
|
info["key"] = _refresh_key(info["base"], info["api_key"])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Skipping model refresh for %s: could not resolve provider auth: %s", getattr(ep, "name", ep.id), e)
|
||||||
|
continue
|
||||||
groups.setdefault(info["key"], {
|
groups.setdefault(info["key"], {
|
||||||
"base": info["base"],
|
"base": info["base"],
|
||||||
"api_key": info["api_key"],
|
"api_key": info["api_key"],
|
||||||
@@ -1266,12 +1347,20 @@ def setup_model_routes(model_discovery):
|
|||||||
"endpoint_kind": kind,
|
"endpoint_kind": kind,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
t0 = _time.time()
|
if _is_discovery_only_provider(provider):
|
||||||
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
# No general health endpoint — an unauthenticated GET just
|
||||||
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
# 401s. Report status from cached models instead of pinging.
|
||||||
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
entry["latency_ms"] = None
|
||||||
entry["error"] = ping.get("error")
|
entry["status"] = "online" if cached_count else "offline"
|
||||||
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
entry["error"] = None
|
||||||
|
entry["model_count"] = cached_count
|
||||||
|
else:
|
||||||
|
t0 = _time.time()
|
||||||
|
ping = _ping_endpoint(base, ep.api_key, timeout=1.5)
|
||||||
|
entry["latency_ms"] = round((_time.time() - t0) * 1000)
|
||||||
|
entry["status"] = "online" if ping.get("reachable") or cached_count else "offline"
|
||||||
|
entry["error"] = ping.get("error")
|
||||||
|
entry["model_count"] = cached_count or (len(ANTHROPIC_MODELS) if provider == "anthropic" else 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
entry["latency_ms"] = None
|
entry["latency_ms"] = None
|
||||||
entry["status"] = "online" if cached_count else "offline"
|
entry["status"] = "online" if cached_count else "offline"
|
||||||
@@ -1304,7 +1393,7 @@ def setup_model_routes(model_discovery):
|
|||||||
if ep_id and ep_id not in endpoints_cache:
|
if ep_id and ep_id not in endpoints_cache:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||||
if ep:
|
if ep:
|
||||||
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": ep.api_key}
|
endpoints_cache[ep_id] = {"base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||||
ep_data = endpoints_cache.get(ep_id)
|
ep_data = endpoints_cache.get(ep_id)
|
||||||
if not ep_data:
|
if not ep_data:
|
||||||
# Try to find by base_url from the model's endpoint field
|
# Try to find by base_url from the model's endpoint field
|
||||||
@@ -1343,7 +1432,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"id": ep.id,
|
"id": ep.id,
|
||||||
"name": ep.name,
|
"name": ep.name,
|
||||||
"base_url": ep.base_url,
|
"base_url": ep.base_url,
|
||||||
"api_key": ep.api_key,
|
"api_key": _resolve_probe_key(ep),
|
||||||
})
|
})
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -1432,12 +1521,14 @@ def setup_model_routes(model_discovery):
|
|||||||
# Endpoint counts as reachable if it has any model — including
|
# Endpoint counts as reachable if it has any model — including
|
||||||
# admin-pinned IDs that a probe would never surface.
|
# admin-pinned IDs that a probe would never surface.
|
||||||
status = "online" if (all_models or pinned) else "offline"
|
status = "online" if (all_models or pinned) else "offline"
|
||||||
|
base = _normalize_base(r.base_url)
|
||||||
ping = None
|
ping = None
|
||||||
if not all_models and not pinned and r.is_enabled:
|
# Discovery-only providers have no health endpoint — an
|
||||||
|
# unauthenticated ping just 401s, so don't bother.
|
||||||
|
if not all_models and not pinned and r.is_enabled and not _is_discovery_only_provider(_detect_provider(base)):
|
||||||
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
ping = _ping_endpoint(r.base_url, r.api_key, timeout=1.0)
|
||||||
if ping.get("reachable"):
|
if ping.get("reachable"):
|
||||||
status = "empty"
|
status = "empty"
|
||||||
base = _normalize_base(r.base_url)
|
|
||||||
kind = _effective_endpoint_kind(r, base)
|
kind = _effective_endpoint_kind(r, base)
|
||||||
results.append({
|
results.append({
|
||||||
"id": r.id,
|
"id": r.id,
|
||||||
@@ -1713,7 +1804,7 @@ def setup_model_routes(model_discovery):
|
|||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == ep_id).first()
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found")
|
raise HTTPException(404, "Endpoint not found")
|
||||||
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": ep.api_key}
|
ep_data = {"id": ep.id, "name": ep.name, "base_url": ep.base_url, "api_key": _resolve_probe_key(ep)}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -1777,7 +1868,7 @@ def setup_model_routes(model_discovery):
|
|||||||
category = _classify_endpoint(base, kind)
|
category = _classify_endpoint(base, kind)
|
||||||
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
timeout = _manual_refresh_timeout(ep, category, refresh_timeout)
|
||||||
try:
|
try:
|
||||||
probed = _probe_endpoint(base, ep.api_key, timeout=timeout)
|
probed = _probe_endpoint(base, _resolve_probe_key(ep), timeout=timeout)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
logger.warning("Manual model refresh failed for endpoint %s at %s: %s", ep_id, base, exc)
|
||||||
probed = []
|
probed = []
|
||||||
@@ -2116,7 +2207,9 @@ def setup_model_routes(model_discovery):
|
|||||||
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
cleared_user_preferences = _clear_user_prefs_for_endpoint(ep_id)
|
||||||
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
cleared_sessions = _clear_sessions_for_endpoint(db, ep.base_url)
|
||||||
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
cleared_loaded_sessions = _clear_loaded_sessions_for_endpoint(ep.base_url)
|
||||||
|
auth_id = getattr(ep, "provider_auth_id", None)
|
||||||
db.delete(ep)
|
db.delete(ep)
|
||||||
|
cleared_provider_auth = _delete_orphaned_provider_auth(db, auth_id, exclude_ep_id=ep_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
_invalidate_models_cache()
|
_invalidate_models_cache()
|
||||||
_local_probe_cache["data"] = None
|
_local_probe_cache["data"] = None
|
||||||
@@ -2126,6 +2219,7 @@ def setup_model_routes(model_discovery):
|
|||||||
"cleared_user_preferences": cleared_user_preferences,
|
"cleared_user_preferences": cleared_user_preferences,
|
||||||
"cleared_sessions": cleared_sessions,
|
"cleared_sessions": cleared_sessions,
|
||||||
"cleared_loaded_sessions": cleared_loaded_sessions,
|
"cleared_loaded_sessions": cleared_loaded_sessions,
|
||||||
|
"cleared_provider_auth": cleared_provider_auth,
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
+39
-26
@@ -75,6 +75,38 @@ def _owned_enabled_endpoint(db, owner, endpoint_id=None):
|
|||||||
return owner_filter(q, ModelEndpoint, owner).first()
|
return owner_filter(q, ModelEndpoint, owner).first()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_endpoint_runtime(ep, owner=None, model: Optional[str] = None):
|
||||||
|
"""Resolve a ModelEndpoint row into (chat_url, model, headers).
|
||||||
|
|
||||||
|
Mirrors endpoint_resolver.resolve_endpoint's provider-auth handling for
|
||||||
|
panel-selected research endpoints. ChatGPT Subscription endpoints keep
|
||||||
|
OAuth tokens in ProviderAuthSession, so ep.api_key is intentionally empty.
|
||||||
|
"""
|
||||||
|
from src.endpoint_resolver import (
|
||||||
|
build_chat_url,
|
||||||
|
build_headers,
|
||||||
|
resolve_endpoint_runtime as resolve_model_endpoint_runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
base, api_key = resolve_model_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not resolve endpoint credentials for research: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
ep_model = (model or "").strip()
|
||||||
|
if not ep_model:
|
||||||
|
try:
|
||||||
|
models = json.loads(ep.cached_models) if ep.cached_models else []
|
||||||
|
if models:
|
||||||
|
ep_model = _first_chat_model(models)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not ep_model:
|
||||||
|
return None
|
||||||
|
return build_chat_url(base), ep_model, build_headers(api_key, base)
|
||||||
|
|
||||||
|
|
||||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||||
router = APIRouter(tags=["research"])
|
router = APIRouter(tags=["research"])
|
||||||
|
|
||||||
@@ -371,7 +403,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
|
|
||||||
if body.endpoint_id:
|
if body.endpoint_id:
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Owner-scoped: never resolve another user's private endpoint
|
# Owner-scoped: never resolve another user's private endpoint
|
||||||
@@ -380,18 +411,10 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
ep = _owned_enabled_endpoint(db, user, body.endpoint_id)
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(404, "Endpoint not found or disabled")
|
raise HTTPException(404, "Endpoint not found or disabled")
|
||||||
base = normalize_base(ep.base_url)
|
resolved = _resolve_endpoint_runtime(ep, owner=user, model=body.model)
|
||||||
ep_url = build_chat_url(base)
|
if not resolved:
|
||||||
ep_headers = build_headers(ep.api_key, base)
|
raise HTTPException(400, "Endpoint is not configured with a usable model.")
|
||||||
ep_model = body.model or ""
|
ep_url, ep_model, ep_headers = resolved
|
||||||
if not ep_model:
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
|
||||||
if models:
|
|
||||||
ep_model = _first_chat_model(models)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
else:
|
else:
|
||||||
@@ -408,7 +431,6 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
ep_url, ep_model, ep_headers = resolve_endpoint("chat", owner=user)
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
from src.database import SessionLocal
|
from src.database import SessionLocal
|
||||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
# Owner-scoped first-enabled fallback: the caller's own rows
|
# Owner-scoped first-enabled fallback: the caller's own rows
|
||||||
@@ -417,18 +439,9 @@ def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
|||||||
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
# /api/v1/chat fallback (webhook_routes._first_enabled_endpoint).
|
||||||
ep = _owned_enabled_endpoint(db, user)
|
ep = _owned_enabled_endpoint(db, user)
|
||||||
if ep:
|
if ep:
|
||||||
base = normalize_base(ep.base_url)
|
resolved = _resolve_endpoint_runtime(ep, owner=user)
|
||||||
ep_url = build_chat_url(base)
|
if resolved:
|
||||||
ep_headers = build_headers(ep.api_key, base)
|
ep_url, ep_model, ep_headers = resolved
|
||||||
ep_model = ""
|
|
||||||
if ep.cached_models:
|
|
||||||
try:
|
|
||||||
import json as _json
|
|
||||||
models = _json.loads(ep.cached_models)
|
|
||||||
if models:
|
|
||||||
ep_model = _first_chat_model(models)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
if not ep_url:
|
if not ep_url:
|
||||||
|
|||||||
+22
-17
@@ -92,18 +92,13 @@ def _reject_compact_during_active_run(session_id: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
|
||||||
"""Verify the current user owns the session. Raises 404 if not.
|
"""Verify the current user owns the session, honoring single-user modes.
|
||||||
|
|
||||||
Ownership is checked against the DB row when one exists (unchanged). If
|
Authenticated requests must match the stored DB or in-memory owner. When
|
||||||
there is no DB row but the caller owns an in-memory "ghost" session — one
|
auth is disabled and no user is present, treat the app as single-user mode:
|
||||||
that lives only in ``session_manager`` because it was never persisted, or
|
verify that the session exists, but do not compare its stored owner. This
|
||||||
its DB row was removed out-of-band — fall back to the in-memory owner so the
|
keeps QA/dev instances with AUTH_ENABLED=false from rejecting owner-stamped
|
||||||
user can still manage and delete it. Without this fallback such sessions are
|
rows created while auth was previously enabled.
|
||||||
listed by ``/api/sessions`` (they come from the in-memory manager) yet every
|
|
||||||
per-session operation 404s, making them impossible to delete (issue #1044).
|
|
||||||
|
|
||||||
``session_manager`` is optional and defaults to ``None`` so existing callers
|
|
||||||
that only care about persisted sessions keep their exact prior behavior.
|
|
||||||
"""
|
"""
|
||||||
user = effective_user(request)
|
user = effective_user(request)
|
||||||
if not user and not _auth_disabled():
|
if not user and not _auth_disabled():
|
||||||
@@ -114,13 +109,13 @@ def _verify_session_owner(request: Request, session_id: str, session_manager=Non
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
if row is not None:
|
if row is not None:
|
||||||
if row.owner != user:
|
if user and row.owner != user:
|
||||||
raise HTTPException(404, f"Session {session_id} not found")
|
raise HTTPException(404, f"Session {session_id} not found")
|
||||||
return
|
return
|
||||||
# No DB row — allow the caller to act on an in-memory ghost they own.
|
# No DB row — allow the caller to act on an in-memory ghost they own.
|
||||||
if session_manager is not None:
|
if session_manager is not None:
|
||||||
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
ghost = getattr(session_manager, "sessions", {}).get(session_id)
|
||||||
if ghost is not None and getattr(ghost, "owner", None) == user:
|
if ghost is not None and (not user or getattr(ghost, "owner", None) == user):
|
||||||
return
|
return
|
||||||
raise HTTPException(404, f"Session {session_id} not found")
|
raise HTTPException(404, f"Session {session_id} not found")
|
||||||
|
|
||||||
@@ -372,8 +367,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
pass
|
pass
|
||||||
elif not model_to_use:
|
elif not model_to_use:
|
||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
ids = list_model_ids(
|
||||||
headers=validation_headers)
|
endpoint_url,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=validation_headers,
|
||||||
|
owner=user,
|
||||||
|
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||||
|
)
|
||||||
if not ids:
|
if not ids:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
# Default to the first CHAT model — endpoints often list embedding/
|
# Default to the first CHAT model — endpoints often list embedding/
|
||||||
@@ -387,8 +387,13 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
import os as _os
|
import os as _os
|
||||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
avail = list_model_ids(
|
||||||
headers=validation_headers)
|
endpoint_url,
|
||||||
|
timeout=REQUEST_TIMEOUT,
|
||||||
|
headers=validation_headers,
|
||||||
|
owner=user,
|
||||||
|
endpoint_id=endpoint_id.strip() if endpoint_id else None,
|
||||||
|
)
|
||||||
if not avail:
|
if not avail:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
if model_to_use not in avail:
|
if model_to_use not in avail:
|
||||||
|
|||||||
@@ -1109,6 +1109,35 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
idx = skills_manager.index_for(owner=user)
|
idx = skills_manager.index_for(owner=user)
|
||||||
return {"index": idx, "count": len(idx)}
|
return {"index": idx, "count": len(idx)}
|
||||||
|
|
||||||
|
@router.get("/slash-catalog")
|
||||||
|
async def get_slash_catalog(request: Request):
|
||||||
|
"""Return skills that are available as slash commands.
|
||||||
|
|
||||||
|
Mirrors the agent prompt's published-skill index so the UI never offers
|
||||||
|
a slash command the model would not normally be allowed to discover.
|
||||||
|
"""
|
||||||
|
user = _owner(request)
|
||||||
|
all_skills = {s.get("name"): s for s in skills_manager.load(owner=user)}
|
||||||
|
entries = []
|
||||||
|
for s in skills_manager.index_for(owner=user):
|
||||||
|
name = (s.get("name") or "").strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
full = all_skills.get(name) or {}
|
||||||
|
category = (s.get("category") or full.get("category") or "general").strip() or "general"
|
||||||
|
entries.append({
|
||||||
|
"type": "skill",
|
||||||
|
"token": f"/{name}",
|
||||||
|
"name": name,
|
||||||
|
"category": f"Skills / {category}",
|
||||||
|
"help": s.get("description") or full.get("description") or "",
|
||||||
|
"usage": f"/{name} <request>",
|
||||||
|
"uses": int(full.get("uses") or 0),
|
||||||
|
"last_used": full.get("last_used"),
|
||||||
|
})
|
||||||
|
entries.sort(key=lambda row: row["name"])
|
||||||
|
return {"skills": entries, "count": len(entries)}
|
||||||
|
|
||||||
@router.get("/builtin")
|
@router.get("/builtin")
|
||||||
async def list_builtin_skills(request: Request):
|
async def list_builtin_skills(request: Request):
|
||||||
"""Read-only list of the agent's built-in tool capabilities (research,
|
"""Read-only list of the agent's built-in tool capabilities (research,
|
||||||
@@ -1272,6 +1301,47 @@ def setup_skills_routes(skills_manager: SkillsManager) -> APIRouter:
|
|||||||
_fire_skill_added(user)
|
_fire_skill_added(user)
|
||||||
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
return {"ok": True, "deduped": bool(entry.get("_deduped")), "skill": entry}
|
||||||
|
|
||||||
|
@router.post("/{skill_id}/invoke")
|
||||||
|
async def invoke_skill(request: Request, skill_id: str):
|
||||||
|
"""Build a skill-pinned prompt for slash-command invocation.
|
||||||
|
|
||||||
|
This is intentionally server-side so availability, ownership, and usage
|
||||||
|
accounting use the same rules as the SkillsManager.
|
||||||
|
"""
|
||||||
|
user = _owner(request)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
body = {}
|
||||||
|
request_text = (body.get("request") or "").strip() if isinstance(body, dict) else ""
|
||||||
|
|
||||||
|
invokable = {
|
||||||
|
s.get("name"): s for s in skills_manager.index_for(owner=user)
|
||||||
|
if (s.get("name") or "").strip()
|
||||||
|
}
|
||||||
|
match = invokable.get(skill_id)
|
||||||
|
if not match:
|
||||||
|
raise HTTPException(404, "Skill is not available for slash invocation")
|
||||||
|
|
||||||
|
name = match.get("name")
|
||||||
|
md = skills_manager.read_skill_md(name, owner=user)
|
||||||
|
if md is None:
|
||||||
|
raise HTTPException(404, "Skill source unavailable")
|
||||||
|
|
||||||
|
skills_manager.record_use(name, owner=user)
|
||||||
|
message = (
|
||||||
|
"Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n"
|
||||||
|
f"--- BEGIN SKILL ---\n{md}\n--- END SKILL ---\n\n"
|
||||||
|
+ (f"Request: {request_text}" if request_text else "Request: (use the skill as appropriate)")
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"type": "skill",
|
||||||
|
"name": name,
|
||||||
|
"command": f"/{name}",
|
||||||
|
"message": message,
|
||||||
|
}
|
||||||
|
|
||||||
@router.get("/{skill_id}")
|
@router.get("/{skill_id}")
|
||||||
async def get_skill(request: Request, skill_id: str):
|
async def get_skill(request: Request, skill_id: str):
|
||||||
user = _owner(request)
|
user = _owner(request)
|
||||||
|
|||||||
+21
-10
@@ -325,22 +325,33 @@ def setup_webhook_routes(
|
|||||||
endpoint_url = build_chat_url(base_url)
|
endpoint_url = build_chat_url(base_url)
|
||||||
model = body.model or "auto"
|
model = body.model or "auto"
|
||||||
api_key = ep.api_key
|
api_key = ep.api_key
|
||||||
|
if getattr(ep, "provider_auth_id", None):
|
||||||
|
try:
|
||||||
|
from src.endpoint_resolver import resolve_endpoint_runtime
|
||||||
|
base_url, api_key = resolve_endpoint_runtime(ep, owner=token_owner)
|
||||||
|
endpoint_url = build_chat_url(base_url)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(500, "Could not resolve endpoint credentials")
|
||||||
|
|
||||||
if model == "auto":
|
if model == "auto":
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=5) as client:
|
async with httpx.AsyncClient(timeout=5) as client:
|
||||||
models_url = build_models_url(base_url)
|
models_url = build_models_url(base_url)
|
||||||
hdrs = build_headers(api_key, base_url)
|
hdrs = build_headers(api_key, base_url)
|
||||||
resp = await client.get(models_url, headers=hdrs)
|
if models_url:
|
||||||
resp.raise_for_status()
|
resp = await client.get(models_url, headers=hdrs)
|
||||||
data = resp.json()
|
resp.raise_for_status()
|
||||||
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
data = resp.json()
|
||||||
if not ids:
|
ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
ids = [
|
if not ids:
|
||||||
m.get("name") or m.get("model")
|
ids = [
|
||||||
for m in (data.get("models") or [])
|
m.get("name") or m.get("model")
|
||||||
if m.get("name") or m.get("model")
|
for m in (data.get("models") or [])
|
||||||
]
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
import json as _json
|
||||||
|
ids = _json.loads(ep.cached_models or "[]")
|
||||||
model = ids[0] if ids else "auto"
|
model = ids[0] if ids else "auto"
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(500, "Could not discover models from endpoint")
|
raise HTTPException(500, "Could not discover models from endpoint")
|
||||||
|
|||||||
+39
-25
@@ -57,7 +57,7 @@ def set_rag_manager(rag_mgr, personal_docs_mgr=None):
|
|||||||
# Model resolution
|
# Model resolution
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from src.endpoint_resolver import normalize_base as _normalize_base, build_chat_url, build_headers, build_models_url
|
from src.endpoint_resolver import build_chat_url, build_headers, build_models_url, resolve_endpoint_runtime
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Dict]:
|
||||||
@@ -98,9 +98,12 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
|||||||
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
(f" matching '{target_endpoint_name}'" if target_endpoint_name else ""))
|
||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
# Anthropic: match against hardcoded model list
|
# Anthropic: match against hardcoded model list
|
||||||
@@ -114,16 +117,20 @@ def _resolve_model(spec: str, owner: Optional[str] = None) -> Tuple[str, str, Di
|
|||||||
else:
|
else:
|
||||||
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
# OpenAI-compatible and native Ollama: probe the provider's model list.
|
||||||
try:
|
try:
|
||||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
models_url = build_models_url(base)
|
||||||
r.raise_for_status()
|
if models_url:
|
||||||
data = r.json()
|
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
r.raise_for_status()
|
||||||
if not model_ids:
|
data = r.json()
|
||||||
model_ids = [
|
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
m.get("name") or m.get("model")
|
if not model_ids:
|
||||||
for m in (data.get("models") or [])
|
model_ids = [
|
||||||
if m.get("name") or m.get("model")
|
m.get("name") or m.get("model")
|
||||||
]
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_ids = json.loads(ep.cached_models or "[]")
|
||||||
except Exception:
|
except Exception:
|
||||||
model_ids = []
|
model_ids = []
|
||||||
|
|
||||||
@@ -1121,25 +1128,32 @@ async def do_list_models(content: str, session_id: Optional[str] = None, owner:
|
|||||||
total_models = 0
|
total_models = 0
|
||||||
|
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
base = _normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
model_ids = []
|
model_ids = []
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
model_ids = list(ANTHROPIC_MODELS)
|
model_ids = list(ANTHROPIC_MODELS)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
r = httpx.get(build_models_url(base), headers=headers, timeout=5)
|
models_url = build_models_url(base)
|
||||||
r.raise_for_status()
|
if models_url:
|
||||||
data = r.json()
|
r = httpx.get(models_url, headers=headers, timeout=5)
|
||||||
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
r.raise_for_status()
|
||||||
if not model_ids:
|
data = r.json()
|
||||||
model_ids = [
|
model_ids = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||||
m.get("name") or m.get("model")
|
if not model_ids:
|
||||||
for m in (data.get("models") or [])
|
model_ids = [
|
||||||
if m.get("name") or m.get("model")
|
m.get("name") or m.get("model")
|
||||||
]
|
for m in (data.get("models") or [])
|
||||||
|
if m.get("name") or m.get("model")
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_ids = json.loads(ep.cached_models or "[]")
|
||||||
except Exception:
|
except Exception:
|
||||||
model_ids = ["(endpoint offline)"]
|
model_ids = ["(endpoint offline)"]
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,311 @@
|
|||||||
|
"""ChatGPT subscription / Codex backend OAuth helpers.
|
||||||
|
|
||||||
|
This provider is intentionally separate from OpenAI API-key endpoints. It uses
|
||||||
|
OpenAI account OAuth device authorization, stores refresh tokens server-side,
|
||||||
|
and resolves a fresh bearer token at request time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from core.database import ProviderAuthSession, SessionLocal, utcnow_naive
|
||||||
|
|
||||||
|
DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL = (
|
||||||
|
os.getenv("CHATGPT_SUBSCRIPTION_BASE_URL", "").strip().rstrip("/")
|
||||||
|
or "https://chatgpt.com/backend-api/codex"
|
||||||
|
)
|
||||||
|
CHATGPT_SUBSCRIPTION_PROVIDER = "chatgpt-subscription"
|
||||||
|
CHATGPT_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
|
CHATGPT_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||||
|
CHATGPT_OAUTH_ISSUER = "https://auth.openai.com"
|
||||||
|
CHATGPT_OAUTH_REDIRECT_URI = f"{CHATGPT_OAUTH_ISSUER}/deviceauth/callback"
|
||||||
|
CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||||
|
_AUTH_REFRESH_LOCKS: dict[str, threading.Lock] = {}
|
||||||
|
_AUTH_REFRESH_LOCKS_GUARD = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_lock_for(auth_id: str) -> threading.Lock:
|
||||||
|
with _AUTH_REFRESH_LOCKS_GUARD:
|
||||||
|
lock = _AUTH_REFRESH_LOCKS.get(auth_id)
|
||||||
|
if lock is None:
|
||||||
|
lock = threading.Lock()
|
||||||
|
_AUTH_REFRESH_LOCKS[auth_id] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionError(RuntimeError):
|
||||||
|
"""Base error for ChatGPT subscription provider failures."""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionReauthRequired(ChatGPTSubscriptionError):
|
||||||
|
"""Stored OAuth credentials are invalid or expired beyond refresh."""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionRateLimited(ChatGPTSubscriptionError):
|
||||||
|
"""Upstream quota/rate limit; reconnecting will not fix it."""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTSubscriptionAuthNotFound(ChatGPTSubscriptionError):
|
||||||
|
"""No matching owner-scoped auth session exists."""
|
||||||
|
|
||||||
|
|
||||||
|
def is_chatgpt_subscription_base(url: str) -> bool:
|
||||||
|
try:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(url or "")
|
||||||
|
host = (parsed.hostname or "").lower().rstrip(".")
|
||||||
|
path = (parsed.path or "").rstrip("/")
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
return host == "chatgpt.com" and (
|
||||||
|
path == "/backend-api/codex" or path.startswith("/backend-api/codex/")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def chatgpt_headers(access_token: Optional[str]) -> Dict[str, str]:
|
||||||
|
headers = {
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"Origin": "https://chatgpt.com",
|
||||||
|
"Referer": "https://chatgpt.com/codex",
|
||||||
|
"User-Agent": "Odysseus ChatGPT Subscription",
|
||||||
|
}
|
||||||
|
if access_token:
|
||||||
|
headers["Authorization"] = f"Bearer {access_token}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_available_models(access_token: str, timeout: float = 10.0) -> list[str]:
|
||||||
|
if not access_token:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
response = httpx.get(
|
||||||
|
"https://chatgpt.com/backend-api/codex/models?client_version=1.0.0",
|
||||||
|
headers=chatgpt_headers(access_token),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
return []
|
||||||
|
data = response.json()
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
entries = data.get("models", []) if isinstance(data, dict) else []
|
||||||
|
sortable: list[tuple[int, str]] = []
|
||||||
|
for item in entries:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
slug = item.get("slug")
|
||||||
|
if not isinstance(slug, str) or not slug.strip():
|
||||||
|
continue
|
||||||
|
visibility = item.get("visibility", "")
|
||||||
|
if isinstance(visibility, str) and visibility.strip().lower() in {"hide", "hidden"}:
|
||||||
|
continue
|
||||||
|
priority = item.get("priority")
|
||||||
|
rank = int(priority) if isinstance(priority, (int, float)) else 10_000
|
||||||
|
sortable.append((rank, slug.strip()))
|
||||||
|
sortable.sort(key=lambda item: (item[0], item[1]))
|
||||||
|
ordered: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for _, slug in sortable:
|
||||||
|
if slug not in seen:
|
||||||
|
ordered.append(slug)
|
||||||
|
seen.add(slug)
|
||||||
|
return ordered
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_for_oauth_response(response: httpx.Response, action: str) -> None:
|
||||||
|
if response.status_code < 400:
|
||||||
|
return
|
||||||
|
code = ""
|
||||||
|
message = f"ChatGPT Subscription {action} failed with HTTP {response.status_code}."
|
||||||
|
try:
|
||||||
|
payload = response.json()
|
||||||
|
err = payload.get("error") if isinstance(payload, dict) else None
|
||||||
|
if isinstance(err, dict):
|
||||||
|
code = str(err.get("code") or err.get("type") or "").strip()
|
||||||
|
msg = err.get("message")
|
||||||
|
if msg:
|
||||||
|
message = f"ChatGPT Subscription {action} failed: {msg}"
|
||||||
|
elif isinstance(err, str):
|
||||||
|
code = err.strip()
|
||||||
|
desc = payload.get("error_description") or payload.get("message")
|
||||||
|
if desc:
|
||||||
|
message = f"ChatGPT Subscription {action} failed: {desc}"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if response.status_code == 429:
|
||||||
|
raise ChatGPTSubscriptionRateLimited(
|
||||||
|
"ChatGPT Subscription quota or rate limit was reached. Credentials are still valid."
|
||||||
|
)
|
||||||
|
if response.status_code in (401, 403) or code in {"invalid_grant", "invalid_token", "invalid_request", "refresh_token_reused"}:
|
||||||
|
raise ChatGPTSubscriptionReauthRequired(message)
|
||||||
|
raise ChatGPTSubscriptionError(message)
|
||||||
|
|
||||||
|
|
||||||
|
def _json_or_error(response: httpx.Response, action: str) -> Dict[str, Any]:
|
||||||
|
_raise_for_oauth_response(response, action)
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except Exception as exc:
|
||||||
|
raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned invalid JSON.") from exc
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ChatGPTSubscriptionError(f"ChatGPT Subscription {action} returned an unexpected response.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def request_device_code(timeout: float = 15.0) -> Dict[str, Any]:
|
||||||
|
response = httpx.post(
|
||||||
|
f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/usercode",
|
||||||
|
json={"client_id": CHATGPT_OAUTH_CLIENT_ID},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
data = _json_or_error(response, "device-code request")
|
||||||
|
if not data.get("device_auth_id") or not data.get("user_code"):
|
||||||
|
raise ChatGPTSubscriptionError("ChatGPT device-code response was missing required fields.")
|
||||||
|
data.setdefault("verification_uri", f"{CHATGPT_OAUTH_ISSUER}/codex/device")
|
||||||
|
data.setdefault("interval", 5)
|
||||||
|
data.setdefault("expires_in", 900)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def poll_device_auth(device_auth_id: str, user_code: str, timeout: float = 15.0) -> Dict[str, Any]:
|
||||||
|
response = httpx.post(
|
||||||
|
f"{CHATGPT_OAUTH_ISSUER}/api/accounts/deviceauth/token",
|
||||||
|
json={"device_auth_id": device_auth_id, "user_code": user_code},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
if response.status_code in (403, 404):
|
||||||
|
return {"status": "pending", "error": "authorization_pending"}
|
||||||
|
return _json_or_error(response, "device-code poll")
|
||||||
|
|
||||||
|
|
||||||
|
def exchange_authorization_code(authorization_code: str, code_verifier: str, timeout: float = 15.0) -> Dict[str, Any]:
|
||||||
|
response = httpx.post(
|
||||||
|
CHATGPT_OAUTH_TOKEN_URL,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
data={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": authorization_code,
|
||||||
|
"redirect_uri": CHATGPT_OAUTH_REDIRECT_URI,
|
||||||
|
"client_id": CHATGPT_OAUTH_CLIENT_ID,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
data = _json_or_error(response, "token exchange")
|
||||||
|
if not data.get("access_token"):
|
||||||
|
raise ChatGPTSubscriptionReauthRequired("ChatGPT token exchange did not return an access token.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_oauth_tokens(access_token: str, refresh_token: str, timeout: float = 20.0) -> Dict[str, Any]:
|
||||||
|
del access_token
|
||||||
|
if not refresh_token:
|
||||||
|
raise ChatGPTSubscriptionReauthRequired("ChatGPT Subscription is missing a refresh token. Reconnect the provider.")
|
||||||
|
response = httpx.post(
|
||||||
|
CHATGPT_OAUTH_TOKEN_URL,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
data={
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"client_id": CHATGPT_OAUTH_CLIENT_ID,
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
data = _json_or_error(response, "token refresh")
|
||||||
|
if not data.get("access_token"):
|
||||||
|
raise ChatGPTSubscriptionReauthRequired("ChatGPT token refresh did not return an access token.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_jwt_payload(token: str) -> Dict[str, Any]:
|
||||||
|
parts = (token or "").split(".")
|
||||||
|
if len(parts) < 2:
|
||||||
|
raise ValueError("not a JWT")
|
||||||
|
segment = parts[1]
|
||||||
|
segment += "=" * (-len(segment) % 4)
|
||||||
|
raw = base64.urlsafe_b64decode(segment.encode("ascii"))
|
||||||
|
payload = json.loads(raw.decode("utf-8"))
|
||||||
|
return payload if isinstance(payload, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def access_token_is_expiring(access_token: str, skew_seconds: int = CHATGPT_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
|
||||||
|
try:
|
||||||
|
exp = int(_decode_jwt_payload(access_token).get("exp") or 0)
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
return exp <= int(time.time()) + int(skew_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_runtime_credentials(auth_id: str, owner: Optional[str] = None, *, force_refresh: bool = False) -> Dict[str, Any]:
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
q = db.query(ProviderAuthSession).filter(
|
||||||
|
ProviderAuthSession.id == auth_id,
|
||||||
|
ProviderAuthSession.provider == CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
)
|
||||||
|
if owner:
|
||||||
|
q = q.filter(ProviderAuthSession.owner == owner)
|
||||||
|
row = q.first()
|
||||||
|
if row is None:
|
||||||
|
raise ChatGPTSubscriptionAuthNotFound("ChatGPT Subscription credentials were not found for this user.")
|
||||||
|
|
||||||
|
access_token = row.access_token or ""
|
||||||
|
if force_refresh or access_token_is_expiring(access_token):
|
||||||
|
with _refresh_lock_for(auth_id):
|
||||||
|
db.refresh(row)
|
||||||
|
access_token = row.access_token or ""
|
||||||
|
refresh_token = row.refresh_token or ""
|
||||||
|
if force_refresh or access_token_is_expiring(access_token):
|
||||||
|
refreshed = refresh_oauth_tokens(access_token, refresh_token)
|
||||||
|
row.access_token = refreshed["access_token"]
|
||||||
|
if refreshed.get("refresh_token"):
|
||||||
|
row.refresh_token = refreshed["refresh_token"]
|
||||||
|
row.last_refresh = utcnow_naive()
|
||||||
|
db.commit()
|
||||||
|
db.refresh(row)
|
||||||
|
access_token = row.access_token or ""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"provider": CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
"base_url": (row.base_url or DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL).rstrip("/"),
|
||||||
|
"api_key": access_token,
|
||||||
|
"auth_mode": row.auth_mode or "chatgpt",
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def to_http_exception(exc: Exception) -> HTTPException:
|
||||||
|
if isinstance(exc, ChatGPTSubscriptionRateLimited):
|
||||||
|
return HTTPException(429, str(exc))
|
||||||
|
if isinstance(exc, (ChatGPTSubscriptionReauthRequired, ChatGPTSubscriptionAuthNotFound)):
|
||||||
|
return HTTPException(401, f"{exc} Reconnect the provider.")
|
||||||
|
return HTTPException(502, str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
def build_responses_input(messages: list[dict]) -> list[dict]:
|
||||||
|
input_items: list[dict] = []
|
||||||
|
for msg in messages or []:
|
||||||
|
role = msg.get("role") or "user"
|
||||||
|
if role == "tool":
|
||||||
|
role = "user"
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text = "\n".join(str(part.get("text") or part.get("content") or "") for part in content if isinstance(part, dict))
|
||||||
|
else:
|
||||||
|
text = "" if content is None else str(content)
|
||||||
|
input_type = "output_text" if role == "assistant" else "input_text"
|
||||||
|
input_items.append({"role": role, "content": [{"type": input_type, "text": text}]})
|
||||||
|
return input_items
|
||||||
@@ -70,6 +70,25 @@ def _endpoint_enabled_models(ep) -> list:
|
|||||||
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
|
return [m for m in _endpoint_cached_models(ep) if m not in hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_endpoint_runtime(ep, owner: Optional[str] = None) -> Tuple[str, Optional[str]]:
|
||||||
|
"""Resolve a ModelEndpoint row to its runtime base URL and bearer/API key.
|
||||||
|
|
||||||
|
Static-key providers use ``ModelEndpoint.api_key``. Session-backed providers
|
||||||
|
store refreshable credentials in ProviderAuthSession and must resolve a
|
||||||
|
current access token at call time.
|
||||||
|
"""
|
||||||
|
base = normalize_base(getattr(ep, "base_url", "") or "")
|
||||||
|
api_key = getattr(ep, "api_key", None)
|
||||||
|
auth_id = getattr(ep, "provider_auth_id", None)
|
||||||
|
if auth_id:
|
||||||
|
from src.chatgpt_subscription import resolve_runtime_credentials
|
||||||
|
|
||||||
|
creds = resolve_runtime_credentials(auth_id, owner=owner)
|
||||||
|
base = normalize_base(creds.get("base_url") or base)
|
||||||
|
api_key = creds.get("api_key")
|
||||||
|
return base, api_key
|
||||||
|
|
||||||
|
|
||||||
# Cache for Tailscale hostname → IP resolution
|
# Cache for Tailscale hostname → IP resolution
|
||||||
_tailscale_cache: Dict[str, Optional[str]] = {}
|
_tailscale_cache: Dict[str, Optional[str]] = {}
|
||||||
|
|
||||||
@@ -133,7 +152,7 @@ def resolve_url(url: str) -> str:
|
|||||||
def normalize_base(url: str) -> str:
|
def normalize_base(url: str) -> str:
|
||||||
"""Strip known API path suffixes from a base URL."""
|
"""Strip known API path suffixes from a base URL."""
|
||||||
url = (url or "").strip().rstrip("/")
|
url = (url or "").strip().rstrip("/")
|
||||||
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages"]:
|
for suffix in ["/models", "/chat/completions", "/completions", "/v1/messages", "/responses"]:
|
||||||
if url.endswith(suffix):
|
if url.endswith(suffix):
|
||||||
url = url[: -len(suffix)].rstrip("/")
|
url = url[: -len(suffix)].rstrip("/")
|
||||||
for suffix in ["/chat", "/tags", "/generate"]:
|
for suffix in ["/chat", "/tags", "/generate"]:
|
||||||
@@ -158,10 +177,12 @@ def build_chat_url(base: str) -> str:
|
|||||||
return _anthropic_api_root(base) + "/v1/messages"
|
return _anthropic_api_root(base) + "/v1/messages"
|
||||||
if provider == "ollama":
|
if provider == "ollama":
|
||||||
return _ollama_api_root(base) + "/chat"
|
return _ollama_api_root(base) + "/chat"
|
||||||
|
if provider == "chatgpt-subscription":
|
||||||
|
return base.rstrip("/") + "/responses"
|
||||||
return base + "/chat/completions"
|
return base + "/chat/completions"
|
||||||
|
|
||||||
|
|
||||||
def build_models_url(base: str) -> str:
|
def build_models_url(base: str) -> Optional[str]:
|
||||||
"""Return the provider-specific model-list endpoint URL for a base."""
|
"""Return the provider-specific model-list endpoint URL for a base."""
|
||||||
base = resolve_url(base)
|
base = resolve_url(base)
|
||||||
provider = _detect_provider(base)
|
provider = _detect_provider(base)
|
||||||
@@ -169,6 +190,8 @@ def build_models_url(base: str) -> str:
|
|||||||
return _anthropic_api_root(base) + "/v1/models"
|
return _anthropic_api_root(base) + "/v1/models"
|
||||||
if provider == "ollama":
|
if provider == "ollama":
|
||||||
return _ollama_api_root(base) + "/tags"
|
return _ollama_api_root(base) + "/tags"
|
||||||
|
if provider == "chatgpt-subscription":
|
||||||
|
return None
|
||||||
return base + "/models"
|
return base + "/models"
|
||||||
|
|
||||||
|
|
||||||
@@ -184,6 +207,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
|||||||
if provider == "copilot":
|
if provider == "copilot":
|
||||||
from src.copilot import copilot_headers
|
from src.copilot import copilot_headers
|
||||||
return copilot_headers(api_key)
|
return copilot_headers(api_key)
|
||||||
|
if provider == "chatgpt-subscription":
|
||||||
|
from src.chatgpt_subscription import chatgpt_headers
|
||||||
|
return chatgpt_headers(api_key)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
if provider == "openrouter":
|
if provider == "openrouter":
|
||||||
@@ -262,9 +288,13 @@ def resolve_endpoint(
|
|||||||
if not ep:
|
if not ep:
|
||||||
return fallback_url, fallback_model, fallback_headers
|
return fallback_url, fallback_model, fallback_headers
|
||||||
|
|
||||||
base = normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
|
||||||
|
return fallback_url, fallback_model, fallback_headers
|
||||||
chat_url = build_chat_url(base)
|
chat_url = build_chat_url(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
|
|
||||||
# Discard a configured model the user has since disabled on the
|
# Discard a configured model the user has since disabled on the
|
||||||
# endpoint (e.g. a stale `default_model` left pointing at a now-hidden
|
# endpoint (e.g. a stale `default_model` left pointing at a now-hidden
|
||||||
@@ -308,9 +338,13 @@ def resolve_endpoint_by_id(
|
|||||||
ep = q.first()
|
ep = q.first()
|
||||||
if not ep:
|
if not ep:
|
||||||
return None
|
return None
|
||||||
base = normalize_base(ep.base_url)
|
try:
|
||||||
|
base, api_key = resolve_endpoint_runtime(ep, owner=owner)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not resolve endpoint runtime credentials: %s", e)
|
||||||
|
return None
|
||||||
chat_url = build_chat_url(base)
|
chat_url = build_chat_url(base)
|
||||||
headers = build_headers(ep.api_key, base)
|
headers = build_headers(api_key, base)
|
||||||
m = (model or "").strip()
|
m = (model or "").strip()
|
||||||
# Drop a model the user disabled on the endpoint, then pick the first
|
# Drop a model the user disabled on the endpoint, then pick the first
|
||||||
# enabled chat model rather than a hidden one.
|
# enabled chat model rather than a hidden one.
|
||||||
|
|||||||
+217
-7
@@ -426,6 +426,9 @@ def _detect_provider(url: str) -> str:
|
|||||||
return "openrouter"
|
return "openrouter"
|
||||||
if _host_match(url, "groq.com"):
|
if _host_match(url, "groq.com"):
|
||||||
return "groq"
|
return "groq"
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
if is_chatgpt_subscription_base(url):
|
||||||
|
return "chatgpt-subscription"
|
||||||
from src.copilot import is_copilot_base
|
from src.copilot import is_copilot_base
|
||||||
if is_copilot_base(url):
|
if is_copilot_base(url):
|
||||||
return "copilot"
|
return "copilot"
|
||||||
@@ -462,6 +465,8 @@ def _provider_label(url: str) -> str:
|
|||||||
if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go"
|
if _host_match(url, "opencode.ai/zen/go"): return "OpenCode Go"
|
||||||
if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen"
|
if _host_match(url, "opencode.ai/zen"): return "OpenCode Zen"
|
||||||
if _host_match(url, "groq.com"): return "Groq"
|
if _host_match(url, "groq.com"): return "Groq"
|
||||||
|
from src.chatgpt_subscription import is_chatgpt_subscription_base
|
||||||
|
if is_chatgpt_subscription_base(url): return "ChatGPT Subscription"
|
||||||
from src.copilot import is_copilot_base
|
from src.copilot import is_copilot_base
|
||||||
if is_copilot_base(url): return "GitHub Copilot"
|
if is_copilot_base(url): return "GitHub Copilot"
|
||||||
if _host_match(url, "mistral.ai"): return "Mistral"
|
if _host_match(url, "mistral.ai"): return "Mistral"
|
||||||
@@ -479,6 +484,77 @@ def _provider_label(url: str) -> str:
|
|||||||
return host or "provider"
|
return host or "provider"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_chatgpt_subscription_url(url: str) -> str:
|
||||||
|
base = (url or "").strip().rstrip("/")
|
||||||
|
if base.endswith("/responses"):
|
||||||
|
return base
|
||||||
|
return base + "/responses"
|
||||||
|
|
||||||
|
|
||||||
|
def _message_content_as_text(content) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for part in content:
|
||||||
|
if not isinstance(part, dict):
|
||||||
|
if part:
|
||||||
|
parts.append(str(part))
|
||||||
|
continue
|
||||||
|
if isinstance(part.get("text"), str):
|
||||||
|
parts.append(part["text"])
|
||||||
|
continue
|
||||||
|
if isinstance(part.get("content"), str):
|
||||||
|
parts.append(part["content"])
|
||||||
|
return "\n".join(parts)
|
||||||
|
return "" if content is None else str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _chatgpt_subscription_instructions(messages: List[Dict]) -> str:
|
||||||
|
instructions = [
|
||||||
|
_message_content_as_text(msg.get("content")).strip()
|
||||||
|
for msg in messages or []
|
||||||
|
if (msg.get("role") or "") == "system"
|
||||||
|
]
|
||||||
|
instructions = [part for part in instructions if part]
|
||||||
|
if instructions:
|
||||||
|
return "\n\n".join(instructions)
|
||||||
|
return "You are a helpful AI assistant."
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chatgpt_responses_payload(
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict],
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
*,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Dict:
|
||||||
|
from src.chatgpt_subscription import build_responses_input
|
||||||
|
|
||||||
|
conversation = [msg for msg in (messages or []) if (msg.get("role") or "") != "system"]
|
||||||
|
payload: Dict = {
|
||||||
|
"model": model,
|
||||||
|
"instructions": _chatgpt_subscription_instructions(messages),
|
||||||
|
"input": build_responses_input(conversation),
|
||||||
|
"stream": stream,
|
||||||
|
"store": False,
|
||||||
|
}
|
||||||
|
if not _restricts_temperature(model):
|
||||||
|
payload["temperature"] = temperature
|
||||||
|
if max_tokens and max_tokens > 0:
|
||||||
|
payload["max_output_tokens"] = max_tokens
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _format_chatgpt_subscription_error(status_code: int, text: str) -> str:
|
||||||
|
if status_code in (401, 403):
|
||||||
|
return "ChatGPT Subscription credentials expired or were rejected. Reconnect the provider."
|
||||||
|
if status_code == 429:
|
||||||
|
return "ChatGPT Subscription quota or rate limit was reached. Retry after the upstream limit resets."
|
||||||
|
return _format_upstream_error(status_code, text, "https://chatgpt.com/backend-api/codex")
|
||||||
|
|
||||||
|
|
||||||
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
|
def _format_upstream_error(status: int, body: bytes | str, url: str) -> str:
|
||||||
"""Turn an upstream HTTP error into a user-readable sentence.
|
"""Turn an upstream HTTP error into a user-readable sentence.
|
||||||
|
|
||||||
@@ -874,7 +950,7 @@ def _normalize_anthropic_url(url: str) -> str:
|
|||||||
def _model_list_base(url: str) -> str:
|
def _model_list_base(url: str) -> str:
|
||||||
"""Normalize model/chat URLs to the configured endpoint base."""
|
"""Normalize model/chat URLs to the configured endpoint base."""
|
||||||
base = (url or "").strip().rstrip("/")
|
base = (url or "").strip().rstrip("/")
|
||||||
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages"):
|
for suffix in ("/models", "/chat/completions", "/completions", "/v1/messages", "/responses"):
|
||||||
if base.endswith(suffix):
|
if base.endswith(suffix):
|
||||||
base = base[: -len(suffix)].rstrip("/")
|
base = base[: -len(suffix)].rstrip("/")
|
||||||
for suffix in ("/chat", "/tags", "/generate"):
|
for suffix in ("/chat", "/tags", "/generate"):
|
||||||
@@ -903,7 +979,12 @@ def _parse_model_cache(raw) -> List[str]:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
def _configured_cached_model_ids(
|
||||||
|
endpoint_url: str,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
endpoint_id: Optional[str] = None,
|
||||||
|
) -> List[str]:
|
||||||
"""Return cached models for a configured endpoint matching endpoint_url."""
|
"""Return cached models for a configured endpoint matching endpoint_url."""
|
||||||
target = _model_list_base(endpoint_url)
|
target = _model_list_base(endpoint_url)
|
||||||
if not target:
|
if not target:
|
||||||
@@ -914,7 +995,13 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
|||||||
return []
|
return []
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
rows = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
if endpoint_id:
|
||||||
|
q = q.filter(ModelEndpoint.id == endpoint_id)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
rows = q.all()
|
||||||
for ep in rows:
|
for ep in rows:
|
||||||
if _model_list_base(getattr(ep, "base_url", "")) != target:
|
if _model_list_base(getattr(ep, "base_url", "")) != target:
|
||||||
continue
|
continue
|
||||||
@@ -933,9 +1020,16 @@ def _configured_cached_model_ids(endpoint_url: str) -> List[str]:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT, headers: Optional[Dict] = None) -> List[str]:
|
def list_model_ids(
|
||||||
|
base_chat_url: str,
|
||||||
|
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||||
|
headers: Optional[Dict] = None,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
endpoint_id: Optional[str] = None,
|
||||||
|
) -> List[str]:
|
||||||
"""List available model IDs from an endpoint."""
|
"""List available model IDs from an endpoint."""
|
||||||
cached = _configured_cached_model_ids(base_chat_url)
|
cached = _configured_cached_model_ids(base_chat_url, owner=owner, endpoint_id=endpoint_id)
|
||||||
if cached:
|
if cached:
|
||||||
return cached
|
return cached
|
||||||
provider = _detect_provider(base_chat_url)
|
provider = _detect_provider(base_chat_url)
|
||||||
@@ -971,9 +1065,16 @@ def list_model_ids(base_chat_url: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
|||||||
pass
|
pass
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def normalize_model_id(endpoint_url: str, requested: str, timeout: int = LLMConfig.DEFAULT_TIMEOUT) -> Optional[str]:
|
def normalize_model_id(
|
||||||
|
endpoint_url: str,
|
||||||
|
requested: str,
|
||||||
|
timeout: int = LLMConfig.DEFAULT_TIMEOUT,
|
||||||
|
*,
|
||||||
|
owner: Optional[str] = None,
|
||||||
|
endpoint_id: Optional[str] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
"""Normalize a model ID to match available models."""
|
"""Normalize a model ID to match available models."""
|
||||||
avail = list_model_ids(endpoint_url, timeout)
|
avail = list_model_ids(endpoint_url, timeout, owner=owner, endpoint_id=endpoint_id)
|
||||||
if not avail:
|
if not avail:
|
||||||
return None
|
return None
|
||||||
if requested in avail:
|
if requested in avail:
|
||||||
@@ -1169,6 +1270,49 @@ async def llm_call_async(
|
|||||||
logger.debug(f"Returning cached response for key: {cache_key}")
|
logger.debug(f"Returning cached response for key: {cache_key}")
|
||||||
return cached_response
|
return cached_response
|
||||||
|
|
||||||
|
if provider == "chatgpt-subscription":
|
||||||
|
# ChatGPT/Codex requires streamed Responses requests even for callers
|
||||||
|
# that want a plain string (auto-title, memory extraction, etc.).
|
||||||
|
# Reuse stream_llm's validated Codex SSE path and collect deltas.
|
||||||
|
parts: List[str] = []
|
||||||
|
async for chunk in stream_llm(
|
||||||
|
url,
|
||||||
|
model,
|
||||||
|
messages_copy,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout,
|
||||||
|
):
|
||||||
|
event_is_error = False
|
||||||
|
for line in str(chunk).splitlines():
|
||||||
|
if line.startswith("event:"):
|
||||||
|
event_is_error = line[6:].strip() == "error"
|
||||||
|
continue
|
||||||
|
if not line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
raw = line[5:].strip()
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
if raw == "[DONE]":
|
||||||
|
response = "".join(parts)
|
||||||
|
_set_cached_response(cache_key, response)
|
||||||
|
return response
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if event_is_error or data.get("error") or (data.get("status") and data.get("text")):
|
||||||
|
status = int(data.get("status") or 502)
|
||||||
|
text = data.get("text") or data.get("error") or "ChatGPT Subscription request failed"
|
||||||
|
raise HTTPException(status, text)
|
||||||
|
delta = data.get("delta")
|
||||||
|
if isinstance(delta, str):
|
||||||
|
parts.append(delta)
|
||||||
|
response = "".join(parts)
|
||||||
|
_set_cached_response(cache_key, response)
|
||||||
|
return response
|
||||||
|
|
||||||
if provider == "anthropic":
|
if provider == "anthropic":
|
||||||
target_url = _normalize_anthropic_url(url)
|
target_url = _normalize_anthropic_url(url)
|
||||||
h = _build_anthropic_headers(headers)
|
h = _build_anthropic_headers(headers)
|
||||||
@@ -1294,6 +1438,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
|||||||
model, messages_copy, temperature, max_tokens,
|
model, messages_copy, temperature, max_tokens,
|
||||||
stream=True, tools=tools, num_ctx=get_context_length(url, model),
|
stream=True, tools=tools, num_ctx=get_context_length(url, model),
|
||||||
)
|
)
|
||||||
|
elif provider == "chatgpt-subscription":
|
||||||
|
target_url = _normalize_chatgpt_subscription_url(url)
|
||||||
|
h = _provider_headers(provider, headers)
|
||||||
|
payload = _build_chatgpt_responses_payload(model, messages_copy, temperature, max_tokens, stream=True)
|
||||||
else:
|
else:
|
||||||
target_url = url
|
target_url = url
|
||||||
payload = {
|
payload = {
|
||||||
@@ -1325,6 +1473,68 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
|||||||
return
|
return
|
||||||
note_model_activity(target_url, model)
|
note_model_activity(target_url, model)
|
||||||
|
|
||||||
|
# ── ChatGPT Subscription / Codex Responses streaming ──
|
||||||
|
if provider == "chatgpt-subscription":
|
||||||
|
event_name = ""
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
try:
|
||||||
|
client = _get_http_client()
|
||||||
|
async with client.stream('POST', target_url, json=payload, headers=h, timeout=stream_timeout) as r:
|
||||||
|
_clear_host_dead(target_url)
|
||||||
|
if r.status_code != 200:
|
||||||
|
raw = (await r.aread()).decode(errors="replace")
|
||||||
|
friendly = _format_chatgpt_subscription_error(r.status_code, raw)
|
||||||
|
yield f'event: error\ndata: {json.dumps({"status": r.status_code, "text": friendly, "raw": raw[:500]})}\n\n'
|
||||||
|
return
|
||||||
|
async for line in r.aiter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if line.startswith("event:"):
|
||||||
|
event_name = line[6:].strip()
|
||||||
|
continue
|
||||||
|
if not line.startswith("data:"):
|
||||||
|
continue
|
||||||
|
raw = line[5:].strip()
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
evt = data.get("type") or event_name
|
||||||
|
if evt == "response.output_text.delta":
|
||||||
|
delta = data.get("delta") or ""
|
||||||
|
if delta:
|
||||||
|
yield f'data: {json.dumps({"delta": delta})}\n\n'
|
||||||
|
elif evt == "response.completed":
|
||||||
|
usage = (data.get("response") or {}).get("usage") or data.get("usage") or {}
|
||||||
|
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens") or input_tokens
|
||||||
|
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens") or output_tokens
|
||||||
|
if input_tokens or output_tokens:
|
||||||
|
yield f'data: {json.dumps({"type": "usage", "data": {"input_tokens": input_tokens, "output_tokens": output_tokens}})}\n\n'
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
elif evt in ("response.failed", "error"):
|
||||||
|
err = data.get("error") or (data.get("response") or {}).get("error") or {}
|
||||||
|
text = err.get("message") if isinstance(err, dict) else str(err or "ChatGPT Subscription request failed")
|
||||||
|
yield f'event: error\ndata: {json.dumps({"status": 502, "text": text})}\n\n'
|
||||||
|
return
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||||
|
_cooled = _mark_host_dead(target_url)
|
||||||
|
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||||
|
logger.warning(f"ChatGPT Subscription stream connect to {target_url} failed: {e}{_tail}")
|
||||||
|
yield f'event: error\ndata: {json.dumps({"error": f"Cannot reach {_host_key(target_url)}", "status": 503})}\n\n'
|
||||||
|
except httpx.ReadTimeout:
|
||||||
|
yield f'event: error\ndata: {json.dumps({"error": "Read timeout", "status": 504})}\n\n'
|
||||||
|
except httpx.NetworkError:
|
||||||
|
yield f'event: error\ndata: {json.dumps({"error": "Network error", "status": 502})}\n\n'
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"ChatGPT Subscription stream error: {e}")
|
||||||
|
yield f'event: error\ndata: {json.dumps({"error": str(e), "status": 502})}\n\n'
|
||||||
|
return
|
||||||
|
|
||||||
# ── Native Ollama streaming ──
|
# ── Native Ollama streaming ──
|
||||||
if provider == "ollama":
|
if provider == "ollama":
|
||||||
_ollama_tool_calls: List[Dict] = []
|
_ollama_tool_calls: List[Dict] = []
|
||||||
|
|||||||
@@ -2108,6 +2108,8 @@
|
|||||||
<option value="https://api.anthropic.com" data-logo="anthropic">Anthropic</option>
|
<option value="https://api.anthropic.com" data-logo="anthropic">Anthropic</option>
|
||||||
<option value="https://api.deepseek.com/v1" data-logo="deepseek" selected>DeepSeek</option>
|
<option value="https://api.deepseek.com/v1" data-logo="deepseek" selected>DeepSeek</option>
|
||||||
<option value="https://api.openai.com/v1" data-logo="openai">OpenAI</option>
|
<option value="https://api.openai.com/v1" data-logo="openai">OpenAI</option>
|
||||||
|
<option value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot</option>
|
||||||
|
<option value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription</option>
|
||||||
<option value="https://openrouter.ai/api/v1" data-logo="openrouter">OpenRouter</option>
|
<option value="https://openrouter.ai/api/v1" data-logo="openrouter">OpenRouter</option>
|
||||||
<option value="https://ollama.com/api" data-logo="ollama">Ollama Cloud</option>
|
<option value="https://ollama.com/api" data-logo="ollama">Ollama Cloud</option>
|
||||||
<option value="https://api.groq.com/openai/v1" data-logo="groq">Groq</option>
|
<option value="https://api.groq.com/openai/v1" data-logo="groq">Groq</option>
|
||||||
@@ -2136,6 +2138,7 @@
|
|||||||
<button class="admin-btn-add" id="adm-epAddBtn" style="width:55px;text-align:center;">Add</button>
|
<button class="admin-btn-add" id="adm-epAddBtn" style="width:55px;text-align:center;">Add</button>
|
||||||
</div>
|
</div>
|
||||||
<div id="adm-epApiMsg" class="adm-ep-inline-msg"></div>
|
<div id="adm-epApiMsg" class="adm-ep-inline-msg"></div>
|
||||||
|
<div id="adm-deviceAuthStatus" class="adm-ep-inline-msg"></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
+201
-68
@@ -5,6 +5,7 @@ import uiModule from './ui.js';
|
|||||||
import settingsModule from './settings.js';
|
import settingsModule from './settings.js';
|
||||||
import { providerLogo } from './providers.js';
|
import { providerLogo } from './providers.js';
|
||||||
import { sortModelObjects } from './modelSort.js';
|
import { sortModelObjects } from './modelSort.js';
|
||||||
|
import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js';
|
||||||
|
|
||||||
let initialized = false;
|
let initialized = false;
|
||||||
let modalEl = null;
|
let modalEl = null;
|
||||||
@@ -707,6 +708,80 @@ function initEndpointForm() {
|
|||||||
const pickerBtn = el('adm-provider-btn');
|
const pickerBtn = el('adm-provider-btn');
|
||||||
const pickerMenu = el('adm-provider-menu');
|
const pickerMenu = el('adm-provider-menu');
|
||||||
const pickerCurrent = picker ? picker.querySelector('.adm-provider-current') : null;
|
const pickerCurrent = picker ? picker.querySelector('.adm-provider-current') : null;
|
||||||
|
const DEVICE_AUTH_PROVIDER_VALUES = new Set(Object.keys(PROVIDER_DEVICE_FLOWS));
|
||||||
|
let deviceAuthPolling = false;
|
||||||
|
function _selectedProviderOption() {
|
||||||
|
return provider && provider.selectedOptions ? provider.selectedOptions[0] : null;
|
||||||
|
}
|
||||||
|
function _selectedDeviceAuthProvider() {
|
||||||
|
const opt = _selectedProviderOption();
|
||||||
|
const flow = opt && opt.dataset ? opt.dataset.authFlow : '';
|
||||||
|
if (flow && DEVICE_AUTH_PROVIDER_VALUES.has(flow)) return flow;
|
||||||
|
return DEVICE_AUTH_PROVIDER_VALUES.has(provider.value) ? provider.value : '';
|
||||||
|
}
|
||||||
|
function _isDeviceAuthSelected() {
|
||||||
|
return !!_selectedDeviceAuthProvider();
|
||||||
|
}
|
||||||
|
function _setApiFormForProvider() {
|
||||||
|
const deviceAuthProvider = _selectedDeviceAuthProvider();
|
||||||
|
const deviceAuthConfig = PROVIDER_DEVICE_FLOWS[deviceAuthProvider] || null;
|
||||||
|
const apiKey = el('adm-epApiKey');
|
||||||
|
const testBtn = el('adm-epApiTestBtn');
|
||||||
|
const addBtn = el('adm-epAddBtn');
|
||||||
|
const status = el('adm-deviceAuthStatus');
|
||||||
|
const msg = _endpointMsg('api');
|
||||||
|
if (deviceAuthConfig) {
|
||||||
|
urlInput.value = '';
|
||||||
|
urlInput.placeholder = deviceAuthProvider === 'copilot'
|
||||||
|
? 'GitHub Copilot uses GitHub account sign-in'
|
||||||
|
: 'ChatGPT Subscription uses OpenAI account sign-in';
|
||||||
|
urlInput.readOnly = true;
|
||||||
|
if (apiKey) {
|
||||||
|
apiKey.value = '';
|
||||||
|
apiKey.placeholder = 'No API key needed';
|
||||||
|
apiKey.disabled = true;
|
||||||
|
}
|
||||||
|
if (testBtn) {
|
||||||
|
testBtn.disabled = true;
|
||||||
|
testBtn.style.opacity = '0.45';
|
||||||
|
testBtn.style.cursor = 'not-allowed';
|
||||||
|
}
|
||||||
|
if (addBtn) {
|
||||||
|
addBtn.disabled = false;
|
||||||
|
addBtn.textContent = 'Add';
|
||||||
|
addBtn.style.width = '55px';
|
||||||
|
addBtn.style.display = '';
|
||||||
|
}
|
||||||
|
if (kindSel) kindSel.value = 'api';
|
||||||
|
if (msg) {
|
||||||
|
msg.textContent = '';
|
||||||
|
msg.className = '';
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
urlInput.placeholder = 'Base URL or pick provider';
|
||||||
|
urlInput.readOnly = false;
|
||||||
|
if (apiKey) {
|
||||||
|
apiKey.placeholder = 'API key';
|
||||||
|
apiKey.disabled = false;
|
||||||
|
}
|
||||||
|
if (testBtn) {
|
||||||
|
testBtn.disabled = false;
|
||||||
|
testBtn.style.opacity = '';
|
||||||
|
testBtn.style.cursor = '';
|
||||||
|
}
|
||||||
|
if (addBtn) {
|
||||||
|
addBtn.disabled = false;
|
||||||
|
addBtn.textContent = 'Add';
|
||||||
|
addBtn.style.width = '55px';
|
||||||
|
addBtn.style.display = '';
|
||||||
|
}
|
||||||
|
if (msg) {
|
||||||
|
msg.textContent = '';
|
||||||
|
msg.className = '';
|
||||||
|
}
|
||||||
|
if (!deviceAuthPolling && status) status.textContent = '';
|
||||||
|
}
|
||||||
|
}
|
||||||
function _renderPickerMenu() {
|
function _renderPickerMenu() {
|
||||||
if (!pickerMenu) return;
|
if (!pickerMenu) return;
|
||||||
pickerMenu.innerHTML = Array.from(provider.options).map(o => {
|
pickerMenu.innerHTML = Array.from(provider.options).map(o => {
|
||||||
@@ -748,9 +823,16 @@ function initEndpointForm() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
provider.addEventListener('change', () => {
|
provider.addEventListener('change', () => {
|
||||||
|
if (_isDeviceAuthSelected()) {
|
||||||
|
_setApiFormForProvider();
|
||||||
|
_renderPickerMenu();
|
||||||
|
_syncPickerCurrent();
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (provider.value) urlInput.value = provider.value;
|
if (provider.value) urlInput.value = provider.value;
|
||||||
else urlInput.value = '';
|
else urlInput.value = '';
|
||||||
if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy';
|
if (kindSel) kindSel.value = provider.value ? 'api' : 'proxy';
|
||||||
|
_setApiFormForProvider();
|
||||||
});
|
});
|
||||||
urlInput.addEventListener('input', () => {
|
urlInput.addEventListener('input', () => {
|
||||||
if (provider.value && urlInput.value.trim() !== provider.value) {
|
if (provider.value && urlInput.value.trim() !== provider.value) {
|
||||||
@@ -838,6 +920,12 @@ function initEndpointForm() {
|
|||||||
const apiCancelTestBtn = el('adm-epApiCancelTestBtn');
|
const apiCancelTestBtn = el('adm-epApiCancelTestBtn');
|
||||||
if (apiTestBtn) {
|
if (apiTestBtn) {
|
||||||
apiTestBtn.addEventListener('click', async () => {
|
apiTestBtn.addEventListener('click', async () => {
|
||||||
|
if (_isDeviceAuthSelected()) {
|
||||||
|
const msg = _endpointMsg('api');
|
||||||
|
msg.textContent = '';
|
||||||
|
msg.className = '';
|
||||||
|
return;
|
||||||
|
}
|
||||||
const msg = _endpointMsg('api');
|
const msg = _endpointMsg('api');
|
||||||
msg.textContent = ''; msg.className = '';
|
msg.textContent = ''; msg.className = '';
|
||||||
const rawUrl = (urlInput.value || provider.value).trim();
|
const rawUrl = (urlInput.value || provider.value).trim();
|
||||||
@@ -885,6 +973,11 @@ function initEndpointForm() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
el('adm-epAddBtn').addEventListener('click', async () => {
|
el('adm-epAddBtn').addEventListener('click', async () => {
|
||||||
|
const deviceAuthProvider = _selectedDeviceAuthProvider();
|
||||||
|
if (deviceAuthProvider) {
|
||||||
|
await _startProviderDeviceAuth(deviceAuthProvider, el('adm-epAddBtn'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
const msg = _endpointMsg('api');
|
const msg = _endpointMsg('api');
|
||||||
msg.textContent = ''; msg.className = '';
|
msg.textContent = ''; msg.className = '';
|
||||||
const rawUrl = (urlInput.value || provider.value).trim();
|
const rawUrl = (urlInput.value || provider.value).trim();
|
||||||
@@ -936,76 +1029,116 @@ function initEndpointForm() {
|
|||||||
btn.disabled = false; btn.textContent = 'Add';
|
btn.disabled = false; btn.textContent = 'Add';
|
||||||
});
|
});
|
||||||
|
|
||||||
// GitHub Copilot — device-flow login. Starts the flow, shows the user a
|
async function _startProviderDeviceAuth(providerKey, triggerEl = null) {
|
||||||
// code + verification link, and polls until they authorise (or it expires).
|
if (deviceAuthPolling) return;
|
||||||
const copilotBtn = el('adm-copilotConnectBtn');
|
const config = PROVIDER_DEVICE_FLOWS[providerKey];
|
||||||
if (copilotBtn) {
|
if (!config) return;
|
||||||
let copilotPolling = false;
|
const status = el('adm-deviceAuthStatus') || _endpointMsg('api');
|
||||||
copilotBtn.addEventListener('click', async () => {
|
if (!status) return;
|
||||||
if (copilotPolling) return;
|
const triggerText = triggerEl ? triggerEl.textContent : '';
|
||||||
const status = el('adm-copilotStatus');
|
// Render an error with an inline "Try again" (the top button is hidden for
|
||||||
const reset = () => { copilotBtn.disabled = false; copilotBtn.textContent = 'Connect GitHub Copilot'; copilotPolling = false; };
|
// device-auth providers, so retry lives here). Built with DOM methods, not
|
||||||
status.textContent = ''; status.className = 'adm-ep-inline-msg';
|
// innerHTML. Call reset() first so the deviceAuthPolling guard is cleared.
|
||||||
copilotBtn.disabled = true; copilotBtn.textContent = 'Starting...';
|
const showAuthError = (text) => {
|
||||||
copilotPolling = true;
|
status.className = 'admin-error';
|
||||||
let start;
|
status.textContent = text + ' ';
|
||||||
try {
|
const retry = document.createElement('button');
|
||||||
const res = await fetch('/api/copilot/device/start', { method: 'POST', body: new FormData(), credentials: 'same-origin' });
|
retry.type = 'button';
|
||||||
start = await res.json();
|
retry.className = 'admin-btn-sm';
|
||||||
if (!res.ok) { status.textContent = start.detail || 'Failed to start login'; status.className = 'admin-error'; reset(); return; }
|
retry.textContent = 'Try again';
|
||||||
} catch (e) { status.textContent = 'Request failed'; status.className = 'admin-error'; reset(); return; }
|
retry.addEventListener('click', () => { _startProviderDeviceAuth(providerKey, triggerEl); });
|
||||||
|
status.appendChild(retry);
|
||||||
|
};
|
||||||
|
const reset = () => {
|
||||||
|
if (triggerEl) {
|
||||||
|
triggerEl.disabled = false;
|
||||||
|
triggerEl.textContent = triggerText || 'Add';
|
||||||
|
}
|
||||||
|
deviceAuthPolling = false;
|
||||||
|
_setApiFormForProvider();
|
||||||
|
};
|
||||||
|
status.textContent = '';
|
||||||
|
status.className = 'adm-ep-inline-msg';
|
||||||
|
if (triggerEl) {
|
||||||
|
triggerEl.disabled = true;
|
||||||
|
triggerEl.textContent = 'Starting...';
|
||||||
|
}
|
||||||
|
deviceAuthPolling = true;
|
||||||
|
_setApiFormForProvider();
|
||||||
|
status.textContent = `Starting ${config.label} sign-in...`;
|
||||||
|
|
||||||
const { poll_id, user_code, verification_uri, verification_uri_complete, interval, expires_in } = start;
|
try {
|
||||||
// Prefer the "complete" URL — it embeds the code so the user only has to
|
const result = await runProviderDeviceFlow(providerKey, {
|
||||||
// click "Authorize" (no manual code entry).
|
openWindow: () => {},
|
||||||
const authUrl = verification_uri_complete || verification_uri || '';
|
onStart: ({ start, authUrl }) => {
|
||||||
const esc = (s) => String(s || '').replace(/[<>&"]/g, (c) => ({ '<': '<', '>': '>', '&': '&', '"': '"' }[c]));
|
if (triggerEl) triggerEl.textContent = 'Waiting...';
|
||||||
copilotBtn.textContent = 'Waiting…';
|
status.className = '';
|
||||||
|
const authLabel = providerKey === 'copilot' ? 'Authorize on GitHub' : 'Authorize with OpenAI';
|
||||||
// Cohesive waiting panel: spinner + status line, the device code as a
|
const waitLabel = providerKey === 'copilot' ? 'Waiting for GitHub authorization...' : 'Waiting for ChatGPT authorization...';
|
||||||
// copyable chip, and a primary "Authorize on GitHub" action.
|
status.innerHTML =
|
||||||
status.className = '';
|
'<div class="adm-copilot-panel">' +
|
||||||
status.innerHTML =
|
'<div class="adm-copilot-wait"><span class="admin-spinner"></span>' +
|
||||||
'<div class="adm-copilot-panel">' +
|
'<span>' + esc(waitLabel) + '</span></div>' +
|
||||||
'<div class="adm-copilot-wait"><span class="admin-spinner"></span>' +
|
'<div class="adm-copilot-coderow">' +
|
||||||
'<span>Waiting for GitHub authorization…</span></div>' +
|
'<span class="adm-copilot-code-label">Code</span>' +
|
||||||
'<div class="adm-copilot-coderow">' +
|
'<code class="adm-copilot-code">' + esc(start.user_code) + '</code>' +
|
||||||
'<span class="adm-copilot-code-label">Code</span>' +
|
'<button type="button" class="admin-btn-sm adm-device-auth-copy">Copy</button>' +
|
||||||
'<code class="adm-copilot-code">' + esc(user_code) + '</code>' +
|
'</div>' +
|
||||||
'<button type="button" class="admin-btn-sm adm-copilot-copy">Copy</button>' +
|
'<a class="admin-btn-add adm-copilot-auth" href="' + encodeURI(authUrl || '') + '" target="_blank" rel="noopener">' + esc(authLabel) + ' ↗</a>' +
|
||||||
'</div>' +
|
'</div>';
|
||||||
'<a class="admin-btn-add adm-copilot-auth" href="' + encodeURI(authUrl) + '" target="_blank" rel="noopener">Authorize on GitHub ↗</a>' +
|
const copyBtn = status.querySelector('.adm-device-auth-copy');
|
||||||
'<div class="adm-copilot-hint">A new tab opened on GitHub — approve there to finish. Didn\'t open? Use the button above.</div>' +
|
if (copyBtn) copyBtn.addEventListener('click', async () => {
|
||||||
'</div>';
|
const code = start.user_code || '';
|
||||||
const copyBtn = status.querySelector('.adm-copilot-copy');
|
let ok = false;
|
||||||
if (copyBtn) copyBtn.addEventListener('click', async () => {
|
try {
|
||||||
try { await navigator.clipboard.writeText(user_code || ''); copyBtn.textContent = 'Copied'; setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500); } catch (e) {}
|
if (navigator.clipboard && window.isSecureContext) {
|
||||||
|
await navigator.clipboard.writeText(code);
|
||||||
|
ok = true;
|
||||||
|
}
|
||||||
|
} catch (e) {}
|
||||||
|
if (!ok) {
|
||||||
|
// navigator.clipboard is unavailable in non-secure contexts (HTTP
|
||||||
|
// self-host over a LAN IP), so fall back to execCommand('copy').
|
||||||
|
const ta = document.createElement('textarea');
|
||||||
|
ta.value = code;
|
||||||
|
ta.style.cssText = 'position:fixed;top:0;left:0;width:1px;height:1px;padding:0;border:0;opacity:0;font-size:16px;';
|
||||||
|
document.body.appendChild(ta);
|
||||||
|
ta.focus();
|
||||||
|
ta.select();
|
||||||
|
try { ta.setSelectionRange(0, code.length); } catch (e) {}
|
||||||
|
try { ok = document.execCommand('copy'); } catch (e) {}
|
||||||
|
ta.remove();
|
||||||
|
}
|
||||||
|
copyBtn.textContent = ok ? 'Copied' : 'Failed';
|
||||||
|
setTimeout(() => { copyBtn.textContent = 'Copy'; }, 1500);
|
||||||
|
});
|
||||||
|
},
|
||||||
});
|
});
|
||||||
try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {}
|
if (result.status === 'authorized') {
|
||||||
|
const endpoint = result.endpoint || {};
|
||||||
const deadline = Date.now() + (expires_in || 900) * 1000;
|
const n = ((endpoint && endpoint.models) || []).length;
|
||||||
const stepMs = Math.max((interval || 5), 2) * 1000;
|
status.className = 'admin-success';
|
||||||
const done = (cls, text) => { status.className = cls; status.textContent = text; reset(); };
|
status.textContent = 'Connected - ' + n + ' ' + config.label + ' model' + (n !== 1 ? 's' : '') + ' available.';
|
||||||
const poll = async () => {
|
if (endpoint && endpoint.id) _recentlyAddedEpId = String(endpoint.id);
|
||||||
if (Date.now() > deadline) { done('admin-error', 'Authorization expired — try again.'); return; }
|
await loadEndpoints();
|
||||||
try {
|
await _selectAddedModelInChat(endpoint || {});
|
||||||
const fd = new FormData(); fd.append('poll_id', poll_id);
|
reset();
|
||||||
const r = await fetch('/api/copilot/device/poll', { method: 'POST', body: fd, credentials: 'same-origin' });
|
return;
|
||||||
const d = await r.json();
|
}
|
||||||
if (d.status === 'authorized') {
|
if (result.status === 'failed') {
|
||||||
const n = ((d.endpoint && d.endpoint.models) || []).length;
|
reset();
|
||||||
done('admin-success', '✓ Connected — ' + n + ' Copilot model' + (n !== 1 ? 's' : '') + ' available.');
|
showAuthError('Authorization failed (' + (result.error || 'denied') + ').');
|
||||||
if (d.endpoint && d.endpoint.id) _recentlyAddedEpId = String(d.endpoint.id);
|
return;
|
||||||
await loadEndpoints();
|
}
|
||||||
await _selectAddedModelInChat(d.endpoint || {});
|
if (result.status === 'expired') {
|
||||||
return;
|
reset();
|
||||||
}
|
showAuthError('Authorization expired.');
|
||||||
if (d.status === 'failed') { done('admin-error', 'Authorization failed (' + (d.error || 'denied') + ').'); return; }
|
return;
|
||||||
} catch (e) { /* transient — keep polling */ }
|
}
|
||||||
setTimeout(poll, stepMs);
|
} catch (e) {
|
||||||
};
|
reset();
|
||||||
setTimeout(poll, stepMs);
|
showAuthError(formatDeviceFlowError(e));
|
||||||
});
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Local "Add" button — sibling form for self-hosted base URLs.
|
// Local "Add" button — sibling form for self-hosted base URLs.
|
||||||
|
|||||||
+39
-15
@@ -680,9 +680,11 @@ export function applyModelColor(roleEl, modelName) {
|
|||||||
html += '<div><span class="ctx-label">Max tokens</span> ' + _mt.toLocaleString() + ' <span style="opacity:0.4">(configured)</span></div>';
|
html += '<div><span class="ctx-label">Max tokens</span> ' + _mt.toLocaleString() + ' <span style="opacity:0.4">(configured)</span></div>';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (info && info.input != null) html += '<div><span class="ctx-label">Input</span> $' + info.input.toFixed(2) + ' / 1M</div>';
|
if (isCostTrackedEndpoint(_epUrl)) {
|
||||||
if (info && info.output != null) html += '<div><span class="ctx-label">Output</span> $' + info.output.toFixed(2) + ' / 1M</div>';
|
if (info && info.input != null) html += '<div><span class="ctx-label">Input</span> $' + info.input.toFixed(2) + ' / 1M</div>';
|
||||||
if (!info) html += '<div style="opacity:0.4;font-size:0.85em;margin-top:4px;">No pricing data available</div>';
|
if (info && info.output != null) html += '<div><span class="ctx-label">Output</span> $' + info.output.toFixed(2) + ' / 1M</div>';
|
||||||
|
if (!info) html += '<div style="opacity:0.4;font-size:0.85em;margin-top:4px;">No pricing data available</div>';
|
||||||
|
}
|
||||||
popup.innerHTML = html;
|
popup.innerHTML = html;
|
||||||
const rect = roleEl.getBoundingClientRect();
|
const rect = roleEl.getBoundingClientRect();
|
||||||
popup.style.top = (rect.bottom + 4) + 'px';
|
popup.style.top = (rect.bottom + 4) + 'px';
|
||||||
@@ -735,11 +737,31 @@ export function isLocalEndpoint(url) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Cost for the current turn, returning null (free) for local endpoints. */
|
export function isSubscriptionEndpoint(url) {
|
||||||
function _billableCost(model, inputTokens, outputTokens) {
|
if (!url) return false;
|
||||||
const url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
|
try {
|
||||||
|
const parsed = new URL(url);
|
||||||
|
const path = parsed.pathname.replace(/\/+$/, '');
|
||||||
|
return parsed.hostname === 'chatgpt.com'
|
||||||
|
&& (path === '/backend-api/codex' || path.startsWith('/backend-api/codex/'));
|
||||||
|
} catch (_e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function _currentEndpointUrl() {
|
||||||
|
return (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
|
||||||
? window.sessionModule.getCurrentEndpointUrl() : null;
|
? window.sessionModule.getCurrentEndpointUrl() : null;
|
||||||
if (isLocalEndpoint(url)) return null;
|
}
|
||||||
|
|
||||||
|
export function isCostTrackedEndpoint(url) {
|
||||||
|
return !isLocalEndpoint(url) && !isSubscriptionEndpoint(url);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Cost for the current turn, returning null for non-billable endpoints. */
|
||||||
|
function _billableCost(model, inputTokens, outputTokens) {
|
||||||
|
const url = _currentEndpointUrl();
|
||||||
|
if (!isCostTrackedEndpoint(url)) return null;
|
||||||
return getModelCost(model, inputTokens, outputTokens);
|
return getModelCost(model, inputTokens, outputTokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -784,11 +806,10 @@ export function resetSessionCost(sessionId) {
|
|||||||
export function updateSessionCostUI() {
|
export function updateSessionCostUI() {
|
||||||
const el = document.getElementById('session-cost-display');
|
const el = document.getElementById('session-cost-display');
|
||||||
if (!el) return;
|
if (!el) return;
|
||||||
// Local model? It's free — hide the badge and clear any stale cost that a
|
// Non-billable endpoint? Hide the badge and clear stale cost that a previous
|
||||||
// previous (buggy) cloud-rate billing left in localStorage for this session.
|
// cloud-rate calculation may have left in localStorage for this session.
|
||||||
const _url = (window.sessionModule && window.sessionModule.getCurrentEndpointUrl)
|
const _url = _currentEndpointUrl();
|
||||||
? window.sessionModule.getCurrentEndpointUrl() : null;
|
if (!isCostTrackedEndpoint(_url)) {
|
||||||
if (isLocalEndpoint(_url)) {
|
|
||||||
const sid = window.sessionModule && window.sessionModule.getCurrentSessionId();
|
const sid = window.sessionModule && window.sessionModule.getCurrentSessionId();
|
||||||
if (sid && getSessionCost(sid) > 0) {
|
if (sid && getSessionCost(sid) > 0) {
|
||||||
try {
|
try {
|
||||||
@@ -1708,7 +1729,8 @@ export function displayMetrics(messageElement, metrics) {
|
|||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
document.querySelectorAll('.ctx-popup').forEach(p => { if (typeof p._dismiss === 'function') p._dismiss(); else p.remove(); });
|
document.querySelectorAll('.ctx-popup').forEach(p => { if (typeof p._dismiss === 'function') p._dismiss(); else p.remove(); });
|
||||||
|
|
||||||
const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : 'n/a';
|
const costStr = cost !== null ? `$${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}` : '';
|
||||||
|
const costRows = costStr ? `<div><span class="ctx-label">Cost</span> ${costStr}</div>` : '';
|
||||||
const speedStr = tps != null && tps !== 'undefined' ? `${tps} tok/s` : 'n/a';
|
const speedStr = tps != null && tps !== 'undefined' ? `${tps} tok/s` : 'n/a';
|
||||||
const totalTok = inputTokens + outputTokens;
|
const totalTok = inputTokens + outputTokens;
|
||||||
const ctxColor = ctxPct >= 85 ? 'var(--red, #e06c75)' : ctxPct >= 70 ? '#ff9900' : 'var(--color-muted-alt, #6b7280)';
|
const ctxColor = ctxPct >= 85 ? 'var(--red, #e06c75)' : ctxPct >= 70 ? '#ff9900' : 'var(--color-muted-alt, #6b7280)';
|
||||||
@@ -1722,7 +1744,7 @@ export function displayMetrics(messageElement, metrics) {
|
|||||||
// Session total cost
|
// Session total cost
|
||||||
let sessionCostStr = '';
|
let sessionCostStr = '';
|
||||||
const sc = getSessionCost();
|
const sc = getSessionCost();
|
||||||
if (sc > 0) {
|
if (costStr && sc > 0) {
|
||||||
sessionCostStr = `<div><span class="ctx-label">Session</span> $${sc < 0.01 ? sc.toFixed(4) : sc.toFixed(3)}</div>`;
|
sessionCostStr = `<div><span class="ctx-label">Session</span> $${sc < 0.01 ? sc.toFixed(4) : sc.toFixed(3)}</div>`;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1738,7 +1760,7 @@ export function displayMetrics(messageElement, metrics) {
|
|||||||
<div><span class="ctx-label">Time</span> ${responseTime}s</div>
|
<div><span class="ctx-label">Time</span> ${responseTime}s</div>
|
||||||
${prepTime != null ? `<div><span class="ctx-label">Prep</span> ${prepTime}s</div>` : ''}
|
${prepTime != null ? `<div><span class="ctx-label">Prep</span> ${prepTime}s</div>` : ''}
|
||||||
${modelWaitTime != null ? `<div><span class="ctx-label">Model wait</span> ${modelWaitTime}s</div>` : ''}
|
${modelWaitTime != null ? `<div><span class="ctx-label">Model wait</span> ${modelWaitTime}s</div>` : ''}
|
||||||
<div><span class="ctx-label">Cost</span> ${costStr}</div>
|
${costRows}
|
||||||
${sessionCostStr}
|
${sessionCostStr}
|
||||||
${prepDetails ? `<div style="margin-top:6px;padding-top:6px;border-top:1px solid var(--border);font-size:0.85em;opacity:0.8;">
|
${prepDetails ? `<div style="margin-top:6px;padding-top:6px;border-top:1px solid var(--border);font-size:0.85em;opacity:0.8;">
|
||||||
<div style="font-weight:600;margin-bottom:4px;color:var(--fg);">Agent prep</div>
|
<div style="font-weight:600;margin-bottom:4px;color:var(--fg);">Agent prep</div>
|
||||||
@@ -2392,6 +2414,8 @@ const chatRenderer = {
|
|||||||
modelColor,
|
modelColor,
|
||||||
applyModelColor,
|
applyModelColor,
|
||||||
getModelCost,
|
getModelCost,
|
||||||
|
isCostTrackedEndpoint,
|
||||||
|
isSubscriptionEndpoint,
|
||||||
getImageCost,
|
getImageCost,
|
||||||
getSessionCost,
|
getSessionCost,
|
||||||
resetSessionCost,
|
resetSessionCost,
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
// Shared DOM-free provider device-flow runner.
|
||||||
|
|
||||||
|
export const PROVIDER_DEVICE_FLOWS = {
|
||||||
|
copilot: {
|
||||||
|
label: 'GitHub Copilot',
|
||||||
|
startUrl: '/api/copilot/device/start',
|
||||||
|
pollUrl: '/api/copilot/device/poll',
|
||||||
|
authUrl(start) {
|
||||||
|
return start?.verification_uri_complete || start?.verification_uri || '';
|
||||||
|
},
|
||||||
|
},
|
||||||
|
'chatgpt-subscription': {
|
||||||
|
label: 'ChatGPT Subscription',
|
||||||
|
startUrl: '/api/chatgpt-subscription/device/start',
|
||||||
|
pollUrl: '/api/chatgpt-subscription/device/poll',
|
||||||
|
authUrl(start) {
|
||||||
|
return start?.verification_uri || '';
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
function _formData() {
|
||||||
|
if (typeof FormData !== 'undefined') return new FormData();
|
||||||
|
return new URLSearchParams();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _jsonOrEmpty(response) {
|
||||||
|
try {
|
||||||
|
return await response.json();
|
||||||
|
} catch (_) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function _messageFromPayload(payload, fallback) {
|
||||||
|
if (payload && typeof payload.detail === 'string' && payload.detail.trim()) {
|
||||||
|
return payload.detail.trim();
|
||||||
|
}
|
||||||
|
if (payload && typeof payload.error === 'string' && payload.error.trim()) {
|
||||||
|
return payload.error.trim();
|
||||||
|
}
|
||||||
|
if (payload && typeof payload.message === 'string' && payload.message.trim()) {
|
||||||
|
return payload.message.trim();
|
||||||
|
}
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function formatDeviceFlowError(error, fallback = 'Request failed') {
|
||||||
|
if (!error) return fallback;
|
||||||
|
if (typeof error === 'string') return error;
|
||||||
|
if (error.detail) return String(error.detail);
|
||||||
|
if (error.message) return String(error.message);
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _fetchJson(fetchImpl, url, options, fallback) {
|
||||||
|
const response = await fetchImpl(url, options);
|
||||||
|
const payload = await _jsonOrEmpty(response);
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(_messageFromPayload(payload, fallback || `Request failed (HTTP ${response.status})`));
|
||||||
|
}
|
||||||
|
return payload;
|
||||||
|
}
|
||||||
|
|
||||||
|
function _defaultSleep(ms) {
|
||||||
|
return new Promise(resolve => setTimeout(resolve, ms));
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _callCallback(fn, payload) {
|
||||||
|
if (typeof fn === 'function') await fn(payload);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function runProviderDeviceFlow(provider, options = {}) {
|
||||||
|
const cfg = PROVIDER_DEVICE_FLOWS[provider];
|
||||||
|
if (!cfg) throw new Error(`Unknown device-flow provider: ${provider}`);
|
||||||
|
|
||||||
|
const fetchImpl = options.fetchImpl || globalThis.fetch?.bind(globalThis);
|
||||||
|
if (!fetchImpl) throw new Error('Fetch API is unavailable');
|
||||||
|
|
||||||
|
const openWindow = options.openWindow || ((url) => {
|
||||||
|
if (globalThis.window && typeof globalThis.window.open === 'function') {
|
||||||
|
globalThis.window.open(url, '_blank', 'noopener');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const sleep = options.sleep || _defaultSleep;
|
||||||
|
const now = options.now || (() => Date.now());
|
||||||
|
const formData = options.formData || _formData();
|
||||||
|
|
||||||
|
const start = await _fetchJson(fetchImpl, cfg.startUrl, {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData,
|
||||||
|
credentials: 'same-origin',
|
||||||
|
}, `Failed to start ${cfg.label} sign-in`);
|
||||||
|
|
||||||
|
if (!start.poll_id) throw new Error(`${cfg.label} sign-in did not return a poll id`);
|
||||||
|
const authUrl = cfg.authUrl(start);
|
||||||
|
await _callCallback(options.onStart, { provider, config: cfg, start, authUrl });
|
||||||
|
if (authUrl) openWindow(authUrl);
|
||||||
|
|
||||||
|
const deadline = now() + Number(start.expires_in || 900) * 1000;
|
||||||
|
let stepMs = Math.max(Number(start.interval || 5), 2) * 1000;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (now() > deadline) return { status: 'expired' };
|
||||||
|
await _callCallback(options.onWaiting, { provider, config: cfg, start, authUrl });
|
||||||
|
await sleep(stepMs);
|
||||||
|
if (now() > deadline) return { status: 'expired' };
|
||||||
|
|
||||||
|
const fd = _formData();
|
||||||
|
fd.append('poll_id', start.poll_id);
|
||||||
|
const poll = await _fetchJson(fetchImpl, cfg.pollUrl, {
|
||||||
|
method: 'POST',
|
||||||
|
body: fd,
|
||||||
|
credentials: 'same-origin',
|
||||||
|
}, `${cfg.label} sign-in poll failed`);
|
||||||
|
await _callCallback(options.onPoll, { provider, config: cfg, start, poll });
|
||||||
|
|
||||||
|
if (poll.status === 'authorized') {
|
||||||
|
return { status: 'authorized', endpoint: poll.endpoint || {} };
|
||||||
|
}
|
||||||
|
if (poll.status === 'failed') {
|
||||||
|
return { status: 'failed', error: poll.error || 'denied' };
|
||||||
|
}
|
||||||
|
if (poll.interval) {
|
||||||
|
stepMs = Math.max(Number(poll.interval || 5), 2) * 1000;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,10 @@ const _PROVIDERS = [
|
|||||||
[/opencode/i,
|
[/opencode/i,
|
||||||
'<svg viewBox="0 0 24 30" fill="currentColor"><path d="M18 6H6V24H18V6ZM24 30H0V0H24V30Z"/></svg>'],
|
'<svg viewBox="0 0 24 30" fill="currentColor"><path d="M18 6H6V24H18V6ZM24 30H0V0H24V30Z"/></svg>'],
|
||||||
|
|
||||||
|
// GitHub / Copilot
|
||||||
|
[/github|copilot/i,
|
||||||
|
'<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .5A12 12 0 0 0 8.2 23.9c.6.1.8-.3.8-.6v-2.1c-3.3.7-4-1.4-4-1.4-.5-1.4-1.3-1.8-1.3-1.8-1.1-.8.1-.8.1-.8 1.2.1 1.9 1.3 1.9 1.3 1.1 1.9 2.9 1.3 3.6 1 .1-.8.4-1.3.8-1.6-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.3-.5-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.4 11.4 0 0 1 6 0C15.3 4.7 16 5 16 5c.6 1.6.2 2.9.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .5Z"/></svg>'],
|
||||||
|
|
||||||
// OpenRouter
|
// OpenRouter
|
||||||
[/openrouter|open router/i,
|
[/openrouter|open router/i,
|
||||||
'<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="5" cy="12" r="2.5"/><circle cx="19" cy="6" r="2.5"/><circle cx="19" cy="18" r="2.5"/><path d="M7.5 12h4.5c2 0 2.5-6 4.5-6"/><path d="M12 12c2 0 2.5 6 4.5 6"/></svg>'],
|
'<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="5" cy="12" r="2.5"/><circle cx="19" cy="6" r="2.5"/><circle cx="19" cy="18" r="2.5"/><path d="M7.5 12h4.5c2 0 2.5-6 4.5-6"/><path d="M12 12c2 0 2.5 6 4.5 6"/></svg>'],
|
||||||
@@ -102,6 +106,7 @@ export function providerLogo(modelId) {
|
|||||||
// doesn't match `x.ai`.
|
// doesn't match `x.ai`.
|
||||||
const _ENDPOINT_LABELS = [
|
const _ENDPOINT_LABELS = [
|
||||||
[/(^|\.)githubcopilot\.com$/i, "GitHub Copilot"],
|
[/(^|\.)githubcopilot\.com$/i, "GitHub Copilot"],
|
||||||
|
[/(^|\.)chatgpt\.com$/i, "ChatGPT Subscription"],
|
||||||
[/(^|\.)openrouter\.ai$/i, "OpenRouter"],
|
[/(^|\.)openrouter\.ai$/i, "OpenRouter"],
|
||||||
[/(^|\.)anthropic\.com$/i, "Anthropic"],
|
[/(^|\.)anthropic\.com$/i, "Anthropic"],
|
||||||
[/(^|\.)openai\.com$/i, "OpenAI"],
|
[/(^|\.)openai\.com$/i, "OpenAI"],
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
import { COMMANDS, LEGACY_ALIASES } from './slashCommands.js';
|
import { COMMANDS, LEGACY_ALIASES } from './slashCommands.js';
|
||||||
|
|
||||||
const POPUP_ID = 'slash-autocomplete';
|
const POPUP_ID = 'slash-autocomplete';
|
||||||
const MAX_VISIBLE = 12;
|
const MAX_VISIBLE = 14;
|
||||||
|
|
||||||
// Flatten the registry into a searchable list of leaf entries. Each entry is
|
// Flatten the registry into a searchable list of leaf entries. Each entry is
|
||||||
// either a top-level command or a "cmd sub" pair (so subcommands get their
|
// either a top-level command or a "cmd sub" pair (so subcommands get their
|
||||||
@@ -81,6 +81,23 @@ function _flatten() {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function _loadSkillEntries() {
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/skills/slash-catalog', { credentials: 'same-origin' });
|
||||||
|
if (!res.ok) return [];
|
||||||
|
const data = await res.json();
|
||||||
|
return (Array.isArray(data.skills) ? data.skills : []).map(s => ({
|
||||||
|
token: s.token || `/${s.name}`,
|
||||||
|
aliases: [],
|
||||||
|
category: s.category || 'Skills',
|
||||||
|
help: s.help || 'Run skill',
|
||||||
|
usage: s.usage || `${s.token || `/${s.name}`} <request>`,
|
||||||
|
})).filter(e => e.token && e.token.startsWith('/'));
|
||||||
|
} catch {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function _scoreMatch(entry, query) {
|
function _scoreMatch(entry, query) {
|
||||||
// query already starts with "/". Match against token + aliases. Prefix wins
|
// query already starts with "/". Match against token + aliases. Prefix wins
|
||||||
// over substring; alias match scores slightly lower than token match.
|
// over substring; alias match scores slightly lower than token match.
|
||||||
@@ -98,6 +115,17 @@ function _scoreMatch(entry, query) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function _exactCommandGroupItems(all, query) {
|
||||||
|
const q = query.toLowerCase();
|
||||||
|
if (!/^\/[a-z0-9_-]+$/i.test(q)) return [];
|
||||||
|
const parent = all.find(entry => entry.token.toLowerCase() === q);
|
||||||
|
if (!parent) return [];
|
||||||
|
const prefix = q + ' ';
|
||||||
|
const children = all.filter(entry => entry.token.toLowerCase().startsWith(prefix));
|
||||||
|
if (!children.length) return [];
|
||||||
|
return children.concat(parent);
|
||||||
|
}
|
||||||
|
|
||||||
function _ensurePopup(textarea) {
|
function _ensurePopup(textarea) {
|
||||||
let el = document.getElementById(POPUP_ID);
|
let el = document.getElementById(POPUP_ID);
|
||||||
if (el) return el;
|
if (el) return el;
|
||||||
@@ -164,7 +192,7 @@ export function initSlashAutocomplete(textarea) {
|
|||||||
if (!textarea || textarea._slashAcWired) return;
|
if (!textarea || textarea._slashAcWired) return;
|
||||||
textarea._slashAcWired = true;
|
textarea._slashAcWired = true;
|
||||||
|
|
||||||
const all = _flatten();
|
let all = _flatten();
|
||||||
let popup = null;
|
let popup = null;
|
||||||
let visible = false;
|
let visible = false;
|
||||||
let items = [];
|
let items = [];
|
||||||
@@ -191,12 +219,17 @@ export function initSlashAutocomplete(textarea) {
|
|||||||
// the menu hides — we don't autocomplete mid-sentence.
|
// the menu hides — we don't autocomplete mid-sentence.
|
||||||
if (!v.startsWith('/') || v.includes('\n')) { hide(); return; }
|
if (!v.startsWith('/') || v.includes('\n')) { hide(); return; }
|
||||||
const query = v.trim();
|
const query = v.trim();
|
||||||
items = all
|
const groupItems = _exactCommandGroupItems(all, query);
|
||||||
|
if (groupItems.length) {
|
||||||
|
items = groupItems.slice(0, MAX_VISIBLE);
|
||||||
|
} else {
|
||||||
|
items = all
|
||||||
.map(e => ({ e, s: _scoreMatch(e, query) }))
|
.map(e => ({ e, s: _scoreMatch(e, query) }))
|
||||||
.filter(x => x.s > 0)
|
.filter(x => x.s > 0)
|
||||||
.sort((a, b) => b.s - a.s)
|
.sort((a, b) => b.s - a.s)
|
||||||
.slice(0, MAX_VISIBLE)
|
.slice(0, MAX_VISIBLE)
|
||||||
.map(x => x.e);
|
.map(x => x.e);
|
||||||
|
}
|
||||||
if (!items.length && query.length > 1) { hide(); return; }
|
if (!items.length && query.length > 1) { hide(); return; }
|
||||||
if (!items.length) {
|
if (!items.length) {
|
||||||
// Just "/" with no matches — fall back to showing everything up to MAX_VISIBLE
|
// Just "/" with no matches — fall back to showing everything up to MAX_VISIBLE
|
||||||
@@ -207,6 +240,19 @@ export function initSlashAutocomplete(textarea) {
|
|||||||
_render(popup, items, selectedIdx, query);
|
_render(popup, items, selectedIdx, query);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
_loadSkillEntries().then(skillEntries => {
|
||||||
|
if (!skillEntries.length) return;
|
||||||
|
const seen = new Set(all.map(e => e.token));
|
||||||
|
const merged = all.slice();
|
||||||
|
for (const entry of skillEntries) {
|
||||||
|
if (seen.has(entry.token)) continue;
|
||||||
|
seen.add(entry.token);
|
||||||
|
merged.push(entry);
|
||||||
|
}
|
||||||
|
all = merged;
|
||||||
|
if (visible) refresh();
|
||||||
|
});
|
||||||
|
|
||||||
const insert = (token) => {
|
const insert = (token) => {
|
||||||
textarea.value = token + ' ';
|
textarea.value = token + ' ';
|
||||||
textarea.dispatchEvent(new Event('input', { bubbles: true }));
|
textarea.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
|
|||||||
+351
-71
@@ -21,6 +21,7 @@ import workspaceModule from './workspace.js';
|
|||||||
import settingsModule from './settings.js';
|
import settingsModule from './settings.js';
|
||||||
import cookbookModule from './cookbook.js';
|
import cookbookModule from './cookbook.js';
|
||||||
import { EVAL_PROMPTS } from './compare/index.js';
|
import { EVAL_PROMPTS } from './compare/index.js';
|
||||||
|
import { PROVIDER_DEVICE_FLOWS, formatDeviceFlowError, runProviderDeviceFlow } from './providerDeviceFlow.js';
|
||||||
|
|
||||||
// ── Module state ──────────────────────────────────────────────────────
|
// ── Module state ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -58,11 +59,28 @@ const SETUP_PROVIDER_URLS = {
|
|||||||
'opencode-go': { name: 'OpenCode Go', url: 'https://opencode.ai/zen/go/v1' },
|
'opencode-go': { name: 'OpenCode Go', url: 'https://opencode.ai/zen/go/v1' },
|
||||||
};
|
};
|
||||||
const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'ollama', 'xai', 'anthropic', 'groq', 'gemini', 'opencode-zen', 'opencode-go'];
|
const SETUP_PROVIDER_NAMES = ['deepseek', 'openai', 'openrouter', 'ollama', 'xai', 'anthropic', 'groq', 'gemini', 'opencode-zen', 'opencode-go'];
|
||||||
const SETUP_PROVIDER_HINT = SETUP_PROVIDER_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_NAMES[SETUP_PROVIDER_NAMES.length - 1];
|
const SETUP_DEVICE_AUTH_PROVIDERS = [
|
||||||
|
{ key: 'copilot', name: 'GitHub Copilot', aliases: ['github'], command: '/setup copilot' },
|
||||||
|
{ key: 'chatgpt-subscription', name: 'ChatGPT Subscription', aliases: ['chatgptsubscription', 'chatgpt-sub', 'codex'], command: '/setup chatgpt-subscription' },
|
||||||
|
];
|
||||||
|
const SETUP_PROVIDER_HINT_NAMES = SETUP_PROVIDER_NAMES.concat(SETUP_DEVICE_AUTH_PROVIDERS.map(provider => provider.key));
|
||||||
|
const SETUP_PROVIDER_HINT = SETUP_PROVIDER_HINT_NAMES.slice(0, -1).join(', ') + ', or ' + SETUP_PROVIDER_HINT_NAMES[SETUP_PROVIDER_HINT_NAMES.length - 1];
|
||||||
const SETUP_LOCAL_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><rect x="2" y="3" width="20" height="14" rx="2"/><path d="M8 21h8"/><path d="M12 17v4"/></svg>';
|
const SETUP_LOCAL_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><rect x="2" y="3" width="20" height="14" rx="2"/><path d="M8 21h8"/><path d="M12 17v4"/></svg>';
|
||||||
const SETUP_API_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>';
|
const SETUP_API_ICON = '<svg width="11" height="11" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-1px;margin-right:5px;"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>';
|
||||||
const SETUP_SETTINGS_ICON = '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-2px;margin-right:5px;"><circle cx="12" cy="12" r="3"/><path d="M19.4 15a1.65 1.65 0 0 0 .33 1.82l.06.06a2 2 0 0 1-2.83 2.83l-.06-.06a1.65 1.65 0 0 0-1.82-.33 1.65 1.65 0 0 0-1 1.51V21a2 2 0 0 1-4 0v-.09A1.65 1.65 0 0 0 9 19.4a1.65 1.65 0 0 0-1.82.33l-.06.06a2 2 0 0 1-2.83-2.83l.06-.06a1.65 1.65 0 0 0 .33-1.82 1.65 1.65 0 0 0-1.51-1H3a2 2 0 0 1 0-4h.09A1.65 1.65 0 0 0 4.6 9a1.65 1.65 0 0 0-.33-1.82l-.06-.06a2 2 0 0 1 2.83-2.83l.06.06a1.65 1.65 0 0 0 1.82.33H9a1.65 1.65 0 0 0 1-1.51V3a2 2 0 0 1 4 0v.09a1.65 1.65 0 0 0 1 1.51 1.65 1.65 0 0 0 1.82-.33l.06-.06a2 2 0 0 1 2.83 2.83l-.06.06a1.65 1.65 0 0 0-.33 1.82V9a1.65 1.65 0 0 0 1.51 1H21a2 2 0 0 1 0 4h-.09a1.65 1.65 0 0 0-1.51 1z"/></svg>';
|
const SETUP_SETTINGS_ICON = '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" style="vertical-align:-2px;margin-right:5px;"><circle cx="12" cy="12" r="3"/><path d="M19.4 15a1.65 1.65 0 0 0 .33 1.82l.06.06a2 2 0 0 1-2.83 2.83l-.06-.06a1.65 1.65 0 0 0-1.82-.33 1.65 1.65 0 0 0-1 1.51V21a2 2 0 0 1-4 0v-.09A1.65 1.65 0 0 0 9 19.4a1.65 1.65 0 0 0-1.82.33l-.06.06a2 2 0 0 1-2.83-2.83l.06-.06a1.65 1.65 0 0 0 .33-1.82 1.65 1.65 0 0 0-1.51-1H3a2 2 0 0 1 0-4h.09A1.65 1.65 0 0 0 4.6 9a1.65 1.65 0 0 0-.33-1.82l-.06-.06a2 2 0 0 1 2.83-2.83l.06.06a1.65 1.65 0 0 0 1.82.33H9a1.65 1.65 0 0 0 1-1.51V3a2 2 0 0 1 4 0v.09a1.65 1.65 0 0 0 1 1.51 1.65 1.65 0 0 0 1.82-.33l.06-.06a2 2 0 0 1 2.83 2.83l-.06.06a1.65 1.65 0 0 0-.33 1.82V9a1.65 1.65 0 0 0 1.51 1H21a2 2 0 0 1 0 4h-.09a1.65 1.65 0 0 0-1.51 1z"/></svg>';
|
||||||
|
|
||||||
|
function _setupApiProviderChips() {
|
||||||
|
return SETUP_PROVIDER_NAMES.map(name =>
|
||||||
|
'<span class="setup-clickable-provider" data-setup-kind="api-key" data-setup-provider="' + name + '" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
|
||||||
|
).join(' ');
|
||||||
|
}
|
||||||
|
|
||||||
|
function _setupDeviceAuthProviderChips() {
|
||||||
|
return SETUP_DEVICE_AUTH_PROVIDERS.map(provider =>
|
||||||
|
'<span class="setup-clickable-provider" data-setup-kind="device-auth" data-setup-provider="' + provider.key + '" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Run ' + provider.command + '">' + provider.name + '</span>'
|
||||||
|
).join(' ');
|
||||||
|
}
|
||||||
|
|
||||||
function _setupProviderFromInput(input) {
|
function _setupProviderFromInput(input) {
|
||||||
const raw = (input || '').trim().toLowerCase().replace(/\s+/g, '');
|
const raw = (input || '').trim().toLowerCase().replace(/\s+/g, '');
|
||||||
const aliases = {
|
const aliases = {
|
||||||
@@ -84,6 +102,17 @@ function _setupProviderFromInput(input) {
|
|||||||
return SETUP_PROVIDER_URLS[aliases[raw] || raw] || null;
|
return SETUP_PROVIDER_URLS[aliases[raw] || raw] || null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function _setupDeviceAuthProviderFromInput(input) {
|
||||||
|
const raw = (input || '').trim().toLowerCase().replace(/\s+/g, '').replace(/_/g, '-');
|
||||||
|
if (!raw) return '';
|
||||||
|
for (const provider of SETUP_DEVICE_AUTH_PROVIDERS) {
|
||||||
|
const candidates = [provider.key, provider.name, ...(provider.aliases || [])]
|
||||||
|
.map(value => String(value || '').toLowerCase().replace(/\s+/g, '').replace(/_/g, '-'));
|
||||||
|
if (candidates.includes(raw)) return provider.key;
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
function _extractSetupProviderCredential(input) {
|
function _extractSetupProviderCredential(input) {
|
||||||
const raw = (input || '').trim();
|
const raw = (input || '').trim();
|
||||||
if (!raw) return null;
|
if (!raw) return null;
|
||||||
@@ -158,9 +187,8 @@ function _setupReply(text, remember = true) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function _showSetupEndpointChoices() {
|
function _showSetupEndpointChoices() {
|
||||||
const providers = SETUP_PROVIDER_NAMES.map(name =>
|
const providers = _setupApiProviderChips();
|
||||||
'<span class="setup-clickable-provider" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
|
const deviceAuthProviders = _setupDeviceAuthProviderChips();
|
||||||
).join(' ');
|
|
||||||
return slashReply(
|
return slashReply(
|
||||||
'<div class="setup-guide-no-censor" style="display:grid;gap:10px;">' +
|
'<div class="setup-guide-no-censor" style="display:grid;gap:10px;">' +
|
||||||
'<div>' +
|
'<div>' +
|
||||||
@@ -178,6 +206,7 @@ function _showSetupEndpointChoices() {
|
|||||||
'<div>Paste provider name then API key (example):</div>' +
|
'<div>Paste provider name then API key (example):</div>' +
|
||||||
'<pre style="margin:4px 0 0;"><code class="setup-clickable-code" style="cursor:pointer;text-decoration:underline;" title="Click to fill in chat">deepseek sk-...</code></pre>' +
|
'<pre style="margin:4px 0 0;"><code class="setup-clickable-code" style="cursor:pointer;text-decoration:underline;" title="Click to fill in chat">deepseek sk-...</code></pre>' +
|
||||||
'<div style="margin-top:8px;font-size:1em;"><span>Supported providers:</span><br>' + providers + '</div>' +
|
'<div style="margin-top:8px;font-size:1em;"><span>Supported providers:</span><br>' + providers + '</div>' +
|
||||||
|
'<div style="margin-top:8px;font-size:1em;"><span>Account sign-in:</span><br>' + deviceAuthProviders + '</div>' +
|
||||||
'</div>' +
|
'</div>' +
|
||||||
'</div>'
|
'</div>'
|
||||||
);
|
);
|
||||||
@@ -208,9 +237,8 @@ function _showSetupEndpointChoicesStreamed(options = {}) {
|
|||||||
text: 'deepseek sk-...',
|
text: 'deepseek sk-...',
|
||||||
copyText: 'deepseek sk-...',
|
copyText: 'deepseek sk-...',
|
||||||
},
|
},
|
||||||
{ kind: 'p', html: '<strong>Supported providers:</strong><br>' + SETUP_PROVIDER_NAMES.map(name =>
|
{ kind: 'p', html: '<strong>Supported providers:</strong><br>' + _setupApiProviderChips() },
|
||||||
'<span class="setup-clickable-provider" style="cursor:pointer;text-decoration:underline;margin-right:8px;" title="Click to setup ' + name + '">' + name + '</span>'
|
{ kind: 'p', html: '<strong>Account sign-in:</strong><br>' + _setupDeviceAuthProviderChips() },
|
||||||
).join(' ') },
|
|
||||||
];
|
];
|
||||||
return typewriterBlocksReply(blocks, { gap: '4px', bodyClass: 'setup-guide-no-censor', interval: 3 });
|
return typewriterBlocksReply(blocks, { gap: '4px', bodyClass: 'setup-guide-no-censor', interval: 3 });
|
||||||
}
|
}
|
||||||
@@ -231,7 +259,7 @@ async function _hasConfiguredModels() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function _setupProviderPrompt() {
|
function _setupProviderPrompt() {
|
||||||
const chips = SETUP_PROVIDER_NAMES.map(name =>
|
const chips = SETUP_PROVIDER_HINT_NAMES.map(name =>
|
||||||
'<span style="font-weight:650;">' + name + '</span>'
|
'<span style="font-weight:650;">' + name + '</span>'
|
||||||
).join(' ');
|
).join(' ');
|
||||||
slashReply('<b>Supported providers:</b><br>' + chips);
|
slashReply('<b>Supported providers:</b><br>' + chips);
|
||||||
@@ -286,6 +314,53 @@ function slashReply(text) {
|
|||||||
return { el: div, body };
|
return { el: div, body };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let _skillCatalogCache = { at: 0, items: [] };
|
||||||
|
|
||||||
|
async function _loadSkillSlashCatalog(force = false) {
|
||||||
|
const now = Date.now();
|
||||||
|
if (!force && (now - _skillCatalogCache.at) < 15000) return _skillCatalogCache.items;
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_BASE}/api/skills/slash-catalog`, { credentials: 'same-origin' });
|
||||||
|
if (!res.ok) throw new Error('catalog unavailable');
|
||||||
|
const data = await res.json();
|
||||||
|
const items = Array.isArray(data.skills) ? data.skills : [];
|
||||||
|
_skillCatalogCache = { at: now, items };
|
||||||
|
return items;
|
||||||
|
} catch {
|
||||||
|
return _skillCatalogCache.items || [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function _submitComposedMessage(text) {
|
||||||
|
const msgInput = document.getElementById('message');
|
||||||
|
const form = document.getElementById('chat-form');
|
||||||
|
if (!msgInput || !form) return false;
|
||||||
|
msgInput.value = text;
|
||||||
|
msgInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
|
if (typeof form.requestSubmit === 'function') form.requestSubmit();
|
||||||
|
else form.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true }));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _invokeSkillByName(name, requestText, ctx) {
|
||||||
|
const res = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(name)}/invoke`, {
|
||||||
|
method: 'POST',
|
||||||
|
credentials: 'same-origin',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ request: requestText || '' })
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
const err = await res.json().catch(() => null);
|
||||||
|
slashReply(ctx?.esc ? ctx.esc(err?.detail || 'Skill is not available') : 'Skill is not available');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const data = await res.json();
|
||||||
|
if (!data.message || !_submitComposedMessage(data.message)) {
|
||||||
|
slashReply('Could not start skill invocation.');
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/** Minimal footer for slash replies: copy + dismiss */
|
/** Minimal footer for slash replies: copy + dismiss */
|
||||||
function _slashFooter(msgEl) {
|
function _slashFooter(msgEl) {
|
||||||
const footer = document.createElement('div');
|
const footer = document.createElement('div');
|
||||||
@@ -681,6 +756,13 @@ async function handleSetupWizard(mode, input) {
|
|||||||
await _setupProviderPrompt();
|
await _setupProviderPrompt();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
const deviceAuthProvider = _setupDeviceAuthProviderFromInput(input);
|
||||||
|
if (deviceAuthProvider) {
|
||||||
|
_addMessage('user', input);
|
||||||
|
setupMode = false;
|
||||||
|
await _setupProviderDeviceFlow(deviceAuthProvider);
|
||||||
|
return;
|
||||||
|
}
|
||||||
const paired = _extractSetupProviderCredential(input);
|
const paired = _extractSetupProviderCredential(input);
|
||||||
const provider = paired?.provider || _setupProviderFromInput(input);
|
const provider = paired?.provider || _setupProviderFromInput(input);
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
@@ -1429,6 +1511,42 @@ async function _cmdModels(args, ctx) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function _cmdModel(args, ctx) {
|
||||||
|
const sub = (args[0] || '').toLowerCase();
|
||||||
|
if (sub === 'list' || sub === 'ls') return _cmdModels(args.slice(1), ctx);
|
||||||
|
|
||||||
|
const model = sessionModule.getCurrentModel ? sessionModule.getCurrentModel() : '';
|
||||||
|
const endpoint = sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : '';
|
||||||
|
slashReply(`<pre>${[
|
||||||
|
`Current model: ${ctx.esc(model || 'None selected')}`,
|
||||||
|
endpoint ? `Endpoint: ${ctx.esc(endpoint)}` : 'Endpoint: not available',
|
||||||
|
'',
|
||||||
|
'Usage: /model list to show all available models'
|
||||||
|
].join('\n')}</pre>`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _cmdMcp(args, ctx) {
|
||||||
|
const res = await fetch(`${API_BASE}/api/mcp/servers`, { credentials: 'same-origin' });
|
||||||
|
if (!res.ok) {
|
||||||
|
slashReply('MCP status is unavailable for this user.');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const servers = await res.json();
|
||||||
|
if (!Array.isArray(servers) || !servers.length) {
|
||||||
|
slashReply('No MCP servers configured.');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const lines = servers.map(s => {
|
||||||
|
const status = s.status || (s.is_enabled ? 'enabled' : 'disabled');
|
||||||
|
const enabled = Number(s.enabled_tool_count ?? s.tool_count ?? 0);
|
||||||
|
const total = Number(s.tool_count ?? enabled);
|
||||||
|
return `${s.name || s.id || 'MCP server'} - ${status} (${enabled}/${total} tools)`;
|
||||||
|
});
|
||||||
|
slashReply(`<pre>${lines.map(line => ctx.esc(line)).join('\n')}</pre>`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// ── Memory ──
|
// ── Memory ──
|
||||||
|
|
||||||
async function _cmdMemoryList(args, ctx) {
|
async function _cmdMemoryList(args, ctx) {
|
||||||
@@ -1507,6 +1625,73 @@ async function _cmdMemorySearch(args, ctx) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Skills ──
|
||||||
|
|
||||||
|
async function _cmdSkills(args, ctx) {
|
||||||
|
const sub = (args[0] || 'list').toLowerCase();
|
||||||
|
const rest = args.slice(1);
|
||||||
|
|
||||||
|
if (sub === 'list' || sub === 'ls') {
|
||||||
|
const skills = await _loadSkillSlashCatalog(true);
|
||||||
|
if (!skills.length) {
|
||||||
|
slashReply('No published skills available for slash commands.');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const lines = skills.map(s => {
|
||||||
|
const uses = Number(s.uses || 0);
|
||||||
|
const useText = uses > 0 ? ` uses:${uses}` : '';
|
||||||
|
return `${ctx.esc(String(s.token || '').padEnd(24))}${ctx.esc(s.help || '')}${useText}`;
|
||||||
|
});
|
||||||
|
slashReply(`<pre>${lines.join('\n')}</pre>`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sub === 'search' || sub === 'find') {
|
||||||
|
const query = rest.join(' ').trim();
|
||||||
|
if (!query) { slashReply('Usage: /skills search query'); return true; }
|
||||||
|
const res = await fetch(`${API_BASE}/api/skills/search`, {
|
||||||
|
method: 'POST',
|
||||||
|
credentials: 'same-origin',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ query })
|
||||||
|
});
|
||||||
|
if (!res.ok) { slashReply('Skill search failed.'); return true; }
|
||||||
|
const data = await res.json();
|
||||||
|
const skills = Array.isArray(data.skills) ? data.skills : [];
|
||||||
|
if (!skills.length) { slashReply(`No skills found for "${ctx.esc(query)}".`); return true; }
|
||||||
|
const lines = skills.map(s =>
|
||||||
|
ctx.esc(`/${s.name || s.id || ''}`.padEnd(24)) + ctx.esc(s.description || '')
|
||||||
|
);
|
||||||
|
slashReply(`<pre>${lines.join('\n')}</pre>`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sub === 'view' || sub === 'cat' || sub === 'show') {
|
||||||
|
const name = (rest[0] || '').trim();
|
||||||
|
if (!name) { slashReply('Usage: /skills view name'); return true; }
|
||||||
|
const res = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(name)}/markdown`, { credentials: 'same-origin' });
|
||||||
|
if (!res.ok) { slashReply(`Skill "${ctx.esc(name)}" was not found.`); return true; }
|
||||||
|
const data = await res.json();
|
||||||
|
slashReply(`<pre>${ctx.esc(data.markdown || '')}</pre>`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sub === 'use' || sub === 'run') {
|
||||||
|
const name = (rest[0] || '').trim();
|
||||||
|
if (!name) { slashReply('Usage: /skills use name request'); return true; }
|
||||||
|
return _invokeSkillByName(name, rest.slice(1).join(' ').trim(), ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
slashReply('Usage: /skills list | search query | view name | use name request');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function _cmdReloadSkills(args, ctx) {
|
||||||
|
const skills = await _loadSkillSlashCatalog(true);
|
||||||
|
slashReply(`Reloaded skills. ${skills.length} skill command${skills.length === 1 ? '' : 's'} available.`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// ── Note (quick Notes shortcut) ──
|
// ── Note (quick Notes shortcut) ──
|
||||||
|
|
||||||
async function _cmdNote(args, ctx) {
|
async function _cmdNote(args, ctx) {
|
||||||
@@ -1799,6 +1984,53 @@ Uploads: ${d.uploads || '?'}</pre>`);
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function _cmdUsage(args, ctx) {
|
||||||
|
const sid = ctx.sid;
|
||||||
|
if (!sid) {
|
||||||
|
slashReply('No active session.');
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
let session = null;
|
||||||
|
try {
|
||||||
|
const sessions = sessionModule.getSessions ? sessionModule.getSessions() : [];
|
||||||
|
session = (sessions || []).find(s => s.id === sid) || null;
|
||||||
|
if (!session) {
|
||||||
|
const res = await fetch(`${API_BASE}/api/sessions`, { credentials: 'same-origin' });
|
||||||
|
if (res.ok) {
|
||||||
|
const data = await res.json();
|
||||||
|
const items = Array.isArray(data) ? data : (data.sessions || data.items || []);
|
||||||
|
session = items.find(s => s.id === sid) || null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (_) {}
|
||||||
|
|
||||||
|
const model = session?.model || 'Unknown';
|
||||||
|
const endpointUrl = session?.endpoint_url || (
|
||||||
|
sessionModule.getCurrentEndpointUrl ? sessionModule.getCurrentEndpointUrl() : ''
|
||||||
|
);
|
||||||
|
const messageCount = Number(session?.message_count || 0);
|
||||||
|
const totalTokens = Number(session?.total_tokens || 0);
|
||||||
|
const costTracked = chatRenderer.isCostTrackedEndpoint ? chatRenderer.isCostTrackedEndpoint(endpointUrl) : true;
|
||||||
|
const cost = costTracked && chatRenderer.getSessionCost ? Number(chatRenderer.getSessionCost(sid) || 0) : 0;
|
||||||
|
const costLine = costTracked
|
||||||
|
? (cost > 0
|
||||||
|
? `Estimated local cost: $${cost < 0.01 ? cost.toFixed(4) : cost.toFixed(3)}`
|
||||||
|
: 'Estimated local cost: unavailable or zero')
|
||||||
|
: 'Estimated local cost: not tracked for this endpoint';
|
||||||
|
|
||||||
|
slashReply(`<pre>${[
|
||||||
|
`Session: ${ctx.esc(session?.name || 'Current chat')}`,
|
||||||
|
`Model: ${ctx.esc(model)}`,
|
||||||
|
`Messages: ${messageCount.toLocaleString()}`,
|
||||||
|
`Recorded tokens: ${totalTokens.toLocaleString()}`,
|
||||||
|
costLine,
|
||||||
|
'',
|
||||||
|
'Provider account usage is not available from here; check the provider dashboard for account quota/billing.'
|
||||||
|
].join('\n')}</pre>`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// ── Context compaction ──
|
// ── Context compaction ──
|
||||||
|
|
||||||
async function _cmdCompact(args, ctx) {
|
async function _cmdCompact(args, ctx) {
|
||||||
@@ -4783,39 +5015,53 @@ function _clearSetupCommandInput() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GitHub Copilot device-flow sign-in, driven from chat (mirrors the Settings
|
async function _setupProviderDeviceFlow(providerKey) {
|
||||||
// "Connect GitHub Copilot" button). Replies via the setup guide messages.
|
|
||||||
async function _setupCopilot() {
|
|
||||||
_clearSetupGuideMessages();
|
_clearSetupGuideMessages();
|
||||||
await _setupReply('Starting GitHub Copilot sign-in…');
|
const config = PROVIDER_DEVICE_FLOWS[providerKey];
|
||||||
let start;
|
if (!config) {
|
||||||
|
await _setupReply('Provider not recognised.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
await _setupReply(`Starting ${config.label} sign-in...`);
|
||||||
try {
|
try {
|
||||||
const r = await fetch(`${API_BASE}/api/copilot/device/start`, { method: 'POST', body: new FormData(), credentials: 'same-origin' });
|
const result = await runProviderDeviceFlow(providerKey, {
|
||||||
start = await r.json();
|
onStart: async ({ start, authUrl }) => {
|
||||||
if (!r.ok) { await _setupReply(start.detail || 'Failed to start Copilot sign-in.'); return; }
|
const place = providerKey === 'copilot' ? 'GitHub' : 'OpenAI';
|
||||||
} catch (e) { await _setupReply('Request failed.'); return; }
|
const action = providerKey === 'copilot' ? 'approve the request' : 'enter the code';
|
||||||
const authUrl = start.verification_uri_complete || start.verification_uri || '';
|
if (providerKey === 'chatgpt-subscription') {
|
||||||
await _setupReply(`Opening GitHub — approve the request (code ${start.user_code}). Waiting…`);
|
slashReply(
|
||||||
try { if (authUrl) window.open(authUrl, '_blank', 'noopener'); } catch (e) {}
|
'<div class="setup-guide-no-censor" style="display:grid;gap:6px;">' +
|
||||||
const deadline = Date.now() + (start.expires_in || 900) * 1000;
|
'<div>Open this URL in your browser, enter the code, then come back here. Waiting...</div>' +
|
||||||
const stepMs = Math.max((start.interval || 5), 2) * 1000;
|
'<div>Code: <code>' + uiModule.esc(start.user_code || '') + '</code></div>' +
|
||||||
const poll = async () => {
|
'<div><a href="' + uiModule.esc(authUrl || '') + '" target="_blank" rel="noopener noreferrer">' + uiModule.esc(authUrl || '') + '</a></div>' +
|
||||||
if (Date.now() > deadline) { await _setupReply('Copilot sign-in expired — run /setup copilot again.'); return; }
|
'</div>'
|
||||||
try {
|
);
|
||||||
const fd = new FormData(); fd.append('poll_id', start.poll_id);
|
return;
|
||||||
const r = await fetch(`${API_BASE}/api/copilot/device/poll`, { method: 'POST', body: fd, credentials: 'same-origin' });
|
}
|
||||||
const d = await r.json();
|
await _setupReply(`Opening ${place} - ${action} (code ${start.user_code}). Waiting...`);
|
||||||
if (d.status === 'authorized') {
|
},
|
||||||
const n = ((d.endpoint && d.endpoint.models) || []).length;
|
openWindow: (url) => {
|
||||||
await _setupReply(`Connected — ${n} Copilot model${n !== 1 ? 's' : ''} available.`);
|
if (providerKey === 'chatgpt-subscription') return;
|
||||||
if (modelsModule) modelsModule.refreshModels(true);
|
try { if (url) window.open(url, '_blank', 'noopener'); } catch (e) {}
|
||||||
return;
|
},
|
||||||
}
|
});
|
||||||
if (d.status === 'failed') { await _setupReply('Copilot sign-in failed (' + (d.error || 'denied') + ').'); return; }
|
if (result.status === 'authorized') {
|
||||||
} catch (e) { /* transient — keep polling */ }
|
const n = ((result.endpoint && result.endpoint.models) || []).length;
|
||||||
setTimeout(poll, stepMs);
|
await _setupReply(`Connected - ${n} ${config.label} model${n !== 1 ? 's' : ''} available.`);
|
||||||
};
|
if (modelsModule) modelsModule.refreshModels(true);
|
||||||
setTimeout(poll, stepMs);
|
return;
|
||||||
|
}
|
||||||
|
if (result.status === 'failed') {
|
||||||
|
await _setupReply(`${config.label} sign-in failed (${result.error || 'denied'}).`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (result.status === 'expired') {
|
||||||
|
await _setupReply(`${config.label} sign-in expired - run /setup ${providerKey} again.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
await _setupReply(formatDeviceFlowError(e));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function _cmdSetup(args, ctx) {
|
async function _cmdSetup(args, ctx) {
|
||||||
@@ -4823,7 +5069,11 @@ async function _cmdSetup(args, ctx) {
|
|||||||
_clearSetupCommandInput();
|
_clearSetupCommandInput();
|
||||||
const topic = (args[0] || '').trim().toLowerCase();
|
const topic = (args[0] || '').trim().toLowerCase();
|
||||||
const topicArgs = args.slice(1);
|
const topicArgs = args.slice(1);
|
||||||
if (topic === 'copilot' || topic === 'github') { await _setupCopilot(); return true; }
|
const deviceAuthProvider = _setupDeviceAuthProviderFromInput(topic);
|
||||||
|
if (deviceAuthProvider) {
|
||||||
|
await _setupProviderDeviceFlow(deviceAuthProvider);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
const provider = _setupProviderFromInput(topic);
|
const provider = _setupProviderFromInput(topic);
|
||||||
if (provider) {
|
if (provider) {
|
||||||
_clearSetupGuideMessages();
|
_clearSetupGuideMessages();
|
||||||
@@ -5463,8 +5713,20 @@ async function _cmdHelp(args, ctx) {
|
|||||||
lines.push('');
|
lines.push('');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const skillCommands = await _loadSkillSlashCatalog(false);
|
||||||
|
if (skillCommands.length) {
|
||||||
|
lines.push('Skills:');
|
||||||
|
for (const skill of skillCommands.slice(0, 20)) {
|
||||||
|
const token = String(skill.token || '').padEnd(21);
|
||||||
|
lines.push(` ${ctx.esc(token)}${ctx.esc(skill.help || '')}`);
|
||||||
|
}
|
||||||
|
if (skillCommands.length > 20) {
|
||||||
|
lines.push(` ... ${skillCommands.length - 20} more. Use /skills list`);
|
||||||
|
}
|
||||||
|
lines.push('');
|
||||||
|
}
|
||||||
lines.push('Tip: /<command> --help for details');
|
lines.push('Tip: /<command> --help for details');
|
||||||
lines.push('Shortcuts: /new /rename /fork /web /bash /memories /forget');
|
lines.push('Shortcuts: /new /rename /fork /web /bash /memories /skills');
|
||||||
slashReply(`<pre style="line-height:1.7">${lines.join('\n')}</pre>`);
|
slashReply(`<pre style="line-height:1.7">${lines.join('\n')}</pre>`);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@@ -5539,6 +5801,20 @@ const COMMANDS = {
|
|||||||
'search': { handler: _cmdMemorySearch, alias: ['grep'], help: 'Search memories', usage: '/memory search q' }
|
'search': { handler: _cmdMemorySearch, alias: ['grep'], help: 'Search memories', usage: '/memory search q' }
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
skills: {
|
||||||
|
alias: ['skill'],
|
||||||
|
category: 'Memory',
|
||||||
|
help: 'List, search, inspect, or run skills',
|
||||||
|
handler: _cmdSkills,
|
||||||
|
usage: '/skills list | search query | view name | use name request',
|
||||||
|
},
|
||||||
|
'reload-skills': {
|
||||||
|
alias: ['reload_skills'],
|
||||||
|
category: 'Memory',
|
||||||
|
help: 'Refresh the slash skill catalog',
|
||||||
|
handler: _cmdReloadSkills,
|
||||||
|
usage: '/reload-skills',
|
||||||
|
},
|
||||||
rag: {
|
rag: {
|
||||||
alias: [],
|
alias: [],
|
||||||
category: 'RAG',
|
category: 'RAG',
|
||||||
@@ -5572,7 +5848,7 @@ const COMMANDS = {
|
|||||||
category: 'Getting started',
|
category: 'Getting started',
|
||||||
help: 'Add local or API model endpoints',
|
help: 'Add local or API model endpoints',
|
||||||
handler: _cmdSetup,
|
handler: _cmdSetup,
|
||||||
usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup endpoint',
|
usage: '/setup local URL · /setup groq KEY · /setup copilot · /setup chatgpt-subscription',
|
||||||
// Provider subs so the autocomplete popup surfaces "/setup deepseek",
|
// Provider subs so the autocomplete popup surfaces "/setup deepseek",
|
||||||
// "/setup openai", etc. when the user types "/setup de". Each sub's
|
// "/setup openai", etc. when the user types "/setup de". Each sub's
|
||||||
// handler is a thin wrapper that re-prepends the sub name and
|
// handler is a thin wrapper that re-prepends the sub name and
|
||||||
@@ -5590,6 +5866,7 @@ const COMMANDS = {
|
|||||||
xai: { help: 'xAI (Grok)', alias: ['grok'], usage: '/setup xai xai-...', handler: (a, c) => _cmdSetup(['xai', ...a], c) },
|
xai: { help: 'xAI (Grok)', alias: ['grok'], usage: '/setup xai xai-...', handler: (a, c) => _cmdSetup(['xai', ...a], c) },
|
||||||
ollama: { help: 'Ollama Cloud', usage: '/setup ollama KEY', handler: (a, c) => _cmdSetup(['ollama', ...a], c) },
|
ollama: { help: 'Ollama Cloud', usage: '/setup ollama KEY', handler: (a, c) => _cmdSetup(['ollama', ...a], c) },
|
||||||
copilot: { help: 'GitHub Copilot', usage: '/setup copilot', handler: (a, c) => _cmdSetup(['copilot', ...a], c) },
|
copilot: { help: 'GitHub Copilot', usage: '/setup copilot', handler: (a, c) => _cmdSetup(['copilot', ...a], c) },
|
||||||
|
'chatgpt-subscription': { help: 'ChatGPT Subscription', alias: ['codex'], usage: '/setup chatgpt-subscription', handler: (a, c) => _cmdSetup(['chatgpt-subscription', ...a], c) },
|
||||||
local: { help: 'Local model server (vLLM / LM Studio / llama.cpp / Ollama)',
|
local: { help: 'Local model server (vLLM / LM Studio / llama.cpp / Ollama)',
|
||||||
usage: '/setup local http://localhost:8000/v1',
|
usage: '/setup local http://localhost:8000/v1',
|
||||||
handler: (a, c) => _cmdSetup(['local', ...a], c) },
|
handler: (a, c) => _cmdSetup(['local', ...a], c) },
|
||||||
@@ -5767,8 +6044,22 @@ const COMMANDS = {
|
|||||||
handler: (args, ctx) => _cmdToolPanel('compare', args, ctx),
|
handler: (args, ctx) => _cmdToolPanel('compare', args, ctx),
|
||||||
usage: '/compare'
|
usage: '/compare'
|
||||||
},
|
},
|
||||||
|
mcp: {
|
||||||
|
alias: [],
|
||||||
|
category: 'Tools',
|
||||||
|
help: 'Show MCP server status',
|
||||||
|
handler: _cmdMcp,
|
||||||
|
usage: '/mcp'
|
||||||
|
},
|
||||||
|
model: {
|
||||||
|
alias: [],
|
||||||
|
category: 'Settings',
|
||||||
|
help: 'Show current chat model',
|
||||||
|
handler: _cmdModel,
|
||||||
|
usage: '/model · /model list'
|
||||||
|
},
|
||||||
models: {
|
models: {
|
||||||
alias: ['model'],
|
alias: [],
|
||||||
category: 'Settings',
|
category: 'Settings',
|
||||||
help: 'List available models',
|
help: 'List available models',
|
||||||
handler: _cmdModels,
|
handler: _cmdModels,
|
||||||
@@ -5799,10 +6090,16 @@ const COMMANDS = {
|
|||||||
handler: _cmdStats,
|
handler: _cmdStats,
|
||||||
usage: '/stats'
|
usage: '/stats'
|
||||||
},
|
},
|
||||||
|
usage: {
|
||||||
|
alias: ['cost', 'tokens'],
|
||||||
|
category: 'Utility',
|
||||||
|
help: 'Show local usage for the current chat',
|
||||||
|
handler: _cmdUsage,
|
||||||
|
usage: '/usage'
|
||||||
|
},
|
||||||
compact: {
|
compact: {
|
||||||
alias: [],
|
alias: [],
|
||||||
category: 'Utility',
|
category: 'Utility',
|
||||||
hidden: true,
|
|
||||||
help: 'Compact older chat messages',
|
help: 'Compact older chat messages',
|
||||||
handler: _cmdCompact,
|
handler: _cmdCompact,
|
||||||
usage: '/compact'
|
usage: '/compact'
|
||||||
@@ -6075,33 +6372,13 @@ async function handleSlashCommand(input) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// --- 4. Skill invocation: /<skill-name> [request] ---
|
// --- 4. Skill invocation: /<skill-name> [request] ---
|
||||||
// If `rawCmd` matches a published skill, pin its SKILL.md to the user's
|
// If `rawCmd` matches a published skill, the backend records usage and
|
||||||
// message and re-submit. Lets you fire a stored procedure on demand
|
// returns a skill-pinned message to submit as the next agent turn.
|
||||||
// without the model having to discover the skill itself.
|
|
||||||
try {
|
try {
|
||||||
const skillRes = await fetch(`${API_BASE}/api/skills/${encodeURIComponent(rawCmd)}/markdown`, { credentials: 'same-origin' });
|
const catalog = await _loadSkillSlashCatalog(false);
|
||||||
if (skillRes.ok) {
|
if (catalog.some(s => s.name === rawCmd)) {
|
||||||
const skillData = await skillRes.json();
|
_showUser();
|
||||||
const md = skillData.markdown || '';
|
return await _invokeSkillByName(rawCmd, args.join(' ').trim(), ctx);
|
||||||
if (md) {
|
|
||||||
_showUser();
|
|
||||||
const request = args.join(' ').trim();
|
|
||||||
const msgInput = document.getElementById('message');
|
|
||||||
const composed =
|
|
||||||
`Apply the skill below to my request, following its Procedure / Pitfalls / Verification.\n\n` +
|
|
||||||
`--- BEGIN SKILL ---\n${md}\n--- END SKILL ---\n\n` +
|
|
||||||
(request ? `Request: ${request}` : `Request: (use the skill as appropriate)`);
|
|
||||||
if (msgInput) {
|
|
||||||
msgInput.value = composed;
|
|
||||||
const form = document.getElementById('chat-form');
|
|
||||||
if (form && typeof form.requestSubmit === 'function') {
|
|
||||||
form.requestSubmit();
|
|
||||||
} else if (form) {
|
|
||||||
form.dispatchEvent(new Event('submit', { cancelable: true, bubbles: true }));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} catch (_) { /* fall through to fuzzy match */ }
|
} catch (_) { /* fall through to fuzzy match */ }
|
||||||
|
|
||||||
@@ -6158,10 +6435,13 @@ export function initSlashCommands(deps) {
|
|||||||
const providerEl = e.target.closest('.setup-clickable-provider');
|
const providerEl = e.target.closest('.setup-clickable-provider');
|
||||||
if (providerEl) {
|
if (providerEl) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
const providerKey = providerEl.dataset.setupProvider || providerEl.textContent.trim();
|
||||||
const providerName = providerEl.textContent.trim();
|
const providerName = providerEl.textContent.trim();
|
||||||
const messageInput = document.getElementById('message');
|
const messageInput = document.getElementById('message');
|
||||||
if (messageInput) {
|
if (messageInput) {
|
||||||
const text = providerName + ' sk-';
|
const text = providerEl.dataset.setupKind === 'device-auth'
|
||||||
|
? '/setup ' + providerKey
|
||||||
|
: providerName + ' sk-';
|
||||||
messageInput.value = text;
|
messageInput.value = text;
|
||||||
messageInput.dispatchEvent(new Event('input', { bubbles: true }));
|
messageInput.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
messageInput.focus();
|
messageInput.focus();
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
"""Static regressions for Add Models provider device-flow UX."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
_REPO = Path(__file__).resolve().parent.parent
|
||||||
|
_INDEX = (_REPO / "static" / "index.html").read_text(encoding="utf-8")
|
||||||
|
_ADMIN = (_REPO / "static" / "js" / "admin.js").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _between(src: str, start: str, end: str) -> str:
|
||||||
|
start_idx = src.index(start)
|
||||||
|
end_idx = src.index(end, start_idx)
|
||||||
|
return src[start_idx:end_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def test_copilot_and_chatgpt_subscription_are_dropdown_device_auth_options():
|
||||||
|
assert 'value="copilot" data-logo="github" data-auth-flow="copilot">GitHub Copilot' in _INDEX
|
||||||
|
assert 'value="chatgpt-subscription" data-logo="openai" data-auth-flow="chatgpt-subscription">ChatGPT Subscription' in _INDEX
|
||||||
|
assert 'id="adm-deviceAuthStatus"' in _INDEX
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_selection_is_inert_and_add_button_starts_device_flow():
|
||||||
|
change_block = _between(_ADMIN, "provider.addEventListener('change'", "urlInput.addEventListener('input'")
|
||||||
|
add_block = _between(_ADMIN, "el('adm-epAddBtn').addEventListener('click'", "async function _startProviderDeviceAuth")
|
||||||
|
|
||||||
|
assert "_startProviderDeviceAuth" not in change_block
|
||||||
|
assert "_startProviderDeviceAuth(deviceAuthProvider" in add_block
|
||||||
|
|
||||||
|
|
||||||
|
def test_device_auth_selection_disables_and_dims_api_test_button():
|
||||||
|
form_block = _between(_ADMIN, "function _setApiFormForProvider()", "function _renderPickerMenu()")
|
||||||
|
|
||||||
|
assert "testBtn.disabled = true" in form_block
|
||||||
|
assert "testBtn.style.opacity = '0.45'" in form_block
|
||||||
|
assert "testBtn.style.cursor = 'not-allowed'" in form_block
|
||||||
|
assert "testBtn.disabled = false" in form_block
|
||||||
|
assert "testBtn.style.opacity = ''" in form_block
|
||||||
|
assert "testBtn.style.cursor = ''" in form_block
|
||||||
|
|
||||||
|
|
||||||
|
def test_device_auth_keeps_manual_auth_button_without_auto_opening_tab():
|
||||||
|
auth_block = _between(_ADMIN, "async function _startProviderDeviceAuth", "// Local \"Add\" button")
|
||||||
|
|
||||||
|
assert "Authorize with OpenAI" in auth_block
|
||||||
|
assert "Authorize on GitHub" in auth_block
|
||||||
|
assert "adm-copilot-panel" in auth_block
|
||||||
|
assert "adm-device-auth-copy" in auth_block
|
||||||
|
assert "openWindow: () => {}" in auth_block
|
||||||
|
assert "A new tab opened" not in auth_block
|
||||||
|
|
||||||
|
|
||||||
|
def test_loud_oauth_copy_and_removed_button_hooks_do_not_return():
|
||||||
|
forbidden = [
|
||||||
|
"Click Add to start",
|
||||||
|
"uses account sign-in",
|
||||||
|
"Uses ChatGPT/Codex OAuth, not an OpenAI API key.",
|
||||||
|
"adm-chatgptStatus",
|
||||||
|
"adm-chatgptConnectBtn",
|
||||||
|
"adm-copilotConnectBtn",
|
||||||
|
"adm-copilotStatus",
|
||||||
|
]
|
||||||
|
for needle in forbidden:
|
||||||
|
assert needle not in _INDEX
|
||||||
|
assert needle not in _ADMIN
|
||||||
@@ -0,0 +1,280 @@
|
|||||||
|
"""DB-backed ChatGPT Subscription endpoint provisioning tests."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from core.database import Base, ModelEndpoint, ProviderAuthSession
|
||||||
|
import routes.chatgpt_subscription_routes as csr
|
||||||
|
|
||||||
|
|
||||||
|
def _mem_db(monkeypatch):
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
# Match production (core.database SessionLocal is autoflush=False): a pending
|
||||||
|
# db.delete(ep) is NOT flushed before the orphan-auth reference-count SELECT,
|
||||||
|
# which is exactly why _delete_orphaned_provider_auth needs exclude_ep_id.
|
||||||
|
TestSessionLocal = sessionmaker(bind=engine, autoflush=False)
|
||||||
|
monkeypatch.setattr(csr, "SessionLocal", TestSessionLocal)
|
||||||
|
return TestSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_creates_owner_scoped_auth_session_and_endpoint(monkeypatch):
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5", "o4-mini"])
|
||||||
|
|
||||||
|
res = csr._provision_endpoint({"access_token": "AT", "refresh_token": "RT"}, "alice")
|
||||||
|
|
||||||
|
assert res["name"] == "ChatGPT Subscription"
|
||||||
|
assert res["base_url"] == csr.chatgpt_subscription.DEFAULT_CHATGPT_SUBSCRIPTION_BASE_URL
|
||||||
|
assert res["models"] == ["gpt-5.5", "o4-mini"]
|
||||||
|
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
auth = db.query(ProviderAuthSession).first()
|
||||||
|
ep = db.query(ModelEndpoint).filter(ModelEndpoint.id == res["id"]).first()
|
||||||
|
assert auth is not None
|
||||||
|
assert auth.owner == "alice"
|
||||||
|
assert auth.provider == csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER
|
||||||
|
assert auth.access_token == "AT"
|
||||||
|
assert auth.refresh_token == "RT"
|
||||||
|
assert auth.auth_mode == "chatgpt"
|
||||||
|
assert ep is not None
|
||||||
|
assert ep.owner == "alice"
|
||||||
|
assert ep.api_key is None
|
||||||
|
assert ep.provider_auth_id == auth.id
|
||||||
|
assert ep.endpoint_kind == "api"
|
||||||
|
assert ep.model_refresh_mode == "manual"
|
||||||
|
assert ep.supports_tools is False
|
||||||
|
assert json.loads(ep.cached_models) == ["gpt-5.5", "o4-mini"]
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_refreshes_existing_auth_session_and_endpoint(monkeypatch):
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: ["gpt-5.5"])
|
||||||
|
|
||||||
|
first = csr._provision_endpoint({"access_token": "OLD", "refresh_token": "OLD-RT"}, "bob")
|
||||||
|
second = csr._provision_endpoint({"access_token": "NEW", "refresh_token": "NEW-RT"}, "bob")
|
||||||
|
|
||||||
|
assert first["id"] == second["id"]
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
auth_rows = db.query(ProviderAuthSession).filter(ProviderAuthSession.owner == "bob").all()
|
||||||
|
ep_rows = db.query(ModelEndpoint).filter(ModelEndpoint.owner == "bob").all()
|
||||||
|
assert len(auth_rows) == 1
|
||||||
|
assert len(ep_rows) == 1
|
||||||
|
assert auth_rows[0].access_token == "NEW"
|
||||||
|
assert auth_rows[0].refresh_token == "NEW-RT"
|
||||||
|
assert ep_rows[0].provider_auth_id == auth_rows[0].id
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_rejects_missing_tokens(monkeypatch):
|
||||||
|
_mem_db(monkeypatch)
|
||||||
|
with pytest.raises(ValueError, match="missing access_token or refresh_token"):
|
||||||
|
csr._provision_endpoint({"access_token": "AT"}, "alice")
|
||||||
|
|
||||||
|
|
||||||
|
def test_provision_rejects_accounts_without_usable_models(monkeypatch):
|
||||||
|
_mem_db(monkeypatch)
|
||||||
|
monkeypatch.setattr(csr.chatgpt_subscription, "fetch_available_models", lambda token: [])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="no usable Codex models"):
|
||||||
|
csr._provision_endpoint({"access_token": "AT", "refresh_token": "RT"}, "alice")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_auth_and_endpoints(db, *, auth_id="auth1", ep_ids=("ep1",)):
|
||||||
|
db.add(ProviderAuthSession(
|
||||||
|
id=auth_id, provider=csr.chatgpt_subscription.CHATGPT_SUBSCRIPTION_PROVIDER,
|
||||||
|
owner="alice", base_url="https://chatgpt.com/backend-api/codex",
|
||||||
|
refresh_token="RT", auth_mode="chatgpt",
|
||||||
|
))
|
||||||
|
for ep_id in ep_ids:
|
||||||
|
db.add(ModelEndpoint(
|
||||||
|
id=ep_id, name="ChatGPT Subscription",
|
||||||
|
base_url="https://chatgpt.com/backend-api/codex",
|
||||||
|
provider_auth_id=auth_id, owner="alice",
|
||||||
|
))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_orphaned_provider_auth_revokes_when_last_endpoint_removed(monkeypatch):
|
||||||
|
from routes.model_routes import _delete_orphaned_provider_auth
|
||||||
|
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
_add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1",))
|
||||||
|
# Mirror the production delete route: db.delete(ep) is issued (but not yet
|
||||||
|
# flushed/committed) BEFORE the orphan check runs.
|
||||||
|
ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
|
||||||
|
db.delete(ep1)
|
||||||
|
# ep1 (its only referencing endpoint) is being deleted, so the auth clears.
|
||||||
|
assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep1") is True
|
||||||
|
db.commit()
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_orphaned_provider_auth_requires_exclude_ep_id_for_pending_delete(monkeypatch):
|
||||||
|
from routes.model_routes import _delete_orphaned_provider_auth
|
||||||
|
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
_add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1",))
|
||||||
|
ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
|
||||||
|
db.delete(ep1)
|
||||||
|
# Without exclude_ep_id, the un-flushed pending delete leaves ep1 visible
|
||||||
|
# to the reference-count SELECT (autoflush=False), so the helper must
|
||||||
|
# conservatively KEEP the auth row. This is the bug exclude_ep_id fixes.
|
||||||
|
assert _delete_orphaned_provider_auth(db, "auth1") is False
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_orphaned_provider_auth_keeps_auth_while_another_endpoint_uses_it(monkeypatch):
|
||||||
|
from routes.model_routes import _delete_orphaned_provider_auth
|
||||||
|
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
_add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1", "ep2"))
|
||||||
|
# ep2 still references auth1, so deleting ep1 must NOT revoke it.
|
||||||
|
assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep1") is False
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_orphaned_provider_auth_noop_without_auth_id(monkeypatch):
|
||||||
|
from routes.model_routes import _delete_orphaned_provider_auth
|
||||||
|
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
assert _delete_orphaned_provider_auth(db, None, exclude_ep_id="ep1") is False
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_orphaned_provider_auth_noop_when_auth_row_missing(monkeypatch):
|
||||||
|
from routes.model_routes import _delete_orphaned_provider_auth
|
||||||
|
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
# Endpoint points at an auth_id whose ProviderAuthSession is already gone.
|
||||||
|
db.add(ModelEndpoint(
|
||||||
|
id="ep1", name="ChatGPT Subscription",
|
||||||
|
base_url="https://chatgpt.com/backend-api/codex",
|
||||||
|
provider_auth_id="ghost", owner="alice",
|
||||||
|
))
|
||||||
|
db.commit()
|
||||||
|
ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
|
||||||
|
db.delete(ep1)
|
||||||
|
# No other endpoint references "ghost" and no auth row exists → no-op, no error.
|
||||||
|
assert _delete_orphaned_provider_auth(db, "ghost", exclude_ep_id="ep1") is False
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_route(monkeypatch, TestSessionLocal):
|
||||||
|
"""Resolve the real DELETE /model-endpoints/{ep_id} route, wired to the test DB.
|
||||||
|
|
||||||
|
Neutralizes the route's unrelated cleanup side effects (settings/prefs files,
|
||||||
|
in-memory session manager) so the test stays hermetic and focuses on the
|
||||||
|
provider-auth revocation wiring.
|
||||||
|
"""
|
||||||
|
import routes.model_routes as mr
|
||||||
|
import routes.prefs_routes as prefs_routes
|
||||||
|
import src.ai_interaction as ai_interaction
|
||||||
|
|
||||||
|
monkeypatch.setattr(mr, "SessionLocal", TestSessionLocal)
|
||||||
|
monkeypatch.setattr(mr, "require_admin", lambda request: None)
|
||||||
|
monkeypatch.setattr(mr, "_load_settings", lambda: {})
|
||||||
|
monkeypatch.setattr(mr, "_save_settings", lambda settings: None)
|
||||||
|
monkeypatch.setattr(prefs_routes, "_load", lambda: {})
|
||||||
|
monkeypatch.setattr(prefs_routes, "_save", lambda prefs: None)
|
||||||
|
monkeypatch.setattr(ai_interaction, "get_session_manager", lambda: None)
|
||||||
|
|
||||||
|
router = mr.setup_model_routes(model_discovery=None)
|
||||||
|
for route in router.routes:
|
||||||
|
if getattr(route, "path", "") == "/api/model-endpoints/{ep_id}" and "DELETE" in getattr(route, "methods", set()):
|
||||||
|
return route.endpoint
|
||||||
|
raise AssertionError("DELETE /api/model-endpoints/{ep_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_endpoint_route_revokes_orphaned_provider_auth(monkeypatch):
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
_add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1",))
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
delete_endpoint = _delete_route(monkeypatch, TestSessionLocal)
|
||||||
|
result = delete_endpoint("ep1", object())
|
||||||
|
|
||||||
|
assert result["deleted"] is True
|
||||||
|
# The last (only) endpoint backed by auth1 is gone, so the route revokes it.
|
||||||
|
assert result["cleared_provider_auth"] is True
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is None
|
||||||
|
assert db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first() is None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_endpoint_route_keeps_auth_when_shared(monkeypatch):
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
_add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1", "ep2"))
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
delete_endpoint = _delete_route(monkeypatch, TestSessionLocal)
|
||||||
|
result = delete_endpoint("ep1", object())
|
||||||
|
|
||||||
|
assert result["deleted"] is True
|
||||||
|
# ep2 still references auth1, so deleting ep1 must NOT revoke the credentials.
|
||||||
|
assert result["cleared_provider_auth"] is False
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_orphaned_provider_auth_revokes_only_after_last_of_several(monkeypatch):
|
||||||
|
from routes.model_routes import _delete_orphaned_provider_auth
|
||||||
|
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
_add_auth_and_endpoints(db, auth_id="auth1", ep_ids=("ep1", "ep2"))
|
||||||
|
|
||||||
|
# Delete ep1 first: ep2 still references auth1, so the row survives.
|
||||||
|
ep1 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep1").first()
|
||||||
|
db.delete(ep1)
|
||||||
|
assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep1") is False
|
||||||
|
db.commit()
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is not None
|
||||||
|
|
||||||
|
# Now delete the last endpoint ep2: the auth row is finally cleared.
|
||||||
|
ep2 = db.query(ModelEndpoint).filter(ModelEndpoint.id == "ep2").first()
|
||||||
|
db.delete(ep2)
|
||||||
|
assert _delete_orphaned_provider_auth(db, "auth1", exclude_ep_id="ep2") is True
|
||||||
|
db.commit()
|
||||||
|
assert db.query(ProviderAuthSession).filter(ProviderAuthSession.id == "auth1").first() is None
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
"""Shared device-flow route helper regressions."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from routes import device_flow
|
||||||
|
|
||||||
|
|
||||||
|
def _client(monkeypatch, now_ref, start_flow, poll_flow):
|
||||||
|
store = device_flow.PendingDeviceFlowStore(time_func=lambda: now_ref[0])
|
||||||
|
router = device_flow.create_device_flow_router(
|
||||||
|
prefix="/api/test-device",
|
||||||
|
tags=["test-device"],
|
||||||
|
store=store,
|
||||||
|
start_flow=start_flow,
|
||||||
|
poll_flow=poll_flow,
|
||||||
|
)
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router)
|
||||||
|
monkeypatch.setattr(device_flow, "require_admin", lambda request: None)
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _start(_request, _form):
|
||||||
|
return device_flow.DeviceFlowStart(
|
||||||
|
pending={"secret": "server-only", "owner": "alice"},
|
||||||
|
response={"user_code": "ABCD-EFGH", "verification_uri": "https://example.test/device"},
|
||||||
|
interval=5,
|
||||||
|
expires_in=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pending_poll_is_throttled_until_interval(monkeypatch):
|
||||||
|
now = [100.0]
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def poll(_request, pending):
|
||||||
|
calls.append(dict(pending))
|
||||||
|
return device_flow.DeviceFlowPoll.pending()
|
||||||
|
|
||||||
|
client = _client(monkeypatch, now, _start, poll)
|
||||||
|
start = client.post("/api/test-device/device/start").json()
|
||||||
|
|
||||||
|
first = client.post("/api/test-device/device/poll", data={"poll_id": start["poll_id"]})
|
||||||
|
assert first.json() == {"status": "pending"}
|
||||||
|
assert calls == [{"secret": "server-only", "owner": "alice"}]
|
||||||
|
|
||||||
|
second = client.post("/api/test-device/device/poll", data={"poll_id": start["poll_id"]})
|
||||||
|
assert second.json() == {"status": "pending"}
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
now[0] += 5
|
||||||
|
third = client.post("/api/test-device/device/poll", data={"poll_id": start["poll_id"]})
|
||||||
|
assert third.json() == {"status": "pending"}
|
||||||
|
assert len(calls) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_slow_down_updates_poll_interval(monkeypatch):
|
||||||
|
now = [100.0]
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def poll(_request, _pending):
|
||||||
|
calls.append(now[0])
|
||||||
|
if len(calls) == 1:
|
||||||
|
return device_flow.DeviceFlowPoll.slow_down(interval=10)
|
||||||
|
return device_flow.DeviceFlowPoll.authorized({"id": "ep1", "models": ["gpt-4o"]})
|
||||||
|
|
||||||
|
client = _client(monkeypatch, now, _start, poll)
|
||||||
|
poll_id = client.post("/api/test-device/device/start").json()["poll_id"]
|
||||||
|
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json() == {"status": "pending"}
|
||||||
|
now[0] += 9
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json() == {"status": "pending"}
|
||||||
|
assert len(calls) == 1
|
||||||
|
|
||||||
|
now[0] += 1
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": poll_id}).json() == {
|
||||||
|
"status": "authorized",
|
||||||
|
"endpoint": {"id": "ep1", "models": ["gpt-4o"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_authorized_and_failed_polls_remove_pending_session(monkeypatch):
|
||||||
|
now = [100.0]
|
||||||
|
outcomes = [
|
||||||
|
device_flow.DeviceFlowPoll.authorized({"id": "ep1"}),
|
||||||
|
device_flow.DeviceFlowPoll.failed("access_denied"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def poll(_request, _pending):
|
||||||
|
return outcomes.pop(0)
|
||||||
|
|
||||||
|
client = _client(monkeypatch, now, _start, poll)
|
||||||
|
first = client.post("/api/test-device/device/start").json()["poll_id"]
|
||||||
|
second = client.post("/api/test-device/device/start").json()["poll_id"]
|
||||||
|
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": first}).json()["status"] == "authorized"
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": first}).status_code == 404
|
||||||
|
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": second}).json() == {
|
||||||
|
"status": "failed",
|
||||||
|
"error": "access_denied",
|
||||||
|
}
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": second}).status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_and_expiry_remove_pending_session(monkeypatch):
|
||||||
|
now = [100.0]
|
||||||
|
|
||||||
|
def poll(_request, _pending):
|
||||||
|
return device_flow.DeviceFlowPoll.pending()
|
||||||
|
|
||||||
|
client = _client(monkeypatch, now, _start, poll)
|
||||||
|
cancelled = client.post("/api/test-device/device/start").json()["poll_id"]
|
||||||
|
assert client.post("/api/test-device/device/cancel", data={"poll_id": cancelled}).json() == {"status": "cancelled"}
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": cancelled}).status_code == 404
|
||||||
|
|
||||||
|
expired = client.post("/api/test-device/device/start").json()["poll_id"]
|
||||||
|
now[0] += 21
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": expired}).status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_routes_are_admin_gated(monkeypatch):
|
||||||
|
now = [100.0]
|
||||||
|
|
||||||
|
def poll(_request, _pending):
|
||||||
|
return device_flow.DeviceFlowPoll.pending()
|
||||||
|
|
||||||
|
client = _client(monkeypatch, now, _start, poll)
|
||||||
|
|
||||||
|
def deny(_request):
|
||||||
|
raise HTTPException(403, "admin required")
|
||||||
|
|
||||||
|
monkeypatch.setattr(device_flow, "require_admin", deny)
|
||||||
|
assert client.post("/api/test-device/device/start").status_code == 403
|
||||||
|
assert client.post("/api/test-device/device/poll", data={"poll_id": "missing"}).status_code == 403
|
||||||
|
assert client.post("/api/test-device/device/cancel", data={"poll_id": "missing"}).status_code == 403
|
||||||
@@ -25,32 +25,36 @@ from unittest.mock import MagicMock
|
|||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules
|
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state
|
||||||
|
|
||||||
# Match test_model_routes.py: if another test stubbed src.endpoint_resolver
|
with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"):
|
||||||
# during collection, drop the stub so the real URL helpers load here.
|
# Match test_model_routes.py: if another test stubbed src.endpoint_resolver
|
||||||
clear_fake_endpoint_resolver_modules()
|
# during collection, drop the stub so the real URL helpers load here.
|
||||||
|
clear_fake_endpoint_resolver_modules()
|
||||||
|
|
||||||
if "core.database" not in sys.modules:
|
if "core.database" not in sys.modules:
|
||||||
_core_db = types.ModuleType("core.database")
|
_core_db = types.ModuleType("core.database")
|
||||||
for _name in [
|
for _name in [
|
||||||
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
||||||
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
||||||
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer",
|
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun", "McpServer",
|
||||||
]:
|
"ProviderAuthSession", "Base",
|
||||||
setattr(_core_db, _name, MagicMock())
|
]:
|
||||||
sys.modules["core.database"] = _core_db
|
setattr(_core_db, _name, MagicMock())
|
||||||
|
_core_db.utcnow_naive = MagicMock()
|
||||||
|
sys.modules["core.database"] = _core_db
|
||||||
|
|
||||||
import routes.model_routes as model_routes
|
import routes.model_routes as model_routes
|
||||||
import src.endpoint_resolver as endpoint_resolver
|
import src.endpoint_resolver as endpoint_resolver
|
||||||
from routes.model_routes import (
|
from routes.model_routes import (
|
||||||
_probe_endpoint,
|
_probe_endpoint,
|
||||||
_ping_endpoint,
|
_ping_endpoint,
|
||||||
_probe_single_model,
|
_probe_single_model,
|
||||||
_classify_endpoint,
|
_resolve_probe_key,
|
||||||
_rewrite_loopback_for_docker,
|
_classify_endpoint,
|
||||||
_PROVIDER_CURATED,
|
_rewrite_loopback_for_docker,
|
||||||
)
|
_PROVIDER_CURATED,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _patch_resolve(monkeypatch):
|
def _patch_resolve(monkeypatch):
|
||||||
@@ -117,6 +121,26 @@ class TestProbeEndpointParsing:
|
|||||||
)
|
)
|
||||||
assert _probe_endpoint("https://api.example.com/v1") == []
|
assert _probe_endpoint("https://api.example.com/v1") == []
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_probe_uses_discovery_only(self, monkeypatch):
|
||||||
|
_patch_resolve(monkeypatch)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_fetch(access_token, timeout=5):
|
||||||
|
calls.append((access_token, timeout))
|
||||||
|
return ["gpt-5.5"]
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", fake_fetch)
|
||||||
|
|
||||||
|
assert _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS", timeout=7) == ["gpt-5.5"]
|
||||||
|
assert calls == [("ACCESS", 7)]
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_probe_without_discovery_returns_empty(self, monkeypatch):
|
||||||
|
_patch_resolve(monkeypatch)
|
||||||
|
monkeypatch.setattr("src.chatgpt_subscription.fetch_available_models", lambda access_token, timeout=5: [])
|
||||||
|
|
||||||
|
assert _probe_endpoint("https://chatgpt.com/backend-api/codex", "ACCESS") == []
|
||||||
|
assert _probe_endpoint("https://chatgpt.com/backend-api/codex") == []
|
||||||
|
|
||||||
|
|
||||||
# ── _ping_endpoint: reachability classification ──
|
# ── _ping_endpoint: reachability classification ──
|
||||||
|
|
||||||
@@ -321,6 +345,51 @@ class TestProbeSingleModel:
|
|||||||
_probe_single_model("https://api.anthropic.com/v1", "sk-ant", "claude-sonnet-4-5", with_tools=True)
|
_probe_single_model("https://api.anthropic.com/v1", "sk-ant", "claude-sonnet-4-5", with_tools=True)
|
||||||
assert "input_schema" in captured["payload"]["tools"][0]
|
assert "input_schema" in captured["payload"]["tools"][0]
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_skips_completion_probe(self, monkeypatch):
|
||||||
|
# This provider speaks the Responses/Codex API. A chat-completions probe
|
||||||
|
# would 400 and (via the re-probe flow) hide every model, so it must be
|
||||||
|
# short-circuited as discovery-only without any HTTP call.
|
||||||
|
_patch_resolve(monkeypatch)
|
||||||
|
|
||||||
|
def boom(*args, **kwargs):
|
||||||
|
raise AssertionError("must not send a completion probe for chatgpt-subscription")
|
||||||
|
|
||||||
|
monkeypatch.setattr(model_routes.httpx, "post", boom)
|
||||||
|
result = _probe_single_model("https://chatgpt.com/backend-api/codex", None, "gpt-5.1-codex")
|
||||||
|
assert result["status"] == "ok"
|
||||||
|
assert result.get("skipped") is True
|
||||||
|
# Pin the full documented return shape — downstream JSON/UI reads latency_ms.
|
||||||
|
assert result["latency_ms"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── _resolve_probe_key: static key vs provider-auth runtime token ──
|
||||||
|
|
||||||
|
class TestResolveProbeKey:
|
||||||
|
def test_static_endpoint_uses_api_key(self):
|
||||||
|
ep = types.SimpleNamespace(id="e1", api_key="sk-static", provider_auth_id=None, owner=None)
|
||||||
|
assert _resolve_probe_key(ep) == "sk-static"
|
||||||
|
|
||||||
|
def test_provider_auth_endpoint_resolves_runtime_token(self, monkeypatch):
|
||||||
|
ep = types.SimpleNamespace(id="e2", api_key=None, provider_auth_id="auth123", owner="alice")
|
||||||
|
seen = {}
|
||||||
|
|
||||||
|
def fake_runtime(endpoint, owner=None):
|
||||||
|
seen["owner"] = owner
|
||||||
|
return ("https://chatgpt.com/backend-api/codex", "live-bearer")
|
||||||
|
|
||||||
|
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", fake_runtime)
|
||||||
|
assert _resolve_probe_key(ep) == "live-bearer"
|
||||||
|
assert seen["owner"] == "alice"
|
||||||
|
|
||||||
|
def test_provider_auth_resolution_failure_returns_none(self, monkeypatch):
|
||||||
|
ep = types.SimpleNamespace(id="e3", api_key=None, provider_auth_id="auth123", owner=None)
|
||||||
|
|
||||||
|
def boom(endpoint, owner=None):
|
||||||
|
raise RuntimeError("reauth required")
|
||||||
|
|
||||||
|
monkeypatch.setattr(endpoint_resolver, "resolve_endpoint_runtime", boom)
|
||||||
|
assert _resolve_probe_key(ep) is None
|
||||||
|
|
||||||
|
|
||||||
# ── _classify_endpoint: Tailscale CGNAT range ──
|
# ── _classify_endpoint: Tailscale CGNAT range ──
|
||||||
|
|
||||||
|
|||||||
@@ -75,6 +75,28 @@ def test_normal_model_payload_keeps_temperature_above_one(monkeypatch):
|
|||||||
assert payload["temperature"] == 1.2
|
assert payload["temperature"] == 1.2
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_payload_uses_max_output_tokens():
|
||||||
|
payload = llm_core._build_chatgpt_responses_payload(
|
||||||
|
"gpt-5.1-codex",
|
||||||
|
[{"role": "user", "content": "Say OK"}],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=37,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload["max_output_tokens"] == 37
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_payload_omits_empty_max_output_tokens():
|
||||||
|
payload = llm_core._build_chatgpt_responses_payload(
|
||||||
|
"gpt-5.1-codex",
|
||||||
|
[{"role": "user", "content": "Say OK"}],
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "max_output_tokens" not in payload
|
||||||
|
|
||||||
|
|
||||||
def _anthropic_payload(temperature):
|
def _anthropic_payload(temperature):
|
||||||
return llm_core._build_anthropic_payload(
|
return llm_core._build_anthropic_payload(
|
||||||
"claude-3-5-sonnet",
|
"claude-3-5-sonnet",
|
||||||
|
|||||||
+92
-42
@@ -11,49 +11,51 @@ from types import SimpleNamespace
|
|||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules
|
from tests.helpers.import_state import clear_fake_endpoint_resolver_modules, preserve_import_state
|
||||||
|
|
||||||
# Other tests stub this module during collection. These helper tests need
|
with preserve_import_state("core.database", "src.database", "core.session_manager", "routes.model_routes"):
|
||||||
# the real URL normalization helpers so Anthropic /v1 handling is covered.
|
# Other tests stub this module during collection. These helper tests need
|
||||||
clear_fake_endpoint_resolver_modules()
|
# the real URL normalization helpers so Anthropic /v1 handling is covered.
|
||||||
|
clear_fake_endpoint_resolver_modules()
|
||||||
|
|
||||||
if "core.database" not in sys.modules:
|
if "core.database" not in sys.modules:
|
||||||
_core_db = types.ModuleType("core.database")
|
_core_db = types.ModuleType("core.database")
|
||||||
for _name in [
|
for _name in [
|
||||||
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
"SessionLocal", "ModelEndpoint", "Session", "ChatMessage", "Document",
|
||||||
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
"DocumentVersion", "GalleryImage", "GalleryAlbum", "Note",
|
||||||
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun",
|
"CalendarCal", "CalendarEvent", "ScheduledTask", "TaskRun",
|
||||||
"McpServer",
|
"McpServer", "ProviderAuthSession", "Base",
|
||||||
]:
|
]:
|
||||||
setattr(_core_db, _name, MagicMock())
|
setattr(_core_db, _name, MagicMock())
|
||||||
sys.modules["core.database"] = _core_db
|
_core_db.utcnow_naive = MagicMock()
|
||||||
|
sys.modules["core.database"] = _core_db
|
||||||
|
|
||||||
import routes.model_routes as model_routes
|
import routes.model_routes as model_routes
|
||||||
import src.database as src_database
|
import src.database as src_database
|
||||||
import src.endpoint_resolver as endpoint_resolver
|
import src.endpoint_resolver as endpoint_resolver
|
||||||
import src.llm_core as llm_core
|
import src.llm_core as llm_core
|
||||||
from routes.model_routes import (
|
from routes.model_routes import (
|
||||||
_match_provider_curated,
|
_match_provider_curated,
|
||||||
_curate_models,
|
_curate_models,
|
||||||
_visible_models,
|
_visible_models,
|
||||||
_normalize_model_ids,
|
_normalize_model_ids,
|
||||||
_api_key_fingerprint,
|
_api_key_fingerprint,
|
||||||
_is_chat_model,
|
_is_chat_model,
|
||||||
_classify_endpoint,
|
_classify_endpoint,
|
||||||
_effective_endpoint_kind,
|
_effective_endpoint_kind,
|
||||||
_probe_endpoint,
|
_probe_endpoint,
|
||||||
_ping_endpoint,
|
_ping_endpoint,
|
||||||
_parse_model_list,
|
_parse_model_list,
|
||||||
_normalize_refresh_mode,
|
_normalize_refresh_mode,
|
||||||
_truthy,
|
_truthy,
|
||||||
_speech_settings_using_endpoint,
|
_speech_settings_using_endpoint,
|
||||||
_clear_speech_settings_for_endpoint,
|
_clear_speech_settings_for_endpoint,
|
||||||
_endpoint_settings_using_endpoint,
|
_endpoint_settings_using_endpoint,
|
||||||
_clear_endpoint_settings_for_endpoint,
|
_clear_endpoint_settings_for_endpoint,
|
||||||
_clear_user_pref_endpoint_refs,
|
_clear_user_pref_endpoint_refs,
|
||||||
_PROVIDER_CURATED,
|
_PROVIDER_CURATED,
|
||||||
)
|
)
|
||||||
from src.llm_core import ANTHROPIC_MODELS
|
from src.llm_core import ANTHROPIC_MODELS
|
||||||
|
|
||||||
|
|
||||||
# ── speech endpoint settings ──
|
# ── speech endpoint settings ──
|
||||||
@@ -687,8 +689,7 @@ class _PinnedFakeRequest:
|
|||||||
|
|
||||||
|
|
||||||
def _get_route(path, method):
|
def _get_route(path, method):
|
||||||
from routes.model_routes import setup_model_routes
|
router = model_routes.setup_model_routes(model_discovery=None)
|
||||||
router = setup_model_routes(model_discovery=None)
|
|
||||||
for route in router.routes:
|
for route in router.routes:
|
||||||
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
|
if getattr(route, "path", "") == path and method in getattr(route, "methods", set()):
|
||||||
return route.endpoint
|
return route.endpoint
|
||||||
@@ -787,6 +788,55 @@ def test_reprobe_preserves_pinned_models(monkeypatch):
|
|||||||
assert json.loads(ep.cached_models) == ["m1"]
|
assert json.loads(ep.cached_models) == ["m1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_reprobe_chatgpt_subscription_does_not_hide_models(monkeypatch):
|
||||||
|
# The whole point of the _probe_single_model short-circuit is that re-probing
|
||||||
|
# a chatgpt-subscription endpoint must NOT mark every (un-probeable) model as
|
||||||
|
# failed and write them all into hidden_models. Assert that end-to-end at the
|
||||||
|
# route level, with the REAL _probe_single_model doing the skip.
|
||||||
|
ep = _make_endpoint(
|
||||||
|
base_url="https://chatgpt.com/backend-api/codex",
|
||||||
|
api_key=None,
|
||||||
|
hidden_models=json.dumps(["stale-hidden"]),
|
||||||
|
)
|
||||||
|
db = _PinnedFakeDb([ep])
|
||||||
|
monkeypatch.setattr(model_routes, "SessionLocal", lambda: db)
|
||||||
|
monkeypatch.setattr(model_routes, "require_admin", lambda request: None)
|
||||||
|
monkeypatch.setattr(model_routes, "_normalize_base", lambda url: url.rstrip("/"))
|
||||||
|
monkeypatch.setattr(model_routes, "_probe_endpoint", lambda *a, **k: ["gpt-5.1-codex", "gpt-5.1"])
|
||||||
|
monkeypatch.setattr(model_routes, "_is_chat_model", lambda m: True)
|
||||||
|
# Any completion probe would be a bug for this provider.
|
||||||
|
monkeypatch.setattr(
|
||||||
|
model_routes.httpx, "post",
|
||||||
|
lambda *a, **k: (_ for _ in ()).throw(AssertionError("must not probe chatgpt-subscription")),
|
||||||
|
)
|
||||||
|
endpoint = _get_route("/api/model-endpoints/{ep_id}/probe", "GET")
|
||||||
|
|
||||||
|
response = endpoint("ep1", _PinnedFakeRequest())
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
async def _drain():
|
||||||
|
async for chunk in response.body_iterator:
|
||||||
|
chunks.append(chunk.decode() if isinstance(chunk, bytes) else chunk)
|
||||||
|
|
||||||
|
asyncio.run(_drain())
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for chunk in chunks:
|
||||||
|
for line in chunk.splitlines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
events.append(json.loads(line[len("data: "):]))
|
||||||
|
|
||||||
|
done = next(e for e in events if e.get("type") == "probe_done")
|
||||||
|
results = [e for e in events if e.get("type") == "probe_result"]
|
||||||
|
|
||||||
|
# Every model was skipped as ok; none failed → nothing hidden.
|
||||||
|
assert done["hidden"] == 0
|
||||||
|
assert done["ok"] == len(results) == 2
|
||||||
|
assert all(r["status"] == "ok" and r.get("skipped") is True for r in results)
|
||||||
|
# The stale hidden_models is cleared, not repopulated with every model.
|
||||||
|
assert ep.hidden_models is None
|
||||||
|
|
||||||
|
|
||||||
def test_visible_models_handles_malformed_strings():
|
def test_visible_models_handles_malformed_strings():
|
||||||
# Non-JSON cached/pinned strings are treated as comma/newline lists and
|
# Non-JSON cached/pinned strings are treated as comma/newline lists and
|
||||||
# never raise; a malformed hidden string is normalized too.
|
# never raise; a malformed hidden string is normalized too.
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ class TestHostMatch:
|
|||||||
|
|
||||||
|
|
||||||
class TestDetectProviderRealHosts:
|
class TestDetectProviderRealHosts:
|
||||||
|
def test_chatgpt_subscription_codex_backend(self):
|
||||||
|
assert llm_core._detect_provider("https://chatgpt.com/backend-api/codex") == "chatgpt-subscription"
|
||||||
|
assert llm_core._detect_provider("https://chatgpt.com/backend-api/codex/responses") == "chatgpt-subscription"
|
||||||
|
|
||||||
def test_anthropic(self):
|
def test_anthropic(self):
|
||||||
assert llm_core._detect_provider("https://api.anthropic.com") == "anthropic"
|
assert llm_core._detect_provider("https://api.anthropic.com") == "anthropic"
|
||||||
|
|
||||||
@@ -93,6 +97,12 @@ class TestBuildersRejectLookalikeHosts:
|
|||||||
def test_real_anthropic_chat(self):
|
def test_real_anthropic_chat(self):
|
||||||
assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages"
|
assert build_chat_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages"
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_chat_uses_responses(self):
|
||||||
|
assert build_chat_url("https://chatgpt.com/backend-api/codex") == "https://chatgpt.com/backend-api/codex/responses"
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_models_uses_no_live_probe(self):
|
||||||
|
assert build_models_url("https://chatgpt.com/backend-api/codex") is None
|
||||||
|
|
||||||
def test_lookalike_anthropic_chat_is_openai(self):
|
def test_lookalike_anthropic_chat_is_openai(self):
|
||||||
assert build_chat_url("https://notanthropic.com") == "https://notanthropic.com/chat/completions"
|
assert build_chat_url("https://notanthropic.com") == "https://notanthropic.com/chat/completions"
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,157 @@
|
|||||||
|
"""Node-driven tests for the shared provider device-flow runner."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
_REPO = Path(__file__).resolve().parent.parent
|
||||||
|
_HELPER = _REPO / "static" / "js" / "providerDeviceFlow.js"
|
||||||
|
pytestmark = pytest.mark.skipif(not shutil.which("node"), reason="node not on PATH")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_node(script: str):
|
||||||
|
proc = subprocess.run(
|
||||||
|
["node", "--input-type=module"],
|
||||||
|
input=script,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
cwd=str(_REPO),
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
assert proc.returncode == 0, proc.stderr
|
||||||
|
return json.loads(proc.stdout.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def test_copilot_success_uses_complete_verification_uri():
|
||||||
|
js = f"""
|
||||||
|
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
|
||||||
|
const calls = [];
|
||||||
|
const opened = [];
|
||||||
|
let polls = 0;
|
||||||
|
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
|
||||||
|
const fetchImpl = async (url) => {{
|
||||||
|
calls.push(url);
|
||||||
|
if (url.endsWith('/device/start')) {{
|
||||||
|
return response(true, 200, {{
|
||||||
|
poll_id: 'poll-1',
|
||||||
|
user_code: 'GH-CODE',
|
||||||
|
verification_uri: 'https://github.com/login/device',
|
||||||
|
verification_uri_complete: 'https://github.com/login/device?user_code=GH-CODE',
|
||||||
|
interval: 2,
|
||||||
|
expires_in: 30,
|
||||||
|
}});
|
||||||
|
}}
|
||||||
|
polls += 1;
|
||||||
|
return response(true, 200, polls === 1
|
||||||
|
? {{ status: 'pending' }}
|
||||||
|
: {{ status: 'authorized', endpoint: {{ id: 'ep1', models: ['gpt-4o'] }} }}
|
||||||
|
);
|
||||||
|
}};
|
||||||
|
const result = await runProviderDeviceFlow('copilot', {{
|
||||||
|
fetchImpl,
|
||||||
|
openWindow: (url) => opened.push(url),
|
||||||
|
sleep: async () => {{}},
|
||||||
|
now: () => 0,
|
||||||
|
}});
|
||||||
|
console.log(JSON.stringify({{ result, calls, opened }}));
|
||||||
|
"""
|
||||||
|
out = _run_node(js)
|
||||||
|
assert out["result"]["status"] == "authorized"
|
||||||
|
assert out["result"]["endpoint"]["id"] == "ep1"
|
||||||
|
assert out["opened"] == ["https://github.com/login/device?user_code=GH-CODE"]
|
||||||
|
assert out["calls"] == ["/api/copilot/device/start", "/api/copilot/device/poll", "/api/copilot/device/poll"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_success_uses_plain_verification_uri():
|
||||||
|
js = f"""
|
||||||
|
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
|
||||||
|
const opened = [];
|
||||||
|
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
|
||||||
|
const fetchImpl = async (url) => {{
|
||||||
|
if (url.endsWith('/device/start')) {{
|
||||||
|
return response(true, 200, {{
|
||||||
|
poll_id: 'poll-1',
|
||||||
|
user_code: 'OA-CODE',
|
||||||
|
verification_uri: 'https://auth.openai.com/codex/device',
|
||||||
|
interval: 2,
|
||||||
|
expires_in: 30,
|
||||||
|
}});
|
||||||
|
}}
|
||||||
|
return response(true, 200, {{ status: 'authorized', endpoint: {{ id: 'chatgpt', models: ['gpt-5.5'] }} }});
|
||||||
|
}};
|
||||||
|
const result = await runProviderDeviceFlow('chatgpt-subscription', {{
|
||||||
|
fetchImpl,
|
||||||
|
openWindow: (url) => opened.push(url),
|
||||||
|
sleep: async () => {{}},
|
||||||
|
now: () => 0,
|
||||||
|
}});
|
||||||
|
console.log(JSON.stringify({{ result, opened }}));
|
||||||
|
"""
|
||||||
|
out = _run_node(js)
|
||||||
|
assert out["result"]["status"] == "authorized"
|
||||||
|
assert out["opened"] == ["https://auth.openai.com/codex/device"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_errors_surface_backend_detail():
|
||||||
|
js = f"""
|
||||||
|
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
|
||||||
|
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
|
||||||
|
try {{
|
||||||
|
await runProviderDeviceFlow('copilot', {{
|
||||||
|
fetchImpl: async () => response(false, 502, {{ detail: 'GitHub device-code request failed: upstream down' }}),
|
||||||
|
openWindow: () => {{}},
|
||||||
|
sleep: async () => {{}},
|
||||||
|
now: () => 0,
|
||||||
|
}});
|
||||||
|
}} catch (err) {{
|
||||||
|
console.log(JSON.stringify({{ message: err.message }}));
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
out = _run_node(js)
|
||||||
|
assert out["message"] == "GitHub device-code request failed: upstream down"
|
||||||
|
|
||||||
|
|
||||||
|
def test_thrown_fetch_errors_are_preserved():
|
||||||
|
js = f"""
|
||||||
|
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
|
||||||
|
try {{
|
||||||
|
await runProviderDeviceFlow('chatgpt-subscription', {{
|
||||||
|
fetchImpl: async () => {{ throw new Error('network offline'); }},
|
||||||
|
openWindow: () => {{}},
|
||||||
|
sleep: async () => {{}},
|
||||||
|
now: () => 0,
|
||||||
|
}});
|
||||||
|
}} catch (err) {{
|
||||||
|
console.log(JSON.stringify({{ message: err.message }}));
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
out = _run_node(js)
|
||||||
|
assert out["message"] == "network offline"
|
||||||
|
|
||||||
|
|
||||||
|
def test_expired_flow_returns_expired_status():
|
||||||
|
js = f"""
|
||||||
|
import {{ runProviderDeviceFlow }} from '{_HELPER.as_posix()}';
|
||||||
|
let currentTime = 0;
|
||||||
|
const response = (ok, status, payload) => ({{ ok, status, async json() {{ return payload; }} }});
|
||||||
|
const result = await runProviderDeviceFlow('copilot', {{
|
||||||
|
fetchImpl: async (url) => url.endsWith('/device/start')
|
||||||
|
? response(true, 200, {{
|
||||||
|
poll_id: 'poll-1',
|
||||||
|
user_code: 'GH-CODE',
|
||||||
|
verification_uri: 'https://github.com/login/device',
|
||||||
|
interval: 2,
|
||||||
|
expires_in: 1,
|
||||||
|
}})
|
||||||
|
: response(true, 200, {{ status: 'pending' }}),
|
||||||
|
openWindow: () => {{}},
|
||||||
|
sleep: async () => {{ currentTime += 2000; }},
|
||||||
|
now: () => currentTime,
|
||||||
|
}});
|
||||||
|
console.log(JSON.stringify(result));
|
||||||
|
"""
|
||||||
|
out = _run_node(js)
|
||||||
|
assert out == {"status": "expired"}
|
||||||
@@ -24,7 +24,7 @@ _sd = types.ModuleType("src.database")
|
|||||||
_sd.ModelEndpoint = MagicMock()
|
_sd.ModelEndpoint = MagicMock()
|
||||||
sys.modules.setdefault("src.database", _sd)
|
sys.modules.setdefault("src.database", _sd)
|
||||||
|
|
||||||
from routes.research_routes import _owned_enabled_endpoint # noqa: E402
|
from routes.research_routes import _owned_enabled_endpoint, _resolve_endpoint_runtime # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
class _Predicate:
|
class _Predicate:
|
||||||
@@ -129,3 +129,29 @@ def test_null_owner_is_legacy_single_user_noop():
|
|||||||
rows = [_ep("ep-x", "bob"), _ep("ep-y", "alice")]
|
rows = [_ep("ep-x", "bob"), _ep("ep-y", "alice")]
|
||||||
ep = _resolve(rows, None, "ep-x")
|
ep = _resolve(rows, None, "ep-x")
|
||||||
assert ep is not None and ep.id == "ep-x"
|
assert ep is not None and ep.id == "ep-x"
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_resolution_uses_provider_auth_for_chatgpt_subscription(monkeypatch):
|
||||||
|
ep = SimpleNamespace(
|
||||||
|
id="ep-chatgpt",
|
||||||
|
owner="alice",
|
||||||
|
base_url="https://chatgpt.com/backend-api/codex",
|
||||||
|
api_key=None,
|
||||||
|
provider_auth_id="auth-1",
|
||||||
|
cached_models='["gpt-5.5"]',
|
||||||
|
hidden_models=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"src.chatgpt_subscription.resolve_runtime_credentials",
|
||||||
|
lambda auth_id, owner=None: {
|
||||||
|
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||||
|
"api_key": "fresh-access-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
url, model, headers = _resolve_endpoint_runtime(ep, owner="alice", model="")
|
||||||
|
|
||||||
|
assert url == "https://chatgpt.com/backend-api/codex/responses"
|
||||||
|
assert model == "gpt-5.5"
|
||||||
|
assert headers["Authorization"] == "Bearer fresh-access-token"
|
||||||
|
|||||||
@@ -0,0 +1,215 @@
|
|||||||
|
"""resolve_session_auth must not persist the ChatGPT Subscription bearer.
|
||||||
|
|
||||||
|
The ChatGPT Subscription access token is a short-lived OAuth bearer re-resolved
|
||||||
|
(and refreshed) on every request. resolve_session_auth() may set it on the
|
||||||
|
in-memory session for the current request, but it must never write it back into
|
||||||
|
the sessions table — otherwise the live token sits at rest as
|
||||||
|
"Authorization: Bearer ...". Only the encrypted refresh token in
|
||||||
|
ProviderAuthSession is allowed to persist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import types
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
import routes.chat_helpers as chat_helpers
|
||||||
|
import src.endpoint_resolver as endpoint_resolver
|
||||||
|
from core.database import Base, ModelEndpoint, Session as DbSession
|
||||||
|
|
||||||
|
_CODEX_BASE = "https://chatgpt.com/backend-api/codex"
|
||||||
|
|
||||||
|
|
||||||
|
def _mem_db(monkeypatch):
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
# Match production SessionLocal (core.database) which is autoflush=False.
|
||||||
|
TestSessionLocal = sessionmaker(bind=engine, autoflush=False)
|
||||||
|
monkeypatch.setattr(chat_helpers, "SessionLocal", TestSessionLocal)
|
||||||
|
return TestSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_auth_is_not_written_to_sessions_table(monkeypatch):
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
db.add(ModelEndpoint(
|
||||||
|
id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
|
||||||
|
provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
|
||||||
|
))
|
||||||
|
db.add(DbSession(
|
||||||
|
id="sess1", name="chat", endpoint_url=_CODEX_BASE,
|
||||||
|
model="gpt-5.1-codex", owner="alice", headers={},
|
||||||
|
))
|
||||||
|
db.commit()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# A live access token is resolved at request time.
|
||||||
|
monkeypatch.setattr(
|
||||||
|
endpoint_resolver, "resolve_endpoint_runtime",
|
||||||
|
lambda ep, owner=None: (_CODEX_BASE, "live-access-token"),
|
||||||
|
)
|
||||||
|
|
||||||
|
sess = types.SimpleNamespace(
|
||||||
|
id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex",
|
||||||
|
owner="alice", headers={},
|
||||||
|
)
|
||||||
|
chat_helpers.resolve_session_auth(sess, "sess1", owner="alice")
|
||||||
|
|
||||||
|
# In-memory session got request-local auth for this request...
|
||||||
|
assert any(k.lower() == "authorization" for k in sess.headers)
|
||||||
|
assert sess.headers["Authorization"] == "Bearer live-access-token"
|
||||||
|
|
||||||
|
# ...but the DB row must NOT have the bearer persisted.
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
row = db.query(DbSession).filter(DbSession.id == "sess1").first()
|
||||||
|
stored = row.headers or {}
|
||||||
|
assert not any(k.lower() == "authorization" for k in stored), (
|
||||||
|
f"ChatGPT bearer leaked into sessions table: {stored}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_subscription_auth_is_still_persisted_to_sessions_table(monkeypatch):
|
||||||
|
"""The early-return must be scoped to ChatGPT Subscription only.
|
||||||
|
|
||||||
|
Ordinary endpoints rely on resolve_session_auth() persisting the resolved
|
||||||
|
headers into the sessions table so they aren't re-resolved on every request.
|
||||||
|
If the is_chatgpt_subscription guard ever widened, this would silently break;
|
||||||
|
this test pins the persistence path as still reached for normal endpoints.
|
||||||
|
"""
|
||||||
|
base = "https://api.example.com/v1"
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
db.add(ModelEndpoint(
|
||||||
|
id="ep1", name="Generic", base_url=base,
|
||||||
|
owner="alice", is_enabled=True, api_key="sk-static",
|
||||||
|
))
|
||||||
|
db.add(DbSession(
|
||||||
|
id="sess1", name="chat", endpoint_url=base,
|
||||||
|
model="gpt-x", owner="alice", headers={},
|
||||||
|
))
|
||||||
|
db.commit()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
endpoint_resolver, "resolve_endpoint_runtime",
|
||||||
|
lambda ep, owner=None: (base, "sk-static"),
|
||||||
|
)
|
||||||
|
|
||||||
|
sess = types.SimpleNamespace(
|
||||||
|
id="sess1", endpoint_url=base, model="gpt-x", owner="alice", headers={},
|
||||||
|
)
|
||||||
|
chat_helpers.resolve_session_auth(sess, "sess1", owner="alice")
|
||||||
|
|
||||||
|
# In-memory session got auth...
|
||||||
|
assert any(k.lower() in ("authorization", "x-api-key") for k in sess.headers)
|
||||||
|
|
||||||
|
# ...AND it was persisted to the DB row (the normal, non-subscription path).
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
row = db.query(DbSession).filter(DbSession.id == "sess1").first()
|
||||||
|
stored = row.headers or {}
|
||||||
|
assert any(k.lower() in ("authorization", "x-api-key") for k in stored), (
|
||||||
|
f"non-subscription auth was not persisted: {stored}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_clears_previously_persisted_bearer(monkeypatch):
|
||||||
|
"""A bearer left at rest by an older code path is stripped on next resolve."""
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
db.add(ModelEndpoint(
|
||||||
|
id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
|
||||||
|
provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
|
||||||
|
))
|
||||||
|
# Simulate the leak: a stale bearer already sitting in the sessions table.
|
||||||
|
db.add(DbSession(
|
||||||
|
id="sess1", name="chat", endpoint_url=_CODEX_BASE,
|
||||||
|
model="gpt-5.1-codex", owner="alice",
|
||||||
|
headers={"Authorization": "Bearer stale-leaked-token"},
|
||||||
|
))
|
||||||
|
db.commit()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
endpoint_resolver,
|
||||||
|
"resolve_endpoint_runtime",
|
||||||
|
lambda ep, owner=None: (_CODEX_BASE, "live-access-token"),
|
||||||
|
)
|
||||||
|
|
||||||
|
sess = types.SimpleNamespace(
|
||||||
|
id="sess1", endpoint_url=_CODEX_BASE, model="gpt-5.1-codex",
|
||||||
|
owner="alice", headers={},
|
||||||
|
)
|
||||||
|
chat_helpers.resolve_session_auth(sess, "sess1", owner="alice")
|
||||||
|
|
||||||
|
# The stale bearer must have been stripped from the DB row.
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
row = db.query(DbSession).filter(DbSession.id == "sess1").first()
|
||||||
|
stored = row.headers or {}
|
||||||
|
assert not any(k.lower() == "authorization" for k in stored), (
|
||||||
|
f"stale ChatGPT bearer was not cleared: {stored}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_chatgpt_subscription_fallback_auth_is_not_written_to_sessions_table(monkeypatch):
|
||||||
|
"""Fallback endpoint selection must keep the resolved bearer request-local."""
|
||||||
|
TestSessionLocal = _mem_db(monkeypatch)
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
db.add(ModelEndpoint(
|
||||||
|
id="ep1", name="ChatGPT Subscription", base_url=_CODEX_BASE,
|
||||||
|
provider_auth_id="auth1", owner="alice", is_enabled=True, api_key=None,
|
||||||
|
cached_models='["gpt-5.1-codex"]',
|
||||||
|
))
|
||||||
|
db.add(DbSession(
|
||||||
|
id="sess1", name="chat", endpoint_url="https://old.example/v1",
|
||||||
|
model="old-model", owner="alice", headers={},
|
||||||
|
))
|
||||||
|
db.commit()
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
endpoint_resolver,
|
||||||
|
"resolve_endpoint_runtime",
|
||||||
|
lambda ep, owner=None: (_CODEX_BASE, "live-access-token"),
|
||||||
|
)
|
||||||
|
|
||||||
|
sess = types.SimpleNamespace(
|
||||||
|
id="sess1", endpoint_url="https://old.example/v1", model="old-model",
|
||||||
|
owner="alice", headers={},
|
||||||
|
)
|
||||||
|
result = chat_helpers.try_fallback_endpoint(sess, "sess1")
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"model": "gpt-5.1-codex",
|
||||||
|
"endpoint_url": _CODEX_BASE + "/responses",
|
||||||
|
"endpoint_name": "ChatGPT Subscription",
|
||||||
|
}
|
||||||
|
assert sess.headers["Authorization"] == "Bearer live-access-token"
|
||||||
|
|
||||||
|
db = TestSessionLocal()
|
||||||
|
try:
|
||||||
|
row = db.query(DbSession).filter(DbSession.id == "sess1").first()
|
||||||
|
assert row.model == "gpt-5.1-codex"
|
||||||
|
assert row.endpoint_url == _CODEX_BASE + "/responses"
|
||||||
|
stored = row.headers or {}
|
||||||
|
assert not any(k.lower() == "authorization" for k in stored), (
|
||||||
|
f"ChatGPT fallback bearer leaked into sessions table: {stored}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
@@ -386,7 +386,7 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
|
|||||||
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
|
monkeypatch.setattr(chat_helpers, "add_user_message", fake_add_user_message)
|
||||||
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {})
|
monkeypatch.setattr(chat_helpers, "load_prefs_for_user", lambda user: {})
|
||||||
monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester")
|
monkeypatch.setattr(chat_helpers, "get_current_user", lambda request: "tester")
|
||||||
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model: None)
|
monkeypatch.setattr(chat_helpers, "normalize_model_id", lambda endpoint_url, model, **kwargs: None)
|
||||||
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
|
monkeypatch.setattr(chat_helpers, "maybe_compact", fake_maybe_compact)
|
||||||
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
|
monkeypatch.setattr(chat_helpers, "trim_for_context", lambda messages, context_length: messages)
|
||||||
|
|
||||||
|
|||||||
@@ -137,3 +137,12 @@ def test_unauthenticated_caller_rejected(monkeypatch):
|
|||||||
with pytest.raises(HTTPException) as exc:
|
with pytest.raises(HTTPException) as exc:
|
||||||
SR._verify_session_owner(req, "sid")
|
SR._verify_session_owner(req, "sid")
|
||||||
assert exc.value.status_code == 401
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_allows_owner_stamped_session(monkeypatch):
|
||||||
|
monkeypatch.setenv("AUTH_ENABLED", "false")
|
||||||
|
monkeypatch.setattr(SR, "SessionLocal", _session_local_returning("admin"))
|
||||||
|
req = _req(api_token=False, current_user=None)
|
||||||
|
|
||||||
|
# Single-user/auth-disabled mode should verify existence but not compare owner.
|
||||||
|
SR._verify_session_owner(req, "sid-owned-by-admin")
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
"""Static regressions for `/setup` account sign-in providers."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
_REPO = Path(__file__).resolve().parent.parent
|
||||||
|
_SLASH = (_REPO / "static" / "js" / "slashCommands.js").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _between(src: str, start: str, end: str) -> str:
|
||||||
|
start_idx = src.index(start)
|
||||||
|
end_idx = src.index(end, start_idx)
|
||||||
|
return src[start_idx:end_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_guide_lists_account_sign_in_providers():
|
||||||
|
guide_block = _between(_SLASH, "function _showSetupEndpointChoices", "async function _hasConfiguredModels")
|
||||||
|
|
||||||
|
assert 'data-setup-provider="' in _SLASH
|
||||||
|
assert "provider.key" in _SLASH
|
||||||
|
assert "'copilot'" in _SLASH
|
||||||
|
assert "'chatgpt-subscription'" in _SLASH
|
||||||
|
assert "/setup copilot" in _SLASH
|
||||||
|
assert "/setup chatgpt-subscription" in _SLASH
|
||||||
|
|
||||||
|
|
||||||
|
def test_clicking_account_sign_in_provider_prefills_setup_command_not_api_key():
|
||||||
|
click_block = _between(_SLASH, "const providerEl = e.target.closest('.setup-clickable-provider')", "// 3. Check")
|
||||||
|
|
||||||
|
assert "providerEl.dataset.setupProvider" in click_block
|
||||||
|
assert "providerEl.dataset.setupKind === 'device-auth'" in click_block
|
||||||
|
assert "'/setup ' + providerKey" in click_block
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_chatgpt_subscription_prints_auth_url_without_auto_opening_tab():
|
||||||
|
flow_block = _between(_SLASH, "async function _setupProviderDeviceFlow", "async function _cmdSetup")
|
||||||
|
|
||||||
|
assert "providerKey === 'chatgpt-subscription'" in flow_block
|
||||||
|
assert "Open this URL" in flow_block
|
||||||
|
assert "authUrl" in flow_block
|
||||||
|
assert 'href="\' + uiModule.esc(authUrl || \'\') + \'"' in flow_block
|
||||||
|
assert "if (providerKey === 'chatgpt-subscription') return;" in flow_block
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
"""Static regressions for slash autocomplete command-group expansion."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
_REPO = Path(__file__).resolve().parent.parent
|
||||||
|
_AC = (_REPO / "static" / "js" / "slashAutocomplete.js").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def test_exact_parent_command_expands_subcommands_before_top_level_row_cap():
|
||||||
|
assert "function _exactCommandGroupItems" in _AC
|
||||||
|
assert "entry.token.toLowerCase().startsWith(prefix)" in _AC
|
||||||
|
assert "items = groupItems.slice(0, MAX_VISIBLE);" in _AC
|
||||||
|
|
||||||
|
|
||||||
|
def test_setup_group_has_room_for_chatgpt_subscription_suggestion():
|
||||||
|
assert "const MAX_VISIBLE = 14;" in _AC
|
||||||
Reference in New Issue
Block a user