Added a preflight authorize check that automatically re-registers MCP OAuth clients when the stored client ID no longer exists on the server, so the browser flow never hits the stale-ID failure

This commit is contained in:
Taylor Wilsdon 2025-10-18 16:53:44 -04:00
parent d49fb9c010
commit ecbf74dbea

View file

@ -1,4 +1,5 @@
import base64
import copy
import hashlib
import logging
import mimetypes
@ -417,6 +418,205 @@ class OAuthClientManager:
log.info(f"Removed OAuth client {client_id}")
return True
def _find_mcp_connection(self, request, client_id: str):
try:
connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
except Exception:
connections = []
normalized_client_id = client_id.split(":")[-1]
for idx, connection in enumerate(connections):
if not isinstance(connection, dict):
continue
if connection.get("type") != "mcp":
continue
info = connection.get("info") or {}
server_id = info.get("id")
if not server_id:
continue
normalized_server_id = server_id.split(":")[-1]
if normalized_server_id == normalized_client_id:
return idx, connection
return None, None
async def _preflight_authorization_url(
self, client, client_info: OAuthClientInformationFull
) -> bool:
# Only perform preflight checks for Starlette OAuth clients
if not hasattr(client, "create_authorization_url"):
return True
redirect_uri = None
if client_info.redirect_uris:
redirect_uri = str(client_info.redirect_uris[0])
try:
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
authorize_url = auth_data.get("url")
if not authorize_url:
return True
except Exception as e:
log.debug(
"Skipping OAuth preflight for client %s: %s",
client_info.client_id,
e,
)
return True
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
authorize_url,
allow_redirects=False,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as resp:
if resp.status < 400:
return True
body_text = await resp.text()
error = None
error_description = ""
content_type = resp.headers.get("content-type", "")
if "application/json" in content_type:
try:
payload = json.loads(body_text)
error = payload.get("error")
error_description = payload.get(
"error_description", ""
)
except json.JSONDecodeError:
error = None
error_description = ""
else:
error_description = body_text
combined = f"{error or ''} {error_description}".lower()
if "invalid_client" in combined or "invalid client" in combined or "client id" in combined:
log.warning(
"OAuth client preflight detected invalid registration for %s: %s %s",
client_info.client_id,
error,
error_description,
)
return False
except Exception as e:
log.debug(
"Skipping OAuth preflight network check for client %s: %s",
client_info.client_id,
e,
)
return True
async def _re_register_client(self, request, client_id: str) -> bool:
idx, connection = self._find_mcp_connection(request, client_id)
if idx is None or connection is None:
log.warning(
"Unable to locate MCP tool server configuration for client %s during re-registration",
client_id,
)
return False
server_url = connection.get("url")
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
try:
oauth_client_info = (
await get_oauth_client_info_with_dynamic_client_registration(
request,
client_id,
server_url,
oauth_server_key,
)
)
except Exception as e:
log.error(
"Dynamic client re-registration failed for %s: %s",
client_id,
e,
)
return False
encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
updated_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
if idx >= len(updated_connections):
log.error(
"MCP tool server index %s out of range during OAuth client re-registration for %s",
idx,
client_id,
)
return False
updated_connection = copy.deepcopy(connection)
updated_connection.setdefault("info", {})
updated_connection["info"]["oauth_client_info"] = encrypted_info
updated_connections[idx] = updated_connection
try:
request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
except Exception as e:
log.error(
"Failed to persist updated OAuth client info for %s: %s",
client_id,
e,
)
return False
self.remove_client(client_id)
self.add_client(client_id, oauth_client_info)
OAuthSessions.delete_sessions_by_provider(client_id)
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
return True
async def _ensure_valid_client_registration(
self, request, client_id: str
) -> None:
if not client_id.startswith("mcp:"):
return
client = self.get_client(client_id)
client_info = self.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(status.HTTP_404_NOT_FOUND)
is_valid = await self._preflight_authorization_url(client, client_info)
if is_valid:
return
log.info(
"Detected invalid OAuth client %s; attempting re-registration",
client_id,
)
re_registered = await self._re_register_client(request, client_id)
if not re_registered:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to re-register OAuth client",
)
client = self.get_client(client_id)
client_info = self.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client unavailable after re-registration",
)
if not await self._preflight_authorization_url(client, client_info):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client registration is still invalid after re-registration",
)
def get_client(self, client_id):
client = self.clients.get(client_id)
return client["client"] if client else None
@ -602,10 +802,11 @@ class OAuthClientManager:
return None
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
await self._ensure_valid_client_registration(request, client_id)
client = self.get_client(client_id)
if client is None:
raise HTTPException(404)
client_info = self.get_client_info(client_id)
if client_info is None:
raise HTTPException(404)
@ -613,7 +814,8 @@ class OAuthClientManager:
redirect_uri = (
client_info.redirect_uris[0] if client_info.redirect_uris else None
)
return await client.authorize_redirect(request, str(redirect_uri))
redirect_uri_str = str(redirect_uri) if redirect_uri else None
return await client.authorize_redirect(request, redirect_uri_str)
async def handle_callback(self, request, client_id: str, user_id: str, response):
client = self.get_client(client_id)