mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
Added a targeted utility to wipe all OAuth sessions for a provider so the cleanup can remove stale access tokens across every user when a connection is updated
This commit is contained in:
parent
40c450e6e5
commit
c107a3799f
2 changed files with 61 additions and 1 deletions
|
|
@ -262,5 +262,16 @@ class OAuthSessionTable:
|
|||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||
return False
|
||||
|
||||
def delete_sessions_by_provider(self, provider: str) -> bool:
|
||||
"""Delete all OAuth sessions for a provider"""
|
||||
try:
|
||||
with get_db() as db:
|
||||
db.query(OAuthSession).filter_by(provider=provider).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
OAuthSessions = OAuthSessionTable()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import copy
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
import aiohttp
|
||||
|
|
@ -15,6 +16,7 @@ from open_webui.utils.tools import (
|
|||
set_tool_servers,
|
||||
)
|
||||
from open_webui.utils.mcp.client import MCPClient
|
||||
from open_webui.models.oauth_sessions import OAuthSessions
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
|
@ -165,12 +167,59 @@ async def set_tool_servers_config(
|
|||
form_data: ToolServersConfigForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||
old_connections = copy.deepcopy(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||
)
|
||||
|
||||
new_connections = [
|
||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||
]
|
||||
|
||||
old_mcp_connections = {
|
||||
conn.get("info", {}).get("id"): conn
|
||||
for conn in old_connections
|
||||
if conn.get("type") == "mcp"
|
||||
}
|
||||
new_mcp_connections = {
|
||||
conn.get("info", {}).get("id"): conn
|
||||
for conn in new_connections
|
||||
if conn.get("type") == "mcp"
|
||||
}
|
||||
|
||||
purge_oauth_clients = set()
|
||||
|
||||
for server_id, old_conn in old_mcp_connections.items():
|
||||
if not server_id:
|
||||
continue
|
||||
|
||||
old_auth_type = old_conn.get("auth_type", "none")
|
||||
new_conn = new_mcp_connections.get(server_id)
|
||||
|
||||
if new_conn is None:
|
||||
if old_auth_type == "oauth_2.1":
|
||||
purge_oauth_clients.add(server_id)
|
||||
continue
|
||||
|
||||
new_auth_type = new_conn.get("auth_type", "none")
|
||||
|
||||
if old_auth_type == "oauth_2.1":
|
||||
if (
|
||||
new_auth_type != "oauth_2.1"
|
||||
or old_conn.get("url") != new_conn.get("url")
|
||||
or old_conn.get("info", {}).get("oauth_client_info")
|
||||
!= new_conn.get("info", {}).get("oauth_client_info")
|
||||
):
|
||||
purge_oauth_clients.add(server_id)
|
||||
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
|
||||
|
||||
await set_tool_servers(request)
|
||||
|
||||
for server_id in purge_oauth_clients:
|
||||
client_key = f"mcp:{server_id}"
|
||||
request.app.state.oauth_client_manager.remove_client(client_key)
|
||||
OAuthSessions.delete_sessions_by_provider(client_key)
|
||||
|
||||
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||
server_type = connection.get("type", "openapi")
|
||||
if server_type == "mcp":
|
||||
|
|
|
|||
Loading…
Reference in a new issue