From 77e971dd9fbeee806e2864e686df5ec75e82104b Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 25 Sep 2025 01:49:16 -0500 Subject: [PATCH] feat: oauth2.1 mcp integration --- backend/open_webui/main.py | 60 +++++++++++++- backend/open_webui/models/oauth_sessions.py | 20 +++++ backend/open_webui/routers/configs.py | 30 ++++++- backend/open_webui/routers/tools.py | 26 ++++++ backend/open_webui/utils/middleware.py | 17 ++++ backend/open_webui/utils/oauth.py | 82 ++++++++++--------- src/lib/apis/configs/index.ts | 16 +++- src/lib/components/AddToolServerModal.svelte | 18 +++- .../chat/MessageInput/IntegrationsMenu.svelte | 22 ++++- src/routes/(app)/+page.svelte | 10 +++ 10 files changed, 248 insertions(+), 53 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 08a27331c9..180406c942 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -473,7 +473,12 @@ from open_webui.utils.auth import ( get_verified_user, ) from open_webui.utils.plugin import install_tool_and_function_dependencies -from open_webui.utils.oauth import OAuthManager +from open_webui.utils.oauth import ( + OAuthManager, + OAuthClientManager, + decrypt_data, + OAuthClientInformationFull, +) from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.redis import get_redis_connection @@ -603,9 +608,14 @@ app = FastAPI( lifespan=lifespan, ) +# For Open WebUI OIDC/OAuth2 oauth_manager = OAuthManager(app) app.state.oauth_manager = oauth_manager +# For Integrations +oauth_client_manager = OAuthClientManager(app) +app.state.oauth_client_manager = oauth_client_manager + app.state.instance_id = None app.state.config = AppConfig( redis_url=REDIS_URL, @@ -1881,6 +1891,24 @@ async def get_current_usage(user=Depends(get_verified_user)): # OAuth Login & Callback ############################ + +# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1 +if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0: + for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS: + if tool_server_connection.get("type", "openapi") == "mcp": + server_id = tool_server_connection.get("info", {}).get("id") + auth_type = tool_server_connection.get("auth_type", "none") + if server_id and auth_type == "oauth_2.1": + oauth_client_info = tool_server_connection.get("info", {}).get( + "oauth_client_info" + ) + + oauth_client_info = decrypt_data(oauth_client_info) + app.state.oauth_client_manager.add_client( + f"mcp:{server_id}", OAuthClientInformationFull(**oauth_client_info) + ) + + # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: try: @@ -1913,6 +1941,31 @@ if len(OAUTH_PROVIDERS) > 0: ) +@app.get("/oauth/clients/{client_id}/authorize") +async def oauth_client_authorize( + client_id: str, + request: Request, + response: Response, + user=Depends(get_verified_user), +): + return await oauth_client_manager.handle_authorize(request, client_id=client_id) + + +@app.get("/oauth/clients/{client_id}/callback") +async def oauth_client_callback( + client_id: str, + request: Request, + response: Response, + user=Depends(get_verified_user), +): + return await oauth_client_manager.handle_callback( + request, + client_id=client_id, + user_id=user.id if user else None, + response=response, + ) + + @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): return await oauth_manager.handle_login(request, provider) @@ -1924,8 +1977,9 @@ async def oauth_login(provider: str, request: Request): # - This is considered insecure in general, as OAuth providers do not always verify email addresses # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user # - Email addresses are considered unique, so we fail registration if the email address is already taken -@app.get("/oauth/{provider}/callback") -async def oauth_callback(provider: str, request: Request, response: Response): +@app.get("/oauth/{provider}/callback") # Legacy endpoint +@app.get("/oauth/{provider}/login/callback") +async def oauth_login_callback(provider: str, request: Request, response: Response): return await oauth_manager.handle_callback(request, provider, response) diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index 9fd5335ce5..81ce220384 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -176,6 +176,26 @@ class OAuthSessionTable: log.error(f"Error getting OAuth session by ID: {e}") return None + def get_session_by_provider_and_user_id( + self, provider: str, user_id: str + ) -> Optional[OAuthSessionModel]: + """Get OAuth session by provider and user ID""" + try: + with get_db() as db: + session = ( + db.query(OAuthSession) + .filter_by(provider=provider, user_id=user_id) + .first() + ) + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error getting OAuth session by provider and user ID: {e}") + return None + def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: """Get all OAuth sessions for a user""" try: diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 809a1cc47d..76b9bc77a1 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -21,7 +21,9 @@ from open_webui.env import SRC_LOG_LEVELS from open_webui.utils.oauth import ( get_discovery_urls, get_oauth_client_info_with_dynamic_client_registration, - encrypt_token, + encrypt_data, + decrypt_data, + OAuthClientInformationFull, ) from mcp.shared.auth import OAuthMetadata @@ -103,17 +105,22 @@ class OAuthClientRegistrationForm(BaseModel): async def register_oauth_client( request: Request, form_data: OAuthClientRegistrationForm, + type: Optional[str] = None, user=Depends(get_admin_user), ): try: + oauth_client_id = form_data.client_id + if type: + oauth_client_id = f"{type}:{form_data.client_id}" + oauth_client_info = ( await get_oauth_client_info_with_dynamic_client_registration( - request, form_data.url + request, oauth_client_id, form_data.url ) ) return { "status": True, - "oauth_client_info": encrypt_token( + "oauth_client_info": encrypt_data( oauth_client_info.model_dump(mode="json") ), } @@ -161,8 +168,25 @@ async def set_tool_servers_config( request.app.state.config.TOOL_SERVER_CONNECTIONS = [ connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS ] + await set_tool_servers(request) + for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: + server_type = connection.get("type", "openapi") + if server_type == "mcp": + server_id = connection.get("info", {}).get("id") + auth_type = connection.get("auth_type", "none") + if auth_type == "oauth_2.1" and server_id: + try: + oauth_client_info = decrypt_data(oauth_client_info) + await request.app.state.oauth_client_manager.add_client( + f"{server_type}:{server_id}", + OAuthClientInformationFull(**oauth_client_info), + ) + except Exception as e: + log.debug(f"Failed to add OAuth client for MCP tool server: {e}") + continue + return { "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, } diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index eb03969ad7..b42452b1ac 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, HttpUrl from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.tools import ( ToolForm, ToolModel, @@ -80,6 +81,24 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): # MCP Tool Servers for server in request.app.state.config.TOOL_SERVER_CONNECTIONS: if server.get("type", "openapi") == "mcp": + server_id = server.get("info", {}).get("id") + auth_type = server.get("auth_type", "none") + + session_token = None + if auth_type == "oauth_2.1": + splits = server_id.split(":") + server_id = splits[-1] if len(splits) > 1 else server_id + + session_token = ( + await request.app.state.oauth_client_manager.get_oauth_token( + user.id, f"mcp:{server_id}" + ) + ) + + print("User ID:", user.id) + print("Server ID:", server_id) + print("MCP Session Token:", session_token) + tools.append( ToolUserResponse( **{ @@ -96,6 +115,13 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): ), "updated_at": int(time.time()), "created_at": int(time.time()), + **( + { + "authenticated": session_token is not None, + } + if auth_type == "oauth_2.1" + else {} + ), } ) ) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index b300bfa8d3..509f419b07 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse from starlette.responses import Response, StreamingResponse, JSONResponse +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.chats import Chats from open_webui.models.folders import Folders from open_webui.models.users import Users @@ -1047,6 +1048,22 @@ async def process_chat_payload(request, form_data, user, metadata, model): headers["Authorization"] = ( f"Bearer {oauth_token.get('access_token', '')}" ) + elif auth_type == "oauth_2.1": + try: + splits = server_id.split(":") + server_id = splits[-1] if len(splits) > 1 else server_id + + oauth_token = await request.app.state.oauth_client_manager.get_oauth_token( + user.id, f"mcp:{server_id}" + ) + + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + oauth_token = None mcp_client = MCPClient() await mcp_client.connect( diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 6d5ea470cd..16ec1b18e4 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -126,24 +126,24 @@ except Exception as e: raise -def encrypt_token(token) -> str: - """Encrypt OAuth tokens for storage""" +def encrypt_data(data) -> str: + """Encrypt data for storage""" try: - token_json = json.dumps(token) - encrypted = FERNET.encrypt(token_json.encode()).decode() + data_json = json.dumps(data) + encrypted = FERNET.encrypt(data_json.encode()).decode() return encrypted except Exception as e: - log.error(f"Error encrypting tokens: {e}") + log.error(f"Error encrypting data: {e}") raise -def decrypt_token(token: str): - """Decrypt OAuth tokens from storage""" +def decrypt_data(data: str): + """Decrypt data from storage""" try: - decrypted = FERNET.decrypt(token.encode()).decode() + decrypted = FERNET.decrypt(data.encode()).decode() return json.loads(decrypted) except Exception as e: - log.error(f"Error decrypting tokens: {e}") + log.error(f"Error decrypting data: {e}") raise @@ -212,7 +212,10 @@ def get_discovery_urls(server_url) -> list[str]: # TODO: Some OAuth providers require Initial Access Tokens (IATs) for dynamic client registration. # This is not currently supported. async def get_oauth_client_info_with_dynamic_client_registration( - request, oauth_server_url, oauth_server_key: Optional[str] = None + request, + client_id: str, + oauth_server_url: str, + oauth_server_key: Optional[str] = None, ) -> OAuthClientInformationFull: try: oauth_server_metadata = None @@ -221,9 +224,10 @@ async def get_oauth_client_info_with_dynamic_client_registration( redirect_base_url = ( str(request.app.state.config.WEBUI_URL or request.base_url) ).rstrip("/") + oauth_client_metadata = OAuthClientMetadata( client_name="Open WebUI", - redirect_uris=[f"{redirect_base_url}/oauth/callback"], + redirect_uris=[f"{redirect_base_url}/oauth/clients/{client_id}/callback"], grant_types=["authorization_code", "refresh_token"], response_types=["code"], token_endpoint_auth_method="client_secret_post", @@ -315,23 +319,22 @@ class OAuthClientManager: self.clients = {} def add_client(self, client_id, oauth_client_info: OAuthClientInformationFull): - if client_id not in self.clients: - self.clients[client_id] = { - "client": self.oauth.register( - name=client_id, - client_id=oauth_client_info.client_id, - client_secret=oauth_client_info.client_secret, - client_kwargs=( - {"scope": oauth_client_info.scope} - if oauth_client_info.scope - else {} - ), - server_metadata_url=( - oauth_client_info.issuer if oauth_client_info.issuer else None - ), + self.clients[client_id] = { + "client": self.oauth.register( + name=client_id, + client_id=oauth_client_info.client_id, + client_secret=oauth_client_info.client_secret, + client_kwargs=( + {"scope": oauth_client_info.scope} + if oauth_client_info.scope + else {} ), - "client_info": oauth_client_info, - } + server_metadata_url=( + oauth_client_info.issuer if oauth_client_info.issuer else None + ), + ), + "client_info": oauth_client_info, + } return self.clients[client_id] def remove_client(self, client_id): @@ -359,7 +362,7 @@ class OAuthClientManager: return None async def get_oauth_token( - self, user_id: str, session_id: str, force_refresh: bool = False + self, user_id: str, client_id: str, force_refresh: bool = False ): """ Get a valid OAuth token for the user, automatically refreshing if needed. @@ -374,10 +377,12 @@ class OAuthClientManager: """ try: # Get the OAuth session - session = OAuthSessions.get_session_by_id_and_user_id(session_id, user_id) + session = OAuthSessions.get_session_by_provider_and_user_id( + client_id, user_id + ) if not session: log.warning( - f"No OAuth session found for user {user_id}, session {session_id}" + f"No OAuth session found for user {user_id}, client_id {client_id}" ) return None @@ -392,8 +397,9 @@ class OAuthClientManager: return refreshed_token else: log.warning( - f"Token refresh failed for user {user_id}, client_id {session.provider}" + f"Token refresh failed for user {user_id}, client_id {session.provider}, deleting session {session.id}" ) + OAuthSessions.delete_session_by_id(session.id) return None return session.token @@ -533,7 +539,7 @@ class OAuthClientManager: redirect_uri = ( client_info.redirect_uris[0] if client_info.redirect_uris else None ) - return await client.authorize_redirect(request, redirect_uri) + return await client.authorize_redirect(request, str(redirect_uri)) async def handle_callback(self, request, client_id: str, user_id: str, response): client = self.get_client(client_id) @@ -565,7 +571,6 @@ class OAuthClientManager: provider=client_id, token=token, ) - log.info( f"Stored OAuth session server-side for user {user_id}, client_id {client_id}" ) @@ -579,16 +584,17 @@ class OAuthClientManager: error_message = "OAuth callback error" log.warning(f"OAuth callback error: {e}") - redirect_base_url = ( + redirect_url = ( str(request.app.state.config.WEBUI_URL or request.base_url) ).rstrip("/") - redirect_url = f"{redirect_base_url}/auth" if error_message: - redirect_url = f"{redirect_url}?error={error_message}" + log.debug(error_message) + redirect_url = f"{redirect_url}/?error={error_message}" return RedirectResponse(url=redirect_url, headers=response.headers) response = RedirectResponse(url=redirect_url, headers=response.headers) + return response class OAuthManager: @@ -649,8 +655,10 @@ class OAuthManager: return refreshed_token else: log.warning( - f"Token refresh failed for user {user_id}, provider {session.provider}" + f"Token refresh failed for user {user_id}, provider {session.provider}, deleting session {session.id}" ) + OAuthSessions.delete_session_by_id(session.id) + return None return session.token diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index 77374c93b6..c6cfdd2b2b 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -1,4 +1,4 @@ -import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import type { Banner } from '$lib/types'; export const importConfig = async (token: string, config) => { @@ -208,10 +208,15 @@ type RegisterOAuthClientForm = { client_name?: string; }; -export const registerOAuthClient = async (token: string, formData: RegisterOAuthClientForm) => { +export const registerOAuthClient = async ( + token: string, + formData: RegisterOAuthClientForm, + type: null | string = null +) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register`, { + const searchParams = type ? `?type=${type}` : ''; + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/oauth/clients/register${searchParams}`, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -238,6 +243,11 @@ export const registerOAuthClient = async (token: string, formData: RegisterOAuth return res; }; +export const getOAuthClientAuthorizationUrl = (clientId: string, type: null | string = null) => { + const oauthClientId = type ? `${type}:${clientId}` : clientId; + return `${WEBUI_BASE_URL}/oauth/clients/${oauthClientId}/authorize`; +}; + export const getCodeExecutionConfig = async (token: string) => { let error = null; diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte index fea4551b3a..0ddcd9025b 100644 --- a/src/lib/components/AddToolServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -57,16 +57,26 @@ return; } - const res = await registerOAuthClient(localStorage.token, { - url: url, - client_id: id - }).catch((err) => { + const res = await registerOAuthClient( + localStorage.token, + { + url: url, + client_id: id + }, + 'mcp' + ).catch((err) => { toast.error($i18n.t('Registration failed')); return null; }); if (res) { + toast.warning( + $i18n.t( + 'Please save the connection to persist the OAuth client information and do not change the ID' + ) + ); toast.success($i18n.t('Registration successful')); + console.debug('Registration successful', res); oauthClientInfo = res?.oauth_client_info ?? null; } diff --git a/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte b/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte index 7f38d82745..698d1985e1 100644 --- a/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte +++ b/src/lib/components/chat/MessageInput/IntegrationsMenu.svelte @@ -20,6 +20,8 @@ import ChevronRight from '$lib/components/icons/ChevronRight.svelte'; import ChevronLeft from '$lib/components/icons/ChevronLeft.svelte'; import ValvesModal from '$lib/components/workspace/common/ValvesModal.svelte'; + import { getOAuthClientAuthorizationUrl } from '$lib/apis/configs'; + import { partition } from 'd3-hierarchy'; const i18n = getContext('i18n'); @@ -321,11 +323,25 @@ {#each Object.keys(tools) as toolId}