Added a targeted utility to wipe all OAuth sessions for a provider so the cleanup can remove stale access tokens across every user when a connection is updated

This commit is contained in:
Taylor Wilsdon 2025-10-18 14:00:46 -04:00
parent 40c450e6e5
commit c107a3799f
2 changed files with 61 additions and 1 deletions

View file

@ -262,5 +262,16 @@ class OAuthSessionTable:
log.error(f"Error deleting OAuth sessions by user ID: {e}") log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False return False
def delete_sessions_by_provider(self, provider: str) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db() as db:
db.query(OAuthSession).filter_by(provider=provider).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
return False
OAuthSessions = OAuthSessionTable() OAuthSessions = OAuthSessionTable()

View file

@ -1,4 +1,5 @@
import logging import logging
import copy
from fastapi import APIRouter, Depends, Request, HTTPException from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
import aiohttp import aiohttp
@ -15,6 +16,7 @@ from open_webui.utils.tools import (
set_tool_servers, set_tool_servers,
) )
from open_webui.utils.mcp.client import MCPClient from open_webui.utils.mcp.client import MCPClient
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -165,12 +167,59 @@ async def set_tool_servers_config(
form_data: ToolServersConfigForm, form_data: ToolServersConfigForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
request.app.state.config.TOOL_SERVER_CONNECTIONS = [ old_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
new_connections = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
] ]
old_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in old_connections
if conn.get("type") == "mcp"
}
new_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in new_connections
if conn.get("type") == "mcp"
}
purge_oauth_clients = set()
for server_id, old_conn in old_mcp_connections.items():
if not server_id:
continue
old_auth_type = old_conn.get("auth_type", "none")
new_conn = new_mcp_connections.get(server_id)
if new_conn is None:
if old_auth_type == "oauth_2.1":
purge_oauth_clients.add(server_id)
continue
new_auth_type = new_conn.get("auth_type", "none")
if old_auth_type == "oauth_2.1":
if (
new_auth_type != "oauth_2.1"
or old_conn.get("url") != new_conn.get("url")
or old_conn.get("info", {}).get("oauth_client_info")
!= new_conn.get("info", {}).get("oauth_client_info")
):
purge_oauth_clients.add(server_id)
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
await set_tool_servers(request) await set_tool_servers(request)
for server_id in purge_oauth_clients:
client_key = f"mcp:{server_id}"
request.app.state.oauth_client_manager.remove_client(client_key)
OAuthSessions.delete_sessions_by_provider(client_key)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi") server_type = connection.get("type", "openapi")
if server_type == "mcp": if server_type == "mcp":