fix: Add proactive OAuth token refresh for MCP sessions

Add a background task that periodically refreshes OAuth tokens before
they expire, preventing users from having to re-authenticate when MCP
OAuth tokens (like Notion) expire after 1 hour.

Changes:
- Add get_expiring_sessions() method to OAuthSessionTable to query
  sessions expiring within a specified time window
- Enhance OAuthClientManager._perform_token_refresh() to handle
  unregistered MCP clients by discovering OAuth info from stored
  tool server config
- Add periodic_oauth_token_refresh() background task that runs every
  5 minutes and refreshes tokens expiring within 10 minutes
- Start the background task in app lifespan with proper cleanup

Fixes #19809
This commit is contained in:
root 2025-12-08 18:29:07 +11:00
parent ce945a9334
commit 0fe0fcff4d
3 changed files with 200 additions and 17 deletions

View file

@ -504,6 +504,7 @@ from open_webui.utils.oauth import (
OAuthManager,
OAuthClientManager,
OAuthClientInformationFull,
periodic_oauth_token_refresh,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
from open_webui.utils.redis import get_redis_connection
@ -599,6 +600,11 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup())
# Start background task for proactive OAuth token refresh
app.state.oauth_token_refresh_task = asyncio.create_task(
periodic_oauth_token_refresh(app)
)
if app.state.config.ENABLE_BASE_MODELS_CACHE:
await get_all_models(
Request(
@ -622,6 +628,10 @@ async def lifespan(app: FastAPI):
yield
# Cleanup background tasks on shutdown
if hasattr(app.state, "oauth_token_refresh_task"):
app.state.oauth_token_refresh_task.cancel()
if hasattr(app.state, "redis_task_command_listener"):
app.state.redis_task_command_listener.cancel()

View file

@ -273,5 +273,40 @@ class OAuthSessionTable:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
return False
def get_expiring_sessions(self, minutes: int = 10) -> List[OAuthSessionModel]:
"""Get all OAuth sessions expiring within the specified minutes.
Args:
minutes: Number of minutes from now to check for expiring sessions.
Returns:
List of OAuth sessions that will expire within the specified time window.
"""
try:
with get_db() as db:
current_time = int(time.time())
expiry_threshold = current_time + (minutes * 60)
sessions = (
db.query(OAuthSession)
.filter(OAuthSession.expires_at <= expiry_threshold)
.filter(OAuthSession.expires_at > current_time) # Not already expired
.all()
)
results = []
for session in sessions:
try:
session.token = self._decrypt_token(session.token)
results.append(OAuthSessionModel.model_validate(session))
except Exception as e:
log.error(f"Error decrypting token for session {session.id}: {e}")
continue
return results
except Exception as e:
log.error(f"Error getting expiring OAuth sessions: {e}")
return []
OAuthSessions = OAuthSessionTable()

View file

@ -645,34 +645,87 @@ class OAuthClientManager:
try:
client = self.get_client(client_id)
if not client:
log.error(f"No OAuth client found for provider {client_id}")
return None
client_info = self.get_client_info(client_id) if client else None
# Get token endpoint and client credentials
token_endpoint = None
async with aiohttp.ClientSession(trust_env=True) as session_http:
async with session_http.get(
self.get_server_metadata_url(client_id)
) as r:
if r.status == 200:
openid_data = await r.json()
token_endpoint = openid_data.get("token_endpoint")
else:
log.error(
f"Failed to fetch OpenID configuration for client_id {client_id}"
)
oauth_client_id = None
oauth_client_secret = None
if client and client_info:
# Client is registered, use stored info
server_metadata_url = self.get_server_metadata_url(client_id)
if server_metadata_url:
async with aiohttp.ClientSession(trust_env=True) as session_http:
async with session_http.get(
server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as r:
if r.status == 200:
openid_data = await r.json()
token_endpoint = openid_data.get("token_endpoint")
oauth_client_id = client.client_id if hasattr(client, 'client_id') else None
oauth_client_secret = client.client_secret if hasattr(client, 'client_secret') else None
# If client not registered but this is an MCP OAuth session, try to discover token endpoint
if not token_endpoint and client_id.startswith("mcp:"):
# For MCP clients, try to get info from app config or discover from well-known URL
# Extract server ID from client_id (e.g., "mcp:ntn" -> "ntn")
server_id = client_id.replace("mcp:", "")
log.debug(f"MCP client {client_id} not registered, attempting discovery for server {server_id}")
# Try to get the tool server config to find the OAuth info
if hasattr(self.app, 'state') and hasattr(self.app.state, 'config'):
tool_servers = getattr(self.app.state.config, 'TOOL_SERVER_CONNECTIONS', [])
for server in tool_servers:
server_info = server.get('info', {})
if server_info.get('id') == server_id and server.get('auth_type') == 'oauth_2.1':
# Found the server config, try to get oauth_client_info
encrypted_client_info = server_info.get('oauth_client_info')
if encrypted_client_info:
try:
client_info_data = decrypt_data(encrypted_client_info)
oauth_client_id = client_info_data.get('client_id')
oauth_client_secret = client_info_data.get('client_secret')
# Get token endpoint from server metadata
server_metadata = client_info_data.get('server_metadata', {})
if isinstance(server_metadata, dict):
token_endpoint = server_metadata.get('token_endpoint')
# If no token endpoint in stored metadata, try discovery
if not token_endpoint:
issuer = client_info_data.get('issuer')
if issuer:
async with aiohttp.ClientSession(trust_env=True) as session_http:
async with session_http.get(
issuer, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as r:
if r.status == 200:
openid_data = await r.json()
token_endpoint = openid_data.get("token_endpoint")
log.debug(f"Discovered OAuth info for MCP server {server_id}")
break
except Exception as e:
log.warning(f"Failed to decrypt oauth_client_info for {server_id}: {e}")
if not token_endpoint:
log.error(f"No token endpoint found for client_id {client_id}")
return None
if not oauth_client_id:
log.error(f"No OAuth client_id found for {client_id}")
return None
# Prepare refresh request
refresh_data = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": client.client_id,
"client_id": oauth_client_id,
}
if hasattr(client, "client_secret") and client.client_secret:
refresh_data["client_secret"] = client.client_secret
if oauth_client_secret:
refresh_data["client_secret"] = oauth_client_secret
# Make refresh request
async with aiohttp.ClientSession(trust_env=True) as session_http:
@ -1585,3 +1638,88 @@ class OAuthManager:
log.error(f"Failed to store OAuth session server-side: {e}")
return response
async def periodic_oauth_token_refresh(app, interval_minutes: int = 5):
"""
Background task that periodically refreshes OAuth tokens before they expire.
This task runs every `interval_minutes` and checks for OAuth sessions that
will expire within the next 10 minutes. For each expiring session, it attempts
to refresh the token using the stored refresh_token.
Args:
app: The FastAPI application instance (needed to access oauth_client_manager)
interval_minutes: How often to run the refresh check (default: 5 minutes)
"""
import asyncio
log.info(
f"Starting periodic OAuth token refresh task (interval: {interval_minutes} minutes)"
)
while True:
try:
await asyncio.sleep(interval_minutes * 60)
# Get sessions expiring within the next 10 minutes
expiring_sessions = OAuthSessions.get_expiring_sessions(minutes=10)
if expiring_sessions:
log.info(
f"Found {len(expiring_sessions)} OAuth session(s) expiring soon, attempting refresh..."
)
for session in expiring_sessions:
try:
# Check if this is an MCP OAuth session (provider starts with "mcp:")
if session.provider.startswith("mcp:"):
# Use OAuthClientManager for MCP sessions
if hasattr(app.state, "oauth_client_manager"):
oauth_client_manager = app.state.oauth_client_manager
refreshed = await oauth_client_manager._refresh_token(
session
)
if refreshed:
log.info(
f"Successfully refreshed MCP OAuth token for session {session.id} (provider: {session.provider})"
)
else:
log.warning(
f"Failed to refresh MCP OAuth token for session {session.id} (provider: {session.provider})"
)
else:
log.debug(
f"oauth_client_manager not available, skipping MCP session {session.id}"
)
else:
# Use OAuthManager for standard OAuth sessions (login providers)
if hasattr(app.state, "oauth_manager"):
oauth_manager = app.state.oauth_manager
refreshed = await oauth_manager._refresh_token(session)
if refreshed:
log.info(
f"Successfully refreshed OAuth token for session {session.id} (provider: {session.provider})"
)
else:
log.warning(
f"Failed to refresh OAuth token for session {session.id} (provider: {session.provider})"
)
else:
log.debug(
f"oauth_manager not available, skipping session {session.id}"
)
except Exception as e:
log.error(
f"Error refreshing OAuth token for session {session.id}: {e}"
)
continue
except asyncio.CancelledError:
log.info("Periodic OAuth token refresh task cancelled")
break
except Exception as e:
log.error(f"Error in periodic OAuth token refresh task: {e}")
# Continue running even if there's an error
continue