From ec9359e360437c01a28878ff2194f46e76a57e1c Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 27 Oct 2025 15:31:25 -0700 Subject: [PATCH] refac --- backend/open_webui/routers/configs.py | 62 +++++++-------------------- backend/open_webui/utils/oauth.py | 37 ++++++++-------- 2 files changed, 32 insertions(+), 67 deletions(-) diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index e8e876eac7..43ef73f29b 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -167,59 +167,27 @@ async def set_tool_servers_config( form_data: ToolServersConfigForm, user=Depends(get_admin_user), ): - old_connections = copy.deepcopy( - request.app.state.config.TOOL_SERVER_CONNECTIONS or [] - ) + mcp_server_ids = [ + conn.get("info", {}).get("id") + for conn in form_data.TOOL_SERVER_CONNECTIONS + if conn.get("type") == "mcp" + ] - new_connections = [ + for server_id in mcp_server_ids: + # Remove existing OAuth clients for MCP tool servers that are no longer present + client_key = f"mcp:{server_id}" + try: + request.app.state.oauth_client_manager.remove_client(client_key) + except: + pass + + # Set new tool server connections + request.app.state.config.TOOL_SERVER_CONNECTIONS = [ connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS ] - old_mcp_connections = { - conn.get("info", {}).get("id"): conn - for conn in old_connections - if conn.get("type") == "mcp" - } - new_mcp_connections = { - conn.get("info", {}).get("id"): conn - for conn in new_connections - if conn.get("type") == "mcp" - } - - purge_oauth_clients = set() - - for server_id, old_conn in old_mcp_connections.items(): - if not server_id: - continue - - old_auth_type = old_conn.get("auth_type", "none") - new_conn = new_mcp_connections.get(server_id) - - if new_conn is None: - if old_auth_type == "oauth_2.1": - purge_oauth_clients.add(server_id) - continue - - new_auth_type = new_conn.get("auth_type", "none") - - if old_auth_type == "oauth_2.1": - if ( - new_auth_type != "oauth_2.1" - or old_conn.get("url") != new_conn.get("url") - or old_conn.get("info", {}).get("oauth_client_info") - != new_conn.get("info", {}).get("oauth_client_info") - ): - purge_oauth_clients.add(server_id) - - request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections - await set_tool_servers(request) - for server_id in purge_oauth_clients: - client_key = f"mcp:{server_id}" - request.app.state.oauth_client_manager.remove_client(client_key) - OAuthSessions.delete_sessions_by_provider(client_key) - for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: server_type = connection.get("type", "openapi") if server_type == "mcp": diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 34fb441679..30939eb20a 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -153,32 +153,32 @@ def decrypt_data(data: str): raise -def _build_oauth_callback_error_message(exc: Exception) -> str: +def _build_oauth_callback_error_message(e: Exception) -> str: """ Produce a user-facing callback error string with actionable context. Keeps the message short and strips newlines for safe redirect usage. """ - if isinstance(exc, OAuth2Error): - parts = [p for p in [exc.error, exc.description] if p] + if isinstance(e, OAuth2Error): + parts = [p for p in [e.error, e.description] if p] detail = " - ".join(parts) - elif isinstance(exc, HTTPException): - detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail) - elif isinstance(exc, aiohttp.ClientResponseError): - detail = f"Upstream provider returned {exc.status}: {exc.message}" - elif isinstance(exc, aiohttp.ClientError): - detail = str(exc) - elif isinstance(exc, KeyError): - missing = str(exc).strip("'") + elif isinstance(e, HTTPException): + detail = e.detail if isinstance(e.detail, str) else str(e.detail) + elif isinstance(e, aiohttp.ClientResponseError): + detail = f"Upstream provider returned {e.status}: {e.message}" + elif isinstance(e, aiohttp.ClientError): + detail = str(e) + elif isinstance(e, KeyError): + missing = str(e).strip("'") if missing.lower() == "state": detail = "Missing state parameter in callback (session may have expired)" else: detail = f"Missing expected key '{missing}' in OAuth response" else: - detail = str(exc) + detail = str(e) detail = detail.replace("\n", " ").strip() if not detail: - detail = exc.__class__.__name__ + detail = e.__class__.__name__ message = f"OAuth callback failed: {detail}" return message[:197] + "..." if len(message) > 200 else message @@ -402,20 +402,18 @@ class OAuthClientManager: return self.clients[client_id] def remove_client(self, client_id): - removed = False if client_id in self.clients: del self.clients[client_id] - removed = True + log.info(f"Removed OAuth client {client_id}") + if hasattr(self.oauth, "_clients"): if client_id in self.oauth._clients: self.oauth._clients.pop(client_id, None) - removed = True + if hasattr(self.oauth, "_registry"): if client_id in self.oauth._registry: self.oauth._registry.pop(client_id, None) - removed = True - if removed: - log.info(f"Removed OAuth client {client_id}") + return True def _find_mcp_connection(self, request, client_id: str): @@ -574,7 +572,6 @@ class OAuthClientManager: self.remove_client(client_id) self.add_client(client_id, oauth_client_info) - OAuthSessions.delete_sessions_by_provider(client_id) log.info("Re-registered OAuth client %s for MCP tool server", client_id) return True