refac/enh: mcp oauth auth method support

This commit is contained in:
Timothy Jaeryang Baek 2025-11-19 02:26:42 -05:00
parent 76acdabdc3
commit 0c47cbd16a

View file

@ -14,7 +14,7 @@ import fnmatch
import time import time
import secrets import secrets
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from typing import Literal
import aiohttp import aiohttp
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
@ -72,13 +72,20 @@ from open_webui.utils.auth import get_password_hash, create_token
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from mcp.shared.auth import ( from mcp.shared.auth import (
OAuthClientMetadata, OAuthClientMetadata as MCPOAuthClientMetadata,
OAuthMetadata, OAuthMetadata,
) )
from authlib.oauth2.rfc6749.errors import OAuth2Error from authlib.oauth2.rfc6749.errors import OAuth2Error
class OAuthClientMetadata(MCPOAuthClientMetadata):
token_endpoint_auth_method: Literal[
"none", "client_secret_basic", "client_secret_post"
] = "client_secret_post"
pass
class OAuthClientInformationFull(OAuthClientMetadata): class OAuthClientInformationFull(OAuthClientMetadata):
issuer: Optional[str] = None # URL of the OAuth server that issued this client issuer: Optional[str] = None # URL of the OAuth server that issued this client
@ -242,26 +249,28 @@ def get_discovery_urls(server_url) -> list[str]:
if parsed.path and parsed.path != "/": if parsed.path and parsed.path != "/":
# Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
tenant = parsed.path.rstrip('/') tenant = parsed.path.rstrip("/")
urls.extend([ urls.extend(
urllib.parse.urljoin( [
base_url, urllib.parse.urljoin(
f"/.well-known/oauth-authorization-server{tenant}", base_url,
), f"/.well-known/oauth-authorization-server{tenant}",
urllib.parse.urljoin( ),
base_url, urllib.parse.urljoin(
f"/.well-known/openid-configuration{tenant}" base_url, f"/.well-known/openid-configuration{tenant}"
), ),
urllib.parse.urljoin( urllib.parse.urljoin(
base_url, base_url, f"{tenant}/.well-known/openid-configuration"
f"{tenant}/.well-known/openid-configuration" ),
) ]
]) )
urls.extend([ urls.extend(
urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"), [
urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"), urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
]) urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
]
)
return urls return urls
@ -287,7 +296,6 @@ async def get_oauth_client_info_with_dynamic_client_registration(
redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"], redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
grant_types=["authorization_code", "refresh_token"], grant_types=["authorization_code", "refresh_token"],
response_types=["code"], response_types=["code"],
token_endpoint_auth_method="client_secret_post",
) )
# Attempt to fetch OAuth server metadata to get registration endpoint & scopes # Attempt to fetch OAuth server metadata to get registration endpoint & scopes
@ -310,6 +318,17 @@ async def get_oauth_client_info_with_dynamic_client_registration(
oauth_client_metadata.scope = " ".join( oauth_client_metadata.scope = " ".join(
oauth_server_metadata.scopes_supported oauth_server_metadata.scopes_supported
) )
if (
oauth_server_metadata.token_endpoint_auth_methods_supported
and oauth_client_metadata.token_endpoint_auth_method
not in oauth_server_metadata.token_endpoint_auth_methods_supported
):
# Pick the first supported method from the server
oauth_client_metadata.token_endpoint_auth_method = oauth_server_metadata.token_endpoint_auth_methods_supported[
0
]
break break
except Exception as e: except Exception as e:
log.error(f"Error parsing OAuth metadata from {url}: {e}") log.error(f"Error parsing OAuth metadata from {url}: {e}")