mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-14 21:35:19 +00:00
refac
This commit is contained in:
parent
f0bf0e3074
commit
ec9359e360
2 changed files with 32 additions and 67 deletions
|
|
@ -167,59 +167,27 @@ async def set_tool_servers_config(
|
||||||
form_data: ToolServersConfigForm,
|
form_data: ToolServersConfigForm,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
old_connections = copy.deepcopy(
|
mcp_server_ids = [
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
conn.get("info", {}).get("id")
|
||||||
)
|
for conn in form_data.TOOL_SERVER_CONNECTIONS
|
||||||
|
if conn.get("type") == "mcp"
|
||||||
|
]
|
||||||
|
|
||||||
new_connections = [
|
for server_id in mcp_server_ids:
|
||||||
|
# Remove existing OAuth clients for MCP tool servers that are no longer present
|
||||||
|
client_key = f"mcp:{server_id}"
|
||||||
|
try:
|
||||||
|
request.app.state.oauth_client_manager.remove_client(client_key)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Set new tool server connections
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||||
]
|
]
|
||||||
|
|
||||||
old_mcp_connections = {
|
|
||||||
conn.get("info", {}).get("id"): conn
|
|
||||||
for conn in old_connections
|
|
||||||
if conn.get("type") == "mcp"
|
|
||||||
}
|
|
||||||
new_mcp_connections = {
|
|
||||||
conn.get("info", {}).get("id"): conn
|
|
||||||
for conn in new_connections
|
|
||||||
if conn.get("type") == "mcp"
|
|
||||||
}
|
|
||||||
|
|
||||||
purge_oauth_clients = set()
|
|
||||||
|
|
||||||
for server_id, old_conn in old_mcp_connections.items():
|
|
||||||
if not server_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
old_auth_type = old_conn.get("auth_type", "none")
|
|
||||||
new_conn = new_mcp_connections.get(server_id)
|
|
||||||
|
|
||||||
if new_conn is None:
|
|
||||||
if old_auth_type == "oauth_2.1":
|
|
||||||
purge_oauth_clients.add(server_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_auth_type = new_conn.get("auth_type", "none")
|
|
||||||
|
|
||||||
if old_auth_type == "oauth_2.1":
|
|
||||||
if (
|
|
||||||
new_auth_type != "oauth_2.1"
|
|
||||||
or old_conn.get("url") != new_conn.get("url")
|
|
||||||
or old_conn.get("info", {}).get("oauth_client_info")
|
|
||||||
!= new_conn.get("info", {}).get("oauth_client_info")
|
|
||||||
):
|
|
||||||
purge_oauth_clients.add(server_id)
|
|
||||||
|
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
|
|
||||||
|
|
||||||
await set_tool_servers(request)
|
await set_tool_servers(request)
|
||||||
|
|
||||||
for server_id in purge_oauth_clients:
|
|
||||||
client_key = f"mcp:{server_id}"
|
|
||||||
request.app.state.oauth_client_manager.remove_client(client_key)
|
|
||||||
OAuthSessions.delete_sessions_by_provider(client_key)
|
|
||||||
|
|
||||||
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||||
server_type = connection.get("type", "openapi")
|
server_type = connection.get("type", "openapi")
|
||||||
if server_type == "mcp":
|
if server_type == "mcp":
|
||||||
|
|
|
||||||
|
|
@ -153,32 +153,32 @@ def decrypt_data(data: str):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _build_oauth_callback_error_message(exc: Exception) -> str:
|
def _build_oauth_callback_error_message(e: Exception) -> str:
|
||||||
"""
|
"""
|
||||||
Produce a user-facing callback error string with actionable context.
|
Produce a user-facing callback error string with actionable context.
|
||||||
Keeps the message short and strips newlines for safe redirect usage.
|
Keeps the message short and strips newlines for safe redirect usage.
|
||||||
"""
|
"""
|
||||||
if isinstance(exc, OAuth2Error):
|
if isinstance(e, OAuth2Error):
|
||||||
parts = [p for p in [exc.error, exc.description] if p]
|
parts = [p for p in [e.error, e.description] if p]
|
||||||
detail = " - ".join(parts)
|
detail = " - ".join(parts)
|
||||||
elif isinstance(exc, HTTPException):
|
elif isinstance(e, HTTPException):
|
||||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
detail = e.detail if isinstance(e.detail, str) else str(e.detail)
|
||||||
elif isinstance(exc, aiohttp.ClientResponseError):
|
elif isinstance(e, aiohttp.ClientResponseError):
|
||||||
detail = f"Upstream provider returned {exc.status}: {exc.message}"
|
detail = f"Upstream provider returned {e.status}: {e.message}"
|
||||||
elif isinstance(exc, aiohttp.ClientError):
|
elif isinstance(e, aiohttp.ClientError):
|
||||||
detail = str(exc)
|
detail = str(e)
|
||||||
elif isinstance(exc, KeyError):
|
elif isinstance(e, KeyError):
|
||||||
missing = str(exc).strip("'")
|
missing = str(e).strip("'")
|
||||||
if missing.lower() == "state":
|
if missing.lower() == "state":
|
||||||
detail = "Missing state parameter in callback (session may have expired)"
|
detail = "Missing state parameter in callback (session may have expired)"
|
||||||
else:
|
else:
|
||||||
detail = f"Missing expected key '{missing}' in OAuth response"
|
detail = f"Missing expected key '{missing}' in OAuth response"
|
||||||
else:
|
else:
|
||||||
detail = str(exc)
|
detail = str(e)
|
||||||
|
|
||||||
detail = detail.replace("\n", " ").strip()
|
detail = detail.replace("\n", " ").strip()
|
||||||
if not detail:
|
if not detail:
|
||||||
detail = exc.__class__.__name__
|
detail = e.__class__.__name__
|
||||||
|
|
||||||
message = f"OAuth callback failed: {detail}"
|
message = f"OAuth callback failed: {detail}"
|
||||||
return message[:197] + "..." if len(message) > 200 else message
|
return message[:197] + "..." if len(message) > 200 else message
|
||||||
|
|
@ -402,20 +402,18 @@ class OAuthClientManager:
|
||||||
return self.clients[client_id]
|
return self.clients[client_id]
|
||||||
|
|
||||||
def remove_client(self, client_id):
|
def remove_client(self, client_id):
|
||||||
removed = False
|
|
||||||
if client_id in self.clients:
|
if client_id in self.clients:
|
||||||
del self.clients[client_id]
|
del self.clients[client_id]
|
||||||
removed = True
|
log.info(f"Removed OAuth client {client_id}")
|
||||||
|
|
||||||
if hasattr(self.oauth, "_clients"):
|
if hasattr(self.oauth, "_clients"):
|
||||||
if client_id in self.oauth._clients:
|
if client_id in self.oauth._clients:
|
||||||
self.oauth._clients.pop(client_id, None)
|
self.oauth._clients.pop(client_id, None)
|
||||||
removed = True
|
|
||||||
if hasattr(self.oauth, "_registry"):
|
if hasattr(self.oauth, "_registry"):
|
||||||
if client_id in self.oauth._registry:
|
if client_id in self.oauth._registry:
|
||||||
self.oauth._registry.pop(client_id, None)
|
self.oauth._registry.pop(client_id, None)
|
||||||
removed = True
|
|
||||||
if removed:
|
|
||||||
log.info(f"Removed OAuth client {client_id}")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _find_mcp_connection(self, request, client_id: str):
|
def _find_mcp_connection(self, request, client_id: str):
|
||||||
|
|
@ -574,7 +572,6 @@ class OAuthClientManager:
|
||||||
|
|
||||||
self.remove_client(client_id)
|
self.remove_client(client_id)
|
||||||
self.add_client(client_id, oauth_client_info)
|
self.add_client(client_id, oauth_client_info)
|
||||||
OAuthSessions.delete_sessions_by_provider(client_id)
|
|
||||||
|
|
||||||
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue