This commit is contained in:
Timothy Jaeryang Baek 2025-10-27 15:38:59 -07:00
parent 92aafd6c06
commit c8b2313362
3 changed files with 17 additions and 13 deletions

View file

@ -1941,6 +1941,7 @@ if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0:
if tool_server_connection.get("type", "openapi") == "mcp": if tool_server_connection.get("type", "openapi") == "mcp":
server_id = tool_server_connection.get("info", {}).get("id") server_id = tool_server_connection.get("info", {}).get("id")
auth_type = tool_server_connection.get("auth_type", "none") auth_type = tool_server_connection.get("auth_type", "none")
if server_id and auth_type == "oauth_2.1": if server_id and auth_type == "oauth_2.1":
oauth_client_info = tool_server_connection.get("info", {}).get( oauth_client_info = tool_server_connection.get("info", {}).get(
"oauth_client_info", "" "oauth_client_info", ""

View file

@ -167,15 +167,15 @@ async def set_tool_servers_config(
form_data: ToolServersConfigForm, form_data: ToolServersConfigForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
mcp_server_ids = [ for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
conn.get("info", {}).get("id") server_type = connection.get("type", "openapi")
for conn in form_data.TOOL_SERVER_CONNECTIONS auth_type = connection.get("auth_type", "none")
if conn.get("type") == "mcp"
] if auth_type == "oauth_2.1":
# Remove existing OAuth clients for tool servers
server_id = connection.get("info", {}).get("id")
client_key = f"{server_type}:{server_id}"
for server_id in mcp_server_ids:
# Remove existing OAuth clients for MCP tool servers that are no longer present
client_key = f"mcp:{server_id}"
try: try:
request.app.state.oauth_client_manager.remove_client(client_key) request.app.state.oauth_client_manager.remove_client(client_key)
except: except:
@ -193,6 +193,7 @@ async def set_tool_servers_config(
if server_type == "mcp": if server_type == "mcp":
server_id = connection.get("info", {}).get("id") server_id = connection.get("info", {}).get("id")
auth_type = connection.get("auth_type", "none") auth_type = connection.get("auth_type", "none")
if auth_type == "oauth_2.1" and server_id: if auth_type == "oauth_2.1" and server_id:
try: try:
oauth_client_info = connection.get("info", {}).get( oauth_client_info = connection.get("info", {}).get(

View file

@ -582,6 +582,7 @@ class OAuthClientManager:
client = self.get_client(client_id) client = self.get_client(client_id)
client_info = self.get_client_info(client_id) client_info = self.get_client_info(client_id)
if client is None or client_info is None: if client is None or client_info is None:
raise HTTPException(status.HTTP_404_NOT_FOUND) raise HTTPException(status.HTTP_404_NOT_FOUND)
@ -593,6 +594,7 @@ class OAuthClientManager:
"Detected invalid OAuth client %s; attempting re-registration", "Detected invalid OAuth client %s; attempting re-registration",
client_id, client_id,
) )
re_registered = await self._re_register_client(request, client_id) re_registered = await self._re_register_client(request, client_id)
if not re_registered: if not re_registered:
raise HTTPException( raise HTTPException(
@ -799,7 +801,7 @@ class OAuthClientManager:
return None return None
async def handle_authorize(self, request, client_id: str) -> RedirectResponse: async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
await self._ensure_valid_client_registration(request, client_id) # await self._ensure_valid_client_registration(request, client_id)
client = self.get_client(client_id) client = self.get_client(client_id)
if client is None: if client is None: