From ecbf74dbea59534cebdd3ce1a40749fcfd68133e Mon Sep 17 00:00:00 2001 From: Taylor Wilsdon Date: Sat, 18 Oct 2025 16:53:44 -0400 Subject: [PATCH] Added a preflight authorize check that automatically re-registers MCP OAuth clients when the stored client ID no longer exists on the server, so the browser flow never hits the stale-ID failure --- backend/open_webui/utils/oauth.py | 206 +++++++++++++++++++++++++++++- 1 file changed, 204 insertions(+), 2 deletions(-) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 85f694cb13..d24a379ede 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1,4 +1,5 @@ import base64 +import copy import hashlib import logging import mimetypes @@ -417,6 +418,205 @@ class OAuthClientManager: log.info(f"Removed OAuth client {client_id}") return True + def _find_mcp_connection(self, request, client_id: str): + try: + connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or [] + except Exception: + connections = [] + + normalized_client_id = client_id.split(":")[-1] + + for idx, connection in enumerate(connections): + if not isinstance(connection, dict): + continue + if connection.get("type") != "mcp": + continue + + info = connection.get("info") or {} + server_id = info.get("id") + if not server_id: + continue + + normalized_server_id = server_id.split(":")[-1] + if normalized_server_id == normalized_client_id: + return idx, connection + + return None, None + + async def _preflight_authorization_url( + self, client, client_info: OAuthClientInformationFull + ) -> bool: + # Only perform preflight checks for Starlette OAuth clients + if not hasattr(client, "create_authorization_url"): + return True + + redirect_uri = None + if client_info.redirect_uris: + redirect_uri = str(client_info.redirect_uris[0]) + + try: + auth_data = await client.create_authorization_url(redirect_uri=redirect_uri) + authorize_url = auth_data.get("url") + if not authorize_url: + return True + except Exception as e: + log.debug( + "Skipping OAuth preflight for client %s: %s", + client_info.client_id, + e, + ) + return True + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + authorize_url, + allow_redirects=False, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as resp: + if resp.status < 400: + return True + + body_text = await resp.text() + error = None + error_description = "" + content_type = resp.headers.get("content-type", "") + + if "application/json" in content_type: + try: + payload = json.loads(body_text) + error = payload.get("error") + error_description = payload.get( + "error_description", "" + ) + except json.JSONDecodeError: + error = None + error_description = "" + else: + error_description = body_text + + combined = f"{error or ''} {error_description}".lower() + if "invalid_client" in combined or "invalid client" in combined or "client id" in combined: + log.warning( + "OAuth client preflight detected invalid registration for %s: %s %s", + client_info.client_id, + error, + error_description, + ) + return False + except Exception as e: + log.debug( + "Skipping OAuth preflight network check for client %s: %s", + client_info.client_id, + e, + ) + + return True + + async def _re_register_client(self, request, client_id: str) -> bool: + idx, connection = self._find_mcp_connection(request, client_id) + if idx is None or connection is None: + log.warning( + "Unable to locate MCP tool server configuration for client %s during re-registration", + client_id, + ) + return False + + server_url = connection.get("url") + oauth_server_key = (connection.get("config") or {}).get("oauth_server_key") + + try: + oauth_client_info = ( + await get_oauth_client_info_with_dynamic_client_registration( + request, + client_id, + server_url, + oauth_server_key, + ) + ) + except Exception as e: + log.error( + "Dynamic client re-registration failed for %s: %s", + client_id, + e, + ) + return False + + encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json")) + + updated_connections = copy.deepcopy( + request.app.state.config.TOOL_SERVER_CONNECTIONS or [] + ) + if idx >= len(updated_connections): + log.error( + "MCP tool server index %s out of range during OAuth client re-registration for %s", + idx, + client_id, + ) + return False + + updated_connection = copy.deepcopy(connection) + updated_connection.setdefault("info", {}) + updated_connection["info"]["oauth_client_info"] = encrypted_info + updated_connections[idx] = updated_connection + + try: + request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections + except Exception as e: + log.error( + "Failed to persist updated OAuth client info for %s: %s", + client_id, + e, + ) + return False + + self.remove_client(client_id) + self.add_client(client_id, oauth_client_info) + OAuthSessions.delete_sessions_by_provider(client_id) + + log.info("Re-registered OAuth client %s for MCP tool server", client_id) + return True + + async def _ensure_valid_client_registration( + self, request, client_id: str + ) -> None: + if not client_id.startswith("mcp:"): + return + + client = self.get_client(client_id) + client_info = self.get_client_info(client_id) + if client is None or client_info is None: + raise HTTPException(status.HTTP_404_NOT_FOUND) + + is_valid = await self._preflight_authorization_url(client, client_info) + if is_valid: + return + + log.info( + "Detected invalid OAuth client %s; attempting re-registration", + client_id, + ) + re_registered = await self._re_register_client(request, client_id) + if not re_registered: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to re-register OAuth client", + ) + + client = self.get_client(client_id) + client_info = self.get_client_info(client_id) + if client is None or client_info is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth client unavailable after re-registration", + ) + + if not await self._preflight_authorization_url(client, client_info): + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth client registration is still invalid after re-registration", + ) + def get_client(self, client_id): client = self.clients.get(client_id) return client["client"] if client else None @@ -602,10 +802,11 @@ class OAuthClientManager: return None async def handle_authorize(self, request, client_id: str) -> RedirectResponse: + await self._ensure_valid_client_registration(request, client_id) + client = self.get_client(client_id) if client is None: raise HTTPException(404) - client_info = self.get_client_info(client_id) if client_info is None: raise HTTPException(404) @@ -613,7 +814,8 @@ class OAuthClientManager: redirect_uri = ( client_info.redirect_uris[0] if client_info.redirect_uris else None ) - return await client.authorize_redirect(request, str(redirect_uri)) + redirect_uri_str = str(redirect_uri) if redirect_uri else None + return await client.authorize_redirect(request, redirect_uri_str) async def handle_callback(self, request, client_id: str, user_id: str, response): client = self.get_client(client_id)