This commit is contained in:
Timothy Jaeryang Baek 2025-10-27 15:31:25 -07:00 committed by Stoyan Zlatev
parent f0bf0e3074
commit ec9359e360
2 changed files with 32 additions and 67 deletions

View file

@ -167,59 +167,27 @@ async def set_tool_servers_config(
form_data: ToolServersConfigForm, form_data: ToolServersConfigForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
old_connections = copy.deepcopy( mcp_server_ids = [
request.app.state.config.TOOL_SERVER_CONNECTIONS or [] conn.get("info", {}).get("id")
) for conn in form_data.TOOL_SERVER_CONNECTIONS
if conn.get("type") == "mcp"
]
new_connections = [ for server_id in mcp_server_ids:
# Remove existing OAuth clients for MCP tool servers that are no longer present
client_key = f"mcp:{server_id}"
try:
request.app.state.oauth_client_manager.remove_client(client_key)
except:
pass
# Set new tool server connections
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
] ]
old_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in old_connections
if conn.get("type") == "mcp"
}
new_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in new_connections
if conn.get("type") == "mcp"
}
purge_oauth_clients = set()
for server_id, old_conn in old_mcp_connections.items():
if not server_id:
continue
old_auth_type = old_conn.get("auth_type", "none")
new_conn = new_mcp_connections.get(server_id)
if new_conn is None:
if old_auth_type == "oauth_2.1":
purge_oauth_clients.add(server_id)
continue
new_auth_type = new_conn.get("auth_type", "none")
if old_auth_type == "oauth_2.1":
if (
new_auth_type != "oauth_2.1"
or old_conn.get("url") != new_conn.get("url")
or old_conn.get("info", {}).get("oauth_client_info")
!= new_conn.get("info", {}).get("oauth_client_info")
):
purge_oauth_clients.add(server_id)
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
await set_tool_servers(request) await set_tool_servers(request)
for server_id in purge_oauth_clients:
client_key = f"mcp:{server_id}"
request.app.state.oauth_client_manager.remove_client(client_key)
OAuthSessions.delete_sessions_by_provider(client_key)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi") server_type = connection.get("type", "openapi")
if server_type == "mcp": if server_type == "mcp":

View file

@ -153,32 +153,32 @@ def decrypt_data(data: str):
raise raise
def _build_oauth_callback_error_message(exc: Exception) -> str: def _build_oauth_callback_error_message(e: Exception) -> str:
""" """
Produce a user-facing callback error string with actionable context. Produce a user-facing callback error string with actionable context.
Keeps the message short and strips newlines for safe redirect usage. Keeps the message short and strips newlines for safe redirect usage.
""" """
if isinstance(exc, OAuth2Error): if isinstance(e, OAuth2Error):
parts = [p for p in [exc.error, exc.description] if p] parts = [p for p in [e.error, e.description] if p]
detail = " - ".join(parts) detail = " - ".join(parts)
elif isinstance(exc, HTTPException): elif isinstance(e, HTTPException):
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail) detail = e.detail if isinstance(e.detail, str) else str(e.detail)
elif isinstance(exc, aiohttp.ClientResponseError): elif isinstance(e, aiohttp.ClientResponseError):
detail = f"Upstream provider returned {exc.status}: {exc.message}" detail = f"Upstream provider returned {e.status}: {e.message}"
elif isinstance(exc, aiohttp.ClientError): elif isinstance(e, aiohttp.ClientError):
detail = str(exc) detail = str(e)
elif isinstance(exc, KeyError): elif isinstance(e, KeyError):
missing = str(exc).strip("'") missing = str(e).strip("'")
if missing.lower() == "state": if missing.lower() == "state":
detail = "Missing state parameter in callback (session may have expired)" detail = "Missing state parameter in callback (session may have expired)"
else: else:
detail = f"Missing expected key '{missing}' in OAuth response" detail = f"Missing expected key '{missing}' in OAuth response"
else: else:
detail = str(exc) detail = str(e)
detail = detail.replace("\n", " ").strip() detail = detail.replace("\n", " ").strip()
if not detail: if not detail:
detail = exc.__class__.__name__ detail = e.__class__.__name__
message = f"OAuth callback failed: {detail}" message = f"OAuth callback failed: {detail}"
return message[:197] + "..." if len(message) > 200 else message return message[:197] + "..." if len(message) > 200 else message
@ -402,20 +402,18 @@ class OAuthClientManager:
return self.clients[client_id] return self.clients[client_id]
def remove_client(self, client_id): def remove_client(self, client_id):
removed = False
if client_id in self.clients: if client_id in self.clients:
del self.clients[client_id] del self.clients[client_id]
removed = True log.info(f"Removed OAuth client {client_id}")
if hasattr(self.oauth, "_clients"): if hasattr(self.oauth, "_clients"):
if client_id in self.oauth._clients: if client_id in self.oauth._clients:
self.oauth._clients.pop(client_id, None) self.oauth._clients.pop(client_id, None)
removed = True
if hasattr(self.oauth, "_registry"): if hasattr(self.oauth, "_registry"):
if client_id in self.oauth._registry: if client_id in self.oauth._registry:
self.oauth._registry.pop(client_id, None) self.oauth._registry.pop(client_id, None)
removed = True
if removed:
log.info(f"Removed OAuth client {client_id}")
return True return True
def _find_mcp_connection(self, request, client_id: str): def _find_mcp_connection(self, request, client_id: str):
@ -574,7 +572,6 @@ class OAuthClientManager:
self.remove_client(client_id) self.remove_client(client_id)
self.add_client(client_id, oauth_client_info) self.add_client(client_id, oauth_client_info)
OAuthSessions.delete_sessions_by_provider(client_id)
log.info("Re-registered OAuth client %s for MCP tool server", client_id) log.info("Re-registered OAuth client %s for MCP tool server", client_id)
return True return True