mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-15 13:55:19 +00:00
Merge pull request #18415 from taylorwilsdon/oauth_error_handling_enh
enh: More detailed OAuth2.1 tool callback error handling + fix for editing existing tools
This commit is contained in:
commit
bfadbc9934
3 changed files with 327 additions and 7 deletions
|
|
@ -262,5 +262,16 @@ class OAuthSessionTable:
|
||||||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def delete_sessions_by_provider(self, provider: str) -> bool:
|
||||||
|
"""Delete all OAuth sessions for a provider"""
|
||||||
|
try:
|
||||||
|
with get_db() as db:
|
||||||
|
db.query(OAuthSession).filter_by(provider=provider).delete()
|
||||||
|
db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
OAuthSessions = OAuthSessionTable()
|
OAuthSessions = OAuthSessionTable()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import copy
|
||||||
from fastapi import APIRouter, Depends, Request, HTTPException
|
from fastapi import APIRouter, Depends, Request, HTTPException
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
@ -15,6 +16,7 @@ from open_webui.utils.tools import (
|
||||||
set_tool_servers,
|
set_tool_servers,
|
||||||
)
|
)
|
||||||
from open_webui.utils.mcp.client import MCPClient
|
from open_webui.utils.mcp.client import MCPClient
|
||||||
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
@ -165,12 +167,59 @@ async def set_tool_servers_config(
|
||||||
form_data: ToolServersConfigForm,
|
form_data: ToolServersConfigForm,
|
||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
old_connections = copy.deepcopy(
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||||
|
)
|
||||||
|
|
||||||
|
new_connections = [
|
||||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||||
]
|
]
|
||||||
|
|
||||||
|
old_mcp_connections = {
|
||||||
|
conn.get("info", {}).get("id"): conn
|
||||||
|
for conn in old_connections
|
||||||
|
if conn.get("type") == "mcp"
|
||||||
|
}
|
||||||
|
new_mcp_connections = {
|
||||||
|
conn.get("info", {}).get("id"): conn
|
||||||
|
for conn in new_connections
|
||||||
|
if conn.get("type") == "mcp"
|
||||||
|
}
|
||||||
|
|
||||||
|
purge_oauth_clients = set()
|
||||||
|
|
||||||
|
for server_id, old_conn in old_mcp_connections.items():
|
||||||
|
if not server_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
old_auth_type = old_conn.get("auth_type", "none")
|
||||||
|
new_conn = new_mcp_connections.get(server_id)
|
||||||
|
|
||||||
|
if new_conn is None:
|
||||||
|
if old_auth_type == "oauth_2.1":
|
||||||
|
purge_oauth_clients.add(server_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_auth_type = new_conn.get("auth_type", "none")
|
||||||
|
|
||||||
|
if old_auth_type == "oauth_2.1":
|
||||||
|
if (
|
||||||
|
new_auth_type != "oauth_2.1"
|
||||||
|
or old_conn.get("url") != new_conn.get("url")
|
||||||
|
or old_conn.get("info", {}).get("oauth_client_info")
|
||||||
|
!= new_conn.get("info", {}).get("oauth_client_info")
|
||||||
|
):
|
||||||
|
purge_oauth_clients.add(server_id)
|
||||||
|
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS = new_connections
|
||||||
|
|
||||||
await set_tool_servers(request)
|
await set_tool_servers(request)
|
||||||
|
|
||||||
|
for server_id in purge_oauth_clients:
|
||||||
|
client_key = f"mcp:{server_id}"
|
||||||
|
request.app.state.oauth_client_manager.remove_client(client_key)
|
||||||
|
OAuthSessions.delete_sessions_by_provider(client_key)
|
||||||
|
|
||||||
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS:
|
||||||
server_type = connection.get("type", "openapi")
|
server_type = connection.get("type", "openapi")
|
||||||
if server_type == "mcp":
|
if server_type == "mcp":
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import base64
|
import base64
|
||||||
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
|
@ -74,6 +75,8 @@ from mcp.shared.auth import (
|
||||||
OAuthMetadata,
|
OAuthMetadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from authlib.oauth2.rfc6749.errors import OAuth2Error
|
||||||
|
|
||||||
|
|
||||||
class OAuthClientInformationFull(OAuthClientMetadata):
|
class OAuthClientInformationFull(OAuthClientMetadata):
|
||||||
issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
issuer: Optional[str] = None # URL of the OAuth server that issued this client
|
||||||
|
|
@ -150,6 +153,37 @@ def decrypt_data(data: str):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _build_oauth_callback_error_message(exc: Exception) -> str:
|
||||||
|
"""
|
||||||
|
Produce a user-facing callback error string with actionable context.
|
||||||
|
Keeps the message short and strips newlines for safe redirect usage.
|
||||||
|
"""
|
||||||
|
if isinstance(exc, OAuth2Error):
|
||||||
|
parts = [p for p in [exc.error, exc.description] if p]
|
||||||
|
detail = " - ".join(parts)
|
||||||
|
elif isinstance(exc, HTTPException):
|
||||||
|
detail = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
|
||||||
|
elif isinstance(exc, aiohttp.ClientResponseError):
|
||||||
|
detail = f"Upstream provider returned {exc.status}: {exc.message}"
|
||||||
|
elif isinstance(exc, aiohttp.ClientError):
|
||||||
|
detail = str(exc)
|
||||||
|
elif isinstance(exc, KeyError):
|
||||||
|
missing = str(exc).strip("'")
|
||||||
|
if missing.lower() == "state":
|
||||||
|
detail = "Missing state parameter in callback (session may have expired)"
|
||||||
|
else:
|
||||||
|
detail = f"Missing expected key '{missing}' in OAuth response"
|
||||||
|
else:
|
||||||
|
detail = str(exc)
|
||||||
|
|
||||||
|
detail = detail.replace("\n", " ").strip()
|
||||||
|
if not detail:
|
||||||
|
detail = exc.__class__.__name__
|
||||||
|
|
||||||
|
message = f"OAuth callback failed: {detail}"
|
||||||
|
return message[:197] + "..." if len(message) > 200 else message
|
||||||
|
|
||||||
|
|
||||||
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
def is_in_blocked_groups(group_name: str, groups: list) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a group name matches any blocked pattern.
|
Check if a group name matches any blocked pattern.
|
||||||
|
|
@ -368,11 +402,221 @@ class OAuthClientManager:
|
||||||
return self.clients[client_id]
|
return self.clients[client_id]
|
||||||
|
|
||||||
def remove_client(self, client_id):
|
def remove_client(self, client_id):
|
||||||
|
removed = False
|
||||||
if client_id in self.clients:
|
if client_id in self.clients:
|
||||||
del self.clients[client_id]
|
del self.clients[client_id]
|
||||||
|
removed = True
|
||||||
|
if hasattr(self.oauth, "_clients"):
|
||||||
|
if client_id in self.oauth._clients:
|
||||||
|
self.oauth._clients.pop(client_id, None)
|
||||||
|
removed = True
|
||||||
|
if hasattr(self.oauth, "_registry"):
|
||||||
|
if client_id in self.oauth._registry:
|
||||||
|
self.oauth._registry.pop(client_id, None)
|
||||||
|
removed = True
|
||||||
|
if removed:
|
||||||
log.info(f"Removed OAuth client {client_id}")
|
log.info(f"Removed OAuth client {client_id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _find_mcp_connection(self, request, client_id: str):
|
||||||
|
try:
|
||||||
|
connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||||
|
except Exception:
|
||||||
|
connections = []
|
||||||
|
|
||||||
|
normalized_client_id = client_id.split(":")[-1]
|
||||||
|
|
||||||
|
for idx, connection in enumerate(connections):
|
||||||
|
if not isinstance(connection, dict):
|
||||||
|
continue
|
||||||
|
if connection.get("type") != "mcp":
|
||||||
|
continue
|
||||||
|
|
||||||
|
info = connection.get("info") or {}
|
||||||
|
server_id = info.get("id")
|
||||||
|
if not server_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_server_id = server_id.split(":")[-1]
|
||||||
|
if normalized_server_id == normalized_client_id:
|
||||||
|
return idx, connection
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def _preflight_authorization_url(
|
||||||
|
self, client, client_info: OAuthClientInformationFull
|
||||||
|
) -> bool:
|
||||||
|
# Only perform preflight checks for Starlette OAuth clients
|
||||||
|
if not hasattr(client, "create_authorization_url"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
redirect_uri = None
|
||||||
|
if client_info.redirect_uris:
|
||||||
|
redirect_uri = str(client_info.redirect_uris[0])
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
|
||||||
|
authorize_url = auth_data.get("url")
|
||||||
|
if not authorize_url:
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(
|
||||||
|
"Skipping OAuth preflight for client %s: %s",
|
||||||
|
client_info.client_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||||
|
async with session.get(
|
||||||
|
authorize_url,
|
||||||
|
allow_redirects=False,
|
||||||
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
|
) as resp:
|
||||||
|
if resp.status < 400:
|
||||||
|
return True
|
||||||
|
|
||||||
|
body_text = await resp.text()
|
||||||
|
error = None
|
||||||
|
error_description = ""
|
||||||
|
content_type = resp.headers.get("content-type", "")
|
||||||
|
|
||||||
|
if "application/json" in content_type:
|
||||||
|
try:
|
||||||
|
payload = json.loads(body_text)
|
||||||
|
error = payload.get("error")
|
||||||
|
error_description = payload.get("error_description", "")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
error = None
|
||||||
|
error_description = ""
|
||||||
|
else:
|
||||||
|
error_description = body_text
|
||||||
|
|
||||||
|
combined = f"{error or ''} {error_description}".lower()
|
||||||
|
if (
|
||||||
|
"invalid_client" in combined
|
||||||
|
or "invalid client" in combined
|
||||||
|
or "client id" in combined
|
||||||
|
):
|
||||||
|
log.warning(
|
||||||
|
"OAuth client preflight detected invalid registration for %s: %s %s",
|
||||||
|
client_info.client_id,
|
||||||
|
error,
|
||||||
|
error_description,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(
|
||||||
|
"Skipping OAuth preflight network check for client %s: %s",
|
||||||
|
client_info.client_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _re_register_client(self, request, client_id: str) -> bool:
|
||||||
|
idx, connection = self._find_mcp_connection(request, client_id)
|
||||||
|
if idx is None or connection is None:
|
||||||
|
log.warning(
|
||||||
|
"Unable to locate MCP tool server configuration for client %s during re-registration",
|
||||||
|
client_id,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
server_url = connection.get("url")
|
||||||
|
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
|
||||||
|
|
||||||
|
try:
|
||||||
|
oauth_client_info = (
|
||||||
|
await get_oauth_client_info_with_dynamic_client_registration(
|
||||||
|
request,
|
||||||
|
client_id,
|
||||||
|
server_url,
|
||||||
|
oauth_server_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
"Dynamic client re-registration failed for %s: %s",
|
||||||
|
client_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
|
||||||
|
|
||||||
|
updated_connections = copy.deepcopy(
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
||||||
|
)
|
||||||
|
if idx >= len(updated_connections):
|
||||||
|
log.error(
|
||||||
|
"MCP tool server index %s out of range during OAuth client re-registration for %s",
|
||||||
|
idx,
|
||||||
|
client_id,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
updated_connection = copy.deepcopy(connection)
|
||||||
|
updated_connection.setdefault("info", {})
|
||||||
|
updated_connection["info"]["oauth_client_info"] = encrypted_info
|
||||||
|
updated_connections[idx] = updated_connection
|
||||||
|
|
||||||
|
try:
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
"Failed to persist updated OAuth client info for %s: %s",
|
||||||
|
client_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.remove_client(client_id)
|
||||||
|
self.add_client(client_id, oauth_client_info)
|
||||||
|
OAuthSessions.delete_sessions_by_provider(client_id)
|
||||||
|
|
||||||
|
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _ensure_valid_client_registration(self, request, client_id: str) -> None:
|
||||||
|
if not client_id.startswith("mcp:"):
|
||||||
|
return
|
||||||
|
|
||||||
|
client = self.get_client(client_id)
|
||||||
|
client_info = self.get_client_info(client_id)
|
||||||
|
if client is None or client_info is None:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
is_valid = await self._preflight_authorization_url(client, client_info)
|
||||||
|
if is_valid:
|
||||||
|
return
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"Detected invalid OAuth client %s; attempting re-registration",
|
||||||
|
client_id,
|
||||||
|
)
|
||||||
|
re_registered = await self._re_register_client(request, client_id)
|
||||||
|
if not re_registered:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to re-register OAuth client",
|
||||||
|
)
|
||||||
|
|
||||||
|
client = self.get_client(client_id)
|
||||||
|
client_info = self.get_client_info(client_id)
|
||||||
|
if client is None or client_info is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth client unavailable after re-registration",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not await self._preflight_authorization_url(client, client_info):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth client registration is still invalid after re-registration",
|
||||||
|
)
|
||||||
|
|
||||||
def get_client(self, client_id):
|
def get_client(self, client_id):
|
||||||
client = self.clients.get(client_id)
|
client = self.clients.get(client_id)
|
||||||
return client["client"] if client else None
|
return client["client"] if client else None
|
||||||
|
|
@ -558,10 +802,11 @@ class OAuthClientManager:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
||||||
|
await self._ensure_valid_client_registration(request, client_id)
|
||||||
|
|
||||||
client = self.get_client(client_id)
|
client = self.get_client(client_id)
|
||||||
if client is None:
|
if client is None:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
|
|
||||||
client_info = self.get_client_info(client_id)
|
client_info = self.get_client_info(client_id)
|
||||||
if client_info is None:
|
if client_info is None:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
|
|
@ -569,7 +814,8 @@ class OAuthClientManager:
|
||||||
redirect_uri = (
|
redirect_uri = (
|
||||||
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
client_info.redirect_uris[0] if client_info.redirect_uris else None
|
||||||
)
|
)
|
||||||
return await client.authorize_redirect(request, str(redirect_uri))
|
redirect_uri_str = str(redirect_uri) if redirect_uri else None
|
||||||
|
return await client.authorize_redirect(request, redirect_uri_str)
|
||||||
|
|
||||||
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
async def handle_callback(self, request, client_id: str, user_id: str, response):
|
||||||
client = self.get_client(client_id)
|
client = self.get_client(client_id)
|
||||||
|
|
@ -621,8 +867,14 @@ class OAuthClientManager:
|
||||||
error_message = "Failed to obtain OAuth token"
|
error_message = "Failed to obtain OAuth token"
|
||||||
log.warning(error_message)
|
log.warning(error_message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = "OAuth callback error"
|
error_message = _build_oauth_callback_error_message(e)
|
||||||
log.warning(f"OAuth callback error: {e}")
|
log.warning(
|
||||||
|
"OAuth callback error for user_id=%s client_id=%s: %s",
|
||||||
|
user_id,
|
||||||
|
client_id,
|
||||||
|
error_message,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
redirect_url = (
|
redirect_url = (
|
||||||
str(request.app.state.config.WEBUI_URL or request.base_url)
|
str(request.app.state.config.WEBUI_URL or request.base_url)
|
||||||
|
|
@ -630,7 +882,9 @@ class OAuthClientManager:
|
||||||
|
|
||||||
if error_message:
|
if error_message:
|
||||||
log.debug(error_message)
|
log.debug(error_message)
|
||||||
redirect_url = f"{redirect_url}/?error={error_message}"
|
redirect_url = (
|
||||||
|
f"{redirect_url}/?error={urllib.parse.quote_plus(error_message)}"
|
||||||
|
)
|
||||||
return RedirectResponse(url=redirect_url, headers=response.headers)
|
return RedirectResponse(url=redirect_url, headers=response.headers)
|
||||||
|
|
||||||
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
response = RedirectResponse(url=redirect_url, headers=response.headers)
|
||||||
|
|
@ -1104,7 +1358,13 @@ class OAuthManager:
|
||||||
try:
|
try:
|
||||||
token = await client.authorize_access_token(request)
|
token = await client.authorize_access_token(request)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"OAuth callback error: {e}")
|
detailed_error = _build_oauth_callback_error_message(e)
|
||||||
|
log.warning(
|
||||||
|
"OAuth callback error during authorize_access_token for provider %s: %s",
|
||||||
|
provider,
|
||||||
|
detailed_error,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
# Try to get userinfo from the token first, some providers include it there
|
# Try to get userinfo from the token first, some providers include it there
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue