diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 243b8212a8..e02424f969 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -474,6 +474,10 @@ ENABLE_OAUTH_ID_TOKEN_COOKIE = ( os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true" ) +OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get( + "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY +) + OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get( "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY ) diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 31d7bce404..809a1cc47d 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,6 +1,7 @@ import logging from fastapi import APIRouter, Depends, Request, HTTPException from pydantic import BaseModel, ConfigDict +import aiohttp from typing import Optional @@ -17,6 +18,12 @@ from open_webui.utils.mcp.client import MCPClient from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.oauth import ( + get_discovery_urls, + get_oauth_client_info_with_dynamic_client_registration, + encrypt_token, +) +from mcp.shared.auth import OAuthMetadata router = APIRouter() @@ -86,6 +93,38 @@ async def set_connections_config( } +class OAuthClientRegistrationForm(BaseModel): + url: str + client_id: str + client_name: Optional[str] = None + + +@router.post("/oauth/clients/register") +async def register_oauth_client( + request: Request, + form_data: OAuthClientRegistrationForm, + user=Depends(get_admin_user), +): + try: + oauth_client_info = ( + await get_oauth_client_info_with_dynamic_client_registration( + request, form_data.url + ) + ) + return { + "status": True, + "oauth_client_info": encrypt_token( + oauth_client_info.model_dump(mode="json") + ), + } + except Exception as e: + log.debug(f"Failed to register OAuth client: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to register OAuth client", + ) + + ############################ # ToolServers Config ############################ @@ -138,46 +177,79 @@ async def verify_tool_servers_config( """ try: if form_data.type == "mcp": - try: - client = MCPClient() - auth = None - headers = None + if form_data.auth_type == "oauth_2.1": + discovery_urls = get_discovery_urls(form_data.url) + async with aiohttp.ClientSession() as session: + async with session.get( + discovery_urls[0] + ) as oauth_server_metadata_response: + if oauth_server_metadata_response.status != 200: + raise HTTPException( + status_code=400, + detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}", + ) - token = None - if form_data.auth_type == "bearer": - token = form_data.key - elif form_data.auth_type == "session": - token = request.state.token.credentials - elif form_data.auth_type == "system_oauth": - try: - if request.cookies.get("oauth_session_id", None): - token = ( - await request.app.state.oauth_manager.get_oauth_token( + try: + oauth_server_metadata = OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() + ) + return { + "status": True, + "oauth_server_metadata": oauth_server_metadata.model_dump( + mode="json" + ), + } + except Exception as e: + log.info( + f"Failed to parse OAuth 2.1 discovery document: {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_urls[0]}", + ) + + raise HTTPException( + status_code=400, + detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls[0]}", + ) + else: + try: + client = MCPClient() + headers = None + + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( user.id, request.cookies.get("oauth_session_id", None), ) - ) - except Exception as e: - pass + except Exception as e: + pass - if token: - headers = {"Authorization": f"Bearer {token}"} + if token: + headers = {"Authorization": f"Bearer {token}"} - await client.connect(form_data.url, auth=auth, headers=headers) - specs = await client.list_tool_specs() - return { - "status": True, - "specs": specs, - } - except Exception as e: - log.debug(f"Failed to create MCP client: {e}") - raise HTTPException( - status_code=400, - detail=f"Failed to create MCP client", - ) - finally: - if client: - await client.disconnect() + await client.connect(form_data.url, headers=headers) + specs = await client.list_tool_specs() + return { + "status": True, + "specs": specs, + } + except Exception as e: + log.debug(f"Failed to create MCP client: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to create MCP client", + ) + finally: + if client: + await client.disconnect() else: # openapi token = None if form_data.auth_type == "bearer": diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index 2d352ead24..01df38886c 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -13,13 +13,9 @@ class MCPClient: self.session: Optional[ClientSession] = None self.exit_stack = AsyncExitStack() - async def connect( - self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None - ): + async def connect(self, url: str, headers: Optional[dict] = None): try: - self._streams_context = streamablehttp_client( - url, headers=headers, auth=auth - ) + self._streams_context = streamablehttp_client(url, headers=headers) transport = await self.exit_stack.enter_async_context(self._streams_context) read_stream, write_stream, _ = transport diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 2147428443..6d5ea470cd 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1,7 +1,9 @@ import base64 +import hashlib import logging import mimetypes import sys +import urllib import uuid import json from datetime import datetime, timedelta @@ -9,6 +11,9 @@ from datetime import datetime, timedelta import re import fnmatch import time +import secrets +from cryptography.fernet import Fernet + import aiohttp from authlib.integrations.starlette_client import OAuth @@ -18,6 +23,7 @@ from fastapi import ( status, ) from starlette.responses import RedirectResponse +from typing import Optional from open_webui.models.auths import Auths @@ -56,11 +62,27 @@ from open_webui.env import ( WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, ENABLE_OAUTH_ID_TOKEN_COOKIE, + OAUTH_CLIENT_INFO_ENCRYPTION_KEY, ) from open_webui.utils.misc import parse_duration from open_webui.utils.auth import get_password_hash, create_token from open_webui.utils.webhook import post_webhook +from mcp.shared.auth import ( + OAuthClientMetadata, + OAuthMetadata, +) + + +class OAuthClientInformationFull(OAuthClientMetadata): + issuer: Optional[str] = None # URL of the OAuth server that issued this client + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -89,6 +111,42 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN = OAUTH_UPDATE_PICTURE_ON_LOGIN +FERNET = None + +if len(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) != 44: + key_bytes = hashlib.sha256(OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode()).digest() + OAUTH_CLIENT_INFO_ENCRYPTION_KEY = base64.urlsafe_b64encode(key_bytes) +else: + OAUTH_CLIENT_INFO_ENCRYPTION_KEY = OAUTH_CLIENT_INFO_ENCRYPTION_KEY.encode() + +try: + FERNET = Fernet(OAUTH_CLIENT_INFO_ENCRYPTION_KEY) +except Exception as e: + log.error(f"Error initializing Fernet with provided key: {e}") + raise + + +def encrypt_token(token) -> str: + """Encrypt OAuth tokens for storage""" + try: + token_json = json.dumps(token) + encrypted = FERNET.encrypt(token_json.encode()).decode() + return encrypted + except Exception as e: + log.error(f"Error encrypting tokens: {e}") + raise + + +def decrypt_token(token: str): + """Decrypt OAuth tokens from storage""" + try: + decrypted = FERNET.decrypt(token.encode()).decode() + return json.loads(decrypted) + except Exception as e: + log.error(f"Error decrypting tokens: {e}") + raise + + def is_in_blocked_groups(group_name: str, groups: list) -> bool: """ Check if a group name matches any blocked pattern. @@ -133,6 +191,406 @@ def is_in_blocked_groups(group_name: str, groups: list) -> bool: return False +def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]: + parsed = urllib.parse.urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + return parsed, base_url + + +def get_discovery_urls(server_url) -> list[str]: + urls = [] + parsed, base_url = get_parsed_and_base_url(server_url) + + urls.append( + urllib.parse.urljoin(base_url, "/.well-known/oauth-authorization-server") + ) + urls.append(urllib.parse.urljoin(base_url, "/.well-known/openid-configuration")) + + return urls + + +# TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration. +# This is not currently supported. +async def get_oauth_client_info_with_dynamic_client_registration( + request, oauth_server_url, oauth_server_key: Optional[str] = None +) -> OAuthClientInformationFull: + try: + oauth_server_metadata = None + oauth_server_metadata_url = None + + redirect_base_url = ( + str(request.app.state.config.WEBUI_URL or request.base_url) + ).rstrip("/") + oauth_client_metadata = OAuthClientMetadata( + client_name="Open WebUI", + redirect_uris=[f"{redirect_base_url}/oauth/callback"], + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + token_endpoint_auth_method="client_secret_post", + ) + + # Attempt to fetch OAuth server metadata to get registration endpoint & scopes + discovery_urls = get_discovery_urls(oauth_server_url) + for url in discovery_urls: + async with aiohttp.ClientSession() as session: + async with session.get( + url, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as oauth_server_metadata_response: + if oauth_server_metadata_response.status == 200: + try: + oauth_server_metadata = OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() + ) + oauth_server_metadata_url = url + if ( + oauth_client_metadata.scope is None + and oauth_server_metadata.scopes_supported is not None + ): + oauth_client_metadata.scope = " ".join( + oauth_server_metadata.scopes_supported + ) + break + except Exception as e: + log.error(f"Error parsing OAuth metadata from {url}: {e}") + continue + + registration_url = None + if oauth_server_metadata and oauth_server_metadata.registration_endpoint: + registration_url = str(oauth_server_metadata.registration_endpoint) + else: + _, base_url = get_parsed_and_base_url(oauth_server_url) + registration_url = urllib.parse.urljoin(base_url, "/register") + + registration_data = oauth_client_metadata.model_dump( + exclude_none=True, + mode="json", + by_alias=True, + ) + + # Perform dynamic client registration and return client info + async with aiohttp.ClientSession() as session: + async with session.post( + registration_url, json=registration_data, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as oauth_client_registration_response: + try: + registration_response_json = ( + await oauth_client_registration_response.json() + ) + oauth_client_info = OAuthClientInformationFull.model_validate( + { + **registration_response_json, + **{"issuer": oauth_server_metadata_url}, + } + ) + log.info( + f"Dynamic client registration successful at {registration_url}, client_id: {oauth_client_info.client_id}" + ) + return oauth_client_info + except Exception as e: + error_text = None + try: + error_text = await oauth_client_registration_response.text() + log.error( + f"Dynamic client registration failed at {registration_url}: {oauth_client_registration_response.status} - {error_text}" + ) + except Exception as e: + pass + + log.error(f"Error parsing client registration response: {e}") + raise Exception( + f"Dynamic client registration failed: {error_text}" + if error_text + else "Error parsing client registration response" + ) + raise Exception("Dynamic client registration failed") + except Exception as e: + log.error(f"Exception during dynamic client registration: {e}") + raise e + + +class OAuthClientManager: + def __init__(self, app): + self.oauth = OAuth() + self.app = app + self.clients = {} + + def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull): + if client_id not in self.clients: + self.clients[client_id] = { + "client": self.oauth.register( + name=client_id, + client_id=oauth_client_info.client_id, + client_secret=oauth_client_info.client_secret, + client_kwargs=( + {"scope": oauth_client_info.scope} + if oauth_client_info.scope + else {} + ), + server_metadata_url=( + oauth_client_info.issuer if oauth_client_info.issuer else None + ), + ), + "client_info": oauth_client_info, + } + return self.clients[client_id] + + def remove_client(self, client_id): + if client_id in self.clients: + del self.clients[client_id] + log.info(f"Removed OAuth client {client_id}") + return True + + def get_client(self, client_id): + client = self.clients.get(client_id) + return client["client"] if client else None + + def get_client_info(self, client_id): + client = self.clients.get(client_id) + return client["client_info"] if client else None + + def get_server_metadata_url(self, client_id): + if client_id in self.clients: + client = self.clients[client_id] + return ( + client.server_metadata_url + if hasattr(client, "server_metadata_url") + else None + ) + return None + + async def get_oauth_token( + self, user_id: str, session_id: str, force_refresh: bool = False + ): + """ + Get a valid OAuth token for the user, automatically refreshing if needed. + + Args: + user_id: The user ID + session_id: The OAuth session ID + force_refresh: Force token refresh even if current token appears valid + + Returns: + dict: OAuth token data with access_token, or None if no valid token available + """ + try: + # Get the OAuth session + session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id) + if not session: + log.warning( + f"No OAuth session found for user {user_id}, session {session_id}" + ) + return None + + if force_refresh or datetime.now() + timedelta( + minutes=5 + ) >= datetime.fromtimestamp(session.expires_at): + log.debug( + f"Token refresh needed for user {user_id}, client_id {session.provider}" + ) + refreshed_token = await self._refresh_token(session) + if refreshed_token: + return refreshed_token + else: + log.warning( + f"Token refresh failed for user {user_id}, client_id {session.provider}" + ) + return None + return session.token + + except Exception as e: + log.error(f"Error getting OAuth token for user {user_id}: {e}") + return None + + async def _refresh_token(self, session) -> dict: + """ + Refresh an OAuth token if needed, with concurrency protection. + + Args: + session: The OAuth session object + + Returns: + dict: Refreshed token data, or None if refresh failed + """ + try: + # Perform the actual refresh + refreshed_token = await self._perform_token_refresh(session) + + if refreshed_token: + # Update the session with new token data + session = OAuthSessions.update_session_by_id( + session.id, refreshed_token + ) + log.info(f"Successfully refreshed token for session {session.id}") + return session.token + else: + log.error(f"Failed to refresh token for session {session.id}") + return None + + except Exception as e: + log.error(f"Error refreshing token for session {session.id}: {e}") + return None + + async def _perform_token_refresh(self, session) -> dict: + """ + Perform the actual OAuth token refresh. + + Args: + session: The OAuth session object + + Returns: + dict: New token data, or None if refresh failed + """ + client_id = session.provider + token_data = session.token + + if not token_data.get("refresh_token"): + log.warning(f"No refresh token available for session {session.id}") + return None + + try: + client = self.get_client(client_id) + if not client: + log.error(f"No OAuth client found for provider {client_id}") + return None + + token_endpoint = None + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.get( + self.get_server_metadata_url(client_id) + ) as r: + if r.status == 200: + openid_data = await r.json() + token_endpoint = openid_data.get("token_endpoint") + else: + log.error( + f"Failed to fetch OpenID configuration for client_id {client_id}" + ) + if not token_endpoint: + log.error(f"No token endpoint found for client_id {client_id}") + return None + + # Prepare refresh request + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": token_data["refresh_token"], + "client_id": client.client_id, + } + if hasattr(client, "client_secret") and client.client_secret: + refresh_data["client_secret"] = client.client_secret + + # Make refresh request + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.post( + token_endpoint, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as r: + if r.status == 200: + new_token_data = await r.json() + + # Merge with existing token data (preserve refresh_token if not provided) + if "refresh_token" not in new_token_data: + new_token_data["refresh_token"] = token_data[ + "refresh_token" + ] + + # Add timestamp for tracking + new_token_data["issued_at"] = datetime.now().timestamp() + + # Calculate expires_at if we have expires_in + if ( + "expires_in" in new_token_data + and "expires_at" not in new_token_data + ): + new_token_data["expires_at"] = int( + datetime.now().timestamp() + + new_token_data["expires_in"] + ) + + log.debug(f"Token refresh successful for client_id {client_id}") + return new_token_data + else: + error_text = await r.text() + log.error( + f"Token refresh failed for client_id {client_id}: {r.status} - {error_text}" + ) + return None + + except Exception as e: + log.error(f"Exception during token refresh for client_id {client_id}: {e}") + return None + + async def handle_authorize(self, request, client_id: str) -> RedirectResponse: + client = self.get_client(client_id) + if client is None: + raise HTTPException(404) + + client_info = self.get_client_info(client_id) + if client_info is None: + raise HTTPException(404) + + redirect_uri = ( + client_info.redirect_uris[0] if client_info.redirect_uris else None + ) + return await client.authorize_redirect(request, redirect_uri) + + async def handle_callback(self, request, client_id: str, user_id: str, response): + client = self.get_client(client_id) + if client is None: + raise HTTPException(404) + + error_message = None + try: + token = await client.authorize_access_token(request) + if token: + try: + # Add timestamp for tracking + token["issued_at"] = datetime.now().timestamp() + + # Calculate expires_at if we have expires_in + if "expires_in" in token and "expires_at" not in token: + token["expires_at"] = ( + datetime.now().timestamp() + token["expires_in"] + ) + + # Clean up any existing sessions for this user/client_id first + sessions = OAuthSessions.get_sessions_by_user_id(user_id) + for session in sessions: + if session.provider == client_id: + OAuthSessions.delete_session_by_id(session.id) + + session = OAuthSessions.create_session( + user_id=user_id, + provider=client_id, + token=token, + ) + + log.info( + f"Stored OAuth session server-side for user {user_id}, client_id {client_id}" + ) + except Exception as e: + error_message = "Failed to store OAuth session server-side" + log.error(f"Failed to store OAuth session server-side: {e}") + else: + error_message = "Failed to obtain OAuth token" + log.warning(error_message) + except Exception as e: + error_message = "OAuth callback error" + log.warning(f"OAuth callback error: {e}") + + redirect_base_url = ( + str(request.app.state.config.WEBUI_URL or request.base_url) + ).rstrip("/") + redirect_url = f"{redirect_base_url}/auth" + + if error_message: + redirect_url = f"{redirect_url}?error={error_message}" + return RedirectResponse(url=redirect_url, headers=response.headers) + + response = RedirectResponse(url=redirect_url, headers=response.headers) + + class OAuthManager: def __init__(self, app): self.oauth = OAuth() @@ -792,9 +1250,9 @@ class OAuthManager: else ERROR_MESSAGES.DEFAULT("Error during OAuth process") ) - redirect_base_url = str(request.app.state.config.WEBUI_URL or request.base_url) - if redirect_base_url.endswith("/"): - redirect_base_url = redirect_base_url[:-1] + redirect_base_url = ( + str(request.app.state.config.WEBUI_URL or request.base_url) + ).rstrip("/") redirect_url = f"{redirect_base_url}/auth" if error_message: diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index ef983e63bf..77374c93b6 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -202,6 +202,42 @@ export const verifyToolServerConnection = async (token: string, connection: obje return res; }; +type RegisterOAuthClientForm = { + url: string; + client_id: string; + client_name?: string; +}; + +export const registerOAuthClient = async (token: string, formData: RegisterOAuthClientForm) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...formData + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.error(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getCodeExecutionConfig = async (token: string) => { let error = null; diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte index 01c87010ef..fea4551b3a 100644 --- a/src/lib/components/AddToolServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -13,7 +13,7 @@ import Switch from '$lib/components/common/Switch.svelte'; import Tags from './common/Tags.svelte'; import { getToolServerData } from '$lib/apis'; - import { verifyToolServerConnection } from '$lib/apis/configs'; + import { verifyToolServerConnection, registerOAuthClient } from '$lib/apis/configs'; import AccessControl from './workspace/common/AccessControl.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; import XMark from '$lib/components/icons/XMark.svelte'; @@ -41,10 +41,37 @@ let name = ''; let description = ''; - let enable = true; + let oauthClientInfo = null; + let enable = true; let loading = false; + const registerOAuthClientHandler = async () => { + if (url === '') { + toast.error($i18n.t('Please enter a valid URL')); + return; + } + + if (id === '') { + toast.error($i18n.t('Please enter a valid ID')); + return; + } + + const res = await registerOAuthClient(localStorage.token, { + url: url, + client_id: id + }).catch((err) => { + toast.error($i18n.t('Registration failed')); + return null; + }); + + if (res) { + toast.success($i18n.t('Registration successful')); + console.debug('Registration successful', res); + oauthClientInfo = res?.oauth_client_info ?? null; + } + }; + const verifyHandler = async () => { if (url === '') { toast.error($i18n.t('Please enter a valid URL')); @@ -106,6 +133,12 @@ return; } + if (type === 'mcp' && auth_type === 'oauth_2.1' && !oauthClientInfo) { + toast.error($i18n.t('Please register the OAuth client')); + loading = false; + return; + } + const connection = { url, path, @@ -119,7 +152,8 @@ info: { id: id, name: name, - description: description + description: description, + ...(oauthClientInfo ? { oauth_client_info: oauthClientInfo } : {}) } }; @@ -139,6 +173,7 @@ id = ''; name = ''; description = ''; + oauthClientInfo = null; enable = true; accessControl = null; @@ -156,6 +191,7 @@ id = connection.info?.id ?? ''; name = connection.info?.name ?? ''; description = connection.info?.description ?? ''; + oauthClientInfo = connection.info?.oauth_client_info ?? null; enable = connection.config?.enable ?? true; accessControl = connection.config?.access_control ?? null; @@ -227,25 +263,6 @@ {/if} - {#if type === 'mcp'} -