feat(mcp): add Streamable HTTP transport with OAuth 2.0 (#1033)

* feat(mcp): add Streamable HTTP transport with OAuth 2.0

  Odysseus could only reach MCP servers over stdio and SSE, so modern
  remote servers like https://mcp.higgsfield.ai/mcp (Streamable HTTP,
  gated behind OAuth) could not be connected.

  Add an `http` transport that connects via the SDK's
  streamablehttp_client and authenticates with the SDK's
  OAuthClientProvider: RFC 9728 protected-resource discovery, RFC 8414
  authorization-server metadata, Dynamic Client Registration,
  authorization-code + PKCE, and token refresh. A small bridge
  (src/mcp_oauth.py) connects the SDK's blocking callback to the existing
  web callback route via an asyncio.Future keyed by the OAuth `state`,
  and the dynamic client registration plus tokens persist per-server in a
  new encrypted `oauth_tokens` column.

  The connect runs as a bounded background task so the "Add server"
  request returns immediately; redirect_handler publishes needs_auth +
  auth_url to connection state as soon as discovery/DCR completes (which
  can exceed the bounded wait), and the UI polls until connected. Remote
  users finish via the existing paste-back flow. The Google OAuth path is
  left unchanged.

  - core/database.py: encrypted oauth_tokens column + migration
  - src/mcp_oauth.py: OAuth provider, DB-backed TokenStorage, state registry
  - src/mcp_manager.py: http dispatch, background connect, _connect_http
  - routes/mcp_routes.py: http validation, needs_auth/auth_url, callback bridge
  - static/js/settings.js: Streamable HTTP option + OAuth flow with polling
  - tests: 5 new unit tests (transport dispatch, registry, token storage)

  Verified against the live Higgsfield server: discovery, DCR (client_id
  issued), loopback redirect accepted, and a PKCE authorization URL with
  needs_auth status. No regressions (full suite delta is only the 5 added
  passing tests).

* fix(mcp): address PR #1033 review feedback

  - mcp_oauth: derive redirect URI from OAUTH_REDIRECT_BASE_URL/APP_PUBLIC_URL
    (default http://localhost:7000) instead of hardcoding the port
  - mcp_oauth: leave OAuth scope unset so the SDK derives it from the server's
    WWW-Authenticate/protected-resource metadata; hardcoding an OIDC scope broke
    non-OpenID MCP servers (verified: Higgsfield still gets its server-derived
    scope)
  - mcp_oauth: prune abandoned OAuth flows (_prune_stale + _pending_ts) so the
    module-level registries can't grow unbounded
  - mcp_oauth: persist tokens/client-info in a single DB session/commit
    (_update) instead of a load+save double round-trip
  - mcp_manager: cancel and drop the background connect task in
    disconnect_server so a deleted server stops publishing status
  - database: document why the oauth_tokens migration uses TEXT while the model
    declares EncryptedText (encryption is applied at the Python layer)
  - settings.js: surface persistent OAuth-poll failures and an explicit timeout
    message instead of silently swallowing errors
  - tests: cover the stale-flow pruning

* static/js/settings.js now shows an in-flight loading state on the buttons that fire requests:
This commit is contained in:
Abylaikhan Zulbukharov
2026-06-05 05:40:52 +05:00
committed by GitHub
parent 85334e8f3d
commit 1d80bf5e65
7 changed files with 519 additions and 11 deletions
+99 -2
View File
@@ -70,7 +70,9 @@ class McpManager:
self._sessions: Dict[str, Any] = {}
# server_id -> exit stack (for cleanup)
self._stacks: Dict[str, Any] = {}
# Tracking updates to tools/connections for RAG indexing
# server_id -> background connect task (HTTP transport / OAuth)
self._connect_tasks: Dict[str, Any] = {}
# Tracking updates to tools/connections for RAG indexing / prompt cache
self._generation = 0
async def connect_server(
@@ -83,12 +85,14 @@ class McpManager:
env: Optional[Dict[str, str]] = None,
url: Optional[str] = None,
) -> bool:
"""Connect to an MCP server via stdio or SSE transport."""
"""Connect to an MCP server via stdio, SSE, or Streamable HTTP transport."""
try:
if transport == "stdio":
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
elif transport == "sse":
res = await self._connect_sse(server_id, name, url)
elif transport == "http":
res = await self._start_http_connect(server_id, name, url)
else:
logger.error(f"Unknown MCP transport: {transport}")
res = False
@@ -211,8 +215,101 @@ class McpManager:
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
return False
async def _start_http_connect(self, server_id: str, name: str, url: str, wait: float = 8.0) -> bool:
"""Begin a Streamable HTTP connect in the background. Returns within
`wait` seconds: True if it connected (cached-token path), otherwise the
flow is awaiting browser authorization and status becomes 'needs_auth'."""
import asyncio
self._connections[server_id] = {"status": "connecting", "name": name, "transport": "http"}
task = asyncio.create_task(self._connect_http(server_id, name, url))
self._connect_tasks[server_id] = task
done, _ = await asyncio.wait({task}, timeout=wait)
if task in done:
try:
return task.result()
except Exception as e:
self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
return False
# Still running → either awaiting authorization, or discovery/DCR is
# still in flight. If _on_redirect already published needs_auth+auth_url,
# leave it; otherwise mark needs_auth (auth_url filled in once it fires).
from src.mcp_oauth import pop_auth_url
cur = self._connections.get(server_id, {})
if cur.get("status") != "needs_auth":
self._connections[server_id] = {
"status": "needs_auth", "name": name, "transport": "http",
"auth_url": pop_auth_url(server_id),
}
return False
async def _connect_http(self, server_id: str, name: str, url: str) -> bool:
"""Connect to a Streamable HTTP MCP server (with automatic OAuth)."""
try:
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from contextlib import AsyncExitStack
from src.mcp_oauth import build_provider, clear_auth_url
def _on_redirect(auth_url):
# Publish needs_auth the moment the URL is known, independent of
# how long discovery/DCR took (may exceed the bounded start wait).
self._connections[server_id] = {
"status": "needs_auth", "name": name, "transport": "http",
"auth_url": auth_url,
}
provider = build_provider(server_id, url, on_redirect=_on_redirect)
stack = AsyncExitStack()
transport = await stack.enter_async_context(streamablehttp_client(url, auth=provider))
read_stream, write_stream, _get_session_id = transport
session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
await session.initialize()
tools_result = await session.list_tools()
tools = []
for tool in tools_result.tools:
tools.append({
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
})
self._sessions[server_id] = session
self._stacks[server_id] = stack
self._tools[server_id] = tools
self._connections[server_id] = {
"status": "connected", "name": name, "transport": "http",
"tool_count": len(tools),
}
clear_auth_url(server_id)
# Tools changed (this can complete after connect_server already
# returned, via the background OAuth flow), so bump the generation
# to invalidate the tool-prompt cache.
self._generation += 1
logger.info(f"MCP server connected: {name} ({server_id}) - {len(tools)} tools via http")
return True
except ImportError:
logger.warning("MCP package not installed. Install with: pip install mcp")
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
return False
except Exception as e:
logger.error(f"Failed to connect HTTP MCP server {name} ({server_id}): {e}")
self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
return False
async def disconnect_server(self, server_id: str):
"""Disconnect from an MCP server."""
# Cancel any in-flight HTTP/OAuth background connect so it stops
# publishing status for a server that may be getting deleted.
task = self._connect_tasks.pop(server_id, None)
if task is not None and not task.done():
task.cancel()
try:
from src.mcp_oauth import clear_auth_url
clear_auth_url(server_id)
except Exception:
pass
stack = self._stacks.pop(server_id, None)
if stack:
try:
+193
View File
@@ -0,0 +1,193 @@
"""mcp_oauth.py — generic OAuth for remote (Streamable HTTP) MCP servers.
Bridges the mcp SDK's OAuthClientProvider (RFC 9728 discovery, Dynamic Client
Registration, authorization-code + PKCE, token refresh) to Odysseus's web
callback route. Tokens and the dynamic registration persist per-server,
encrypted, so the interactive flow runs only once.
"""
import asyncio
import json
import logging
import os
import time
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse, parse_qs
logger = logging.getLogger(__name__)
# OAuth redirect URI registered with every authorization server via DCR. Loopback
# is allowed for native/desktop clients (RFC 8252); remote users finish via the
# paste-back flow. Deployments not reachable at http://localhost:7000 (custom
# port, reverse proxy, or public domain) must set OAUTH_REDIRECT_BASE_URL (or
# APP_PUBLIC_URL) to their externally reachable origin so the redirect lands back
# on Odysseus. APP_PORT is intentionally not used: it is only the Docker host
# port-map; the app always listens on 7000 inside the container.
_REDIRECT_BASE = (
os.environ.get("OAUTH_REDIRECT_BASE_URL")
or os.environ.get("APP_PUBLIC_URL")
or "http://localhost:7000"
).rstrip("/")
REDIRECT_URI = f"{_REDIRECT_BASE}/api/mcp/oauth/callback"
# How long the background connect waits for the user to authorize before giving up.
AUTH_WAIT_SECONDS = 300
_pending: Dict[str, asyncio.Future] = {} # state -> Future[(code, state)]
_pending_ts: Dict[str, float] = {} # state -> monotonic timestamp, for pruning
_auth_urls: Dict[str, str] = {} # server_id -> authorization URL
def _prune_stale() -> None:
"""Drop abandoned flows whose authorization window has elapsed so the
module-level registries don't grow unbounded (e.g. a user who never
finishes the browser step)."""
now = time.monotonic()
for state in [s for s, ts in _pending_ts.items() if now - ts > AUTH_WAIT_SECONDS]:
fut = _pending.pop(state, None)
_pending_ts.pop(state, None)
if fut is not None and not fut.done():
fut.cancel()
def _discard_pending(state: Optional[str]) -> None:
if state is None:
return
_pending.pop(state, None)
_pending_ts.pop(state, None)
def register_pending(state: str) -> asyncio.Future:
_prune_stale()
fut = asyncio.get_running_loop().create_future()
_pending[state] = fut
_pending_ts[state] = time.monotonic()
return fut
def resolve_pending(state: str, code: str) -> bool:
fut = _pending.get(state)
if fut is not None and not fut.done():
fut.set_result((code, state))
return True
return False
def pop_auth_url(server_id: str) -> Optional[str]:
return _auth_urls.get(server_id)
def clear_auth_url(server_id: str) -> None:
_auth_urls.pop(server_id, None)
class DbTokenStorage:
"""SDK TokenStorage backed by the encrypted McpServer.oauth_tokens column."""
def __init__(self, server_id: str, session_factory=None):
self.server_id = server_id
if session_factory is None:
from core.database import SessionLocal
session_factory = SessionLocal
self._sf = session_factory
def _load(self) -> dict:
from core.database import McpServer
db = self._sf()
try:
srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
if srv and srv.oauth_tokens:
return json.loads(srv.oauth_tokens)
finally:
db.close()
return {}
def _update(self, key: str, value: dict) -> None:
"""Load, set one key, and persist the oauth_tokens JSON in a single
session/commit (avoids the load+save double round-trip per write)."""
from core.database import McpServer
db = self._sf()
try:
srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
if srv is None:
return
data = json.loads(srv.oauth_tokens) if srv.oauth_tokens else {}
data[key] = value
srv.oauth_tokens = json.dumps(data)
db.commit()
finally:
db.close()
async def get_tokens(self):
from mcp.shared.auth import OAuthToken
data = self._load().get("tokens")
return OAuthToken.model_validate(data) if data else None
async def set_tokens(self, tokens) -> None:
self._update("tokens", json.loads(tokens.model_dump_json()))
async def get_client_info(self):
from mcp.shared.auth import OAuthClientInformationFull
data = self._load().get("client_info")
return OAuthClientInformationFull.model_validate(data) if data else None
async def set_client_info(self, client_info) -> None:
self._update("client_info", json.loads(client_info.model_dump_json()))
def build_provider(server_id: str, url: str, on_redirect=None):
"""Construct an OAuthClientProvider that drives the browser flow via the
Odysseus callback route.
on_redirect(authorization_url): optional sync callback invoked the moment
the authorization URL is known (after discovery + DCR). The manager uses it
to publish 'needs_auth' + auth_url to connection state regardless of how
long discovery/DCR took.
"""
from mcp.client.auth import OAuthClientProvider
from mcp.shared.auth import OAuthClientMetadata
client_metadata = OAuthClientMetadata(
client_name="Odysseus",
redirect_uris=[REDIRECT_URI],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
# Leave scope unset: the SDK applies the MCP scope-selection strategy and
# overwrites this from the server's WWW-Authenticate / protected-resource
# metadata before building the auth URL. Hardcoding an OIDC scope here
# would break the many MCP servers that are not OpenID providers.
scope=None,
token_endpoint_auth_method="none",
)
async def redirect_handler(authorization_url: str) -> None:
state = (parse_qs(urlparse(authorization_url).query).get("state") or [None])[0]
if state:
register_pending(state)
_auth_urls[server_id] = authorization_url
if on_redirect is not None:
try:
on_redirect(authorization_url)
except Exception as e:
logger.warning(f"MCP OAuth on_redirect callback failed: {e}")
logger.info(f"MCP OAuth: server {server_id} awaiting authorization (state={state})")
async def callback_handler() -> Tuple[str, Optional[str]]:
auth_url = _auth_urls.get(server_id)
state = (parse_qs(urlparse(auth_url).query).get("state") or [None])[0] if auth_url else None
fut = _pending.get(state)
if fut is None:
raise RuntimeError("No pending OAuth flow for this server")
try:
code, ret_state = await asyncio.wait_for(fut, timeout=AUTH_WAIT_SECONDS)
return code, ret_state
finally:
_discard_pending(state)
_auth_urls.pop(server_id, None)
return OAuthClientProvider(
server_url=url,
client_metadata=client_metadata,
storage=DbTokenStorage(server_id),
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)