Merge pull request #18415 from taylorwilsdon/oauth_error_handling_enh

enh: More detailed OAuth2.1 tool callback error handling + fix for editing existing tools
This commit is contained in:
Tim Baek 2025-10-27 15:13:33 -07:00 committed by GitHub
commit bfadbc9934
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 327 additions and 7 deletions

View file

@ -262,5 +262,16 @@ class OAuthSessionTable:
log.error(f"Error deleting OAuth sessions by user ID: {e}") log.error(f"Error deleting OAuth sessions by user ID: {e}")
return False return False
def delete_sessions_by_provider(self, provider: str) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db() as db:
db.query(OAuthSession).filter_by(provider=provider).delete()
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
return False
OAuthSessions = OAuthSessionTable() OAuthSessions = OAuthSessionTable()

View file

@ -1,4 +1,5 @@
import logging import logging
import copy
from fastapi import APIRouter, Depends, Request, HTTPException from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
import aiohttp import aiohttp
@ -15,6 +16,7 @@ from open_webui.utils.tools import (
set_tool_servers, set_tool_servers,
) )
from open_webui.utils.mcp.client import MCPClient from open_webui.utils.mcp.client import MCPClient
from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
@ -165,12 +167,59 @@ async def set_tool_servers_config(
form_data: ToolServersConfigForm, form_data: ToolServersConfigForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
request.app.state.config.TOOL_SERVER_CONNECTIONS = [ old_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
new_connections = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
] ]
old_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in old_connections
if conn.get("type") == "mcp"
}
new_mcp_connections = {
conn.get("info", {}).get("id"): conn
for conn in new_connections
if conn.get("type") == "mcp"
}
purge_oauth_clients = set()
for server_id, old_conn in old_mcp_connections.items():
if not server_id:
continue
old_auth_type = old_conn.get("auth_type", "none")
new_conn = new_mcp_connections.get(server_id)
if new_conn is None:
if old_auth_type == "oauth_2.1":
purge_oauth_clients.add(server_id)
continue
new_auth_type = new_conn.get("auth_type", "none")
if old_auth_type == "oauth_2.1":
if (
new_auth_type != "oauth_2.1"
or old_conn.get("url") != new_conn.get("url")
or old_conn.get("info", {}).get("oauth_client_info")
!= new_conn.get("info", {}).get("oauth_client_info")
):
purge_oauth_clients.add(server_id)
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
await set_tool_servers(request) await set_tool_servers(request)
for server_id in purge_oauth_clients:
client_key = f"mcp:{server_id}"
request.app.state.oauth_client_manager.remove_client(client_key)
OAuthSessions.delete_sessions_by_provider(client_key)
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
server_type = connection.get("type", "openapi") server_type = connection.get("type", "openapi")
if server_type == "mcp": if server_type == "mcp":

View file

@ -1,4 +1,5 @@
import base64 import base64
import copy
import hashlib import hashlib
import logging import logging
import mimetypes import mimetypes
@ -74,6 +75,8 @@ from mcp.shared.auth import (
OAuthMetadata, OAuthMetadata,
) )
from authlib.oauth2.rfc6749.errors import OAuth2Error
class OAuthClientInformationFull(OAuthClientMetadata): class OAuthClientInformationFull(OAuthClientMetadata):
issuer: Optional[str] = None # URL of the OAuth server that issued this client issuer: Optional[str] = None # URL of the OAuth server that issued this client
@ -150,6 +153,37 @@ def decrypt_data(data: str):
raise raise
def _build_oauth_callback_error_message(exc: Exception) -> str:
"""
Produce a user-facing callback error string with actionable context.
Keeps the message short and strips newlines for safe redirect usage.
"""
if isinstance(exc, OAuth2Error):
parts = [p for p in [exc.error, exc.description] if p]
detail = " - ".join(parts)
elif isinstance(exc, HTTPException):
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
elif isinstance(exc, aiohttp.ClientResponseError):
detail = f"Upstream provider returned {exc.status}: {exc.message}"
elif isinstance(exc, aiohttp.ClientError):
detail = str(exc)
elif isinstance(exc, KeyError):
missing = str(exc).strip("'")
if missing.lower() == "state":
detail = "Missing state parameter in callback (session may have expired)"
else:
detail = f"Missing expected key '{missing}' in OAuth response"
else:
detail = str(exc)
detail = detail.replace("\n", " ").strip()
if not detail:
detail = exc.__class__.__name__
message = f"OAuth callback failed: {detail}"
return message[:197] + "..." if len(message) > 200 else message
def is_in_blocked_groups(group_name: str, groups: list) -> bool: def is_in_blocked_groups(group_name: str, groups: list) -> bool:
""" """
Check if a group name matches any blocked pattern. Check if a group name matches any blocked pattern.
@ -368,11 +402,221 @@ class OAuthClientManager:
return self.clients[client_id] return self.clients[client_id]
def remove_client(self, client_id): def remove_client(self, client_id):
removed = False
if client_id in self.clients: if client_id in self.clients:
del self.clients[client_id] del self.clients[client_id]
removed = True
if hasattr(self.oauth, "_clients"):
if client_id in self.oauth._clients:
self.oauth._clients.pop(client_id, None)
removed = True
if hasattr(self.oauth, "_registry"):
if client_id in self.oauth._registry:
self.oauth._registry.pop(client_id, None)
removed = True
if removed:
log.info(f"Removed OAuth client {client_id}") log.info(f"Removed OAuth client {client_id}")
return True return True
def _find_mcp_connection(self, request, client_id: str):
try:
connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
except Exception:
connections = []
normalized_client_id = client_id.split(":")[-1]
for idx, connection in enumerate(connections):
if not isinstance(connection, dict):
continue
if connection.get("type") != "mcp":
continue
info = connection.get("info") or {}
server_id = info.get("id")
if not server_id:
continue
normalized_server_id = server_id.split(":")[-1]
if normalized_server_id == normalized_client_id:
return idx, connection
return None, None
async def _preflight_authorization_url(
self, client, client_info: OAuthClientInformationFull
) -> bool:
# Only perform preflight checks for Starlette OAuth clients
if not hasattr(client, "create_authorization_url"):
return True
redirect_uri = None
if client_info.redirect_uris:
redirect_uri = str(client_info.redirect_uris[0])
try:
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
authorize_url = auth_data.get("url")
if not authorize_url:
return True
except Exception as e:
log.debug(
"Skipping OAuth preflight for client %s: %s",
client_info.client_id,
e,
)
return True
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
authorize_url,
allow_redirects=False,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as resp:
if resp.status < 400:
return True
body_text = await resp.text()
error = None
error_description = ""
content_type = resp.headers.get("content-type", "")
if "application/json" in content_type:
try:
payload = json.loads(body_text)
error = payload.get("error")
error_description = payload.get("error_description", "")
except json.JSONDecodeError:
error = None
error_description = ""
else:
error_description = body_text
combined = f"{error or ''} {error_description}".lower()
if (
"invalid_client" in combined
or "invalid client" in combined
or "client id" in combined
):
log.warning(
"OAuth client preflight detected invalid registration for %s: %s %s",
client_info.client_id,
error,
error_description,
)
return False
except Exception as e:
log.debug(
"Skipping OAuth preflight network check for client %s: %s",
client_info.client_id,
e,
)
return True
async def _re_register_client(self, request, client_id: str) -> bool:
idx, connection = self._find_mcp_connection(request, client_id)
if idx is None or connection is None:
log.warning(
"Unable to locate MCP tool server configuration for client %s during re-registration",
client_id,
)
return False
server_url = connection.get("url")
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
try:
oauth_client_info = (
await get_oauth_client_info_with_dynamic_client_registration(
request,
client_id,
server_url,
oauth_server_key,
)
)
except Exception as e:
log.error(
"Dynamic client re-registration failed for %s: %s",
client_id,
e,
)
return False
encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
updated_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
if idx >= len(updated_connections):
log.error(
"MCP tool server index %s out of range during OAuth client re-registration for %s",
idx,
client_id,
)
return False
updated_connection = copy.deepcopy(connection)
updated_connection.setdefault("info", {})
updated_connection["info"]["oauth_client_info"] = encrypted_info
updated_connections[idx] = updated_connection
try:
request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
except Exception as e:
log.error(
"Failed to persist updated OAuth client info for %s: %s",
client_id,
e,
)
return False
self.remove_client(client_id)
self.add_client(client_id, oauth_client_info)
OAuthSessions.delete_sessions_by_provider(client_id)
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
return True
async def _ensure_valid_client_registration(self, request, client_id: str) -> None:
if not client_id.startswith("mcp:"):
return
client = self.get_client(client_id)
client_info = self.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(status.HTTP_404_NOT_FOUND)
is_valid = await self._preflight_authorization_url(client, client_info)
if is_valid:
return
log.info(
"Detected invalid OAuth client %s; attempting re-registration",
client_id,
)
re_registered = await self._re_register_client(request, client_id)
if not re_registered:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to re-register OAuth client",
)
client = self.get_client(client_id)
client_info = self.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client unavailable after re-registration",
)
if not await self._preflight_authorization_url(client, client_info):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client registration is still invalid after re-registration",
)
def get_client(self, client_id): def get_client(self, client_id):
client = self.clients.get(client_id) client = self.clients.get(client_id)
return client["client"] if client else None return client["client"] if client else None
@ -558,10 +802,11 @@ class OAuthClientManager:
return None return None
async def handle_authorize(self, request, client_id: str) -> RedirectResponse: async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
await self._ensure_valid_client_registration(request, client_id)
client = self.get_client(client_id) client = self.get_client(client_id)
if client is None: if client is None:
raise HTTPException(404) raise HTTPException(404)
client_info = self.get_client_info(client_id) client_info = self.get_client_info(client_id)
if client_info is None: if client_info is None:
raise HTTPException(404) raise HTTPException(404)
@ -569,7 +814,8 @@ class OAuthClientManager:
redirect_uri = ( redirect_uri = (
client_info.redirect_uris[0] if client_info.redirect_uris else None client_info.redirect_uris[0] if client_info.redirect_uris else None
) )
return await client.authorize_redirect(request, str(redirect_uri)) redirect_uri_str = str(redirect_uri) if redirect_uri else None
return await client.authorize_redirect(request, redirect_uri_str)
async def handle_callback(self, request, client_id: str, user_id: str, response): async def handle_callback(self, request, client_id: str, user_id: str, response):
client = self.get_client(client_id) client = self.get_client(client_id)
@ -621,8 +867,14 @@ class OAuthClientManager:
error_message = "Failed to obtain OAuth token" error_message = "Failed to obtain OAuth token"
log.warning(error_message) log.warning(error_message)
except Exception as e: except Exception as e:
error_message = "OAuth callback error" error_message = _build_oauth_callback_error_message(e)
log.warning(f"OAuth callback error: {e}") log.warning(
"OAuth callback error for user_id=%s client_id=%s: %s",
user_id,
client_id,
error_message,
exc_info=True,
)
redirect_url = ( redirect_url = (
str(request.app.state.config.WEBUI_URL or request.base_url) str(request.app.state.config.WEBUI_URL or request.base_url)
@ -630,7 +882,9 @@ class OAuthClientManager:
if error_message: if error_message:
log.debug(error_message) log.debug(error_message)
redirect_url = f"{redirect_url}/?error={error_message}" redirect_url = (
f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}"
)
return RedirectResponse(url=redirect_url, headers=response.headers) return RedirectResponse(url=redirect_url, headers=response.headers)
response = RedirectResponse(url=redirect_url, headers=response.headers) response = RedirectResponse(url=redirect_url, headers=response.headers)
@ -1104,7 +1358,13 @@ class OAuthManager:
try: try:
token = await client.authorize_access_token(request) token = await client.authorize_access_token(request)
except Exception as e: except Exception as e:
log.warning(f"OAuth callback error: {e}") detailed_error = _build_oauth_callback_error_message(e)
log.warning(
"OAuth callback error during authorize_access_token for provider %s: %s",
provider,
detailed_error,
exc_info=True,
)
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Try to get userinfo from the token first, some providers include it there # Try to get userinfo from the token first, some providers include it there