diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index dc4468e8e7..4716373b52 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -504,6 +504,7 @@ from open_webui.utils.oauth import ( OAuthManager, OAuthClientManager, OAuthClientInformationFull, + periodic_oauth_token_refresh, ) from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.redis import get_redis_connection @@ -599,6 +600,11 @@ async def lifespan(app: FastAPI): asyncio.create_task(periodic_usage_pool_cleanup()) + # Start background task for proactive OAuth token refresh + app.state.oauth_token_refresh_task = asyncio.create_task( + periodic_oauth_token_refresh(app) + ) + if app.state.config.ENABLE_BASE_MODELS_CACHE: await get_all_models( Request( @@ -622,6 +628,10 @@ async def lifespan(app: FastAPI): yield + # Cleanup background tasks on shutdown + if hasattr(app.state, "oauth_token_refresh_task"): + app.state.oauth_token_refresh_task.cancel() + if hasattr(app.state, "redis_task_command_listener"): app.state.redis_task_command_listener.cancel() diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index d07faad35e..797aea5a3c 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -273,5 +273,40 @@ class OAuthSessionTable: log.error(f"Error deleting OAuth sessions by provider {provider}: {e}") return False + def get_expiring_sessions(self, minutes: int = 10) -> List[OAuthSessionModel]: + """Get all OAuth sessions expiring within the specified minutes. + + Args: + minutes: Number of minutes from now to check for expiring sessions. + + Returns: + List of OAuth sessions that will expire within the specified time window. + """ + try: + with get_db() as db: + current_time = int(time.time()) + expiry_threshold = current_time + (minutes * 60) + + sessions = ( + db.query(OAuthSession) + .filter(OAuthSession.expires_at <= expiry_threshold) + .filter(OAuthSession.expires_at > current_time) # Not already expired + .all() + ) + + results = [] + for session in sessions: + try: + session.token = self._decrypt_token(session.token) + results.append(OAuthSessionModel.model_validate(session)) + except Exception as e: + log.error(f"Error decrypting token for session {session.id}: {e}") + continue + + return results + except Exception as e: + log.error(f"Error getting expiring OAuth sessions: {e}") + return [] + OAuthSessions = OAuthSessionTable() diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 1ef5268bae..c536031e4a 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -645,34 +645,87 @@ class OAuthClientManager: try: client = self.get_client(client_id) - if not client: - log.error(f"No OAuth client found for provider {client_id}") - return None + client_info = self.get_client_info(client_id) if client else None + # Get token endpoint and client credentials token_endpoint = None - async with aiohttp.ClientSession(trust_env=True) as session_http: - async with session_http.get( - self.get_server_metadata_url(client_id) - ) as r: - if r.status == 200: - openid_data = await r.json() - token_endpoint = openid_data.get("token_endpoint") - else: - log.error( - f"Failed to fetch OpenID configuration for client_id {client_id}" - ) + oauth_client_id = None + oauth_client_secret = None + + if client and client_info: + # Client is registered, use stored info + server_metadata_url = self.get_server_metadata_url(client_id) + if server_metadata_url: + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.get( + server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as r: + if r.status == 200: + openid_data = await r.json() + token_endpoint = openid_data.get("token_endpoint") + + oauth_client_id = client.client_id if hasattr(client, 'client_id') else None + oauth_client_secret = client.client_secret if hasattr(client, 'client_secret') else None + + # If client not registered but this is an MCP OAuth session, try to discover token endpoint + if not token_endpoint and client_id.startswith("mcp:"): + # For MCP clients, try to get info from app config or discover from well-known URL + # Extract server ID from client_id (e.g., "mcp:ntn" -> "ntn") + server_id = client_id.replace("mcp:", "") + log.debug(f"MCP client {client_id} not registered, attempting discovery for server {server_id}") + + # Try to get the tool server config to find the OAuth info + if hasattr(self.app, 'state') and hasattr(self.app.state, 'config'): + tool_servers = getattr(self.app.state.config, 'TOOL_SERVER_CONNECTIONS', []) + for server in tool_servers: + server_info = server.get('info', {}) + if server_info.get('id') == server_id and server.get('auth_type') == 'oauth_2.1': + # Found the server config, try to get oauth_client_info + encrypted_client_info = server_info.get('oauth_client_info') + if encrypted_client_info: + try: + client_info_data = decrypt_data(encrypted_client_info) + oauth_client_id = client_info_data.get('client_id') + oauth_client_secret = client_info_data.get('client_secret') + + # Get token endpoint from server metadata + server_metadata = client_info_data.get('server_metadata', {}) + if isinstance(server_metadata, dict): + token_endpoint = server_metadata.get('token_endpoint') + + # If no token endpoint in stored metadata, try discovery + if not token_endpoint: + issuer = client_info_data.get('issuer') + if issuer: + async with aiohttp.ClientSession(trust_env=True) as session_http: + async with session_http.get( + issuer, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as r: + if r.status == 200: + openid_data = await r.json() + token_endpoint = openid_data.get("token_endpoint") + + log.debug(f"Discovered OAuth info for MCP server {server_id}") + break + except Exception as e: + log.warning(f"Failed to decrypt oauth_client_info for {server_id}: {e}") + if not token_endpoint: log.error(f"No token endpoint found for client_id {client_id}") return None + if not oauth_client_id: + log.error(f"No OAuth client_id found for {client_id}") + return None + # Prepare refresh request refresh_data = { "grant_type": "refresh_token", "refresh_token": token_data["refresh_token"], - "client_id": client.client_id, + "client_id": oauth_client_id, } - if hasattr(client, "client_secret") and client.client_secret: - refresh_data["client_secret"] = client.client_secret + if oauth_client_secret: + refresh_data["client_secret"] = oauth_client_secret # Make refresh request async with aiohttp.ClientSession(trust_env=True) as session_http: @@ -1585,3 +1638,88 @@ class OAuthManager: log.error(f"Failed to store OAuth session server-side: {e}") return response + + +async def periodic_oauth_token_refresh(app, interval_minutes: int = 5): + """ + Background task that periodically refreshes OAuth tokens before they expire. + + This task runs every `interval_minutes` and checks for OAuth sessions that + will expire within the next 10 minutes. For each expiring session, it attempts + to refresh the token using the stored refresh_token. + + Args: + app: The FastAPI application instance (needed to access oauth_client_manager) + interval_minutes: How often to run the refresh check (default: 5 minutes) + """ + import asyncio + + log.info( + f"Starting periodic OAuth token refresh task (interval: {interval_minutes} minutes)" + ) + + while True: + try: + await asyncio.sleep(interval_minutes * 60) + + # Get sessions expiring within the next 10 minutes + expiring_sessions = OAuthSessions.get_expiring_sessions(minutes=10) + + if expiring_sessions: + log.info( + f"Found {len(expiring_sessions)} OAuth session(s) expiring soon, attempting refresh..." + ) + + for session in expiring_sessions: + try: + # Check if this is an MCP OAuth session (provider starts with "mcp:") + if session.provider.startswith("mcp:"): + # Use OAuthClientManager for MCP sessions + if hasattr(app.state, "oauth_client_manager"): + oauth_client_manager = app.state.oauth_client_manager + refreshed = await oauth_client_manager._refresh_token( + session + ) + if refreshed: + log.info( + f"Successfully refreshed MCP OAuth token for session {session.id} (provider: {session.provider})" + ) + else: + log.warning( + f"Failed to refresh MCP OAuth token for session {session.id} (provider: {session.provider})" + ) + else: + log.debug( + f"oauth_client_manager not available, skipping MCP session {session.id}" + ) + else: + # Use OAuthManager for standard OAuth sessions (login providers) + if hasattr(app.state, "oauth_manager"): + oauth_manager = app.state.oauth_manager + refreshed = await oauth_manager._refresh_token(session) + if refreshed: + log.info( + f"Successfully refreshed OAuth token for session {session.id} (provider: {session.provider})" + ) + else: + log.warning( + f"Failed to refresh OAuth token for session {session.id} (provider: {session.provider})" + ) + else: + log.debug( + f"oauth_manager not available, skipping session {session.id}" + ) + + except Exception as e: + log.error( + f"Error refreshing OAuth token for session {session.id}: {e}" + ) + continue + + except asyncio.CancelledError: + log.info("Periodic OAuth token refresh task cancelled") + break + except Exception as e: + log.error(f"Error in periodic OAuth token refresh task: {e}") + # Continue running even if there's an error + continue