diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index 81ce220384..b0e465dbe7 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -262,5 +262,16 @@ class OAuthSessionTable: log.error(f"Error deleting OAuth sessions by user ID: {e}") return False + def delete_sessions_by_provider(self, provider: str) -> bool: + """Delete all OAuth sessions for a provider""" + try: + with get_db() as db: + db.query(OAuthSession).filter_by(provider=provider).delete() + db.commit() + return True + except Exception as e: + log.error(f"Error deleting OAuth sessions by provider {provider}: {e}") + return False + OAuthSessions = OAuthSessionTable() diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index e7fa13d1ff..e8e876eac7 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,4 +1,5 @@ import logging +import copy from fastapi import APIRouter, Depends, Request, HTTPException from pydantic import BaseModel, ConfigDict import aiohttp @@ -15,6 +16,7 @@ from open_webui.utils.tools import ( set_tool_servers, ) from open_webui.utils.mcp.client import MCPClient +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.env import SRC_LOG_LEVELS @@ -165,12 +167,59 @@ async def set_tool_servers_config( form_data: ToolServersConfigForm, user=Depends(get_admin_user), ): - request.app.state.config.TOOL_SERVER_CONNECTIONS = [ + old_connections = copy.deepcopy( + request.app.state.config.TOOL_SERVER_CONNECTIONS or [] + ) + + new_connections = [ connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS ] + old_mcp_connections = { + conn.get("info", {}).get("id"): conn + for conn in old_connections + if conn.get("type") == "mcp" + } + new_mcp_connections = { + conn.get("info", {}).get("id"): conn + for conn in new_connections + if conn.get("type") == "mcp" + } + + purge_oauth_clients = set() + + for server_id, old_conn in old_mcp_connections.items(): + if not server_id: + continue + + old_auth_type = old_conn.get("auth_type", "none") + new_conn = new_mcp_connections.get(server_id) + + if new_conn is None: + if old_auth_type == "oauth_2.1": + purge_oauth_clients.add(server_id) + continue + + new_auth_type = new_conn.get("auth_type", "none") + + if old_auth_type == "oauth_2.1": + if ( + new_auth_type != "oauth_2.1" + or old_conn.get("url") != new_conn.get("url") + or old_conn.get("info", {}).get("oauth_client_info") + != new_conn.get("info", {}).get("oauth_client_info") + ): + purge_oauth_clients.add(server_id) + + request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections + await set_tool_servers(request) + for server_id in purge_oauth_clients: + client_key = f"mcp:{server_id}" + request.app.state.oauth_client_manager.remove_client(client_key) + OAuthSessions.delete_sessions_by_provider(client_key) + for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: server_type = connection.get("type", "openapi") if server_type == "mcp": diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index e0bf7582c6..34fb441679 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1,4 +1,5 @@ import base64 +import copy import hashlib import logging import mimetypes @@ -74,6 +75,8 @@ from mcp.shared.auth import ( OAuthMetadata, ) +from authlib.oauth2.rfc6749.errors import OAuth2Error + class OAuthClientInformationFull(OAuthClientMetadata): issuer: Optional[str] = None # URL of the OAuth server that issued this client @@ -150,6 +153,37 @@ def decrypt_data(data: str): raise +def _build_oauth_callback_error_message(exc: Exception) -> str: + """ + Produce a user-facing callback error string with actionable context. + Keeps the message short and strips newlines for safe redirect usage. + """ + if isinstance(exc, OAuth2Error): + parts = [p for p in [exc.error, exc.description] if p] + detail = " - ".join(parts) + elif isinstance(exc, HTTPException): + detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail) + elif isinstance(exc, aiohttp.ClientResponseError): + detail = f"Upstream provider returned {exc.status}: {exc.message}" + elif isinstance(exc, aiohttp.ClientError): + detail = str(exc) + elif isinstance(exc, KeyError): + missing = str(exc).strip("'") + if missing.lower() == "state": + detail = "Missing state parameter in callback (session may have expired)" + else: + detail = f"Missing expected key '{missing}' in OAuth response" + else: + detail = str(exc) + + detail = detail.replace("\n", " ").strip() + if not detail: + detail = exc.__class__.__name__ + + message = f"OAuth callback failed: {detail}" + return message[:197] + "..." if len(message) > 200 else message + + def is_in_blocked_groups(group_name: str, groups: list) -> bool: """ Check if a group name matches any blocked pattern. @@ -368,11 +402,221 @@ class OAuthClientManager: return self.clients[client_id] def remove_client(self, client_id): + removed = False if client_id in self.clients: del self.clients[client_id] + removed = True + if hasattr(self.oauth, "_clients"): + if client_id in self.oauth._clients: + self.oauth._clients.pop(client_id, None) + removed = True + if hasattr(self.oauth, "_registry"): + if client_id in self.oauth._registry: + self.oauth._registry.pop(client_id, None) + removed = True + if removed: log.info(f"Removed OAuth client {client_id}") return True + def _find_mcp_connection(self, request, client_id: str): + try: + connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or [] + except Exception: + connections = [] + + normalized_client_id = client_id.split(":")[-1] + + for idx, connection in enumerate(connections): + if not isinstance(connection, dict): + continue + if connection.get("type") != "mcp": + continue + + info = connection.get("info") or {} + server_id = info.get("id") + if not server_id: + continue + + normalized_server_id = server_id.split(":")[-1] + if normalized_server_id == normalized_client_id: + return idx, connection + + return None, None + + async def _preflight_authorization_url( + self, client, client_info: OAuthClientInformationFull + ) -> bool: + # Only perform preflight checks for Starlette OAuth clients + if not hasattr(client, "create_authorization_url"): + return True + + redirect_uri = None + if client_info.redirect_uris: + redirect_uri = str(client_info.redirect_uris[0]) + + try: + auth_data = await client.create_authorization_url(redirect_uri=redirect_uri) + authorize_url = auth_data.get("url") + if not authorize_url: + return True + except Exception as e: + log.debug( + "Skipping OAuth preflight for client %s: %s", + client_info.client_id, + e, + ) + return True + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + authorize_url, + allow_redirects=False, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as resp: + if resp.status < 400: + return True + + body_text = await resp.text() + error = None + error_description = "" + content_type = resp.headers.get("content-type", "") + + if "application/json" in content_type: + try: + payload = json.loads(body_text) + error = payload.get("error") + error_description = payload.get("error_description", "") + except json.JSONDecodeError: + error = None + error_description = "" + else: + error_description = body_text + + combined = f"{error or ''} {error_description}".lower() + if ( + "invalid_client" in combined + or "invalid client" in combined + or "client id" in combined + ): + log.warning( + "OAuth client preflight detected invalid registration for %s: %s %s", + client_info.client_id, + error, + error_description, + ) + return False + except Exception as e: + log.debug( + "Skipping OAuth preflight network check for client %s: %s", + client_info.client_id, + e, + ) + + return True + + async def _re_register_client(self, request, client_id: str) -> bool: + idx, connection = self._find_mcp_connection(request, client_id) + if idx is None or connection is None: + log.warning( + "Unable to locate MCP tool server configuration for client %s during re-registration", + client_id, + ) + return False + + server_url = connection.get("url") + oauth_server_key = (connection.get("config") or {}).get("oauth_server_key") + + try: + oauth_client_info = ( + await get_oauth_client_info_with_dynamic_client_registration( + request, + client_id, + server_url, + oauth_server_key, + ) + ) + except Exception as e: + log.error( + "Dynamic client re-registration failed for %s: %s", + client_id, + e, + ) + return False + + encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json")) + + updated_connections = copy.deepcopy( + request.app.state.config.TOOL_SERVER_CONNECTIONS or [] + ) + if idx >= len(updated_connections): + log.error( + "MCP tool server index %s out of range during OAuth client re-registration for %s", + idx, + client_id, + ) + return False + + updated_connection = copy.deepcopy(connection) + updated_connection.setdefault("info", {}) + updated_connection["info"]["oauth_client_info"] = encrypted_info + updated_connections[idx] = updated_connection + + try: + request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections + except Exception as e: + log.error( + "Failed to persist updated OAuth client info for %s: %s", + client_id, + e, + ) + return False + + self.remove_client(client_id) + self.add_client(client_id, oauth_client_info) + OAuthSessions.delete_sessions_by_provider(client_id) + + log.info("Re-registered OAuth client %s for MCP tool server", client_id) + return True + + async def _ensure_valid_client_registration(self, request, client_id: str) -> None: + if not client_id.startswith("mcp:"): + return + + client = self.get_client(client_id) + client_info = self.get_client_info(client_id) + if client is None or client_info is None: + raise HTTPException(status.HTTP_404_NOT_FOUND) + + is_valid = await self._preflight_authorization_url(client, client_info) + if is_valid: + return + + log.info( + "Detected invalid OAuth client %s; attempting re-registration", + client_id, + ) + re_registered = await self._re_register_client(request, client_id) + if not re_registered: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to re-register OAuth client", + ) + + client = self.get_client(client_id) + client_info = self.get_client_info(client_id) + if client is None or client_info is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth client unavailable after re-registration", + ) + + if not await self._preflight_authorization_url(client, client_info): + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth client registration is still invalid after re-registration", + ) + def get_client(self, client_id): client = self.clients.get(client_id) return client["client"] if client else None @@ -558,10 +802,11 @@ class OAuthClientManager: return None async def handle_authorize(self, request, client_id: str) -> RedirectResponse: + await self._ensure_valid_client_registration(request, client_id) + client = self.get_client(client_id) if client is None: raise HTTPException(404) - client_info = self.get_client_info(client_id) if client_info is None: raise HTTPException(404) @@ -569,7 +814,8 @@ class OAuthClientManager: redirect_uri = ( client_info.redirect_uris[0] if client_info.redirect_uris else None ) - return await client.authorize_redirect(request, str(redirect_uri)) + redirect_uri_str = str(redirect_uri) if redirect_uri else None + return await client.authorize_redirect(request, redirect_uri_str) async def handle_callback(self, request, client_id: str, user_id: str, response): client = self.get_client(client_id) @@ -621,8 +867,14 @@ class OAuthClientManager: error_message = "Failed to obtain OAuth token" log.warning(error_message) except Exception as e: - error_message = "OAuth callback error" - log.warning(f"OAuth callback error: {e}") + error_message = _build_oauth_callback_error_message(e) + log.warning( + "OAuth callback error for user_id=%s client_id=%s: %s", + user_id, + client_id, + error_message, + exc_info=True, + ) redirect_url = ( str(request.app.state.config.WEBUI_URL or request.base_url) @@ -630,7 +882,9 @@ class OAuthClientManager: if error_message: log.debug(error_message) - redirect_url = f"{redirect_url}/?error={error_message}" + redirect_url = ( + f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}" + ) return RedirectResponse(url=redirect_url, headers=response.headers) response = RedirectResponse(url=redirect_url, headers=response.headers) @@ -1104,7 +1358,13 @@ class OAuthManager: try: token = await client.authorize_access_token(request) except Exception as e: - log.warning(f"OAuth callback error: {e}") + detailed_error = _build_oauth_callback_error_message(e) + log.warning( + "OAuth callback error during authorize_access_token for provider %s: %s", + provider, + detailed_error, + exc_info=True, + ) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Try to get userinfo from the token first, some providers include it there