mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
Added a preflight authorize check that automatically re-registers MCP OAuth clients when the stored client ID no longer exists on the server, so the browser flow never hits the stale-ID failure
This commit is contained in:
parent
d49fb9c010
commit
ecbf74dbea
1 changed files with 204 additions and 2 deletions
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import copy
|
||||
import hashlib
|
||||
import logging
|
||||
import mimetypes
|
||||
|
|
@ -417,6 +418,205 @@ class OAuthClientManager:
|
|||
log.info(f"Removed OAuth client {client_id}")
|
||||
return True
|
||||
|
||||
def _find_mcp_connection(self, request, client_id: str):
|
||||
try:
|
||||
connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||
except Exception:
|
||||
connections = []
|
||||
|
||||
normalized_client_id = client_id.split(":")[-1]
|
||||
|
||||
for idx, connection in enumerate(connections):
|
||||
if not isinstance(connection, dict):
|
||||
continue
|
||||
if connection.get("type") != "mcp":
|
||||
continue
|
||||
|
||||
info = connection.get("info") or {}
|
||||
server_id = info.get("id")
|
||||
if not server_id:
|
||||
continue
|
||||
|
||||
normalized_server_id = server_id.split(":")[-1]
|
||||
if normalized_server_id == normalized_client_id:
|
||||
return idx, connection
|
||||
|
||||
return None, None
|
||||
|
||||
async def _preflight_authorization_url(
|
||||
self, client, client_info: OAuthClientInformationFull
|
||||
) -> bool:
|
||||
# Only perform preflight checks for Starlette OAuth clients
|
||||
if not hasattr(client, "create_authorization_url"):
|
||||
return True
|
||||
|
||||
redirect_uri = None
|
||||
if client_info.redirect_uris:
|
||||
redirect_uri = str(client_info.redirect_uris[0])
|
||||
|
||||
try:
|
||||
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
|
||||
authorize_url = auth_data.get("url")
|
||||
if not authorize_url:
|
||||
return True
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Skipping OAuth preflight for client %s: %s",
|
||||
client_info.client_id,
|
||||
e,
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
authorize_url,
|
||||
allow_redirects=False,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as resp:
|
||||
if resp.status < 400:
|
||||
return True
|
||||
|
||||
body_text = await resp.text()
|
||||
error = None
|
||||
error_description = ""
|
||||
content_type = resp.headers.get("content-type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
payload = json.loads(body_text)
|
||||
error = payload.get("error")
|
||||
error_description = payload.get(
|
||||
"error_description", ""
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
error = None
|
||||
error_description = ""
|
||||
else:
|
||||
error_description = body_text
|
||||
|
||||
combined = f"{error or ''} {error_description}".lower()
|
||||
if "invalid_client" in combined or "invalid client" in combined or "client id" in combined:
|
||||
log.warning(
|
||||
"OAuth client preflight detected invalid registration for %s: %s %s",
|
||||
client_info.client_id,
|
||||
error,
|
||||
error_description,
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
log.debug(
|
||||
"Skipping OAuth preflight network check for client %s: %s",
|
||||
client_info.client_id,
|
||||
e,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _re_register_client(self, request, client_id: str) -> bool:
|
||||
idx, connection = self._find_mcp_connection(request, client_id)
|
||||
if idx is None or connection is None:
|
||||
log.warning(
|
||||
"Unable to locate MCP tool server configuration for client %s during re-registration",
|
||||
client_id,
|
||||
)
|
||||
return False
|
||||
|
||||
server_url = connection.get("url")
|
||||
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
|
||||
|
||||
try:
|
||||
oauth_client_info = (
|
||||
await get_oauth_client_info_with_dynamic_client_registration(
|
||||
request,
|
||||
client_id,
|
||||
server_url,
|
||||
oauth_server_key,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Dynamic client re-registration failed for %s: %s",
|
||||
client_id,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
|
||||
|
||||
updated_connections = copy.deepcopy(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||
)
|
||||
if idx >= len(updated_connections):
|
||||
log.error(
|
||||
"MCP tool server index %s out of range during OAuth client re-registration for %s",
|
||||
idx,
|
||||
client_id,
|
||||
)
|
||||
return False
|
||||
|
||||
updated_connection = copy.deepcopy(connection)
|
||||
updated_connection.setdefault("info", {})
|
||||
updated_connection["info"]["oauth_client_info"] = encrypted_info
|
||||
updated_connections[idx] = updated_connection
|
||||
|
||||
try:
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Failed to persist updated OAuth client info for %s: %s",
|
||||
client_id,
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
self.remove_client(client_id)
|
||||
self.add_client(client_id, oauth_client_info)
|
||||
OAuthSessions.delete_sessions_by_provider(client_id)
|
||||
|
||||
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
||||
return True
|
||||
|
||||
async def _ensure_valid_client_registration(
|
||||
self, request, client_id: str
|
||||
) -> None:
|
||||
if not client_id.startswith("mcp:"):
|
||||
return
|
||||
|
||||
client = self.get_client(client_id)
|
||||
client_info = self.get_client_info(client_id)
|
||||
if client is None or client_info is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND)
|
||||
|
||||
is_valid = await self._preflight_authorization_url(client, client_info)
|
||||
if is_valid:
|
||||
return
|
||||
|
||||
log.info(
|
||||
"Detected invalid OAuth client %s; attempting re-registration",
|
||||
client_id,
|
||||
)
|
||||
re_registered = await self._re_register_client(request, client_id)
|
||||
if not re_registered:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to re-register OAuth client",
|
||||
)
|
||||
|
||||
client = self.get_client(client_id)
|
||||
client_info = self.get_client_info(client_id)
|
||||
if client is None or client_info is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth client unavailable after re-registration",
|
||||
)
|
||||
|
||||
if not await self._preflight_authorization_url(client, client_info):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth client registration is still invalid after re-registration",
|
||||
)
|
||||
|
||||
def get_client(self, client_id):
|
||||
client = self.clients.get(client_id)
|
||||
return client["client"] if client else None
|
||||
|
|
@ -602,10 +802,11 @@ class OAuthClientManager:
|
|||
return None
|
||||
|
||||
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
||||
await self._ensure_valid_client_registration(request, client_id)
|
||||
|
||||
client = self.get_client(client_id)
|
||||
if client is None:
|
||||
raise HTTPException(404)
|
||||
|
||||
client_info = self.get_client_info(client_id)
|
||||
if client_info is None:
|
||||
raise HTTPException(404)
|
||||
|
|
@ -613,7 +814,8 @@ class OAuthClientManager:
|
|||
redirect_uri = (
|
||||
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
||||
)
|
||||
return await client.authorize_redirect(request, str(redirect_uri))
|
||||
redirect_uri_str = str(redirect_uri) if redirect_uri else None
|
||||
return await client.authorize_redirect(request, redirect_uri_str)
|
||||
|
||||
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
||||
client = self.get_client(client_id)
|
||||
|
|
|
|||
Loading…
Reference in a new issue