This commit is contained in:
Timothy Jaeryang Baek 2025-10-27 15:31:25 -07:00
parent bfadbc9934
commit 92aafd6c06
2 changed files with 32 additions and 67 deletions

View file

@ -167,59 +167,27 @@ async def set_tool_servers_config(
form_data: ToolServersConfigForm,
user=Depends(get_admin_user),
):
old_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
mcp_server_ids = [
conn.get("info", {}).get("id")
for conn in form_data.TOOL_SERVER_CONNECTIONS
if conn.get("type") == "mcp"
]
new_connections = [
for server_id in mcp_server_ids:
# Remove existing OAuth clients for MCP tool servers that are no longer present
client_key = f"mcp:{server_id}"
try:
request.app.state.oauth_client_manager.remove_client(client_key)
except:
pass
# Set new tool server connections
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
old_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in old_connections
if conn.get("type") == "mcp"
}
new_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in new_connections
if conn.get("type") == "mcp"
}
purge_oauth_clients = set()
for server_id, old_conn in old_mcp_connections.items():
if not server_id:
continue
old_auth_type = old_conn.get("auth_type", "none")
new_conn = new_mcp_connections.get(server_id)
if new_conn is None:
if old_auth_type == "oauth_2.1":
purge_oauth_clients.add(server_id)
continue
new_auth_type = new_conn.get("auth_type", "none")
if old_auth_type == "oauth_2.1":
if (
new_auth_type != "oauth_2.1"
or old_conn.get("url") != new_conn.get("url")
or old_conn.get("info", {}).get("oauth_client_info")
!= new_conn.get("info", {}).get("oauth_client_info")
):
purge_oauth_clients.add(server_id)
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
await set_tool_servers(request)
for server_id in purge_oauth_clients:
client_key = f"mcp:{server_id}"
request.app.state.oauth_client_manager.remove_client(client_key)
OAuthSessions.delete_sessions_by_provider(client_key)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi")
if server_type == "mcp":

View file

@ -153,32 +153,32 @@ def decrypt_data(data: str):
raise
def _build_oauth_callback_error_message(exc: Exception) -> str:
def _build_oauth_callback_error_message(e: Exception) -> str:
"""
Produce a user-facing callback error string with actionable context.
Keeps the message short and strips newlines for safe redirect usage.
"""
if isinstance(exc, OAuth2Error):
parts = [p for p in [exc.error, exc.description] if p]
if isinstance(e, OAuth2Error):
parts = [p for p in [e.error, e.description] if p]
detail = " - ".join(parts)
elif isinstance(exc, HTTPException):
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
elif isinstance(exc, aiohttp.ClientResponseError):
detail = f"Upstream provider returned {exc.status}: {exc.message}"
elif isinstance(exc, aiohttp.ClientError):
detail = str(exc)
elif isinstance(exc, KeyError):
missing = str(exc).strip("'")
elif isinstance(e, HTTPException):
detail = e.detail if isinstance(e.detail, str) else str(e.detail)
elif isinstance(e, aiohttp.ClientResponseError):
detail = f"Upstream provider returned {e.status}: {e.message}"
elif isinstance(e, aiohttp.ClientError):
detail = str(e)
elif isinstance(e, KeyError):
missing = str(e).strip("'")
if missing.lower() == "state":
detail = "Missing state parameter in callback (session may have expired)"
else:
detail = f"Missing expected key '{missing}' in OAuth response"
else:
detail = str(exc)
detail = str(e)
detail = detail.replace("\n", " ").strip()
if not detail:
detail = exc.__class__.__name__
detail = e.__class__.__name__
message = f"OAuth callback failed: {detail}"
return message[:197] + "..." if len(message) > 200 else message
@ -402,20 +402,18 @@ class OAuthClientManager:
return self.clients[client_id]
def remove_client(self, client_id):
removed = False
if client_id in self.clients:
del self.clients[client_id]
removed = True
log.info(f"Removed OAuth client {client_id}")
if hasattr(self.oauth, "_clients"):
if client_id in self.oauth._clients:
self.oauth._clients.pop(client_id, None)
removed = True
if hasattr(self.oauth, "_registry"):
if client_id in self.oauth._registry:
self.oauth._registry.pop(client_id, None)
removed = True
if removed:
log.info(f"Removed OAuth client {client_id}")
return True
def _find_mcp_connection(self, request, client_id: str):
@ -574,7 +572,6 @@ class OAuthClientManager:
self.remove_client(client_id)
self.add_client(client_id, oauth_client_info)
OAuthSessions.delete_sessions_by_provider(client_id)
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
return True