diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 1f1084af2b..079c12876c 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -83,6 +83,8 @@ class OAuthClientInformationFull(OAuthClientMetadata): client_id_issued_at: int | None = None client_secret_expires_at: int | None = None + server_metadata: Optional[OAuthMetadata] = None # Fetched from the OAuth server + from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL @@ -297,6 +299,7 @@ async def get_oauth_client_info_with_dynamic_client_registration( { **registration_response_json, **{"issuer": oauth_server_metadata_url}, + **{"server_metadata": oauth_server_metadata}, } ) log.info( @@ -332,20 +335,27 @@ class OAuthClientManager: self.clients = {} def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull): - self.clients[client_id] = { - "client": self.oauth.register( - name=client_id, - client_id=oauth_client_info.client_id, - client_secret=oauth_client_info.client_secret, - client_kwargs=( - {"scope": oauth_client_info.scope} - if oauth_client_info.scope - else {} - ), - server_metadata_url=( - oauth_client_info.issuer if oauth_client_info.issuer else None - ), + kwargs = { + "name": client_id, + "client_id": oauth_client_info.client_id, + "client_secret": oauth_client_info.client_secret, + "client_kwargs": ( + {"scope": oauth_client_info.scope} if oauth_client_info.scope else {} ), + "server_metadata_url": ( + oauth_client_info.issuer if oauth_client_info.issuer else None + ), + } + + if ( + oauth_client_info.server_metadata + and "S256" + in oauth_client_info.server_metadata.code_challenge_methods_supported + ): + kwargs["code_challenge_method"] = "S256" + + self.clients[client_id] = { + "client": self.oauth.register(**kwargs), "client_info": oauth_client_info, } return self.clients[client_id] @@ -561,7 +571,17 @@ class OAuthClientManager: error_message = None try: - token = await client.authorize_access_token(request) + client_info = self.get_client_info(client_id) + token_params = {} + if ( + client_info + and hasattr(client_info, "client_id") + and hasattr(client_info, "client_secret") + ): + token_params["client_id"] = client_info.client_id + token_params["client_secret"] = client_info.client_secret + + token = await client.authorize_access_token(request, **token_params) if token: try: # Add timestamp for tracking diff --git a/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte b/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte index f0113a8b24..b120495045 100644 --- a/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte +++ b/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte @@ -338,7 +338,7 @@ let serverId = parts?.at(-1) ?? toolId; const authUrl = getOAuthClientAuthorizationUrl(serverId, 'mcp'); - window.open(authUrl, '_blank', 'noopener'); + window.open(authUrl, '_self', 'noopener'); } else { tools[toolId].enabled = !tools[toolId].enabled;