diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index c27dbec447..169042e816 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -14,7 +14,7 @@ import fnmatch import time import secrets from cryptography.fernet import Fernet - +from typing import Literal import aiohttp from authlib.integrations.starlette_client import OAuth @@ -72,13 +72,20 @@ from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook from mcp.shared.auth import ( - OAuthClientMetadata, + OAuthClientMetadata as MCPOAuthClientMetadata, OAuthMetadata, ) from authlib.oauth2.rfc6749.errors import OAuth2Error +class OAuthClientMetadata(MCPOAuthClientMetadata): + token_endpoint_auth_method: Literal[ + "none", "client_secret_basic", "client_secret_post" + ] = "client_secret_post" + pass + + class OAuthClientInformationFull(OAuthClientMetadata): issuer: Optional[str] = None # URL of the OAuth server that issued this client @@ -242,26 +249,28 @@ def get_discovery_urls(server_url) -> list[str]: if parsed.path and parsed.path != "/": # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery - tenant = parsed.path.rstrip('/') - urls.extend([ - urllib.parse.urljoin( - base_url, - f"/.well-known/oauth-authorization-server{tenant}", - ), - urllib.parse.urljoin( - base_url, - f"/.well-known/openid-configuration{tenant}" - ), - urllib.parse.urljoin( - base_url, - f"{tenant}/.well-known/openid-configuration" - ) - ]) + tenant = parsed.path.rstrip("/") + urls.extend( + [ + urllib.parse.urljoin( + base_url, + f"/.well-known/oauth-authorization-server{tenant}", + ), + urllib.parse.urljoin( + base_url, f"/.well-known/openid-configuration{tenant}" + ), + urllib.parse.urljoin( + base_url, f"{tenant}/.well-known/openid-configuration" + ), + ] + ) - urls.extend([ - urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"), - urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"), - ]) + urls.extend( + [ + urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"), + urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"), + ] + ) return urls @@ -287,7 +296,6 @@ async def get_oauth_client_info_with_dynamic_client_registration( redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"], grant_types=["authorization_code", "refresh_token"], response_types=["code"], - token_endpoint_auth_method="client_secret_post", ) # Attempt to fetch OAuth server metadata to get registration endpoint & scopes @@ -310,6 +318,17 @@ async def get_oauth_client_info_with_dynamic_client_registration( oauth_client_metadata.scope = " ".join( oauth_server_metadata.scopes_supported ) + + if ( + oauth_server_metadata.token_endpoint_auth_methods_supported + and oauth_client_metadata.token_endpoint_auth_method + not in oauth_server_metadata.token_endpoint_auth_methods_supported + ): + # Pick the first supported method from the server + oauth_client_metadata.token_endpoint_auth_method = oauth_server_metadata.token_endpoint_auth_methods_supported[ + 0 + ] + break except Exception as e: log.error(f"Error parsing OAuth metadata from {url}: {e}")