mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac/enh: mcp oauth auth method support
This commit is contained in:
parent
76acdabdc3
commit
0c47cbd16a
1 changed files with 41 additions and 22 deletions
|
|
@ -14,7 +14,7 @@ import fnmatch
|
||||||
import time
|
import time
|
||||||
import secrets
|
import secrets
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from authlib.integrations.starlette_client import OAuth
|
from authlib.integrations.starlette_client import OAuth
|
||||||
|
|
@ -72,13 +72,20 @@ from open_webui.utils.auth import get_password_hash, create_token
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
|
|
||||||
from mcp.shared.auth import (
|
from mcp.shared.auth import (
|
||||||
OAuthClientMetadata,
|
OAuthClientMetadata as MCPOAuthClientMetadata,
|
||||||
OAuthMetadata,
|
OAuthMetadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
from authlib.oauth2.rfc6749.errors import OAuth2Error
|
from authlib.oauth2.rfc6749.errors import OAuth2Error
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientMetadata(MCPOAuthClientMetadata):
|
||||||
|
token_endpoint_auth_method: Literal[
|
||||||
|
"none", "client_secret_basic", "client_secret_post"
|
||||||
|
] = "client_secret_post"
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OAuthClientInformationFull(OAuthClientMetadata):
|
class OAuthClientInformationFull(OAuthClientMetadata):
|
||||||
issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
||||||
|
|
||||||
|
|
@ -242,26 +249,28 @@ def get_discovery_urls(server_url) -> list[str]:
|
||||||
|
|
||||||
if parsed.path and parsed.path != "/":
|
if parsed.path and parsed.path != "/":
|
||||||
# Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
|
# Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
|
||||||
tenant = parsed.path.rstrip('/')
|
tenant = parsed.path.rstrip("/")
|
||||||
urls.extend([
|
urls.extend(
|
||||||
urllib.parse.urljoin(
|
[
|
||||||
base_url,
|
urllib.parse.urljoin(
|
||||||
f"/.well-known/oauth-authorization-server{tenant}",
|
base_url,
|
||||||
),
|
f"/.well-known/oauth-authorization-server{tenant}",
|
||||||
urllib.parse.urljoin(
|
),
|
||||||
base_url,
|
urllib.parse.urljoin(
|
||||||
f"/.well-known/openid-configuration{tenant}"
|
base_url, f"/.well-known/openid-configuration{tenant}"
|
||||||
),
|
),
|
||||||
urllib.parse.urljoin(
|
urllib.parse.urljoin(
|
||||||
base_url,
|
base_url, f"{tenant}/.well-known/openid-configuration"
|
||||||
f"{tenant}/.well-known/openid-configuration"
|
),
|
||||||
)
|
]
|
||||||
])
|
)
|
||||||
|
|
||||||
urls.extend([
|
urls.extend(
|
||||||
urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
|
[
|
||||||
urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
|
urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server"),
|
||||||
])
|
urllib.parse.urljoin(base_url, "/.well-known/openid-configuration"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return urls
|
return urls
|
||||||
|
|
||||||
|
|
@ -287,7 +296,6 @@ async def get_oauth_client_info_with_dynamic_client_registration(
|
||||||
redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
|
redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"],
|
||||||
grant_types=["authorization_code", "refresh_token"],
|
grant_types=["authorization_code", "refresh_token"],
|
||||||
response_types=["code"],
|
response_types=["code"],
|
||||||
token_endpoint_auth_method="client_secret_post",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attempt to fetch OAuth server metadata to get registration endpoint & scopes
|
# Attempt to fetch OAuth server metadata to get registration endpoint & scopes
|
||||||
|
|
@ -310,6 +318,17 @@ async def get_oauth_client_info_with_dynamic_client_registration(
|
||||||
oauth_client_metadata.scope = " ".join(
|
oauth_client_metadata.scope = " ".join(
|
||||||
oauth_server_metadata.scopes_supported
|
oauth_server_metadata.scopes_supported
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
oauth_server_metadata.token_endpoint_auth_methods_supported
|
||||||
|
and oauth_client_metadata.token_endpoint_auth_method
|
||||||
|
not in oauth_server_metadata.token_endpoint_auth_methods_supported
|
||||||
|
):
|
||||||
|
# Pick the first supported method from the server
|
||||||
|
oauth_client_metadata.token_endpoint_auth_method = oauth_server_metadata.token_endpoint_auth_methods_supported[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error parsing OAuth metadata from {url}: {e}")
|
log.error(f"Error parsing OAuth metadata from {url}: {e}")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue