From db658a730c32dcdb804d34beba07459caeda53f0 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 27 Oct 2025 16:46:04 -0700 Subject: [PATCH] refac --- backend/open_webui/main.py | 97 ++++++++++++++++- backend/open_webui/utils/oauth.py | 175 ++++-------------------------- 2 files changed, 116 insertions(+), 156 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index da89fd7de4..14ee4dc870 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -482,9 +482,11 @@ from open_webui.utils.auth import ( ) from open_webui.utils.plugin import install_tool_and_function_dependencies from open_webui.utils.oauth import ( + get_oauth_client_info_with_dynamic_client_registration, + encrypt_data, + decrypt_data, OAuthManager, OAuthClientManager, - decrypt_data, OAuthClientInformationFull, ) from open_webui.utils.security_headers import SecurityHeadersMiddleware @@ -1987,6 +1989,64 @@ except Exception as e: ) +async def register_client(self, request, client_id: str) -> bool: + server_type, server_id = client_id.split(":", 1) + + connection = None + connection_idx = None + + for idx, conn in enumerate(request.app.state.config.TOOL_SERVER_CONNECTIONS or []): + if conn.get("type", "openapi") == server_type: + info = conn.get("info", {}) + if info.get("id") == server_id: + connection = conn + connection_idx = idx + break + + if connection is None or connection_idx is None: + log.warning( + f"Unable to locate MCP tool server configuration for client {client_id} during re-registration" + ) + return False + + server_url = connection.get("url") + oauth_server_key = (connection.get("config") or {}).get("oauth_server_key") + + try: + oauth_client_info = ( + await get_oauth_client_info_with_dynamic_client_registration( + request, + client_id, + server_url, + oauth_server_key, + ) + ) + except Exception as e: + log.error(f"Dynamic client re-registration failed for {client_id}: {e}") + return False + + try: + request.app.state.config.TOOL_SERVER_CONNECTIONS[connection_idx] = { + **connection, + "info": { + **connection.get("info", {}), + "oauth_client_info": encrypt_data( + oauth_client_info.model_dump(mode="json") + ), + }, + } + except Exception as e: + log.error( + f"Failed to persist updated OAuth client info for tool server {client_id}: {e}" + ) + return False + + oauth_client_manager.remove_client(client_id) + oauth_client_manager.add_client(client_id, oauth_client_info) + log.info(f"Re-registered OAuth client {client_id} for tool server") + return True + + @app.get("/oauth/clients/{client_id}/authorize") async def oauth_client_authorize( client_id: str, @@ -1994,6 +2054,41 @@ async def oauth_client_authorize( response: Response, user=Depends(get_verified_user), ): + # ensure_valid_client_registration + client = oauth_client_manager.get_client(client_id) + client_info = oauth_client_manager.get_client_info(client_id) + if client is None or client_info is None: + raise HTTPException(status.HTTP_404_NOT_FOUND) + + if not await oauth_client_manager._preflight_authorization_url(client, client_info): + log.info( + "Detected invalid OAuth client %s; attempting re-registration", + client_id, + ) + + re_registered = await register_client(request, client_id) + if not re_registered: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to re-register OAuth client", + ) + + client = oauth_client_manager.get_client(client_id) + client_info = oauth_client_manager.get_client_info(client_id) + if client is None or client_info is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth client unavailable after re-registration", + ) + + if not await oauth_client_manager._preflight_authorization_url( + client, client_info + ): + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth client registration is still invalid after re-registration", + ) + return await oauth_client_manager.handle_authorize(request, client_id=client_id) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 03f7337774..6889f377bc 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -416,34 +416,10 @@ class OAuthClientManager: return True - def _find_mcp_connection(self, request, client_id: str): - try: - connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or [] - except Exception: - connections = [] - - normalized_client_id = client_id.split(":")[-1] - - for idx, connection in enumerate(connections): - if not isinstance(connection, dict): - continue - if connection.get("type") != "mcp": - continue - - info = connection.get("info") or {} - server_id = info.get("id") - if not server_id: - continue - - normalized_server_id = server_id.split(":")[-1] - if normalized_server_id == normalized_client_id: - return idx, connection - - return None, None - async def _preflight_authorization_url( self, client, client_info: OAuthClientInformationFull ) -> bool: + # TODO: Replace this logic with a more robust OAuth client registration validation # Only perform preflight checks for Starlette OAuth clients if not hasattr(client, "create_authorization_url"): return True @@ -454,168 +430,59 @@ class OAuthClientManager: try: auth_data = await client.create_authorization_url(redirect_uri=redirect_uri) - authorize_url = auth_data.get("url") - if not authorize_url: + authorization_url = auth_data.get("url") + + if not authorization_url: return True except Exception as e: log.debug( - "Skipping OAuth preflight for client %s: %s", - client_info.client_id, - e, + f"Skipping OAuth preflight for client {client_info.client_id}: {e}", ) return True try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - authorize_url, + authorization_url, allow_redirects=False, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as resp: if resp.status < 400: return True + response_text = await resp.text() - body_text = await resp.text() error = None error_description = "" - content_type = resp.headers.get("content-type", "") + content_type = resp.headers.get("content-type", "") if "application/json" in content_type: try: - payload = json.loads(body_text) + payload = json.loads(response_text) error = payload.get("error") error_description = payload.get("error_description", "") - except json.JSONDecodeError: - error = None - error_description = "" + except: + pass else: - error_description = body_text + error_description = response_text - combined = f"{error or ''} {error_description}".lower() - if ( - "invalid_client" in combined - or "invalid client" in combined - or "client id" in combined + error_message = f"{error or ''} {error_description or ''}".lower() + + if any( + keyword in error_message + for keyword in ("invalid_client", "invalid client", "client id") ): log.warning( - "OAuth client preflight detected invalid registration for %s: %s %s", - client_info.client_id, - error, - error_description, + f"OAuth client preflight detected invalid registration for {client_info.client_id}: {error} {error_description}" ) + return False except Exception as e: log.debug( - "Skipping OAuth preflight network check for client %s: %s", - client_info.client_id, - e, + f"Skipping OAuth preflight network check for client {client_info.client_id}: {e}" ) return True - async def _re_register_client(self, request, client_id: str) -> bool: - idx, connection = self._find_mcp_connection(request, client_id) - if idx is None or connection is None: - log.warning( - "Unable to locate MCP tool server configuration for client %s during re-registration", - client_id, - ) - return False - - server_url = connection.get("url") - oauth_server_key = (connection.get("config") or {}).get("oauth_server_key") - - try: - oauth_client_info = ( - await get_oauth_client_info_with_dynamic_client_registration( - request, - client_id, - server_url, - oauth_server_key, - ) - ) - except Exception as e: - log.error( - "Dynamic client re-registration failed for %s: %s", - client_id, - e, - ) - return False - - encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json")) - - updated_connections = copy.deepcopy( - request.app.state.config.TOOL_SERVER_CONNECTIONS or [] - ) - if idx >= len(updated_connections): - log.error( - "MCP tool server index %s out of range during OAuth client re-registration for %s", - idx, - client_id, - ) - return False - - updated_connection = copy.deepcopy(connection) - updated_connection.setdefault("info", {}) - updated_connection["info"]["oauth_client_info"] = encrypted_info - updated_connections[idx] = updated_connection - - try: - request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections - except Exception as e: - log.error( - "Failed to persist updated OAuth client info for %s: %s", - client_id, - e, - ) - return False - - self.remove_client(client_id) - self.add_client(client_id, oauth_client_info) - - log.info("Re-registered OAuth client %s for MCP tool server", client_id) - return True - - async def _ensure_valid_client_registration(self, request, client_id: str) -> None: - if not client_id.startswith("mcp:"): - return - - client = self.get_client(client_id) - client_info = self.get_client_info(client_id) - - if client is None or client_info is None: - raise HTTPException(status.HTTP_404_NOT_FOUND) - - is_valid = await self._preflight_authorization_url(client, client_info) - if is_valid: - return - - log.info( - "Detected invalid OAuth client %s; attempting re-registration", - client_id, - ) - - re_registered = await self._re_register_client(request, client_id) - if not re_registered: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to re-register OAuth client", - ) - - client = self.get_client(client_id) - client_info = self.get_client_info(client_id) - if client is None or client_info is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth client unavailable after re-registration", - ) - - if not await self._preflight_authorization_url(client, client_info): - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth client registration is still invalid after re-registration", - ) - def get_client(self, client_id): client = self.clients.get(client_id) return client["client"] if client else None @@ -801,8 +668,6 @@ class OAuthClientManager: return None async def handle_authorize(self, request, client_id: str) -> RedirectResponse: - # await self._ensure_valid_client_registration(request, client_id) - client = self.get_client(client_id) if client is None: raise HTTPException(404)