Added a targeted utility to wipe all OAuth sessions for a provider so the cleanup can remove stale access tokens across every user when a connection is updated

This commit is contained in:
Taylor Wilsdon 2025-10-18 14:00:46 -04:00
parent 40c450e6e5
commit c107a3799f
2 changed files with 61 additions and 1 deletions

View file

@ -262,5 +262,16 @@ class OAuthSessionTable:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False
def delete_sessions_by_provider(self, provider: str) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db() as db:
db.query(OAuthSession).filter_by(provider=provider).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
return False
OAuthSessions = OAuthSessionTable()

View file

@ -1,4 +1,5 @@
import logging
import copy
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
import aiohttp
@ -15,6 +16,7 @@ from open_webui.utils.tools import (
set_tool_servers,
)
from open_webui.utils.mcp.client import MCPClient
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.env import SRC_LOG_LEVELS
@ -165,12 +167,59 @@ async def set_tool_servers_config(
form_data: ToolServersConfigForm,
user=Depends(get_admin_user),
):
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
old_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
new_connections = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
old_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in old_connections
if conn.get("type") == "mcp"
}
new_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in new_connections
if conn.get("type") == "mcp"
}
purge_oauth_clients = set()
for server_id, old_conn in old_mcp_connections.items():
if not server_id:
continue
old_auth_type = old_conn.get("auth_type", "none")
new_conn = new_mcp_connections.get(server_id)
if new_conn is None:
if old_auth_type == "oauth_2.1":
purge_oauth_clients.add(server_id)
continue
new_auth_type = new_conn.get("auth_type", "none")
if old_auth_type == "oauth_2.1":
if (
new_auth_type != "oauth_2.1"
or old_conn.get("url") != new_conn.get("url")
or old_conn.get("info", {}).get("oauth_client_info")
!= new_conn.get("info", {}).get("oauth_client_info")
):
purge_oauth_clients.add(server_id)
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
await set_tool_servers(request)
for server_id in purge_oauth_clients:
client_key = f"mcp:{server_id}"
request.app.state.oauth_client_manager.remove_client(client_key)
OAuthSessions.delete_sessions_by_provider(client_key)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi")
if server_type == "mcp":