async def discover_authorization_server_from_mcp(mcp_server_url: str) -> list[str]:
    """
    Discover OAuth authorization servers via the MCP Protected Resource flow.

    The flow follows OAuth 2.0 Protected Resource Metadata (RFC 9728) as used
    by the MCP authorization specification:

    1. POST an unauthenticated JSON-RPC ``initialize`` request to the MCP endpoint.
    2. On a 401 response, read the ``resource_metadata`` URL from the
       ``WWW-Authenticate`` header.
    3. GET that URL and read ``authorization_servers`` from the JSON document.

    Discovery is best-effort: any network, parsing, or validation failure is
    logged at debug level and results in an empty list; nothing is raised.

    Args:
        mcp_server_url: URL of the MCP endpoint to probe.

    Returns:
        The advertised authorization server URLs (non-empty strings only), or
        an empty list if discovery fails or nothing usable is advertised.
    """
    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            # Step 1: Make unauthenticated request to MCP endpoint to get
            # the WWW-Authenticate header.
            async with session.post(
                mcp_server_url,
                json={"jsonrpc": "2.0", "method": "initialize", "params": {}, "id": 1},
                headers={"Content-Type": "application/json"},
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as response:
                if response.status != 401:
                    return []
                www_auth = response.headers.get("WWW-Authenticate", "")
            # Only the header is needed, so the first response is released
            # before the follow-up request is issued.

            # Parse resource_metadata from WWW-Authenticate header
            # Format: Bearer resource_metadata="https://..."
            # HTTP auth-param names are case-insensitive, hence IGNORECASE.
            match = re.search(
                r'resource_metadata="([^"]+)"', www_auth, re.IGNORECASE
            )
            if not match:
                return []

            resource_metadata_url = match.group(1)
            log.debug(f"Found resource_metadata URL: {resource_metadata_url}")

            # Step 2: Fetch Protected Resource metadata
            # NOTE(review): this URL is supplied by the remote server; consider
            # restricting allowed hosts if the MCP server URL is not trusted.
            async with session.get(
                resource_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
            ) as resource_response:
                if resource_response.status != 200:
                    return []
                resource_metadata = await resource_response.json()

        # Step 3: Extract authorization_servers, validating the shape so that
        # callers really receive a list[str]. A bare string would otherwise be
        # iterated character by character, and non-string items would break
        # string operations performed by callers.
        if not isinstance(resource_metadata, dict):
            return []

        servers = resource_metadata.get("authorization_servers")
        if not isinstance(servers, list):
            return []

        authorization_servers = [
            server for server in servers if isinstance(server, str) and server.strip()
        ]
        if authorization_servers:
            log.debug(f"Discovered authorization servers: {authorization_servers}")
        return authorization_servers
    except Exception as e:
        # Deliberate best-effort boundary: callers fall back to the standard
        # discovery URLs when this returns an empty list.
        log.debug(f"MCP Protected Resource discovery failed: {e}")

    return []
in discovery_urls: + for url in all_discovery_urls: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( url, ssl=AIOHTTP_CLIENT_SESSION_SSL