mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 12:55:19 +00:00
refac
This commit is contained in:
parent
bfadbc9934
commit
92aafd6c06
2 changed files with 32 additions and 67 deletions
|
|
@ -167,59 +167,27 @@ async def set_tool_servers_config(
|
|||
form_data: ToolServersConfigForm,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
old_connections = copy.deepcopy(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||
)
|
||||
mcp_server_ids = [
|
||||
conn.get("info", {}).get("id")
|
||||
for conn in form_data.TOOL_SERVER_CONNECTIONS
|
||||
if conn.get("type") == "mcp"
|
||||
]
|
||||
|
||||
new_connections = [
|
||||
for server_id in mcp_server_ids:
|
||||
# Remove existing OAuth clients for MCP tool servers that are no longer present
|
||||
client_key = f"mcp:{server_id}"
|
||||
try:
|
||||
request.app.state.oauth_client_manager.remove_client(client_key)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Set new tool server connections
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||
]
|
||||
|
||||
old_mcp_connections = {
|
||||
conn.get("info", {}).get("id"): conn
|
||||
for conn in old_connections
|
||||
if conn.get("type") == "mcp"
|
||||
}
|
||||
new_mcp_connections = {
|
||||
conn.get("info", {}).get("id"): conn
|
||||
for conn in new_connections
|
||||
if conn.get("type") == "mcp"
|
||||
}
|
||||
|
||||
purge_oauth_clients = set()
|
||||
|
||||
for server_id, old_conn in old_mcp_connections.items():
|
||||
if not server_id:
|
||||
continue
|
||||
|
||||
old_auth_type = old_conn.get("auth_type", "none")
|
||||
new_conn = new_mcp_connections.get(server_id)
|
||||
|
||||
if new_conn is None:
|
||||
if old_auth_type == "oauth_2.1":
|
||||
purge_oauth_clients.add(server_id)
|
||||
continue
|
||||
|
||||
new_auth_type = new_conn.get("auth_type", "none")
|
||||
|
||||
if old_auth_type == "oauth_2.1":
|
||||
if (
|
||||
new_auth_type != "oauth_2.1"
|
||||
or old_conn.get("url") != new_conn.get("url")
|
||||
or old_conn.get("info", {}).get("oauth_client_info")
|
||||
!= new_conn.get("info", {}).get("oauth_client_info")
|
||||
):
|
||||
purge_oauth_clients.add(server_id)
|
||||
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
|
||||
|
||||
await set_tool_servers(request)
|
||||
|
||||
for server_id in purge_oauth_clients:
|
||||
client_key = f"mcp:{server_id}"
|
||||
request.app.state.oauth_client_manager.remove_client(client_key)
|
||||
OAuthSessions.delete_sessions_by_provider(client_key)
|
||||
|
||||
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||
server_type = connection.get("type", "openapi")
|
||||
if server_type == "mcp":
|
||||
|
|
|
|||
|
|
@ -153,32 +153,32 @@ def decrypt_data(data: str):
|
|||
raise
|
||||
|
||||
|
||||
def _build_oauth_callback_error_message(exc: Exception) -> str:
|
||||
def _build_oauth_callback_error_message(e: Exception) -> str:
|
||||
"""
|
||||
Produce a user-facing callback error string with actionable context.
|
||||
Keeps the message short and strips newlines for safe redirect usage.
|
||||
"""
|
||||
if isinstance(exc, OAuth2Error):
|
||||
parts = [p for p in [exc.error, exc.description] if p]
|
||||
if isinstance(e, OAuth2Error):
|
||||
parts = [p for p in [e.error, e.description] if p]
|
||||
detail = " - ".join(parts)
|
||||
elif isinstance(exc, HTTPException):
|
||||
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||
elif isinstance(exc, aiohttp.ClientResponseError):
|
||||
detail = f"Upstream provider returned {exc.status}: {exc.message}"
|
||||
elif isinstance(exc, aiohttp.ClientError):
|
||||
detail = str(exc)
|
||||
elif isinstance(exc, KeyError):
|
||||
missing = str(exc).strip("'")
|
||||
elif isinstance(e, HTTPException):
|
||||
detail = e.detail if isinstance(e.detail, str) else str(e.detail)
|
||||
elif isinstance(e, aiohttp.ClientResponseError):
|
||||
detail = f"Upstream provider returned {e.status}: {e.message}"
|
||||
elif isinstance(e, aiohttp.ClientError):
|
||||
detail = str(e)
|
||||
elif isinstance(e, KeyError):
|
||||
missing = str(e).strip("'")
|
||||
if missing.lower() == "state":
|
||||
detail = "Missing state parameter in callback (session may have expired)"
|
||||
else:
|
||||
detail = f"Missing expected key '{missing}' in OAuth response"
|
||||
else:
|
||||
detail = str(exc)
|
||||
detail = str(e)
|
||||
|
||||
detail = detail.replace("\n", " ").strip()
|
||||
if not detail:
|
||||
detail = exc.__class__.__name__
|
||||
detail = e.__class__.__name__
|
||||
|
||||
message = f"OAuth callback failed: {detail}"
|
||||
return message[:197] + "..." if len(message) > 200 else message
|
||||
|
|
@ -402,20 +402,18 @@ class OAuthClientManager:
|
|||
return self.clients[client_id]
|
||||
|
||||
def remove_client(self, client_id):
|
||||
removed = False
|
||||
if client_id in self.clients:
|
||||
del self.clients[client_id]
|
||||
removed = True
|
||||
log.info(f"Removed OAuth client {client_id}")
|
||||
|
||||
if hasattr(self.oauth, "_clients"):
|
||||
if client_id in self.oauth._clients:
|
||||
self.oauth._clients.pop(client_id, None)
|
||||
removed = True
|
||||
|
||||
if hasattr(self.oauth, "_registry"):
|
||||
if client_id in self.oauth._registry:
|
||||
self.oauth._registry.pop(client_id, None)
|
||||
removed = True
|
||||
if removed:
|
||||
log.info(f"Removed OAuth client {client_id}")
|
||||
|
||||
return True
|
||||
|
||||
def _find_mcp_connection(self, request, client_id: str):
|
||||
|
|
@ -574,7 +572,6 @@ class OAuthClientManager:
|
|||
|
||||
self.remove_client(client_id)
|
||||
self.add_client(client_id, oauth_client_info)
|
||||
OAuthSessions.delete_sessions_by_provider(client_id)
|
||||
|
||||
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
||||
return True
|
||||
|
|
|
|||
Loading…
Reference in a new issue