mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 12:55:19 +00:00
fix: Add proactive OAuth token refresh for MCP sessions
Add a background task that periodically refreshes OAuth tokens before they expire, preventing users from having to re-authenticate when MCP OAuth tokens (like Notion) expire after 1 hour. Changes: - Add get_expiring_sessions() method to OAuthSessionTable to query sessions expiring within a specified time window - Enhance OAuthClientManager._perform_token_refresh() to handle unregistered MCP clients by discovering OAuth info from stored tool server config - Add periodic_oauth_token_refresh() background task that runs every 5 minutes and refreshes tokens expiring within 10 minutes - Start the background task in app lifespan with proper cleanup Fixes #19809
This commit is contained in:
parent
ce945a9334
commit
0fe0fcff4d
3 changed files with 200 additions and 17 deletions
|
|
@ -504,6 +504,7 @@ from open_webui.utils.oauth import (
|
|||
OAuthManager,
|
||||
OAuthClientManager,
|
||||
OAuthClientInformationFull,
|
||||
periodic_oauth_token_refresh,
|
||||
)
|
||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||
from open_webui.utils.redis import get_redis_connection
|
||||
|
|
@ -599,6 +600,11 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
asyncio.create_task(periodic_usage_pool_cleanup())
|
||||
|
||||
# Start background task for proactive OAuth token refresh
|
||||
app.state.oauth_token_refresh_task = asyncio.create_task(
|
||||
periodic_oauth_token_refresh(app)
|
||||
)
|
||||
|
||||
if app.state.config.ENABLE_BASE_MODELS_CACHE:
|
||||
await get_all_models(
|
||||
Request(
|
||||
|
|
@ -622,6 +628,10 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
yield
|
||||
|
||||
# Cleanup background tasks on shutdown
|
||||
if hasattr(app.state, "oauth_token_refresh_task"):
|
||||
app.state.oauth_token_refresh_task.cancel()
|
||||
|
||||
if hasattr(app.state, "redis_task_command_listener"):
|
||||
app.state.redis_task_command_listener.cancel()
|
||||
|
||||
|
|
|
|||
|
|
@ -273,5 +273,40 @@ class OAuthSessionTable:
|
|||
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
|
||||
return False
|
||||
|
||||
def get_expiring_sessions(self, minutes: int = 10) -> List[OAuthSessionModel]:
|
||||
"""Get all OAuth sessions expiring within the specified minutes.
|
||||
|
||||
Args:
|
||||
minutes: Number of minutes from now to check for expiring sessions.
|
||||
|
||||
Returns:
|
||||
List of OAuth sessions that will expire within the specified time window.
|
||||
"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
current_time = int(time.time())
|
||||
expiry_threshold = current_time + (minutes * 60)
|
||||
|
||||
sessions = (
|
||||
db.query(OAuthSession)
|
||||
.filter(OAuthSession.expires_at <= expiry_threshold)
|
||||
.filter(OAuthSession.expires_at > current_time) # Not already expired
|
||||
.all()
|
||||
)
|
||||
|
||||
results = []
|
||||
for session in sessions:
|
||||
try:
|
||||
session.token = self._decrypt_token(session.token)
|
||||
results.append(OAuthSessionModel.model_validate(session))
|
||||
except Exception as e:
|
||||
log.error(f"Error decrypting token for session {session.id}: {e}")
|
||||
continue
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
log.error(f"Error getting expiring OAuth sessions: {e}")
|
||||
return []
|
||||
|
||||
|
||||
OAuthSessions = OAuthSessionTable()
|
||||
|
|
|
|||
|
|
@ -645,34 +645,87 @@ class OAuthClientManager:
|
|||
|
||||
try:
|
||||
client = self.get_client(client_id)
|
||||
if not client:
|
||||
log.error(f"No OAuth client found for provider {client_id}")
|
||||
return None
|
||||
client_info = self.get_client_info(client_id) if client else None
|
||||
|
||||
# Get token endpoint and client credentials
|
||||
token_endpoint = None
|
||||
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||
async with session_http.get(
|
||||
self.get_server_metadata_url(client_id)
|
||||
) as r:
|
||||
if r.status == 200:
|
||||
openid_data = await r.json()
|
||||
token_endpoint = openid_data.get("token_endpoint")
|
||||
else:
|
||||
log.error(
|
||||
f"Failed to fetch OpenID configuration for client_id {client_id}"
|
||||
)
|
||||
oauth_client_id = None
|
||||
oauth_client_secret = None
|
||||
|
||||
if client and client_info:
|
||||
# Client is registered, use stored info
|
||||
server_metadata_url = self.get_server_metadata_url(client_id)
|
||||
if server_metadata_url:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||
async with session_http.get(
|
||||
server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as r:
|
||||
if r.status == 200:
|
||||
openid_data = await r.json()
|
||||
token_endpoint = openid_data.get("token_endpoint")
|
||||
|
||||
oauth_client_id = client.client_id if hasattr(client, 'client_id') else None
|
||||
oauth_client_secret = client.client_secret if hasattr(client, 'client_secret') else None
|
||||
|
||||
# If client not registered but this is an MCP OAuth session, try to discover token endpoint
|
||||
if not token_endpoint and client_id.startswith("mcp:"):
|
||||
# For MCP clients, try to get info from app config or discover from well-known URL
|
||||
# Extract server ID from client_id (e.g., "mcp:ntn" -> "ntn")
|
||||
server_id = client_id.replace("mcp:", "")
|
||||
log.debug(f"MCP client {client_id} not registered, attempting discovery for server {server_id}")
|
||||
|
||||
# Try to get the tool server config to find the OAuth info
|
||||
if hasattr(self.app, 'state') and hasattr(self.app.state, 'config'):
|
||||
tool_servers = getattr(self.app.state.config, 'TOOL_SERVER_CONNECTIONS', [])
|
||||
for server in tool_servers:
|
||||
server_info = server.get('info', {})
|
||||
if server_info.get('id') == server_id and server.get('auth_type') == 'oauth_2.1':
|
||||
# Found the server config, try to get oauth_client_info
|
||||
encrypted_client_info = server_info.get('oauth_client_info')
|
||||
if encrypted_client_info:
|
||||
try:
|
||||
client_info_data = decrypt_data(encrypted_client_info)
|
||||
oauth_client_id = client_info_data.get('client_id')
|
||||
oauth_client_secret = client_info_data.get('client_secret')
|
||||
|
||||
# Get token endpoint from server metadata
|
||||
server_metadata = client_info_data.get('server_metadata', {})
|
||||
if isinstance(server_metadata, dict):
|
||||
token_endpoint = server_metadata.get('token_endpoint')
|
||||
|
||||
# If no token endpoint in stored metadata, try discovery
|
||||
if not token_endpoint:
|
||||
issuer = client_info_data.get('issuer')
|
||||
if issuer:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||
async with session_http.get(
|
||||
issuer, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as r:
|
||||
if r.status == 200:
|
||||
openid_data = await r.json()
|
||||
token_endpoint = openid_data.get("token_endpoint")
|
||||
|
||||
log.debug(f"Discovered OAuth info for MCP server {server_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to decrypt oauth_client_info for {server_id}: {e}")
|
||||
|
||||
if not token_endpoint:
|
||||
log.error(f"No token endpoint found for client_id {client_id}")
|
||||
return None
|
||||
|
||||
if not oauth_client_id:
|
||||
log.error(f"No OAuth client_id found for {client_id}")
|
||||
return None
|
||||
|
||||
# Prepare refresh request
|
||||
refresh_data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token_data["refresh_token"],
|
||||
"client_id": client.client_id,
|
||||
"client_id": oauth_client_id,
|
||||
}
|
||||
if hasattr(client, "client_secret") and client.client_secret:
|
||||
refresh_data["client_secret"] = client.client_secret
|
||||
if oauth_client_secret:
|
||||
refresh_data["client_secret"] = oauth_client_secret
|
||||
|
||||
# Make refresh request
|
||||
async with aiohttp.ClientSession(trust_env=True) as session_http:
|
||||
|
|
@ -1585,3 +1638,88 @@ class OAuthManager:
|
|||
log.error(f"Failed to store OAuth session server-side: {e}")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def periodic_oauth_token_refresh(app, interval_minutes: int = 5):
|
||||
"""
|
||||
Background task that periodically refreshes OAuth tokens before they expire.
|
||||
|
||||
This task runs every `interval_minutes` and checks for OAuth sessions that
|
||||
will expire within the next 10 minutes. For each expiring session, it attempts
|
||||
to refresh the token using the stored refresh_token.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application instance (needed to access oauth_client_manager)
|
||||
interval_minutes: How often to run the refresh check (default: 5 minutes)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
log.info(
|
||||
f"Starting periodic OAuth token refresh task (interval: {interval_minutes} minutes)"
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(interval_minutes * 60)
|
||||
|
||||
# Get sessions expiring within the next 10 minutes
|
||||
expiring_sessions = OAuthSessions.get_expiring_sessions(minutes=10)
|
||||
|
||||
if expiring_sessions:
|
||||
log.info(
|
||||
f"Found {len(expiring_sessions)} OAuth session(s) expiring soon, attempting refresh..."
|
||||
)
|
||||
|
||||
for session in expiring_sessions:
|
||||
try:
|
||||
# Check if this is an MCP OAuth session (provider starts with "mcp:")
|
||||
if session.provider.startswith("mcp:"):
|
||||
# Use OAuthClientManager for MCP sessions
|
||||
if hasattr(app.state, "oauth_client_manager"):
|
||||
oauth_client_manager = app.state.oauth_client_manager
|
||||
refreshed = await oauth_client_manager._refresh_token(
|
||||
session
|
||||
)
|
||||
if refreshed:
|
||||
log.info(
|
||||
f"Successfully refreshed MCP OAuth token for session {session.id} (provider: {session.provider})"
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Failed to refresh MCP OAuth token for session {session.id} (provider: {session.provider})"
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
f"oauth_client_manager not available, skipping MCP session {session.id}"
|
||||
)
|
||||
else:
|
||||
# Use OAuthManager for standard OAuth sessions (login providers)
|
||||
if hasattr(app.state, "oauth_manager"):
|
||||
oauth_manager = app.state.oauth_manager
|
||||
refreshed = await oauth_manager._refresh_token(session)
|
||||
if refreshed:
|
||||
log.info(
|
||||
f"Successfully refreshed OAuth token for session {session.id} (provider: {session.provider})"
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Failed to refresh OAuth token for session {session.id} (provider: {session.provider})"
|
||||
)
|
||||
else:
|
||||
log.debug(
|
||||
f"oauth_manager not available, skipping session {session.id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error refreshing OAuth token for session {session.id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("Periodic OAuth token refresh task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
log.error(f"Error in periodic OAuth token refresh task: {e}")
|
||||
# Continue running even if there's an error
|
||||
continue
|
||||
|
|
|
|||
Loading…
Reference in a new issue