This commit is contained in:
Timothy Jaeryang Baek 2025-10-27 16:46:04 -07:00
parent c8b2313362
commit cbcab062eb
2 changed files with 116 additions and 156 deletions

View file

@ -482,9 +482,11 @@ from open_webui.utils.auth import (
)
from open_webui.utils.plugin import install_tool_and_function_dependencies
from open_webui.utils.oauth import (
get_oauth_client_info_with_dynamic_client_registration,
encrypt_data,
decrypt_data,
OAuthManager,
OAuthClientManager,
decrypt_data,
OAuthClientInformationFull,
)
from open_webui.utils.security_headers import SecurityHeadersMiddleware
@ -1987,6 +1989,64 @@ except Exception as e:
)
async def register_client(self, request, client_id: str) -> bool:
server_type, server_id = client_id.split(":", 1)
connection = None
connection_idx = None
for idx, conn in enumerate(request.app.state.config.TOOL_SERVER_CONNECTIONS or []):
if conn.get("type", "openapi") == server_type:
info = conn.get("info", {})
if info.get("id") == server_id:
connection = conn
connection_idx = idx
break
if connection is None or connection_idx is None:
log.warning(
f"Unable to locate MCP tool server configuration for client {client_id} during re-registration"
)
return False
server_url = connection.get("url")
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
try:
oauth_client_info = (
await get_oauth_client_info_with_dynamic_client_registration(
request,
client_id,
server_url,
oauth_server_key,
)
)
except Exception as e:
log.error(f"Dynamic client re-registration failed for {client_id}: {e}")
return False
try:
request.app.state.config.TOOL_SERVER_CONNECTIONS[connection_idx] = {
**connection,
"info": {
**connection.get("info", {}),
"oauth_client_info": encrypt_data(
oauth_client_info.model_dump(mode="json")
),
},
}
except Exception as e:
log.error(
f"Failed to persist updated OAuth client info for tool server {client_id}: {e}"
)
return False
oauth_client_manager.remove_client(client_id)
oauth_client_manager.add_client(client_id, oauth_client_info)
log.info(f"Re-registered OAuth client {client_id} for tool server")
return True
@app.get("/oauth/clients/{client_id}/authorize")
async def oauth_client_authorize(
client_id: str,
@ -1994,6 +2054,41 @@ async def oauth_client_authorize(
response: Response,
user=Depends(get_verified_user),
):
# ensure_valid_client_registration
client = oauth_client_manager.get_client(client_id)
client_info = oauth_client_manager.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(status.HTTP_404_NOT_FOUND)
if not await oauth_client_manager._preflight_authorization_url(client, client_info):
log.info(
"Detected invalid OAuth client %s; attempting re-registration",
client_id,
)
re_registered = await register_client(request, client_id)
if not re_registered:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to re-register OAuth client",
)
client = oauth_client_manager.get_client(client_id)
client_info = oauth_client_manager.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client unavailable after re-registration",
)
if not await oauth_client_manager._preflight_authorization_url(
client, client_info
):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client registration is still invalid after re-registration",
)
return await oauth_client_manager.handle_authorize(request, client_id=client_id)

View file

@ -416,34 +416,10 @@ class OAuthClientManager:
return True
def _find_mcp_connection(self, request, client_id: str):
try:
connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
except Exception:
connections = []
normalized_client_id = client_id.split(":")[-1]
for idx, connection in enumerate(connections):
if not isinstance(connection, dict):
continue
if connection.get("type") != "mcp":
continue
info = connection.get("info") or {}
server_id = info.get("id")
if not server_id:
continue
normalized_server_id = server_id.split(":")[-1]
if normalized_server_id == normalized_client_id:
return idx, connection
return None, None
async def _preflight_authorization_url(
self, client, client_info: OAuthClientInformationFull
) -> bool:
# TODO: Replace this logic with a more robust OAuth client registration validation
# Only perform preflight checks for Starlette OAuth clients
if not hasattr(client, "create_authorization_url"):
return True
@ -454,168 +430,59 @@ class OAuthClientManager:
try:
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
authorize_url = auth_data.get("url")
if not authorize_url:
authorization_url = auth_data.get("url")
if not authorization_url:
return True
except Exception as e:
log.debug(
"Skipping OAuth preflight for client %s: %s",
client_info.client_id,
e,
f"Skipping OAuth preflight for client {client_info.client_id}: {e}",
)
return True
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
authorize_url,
authorization_url,
allow_redirects=False,
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as resp:
if resp.status < 400:
return True
response_text = await resp.text()
body_text = await resp.text()
error = None
error_description = ""
content_type = resp.headers.get("content-type", "")
content_type = resp.headers.get("content-type", "")
if "application/json" in content_type:
try:
payload = json.loads(body_text)
payload = json.loads(response_text)
error = payload.get("error")
error_description = payload.get("error_description", "")
except json.JSONDecodeError:
error = None
error_description = ""
except:
pass
else:
error_description = body_text
error_description = response_text
combined = f"{error or ''} {error_description}".lower()
if (
"invalid_client" in combined
or "invalid client" in combined
or "client id" in combined
error_message = f"{error or ''} {error_description or ''}".lower()
if any(
keyword in error_message
for keyword in ("invalid_client", "invalid client", "client id")
):
log.warning(
"OAuth client preflight detected invalid registration for %s: %s %s",
client_info.client_id,
error,
error_description,
f"OAuth client preflight detected invalid registration for {client_info.client_id}: {error} {error_description}"
)
return False
except Exception as e:
log.debug(
"Skipping OAuth preflight network check for client %s: %s",
client_info.client_id,
e,
f"Skipping OAuth preflight network check for client {client_info.client_id}: {e}"
)
return True
async def _re_register_client(self, request, client_id: str) -> bool:
idx, connection = self._find_mcp_connection(request, client_id)
if idx is None or connection is None:
log.warning(
"Unable to locate MCP tool server configuration for client %s during re-registration",
client_id,
)
return False
server_url = connection.get("url")
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
try:
oauth_client_info = (
await get_oauth_client_info_with_dynamic_client_registration(
request,
client_id,
server_url,
oauth_server_key,
)
)
except Exception as e:
log.error(
"Dynamic client re-registration failed for %s: %s",
client_id,
e,
)
return False
encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
updated_connections = copy.deepcopy(
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
)
if idx >= len(updated_connections):
log.error(
"MCP tool server index %s out of range during OAuth client re-registration for %s",
idx,
client_id,
)
return False
updated_connection = copy.deepcopy(connection)
updated_connection.setdefault("info", {})
updated_connection["info"]["oauth_client_info"] = encrypted_info
updated_connections[idx] = updated_connection
try:
request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
except Exception as e:
log.error(
"Failed to persist updated OAuth client info for %s: %s",
client_id,
e,
)
return False
self.remove_client(client_id)
self.add_client(client_id, oauth_client_info)
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
return True
async def _ensure_valid_client_registration(self, request, client_id: str) -> None:
if not client_id.startswith("mcp:"):
return
client = self.get_client(client_id)
client_info = self.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(status.HTTP_404_NOT_FOUND)
is_valid = await self._preflight_authorization_url(client, client_info)
if is_valid:
return
log.info(
"Detected invalid OAuth client %s; attempting re-registration",
client_id,
)
re_registered = await self._re_register_client(request, client_id)
if not re_registered:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to re-register OAuth client",
)
client = self.get_client(client_id)
client_info = self.get_client_info(client_id)
if client is None or client_info is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client unavailable after re-registration",
)
if not await self._preflight_authorization_url(client, client_info):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth client registration is still invalid after re-registration",
)
def get_client(self, client_id):
client = self.clients.get(client_id)
return client["client"] if client else None
@ -801,8 +668,6 @@ class OAuthClientManager:
return None
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
# await self._ensure_valid_client_registration(request, client_id)
client = self.get_client(client_id)
if client is None:
raise HTTPException(404)