mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-14 13:25:20 +00:00
refac
This commit is contained in:
parent
c6cbb05b84
commit
db658a730c
2 changed files with 116 additions and 156 deletions
|
|
@ -482,9 +482,11 @@ from open_webui.utils.auth import (
|
||||||
)
|
)
|
||||||
from open_webui.utils.plugin import install_tool_and_function_dependencies
|
from open_webui.utils.plugin import install_tool_and_function_dependencies
|
||||||
from open_webui.utils.oauth import (
|
from open_webui.utils.oauth import (
|
||||||
|
get_oauth_client_info_with_dynamic_client_registration,
|
||||||
|
encrypt_data,
|
||||||
|
decrypt_data,
|
||||||
OAuthManager,
|
OAuthManager,
|
||||||
OAuthClientManager,
|
OAuthClientManager,
|
||||||
decrypt_data,
|
|
||||||
OAuthClientInformationFull,
|
OAuthClientInformationFull,
|
||||||
)
|
)
|
||||||
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
from open_webui.utils.security_headers import SecurityHeadersMiddleware
|
||||||
|
|
@ -1987,6 +1989,64 @@ except Exception as e:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def register_client(self, request, client_id: str) -> bool:
|
||||||
|
server_type, server_id = client_id.split(":", 1)
|
||||||
|
|
||||||
|
connection = None
|
||||||
|
connection_idx = None
|
||||||
|
|
||||||
|
for idx, conn in enumerate(request.app.state.config.TOOL_SERVER_CONNECTIONS or []):
|
||||||
|
if conn.get("type", "openapi") == server_type:
|
||||||
|
info = conn.get("info", {})
|
||||||
|
if info.get("id") == server_id:
|
||||||
|
connection = conn
|
||||||
|
connection_idx = idx
|
||||||
|
break
|
||||||
|
|
||||||
|
if connection is None or connection_idx is None:
|
||||||
|
log.warning(
|
||||||
|
f"Unable to locate MCP tool server configuration for client {client_id} during re-registration"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
server_url = connection.get("url")
|
||||||
|
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
|
||||||
|
|
||||||
|
try:
|
||||||
|
oauth_client_info = (
|
||||||
|
await get_oauth_client_info_with_dynamic_client_registration(
|
||||||
|
request,
|
||||||
|
client_id,
|
||||||
|
server_url,
|
||||||
|
oauth_server_key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Dynamic client re-registration failed for {client_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS[connection_idx] = {
|
||||||
|
**connection,
|
||||||
|
"info": {
|
||||||
|
**connection.get("info", {}),
|
||||||
|
"oauth_client_info": encrypt_data(
|
||||||
|
oauth_client_info.model_dump(mode="json")
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
log.error(
|
||||||
|
f"Failed to persist updated OAuth client info for tool server {client_id}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
oauth_client_manager.remove_client(client_id)
|
||||||
|
oauth_client_manager.add_client(client_id, oauth_client_info)
|
||||||
|
log.info(f"Re-registered OAuth client {client_id} for tool server")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@app.get("/oauth/clients/{client_id}/authorize")
|
@app.get("/oauth/clients/{client_id}/authorize")
|
||||||
async def oauth_client_authorize(
|
async def oauth_client_authorize(
|
||||||
client_id: str,
|
client_id: str,
|
||||||
|
|
@ -1994,6 +2054,41 @@ async def oauth_client_authorize(
|
||||||
response: Response,
|
response: Response,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
|
# ensure_valid_client_registration
|
||||||
|
client = oauth_client_manager.get_client(client_id)
|
||||||
|
client_info = oauth_client_manager.get_client_info(client_id)
|
||||||
|
if client is None or client_info is None:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
if not await oauth_client_manager._preflight_authorization_url(client, client_info):
|
||||||
|
log.info(
|
||||||
|
"Detected invalid OAuth client %s; attempting re-registration",
|
||||||
|
client_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
re_registered = await register_client(request, client_id)
|
||||||
|
if not re_registered:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to re-register OAuth client",
|
||||||
|
)
|
||||||
|
|
||||||
|
client = oauth_client_manager.get_client(client_id)
|
||||||
|
client_info = oauth_client_manager.get_client_info(client_id)
|
||||||
|
if client is None or client_info is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth client unavailable after re-registration",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not await oauth_client_manager._preflight_authorization_url(
|
||||||
|
client, client_info
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth client registration is still invalid after re-registration",
|
||||||
|
)
|
||||||
|
|
||||||
return await oauth_client_manager.handle_authorize(request, client_id=client_id)
|
return await oauth_client_manager.handle_authorize(request, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -416,34 +416,10 @@ class OAuthClientManager:
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _find_mcp_connection(self, request, client_id: str):
|
|
||||||
try:
|
|
||||||
connections = request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
|
||||||
except Exception:
|
|
||||||
connections = []
|
|
||||||
|
|
||||||
normalized_client_id = client_id.split(":")[-1]
|
|
||||||
|
|
||||||
for idx, connection in enumerate(connections):
|
|
||||||
if not isinstance(connection, dict):
|
|
||||||
continue
|
|
||||||
if connection.get("type") != "mcp":
|
|
||||||
continue
|
|
||||||
|
|
||||||
info = connection.get("info") or {}
|
|
||||||
server_id = info.get("id")
|
|
||||||
if not server_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
normalized_server_id = server_id.split(":")[-1]
|
|
||||||
if normalized_server_id == normalized_client_id:
|
|
||||||
return idx, connection
|
|
||||||
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
async def _preflight_authorization_url(
|
async def _preflight_authorization_url(
|
||||||
self, client, client_info: OAuthClientInformationFull
|
self, client, client_info: OAuthClientInformationFull
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
# TODO: Replace this logic with a more robust OAuth client registration validation
|
||||||
# Only perform preflight checks for Starlette OAuth clients
|
# Only perform preflight checks for Starlette OAuth clients
|
||||||
if not hasattr(client, "create_authorization_url"):
|
if not hasattr(client, "create_authorization_url"):
|
||||||
return True
|
return True
|
||||||
|
|
@ -454,168 +430,59 @@ class OAuthClientManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
|
auth_data = await client.create_authorization_url(redirect_uri=redirect_uri)
|
||||||
authorize_url = auth_data.get("url")
|
authorization_url = auth_data.get("url")
|
||||||
if not authorize_url:
|
|
||||||
|
if not authorization_url:
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(
|
log.debug(
|
||||||
"Skipping OAuth preflight for client %s: %s",
|
f"Skipping OAuth preflight for client {client_info.client_id}: {e}",
|
||||||
client_info.client_id,
|
|
||||||
e,
|
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||||
async with session.get(
|
async with session.get(
|
||||||
authorize_url,
|
authorization_url,
|
||||||
allow_redirects=False,
|
allow_redirects=False,
|
||||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status < 400:
|
if resp.status < 400:
|
||||||
return True
|
return True
|
||||||
|
response_text = await resp.text()
|
||||||
|
|
||||||
body_text = await resp.text()
|
|
||||||
error = None
|
error = None
|
||||||
error_description = ""
|
error_description = ""
|
||||||
content_type = resp.headers.get("content-type", "")
|
|
||||||
|
|
||||||
|
content_type = resp.headers.get("content-type", "")
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
try:
|
try:
|
||||||
payload = json.loads(body_text)
|
payload = json.loads(response_text)
|
||||||
error = payload.get("error")
|
error = payload.get("error")
|
||||||
error_description = payload.get("error_description", "")
|
error_description = payload.get("error_description", "")
|
||||||
except json.JSONDecodeError:
|
except:
|
||||||
error = None
|
pass
|
||||||
error_description = ""
|
|
||||||
else:
|
else:
|
||||||
error_description = body_text
|
error_description = response_text
|
||||||
|
|
||||||
combined = f"{error or ''} {error_description}".lower()
|
error_message = f"{error or ''} {error_description or ''}".lower()
|
||||||
if (
|
|
||||||
"invalid_client" in combined
|
if any(
|
||||||
or "invalid client" in combined
|
keyword in error_message
|
||||||
or "client id" in combined
|
for keyword in ("invalid_client", "invalid client", "client id")
|
||||||
):
|
):
|
||||||
log.warning(
|
log.warning(
|
||||||
"OAuth client preflight detected invalid registration for %s: %s %s",
|
f"OAuth client preflight detected invalid registration for {client_info.client_id}: {error} {error_description}"
|
||||||
client_info.client_id,
|
|
||||||
error,
|
|
||||||
error_description,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(
|
log.debug(
|
||||||
"Skipping OAuth preflight network check for client %s: %s",
|
f"Skipping OAuth preflight network check for client {client_info.client_id}: {e}"
|
||||||
client_info.client_id,
|
|
||||||
e,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def _re_register_client(self, request, client_id: str) -> bool:
|
|
||||||
idx, connection = self._find_mcp_connection(request, client_id)
|
|
||||||
if idx is None or connection is None:
|
|
||||||
log.warning(
|
|
||||||
"Unable to locate MCP tool server configuration for client %s during re-registration",
|
|
||||||
client_id,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
server_url = connection.get("url")
|
|
||||||
oauth_server_key = (connection.get("config") or {}).get("oauth_server_key")
|
|
||||||
|
|
||||||
try:
|
|
||||||
oauth_client_info = (
|
|
||||||
await get_oauth_client_info_with_dynamic_client_registration(
|
|
||||||
request,
|
|
||||||
client_id,
|
|
||||||
server_url,
|
|
||||||
oauth_server_key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(
|
|
||||||
"Dynamic client re-registration failed for %s: %s",
|
|
||||||
client_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
encrypted_info = encrypt_data(oauth_client_info.model_dump(mode="json"))
|
|
||||||
|
|
||||||
updated_connections = copy.deepcopy(
|
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS or []
|
|
||||||
)
|
|
||||||
if idx >= len(updated_connections):
|
|
||||||
log.error(
|
|
||||||
"MCP tool server index %s out of range during OAuth client re-registration for %s",
|
|
||||||
idx,
|
|
||||||
client_id,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
updated_connection = copy.deepcopy(connection)
|
|
||||||
updated_connection.setdefault("info", {})
|
|
||||||
updated_connection["info"]["oauth_client_info"] = encrypted_info
|
|
||||||
updated_connections[idx] = updated_connection
|
|
||||||
|
|
||||||
try:
|
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = updated_connections
|
|
||||||
except Exception as e:
|
|
||||||
log.error(
|
|
||||||
"Failed to persist updated OAuth client info for %s: %s",
|
|
||||||
client_id,
|
|
||||||
e,
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.remove_client(client_id)
|
|
||||||
self.add_client(client_id, oauth_client_info)
|
|
||||||
|
|
||||||
log.info("Re-registered OAuth client %s for MCP tool server", client_id)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _ensure_valid_client_registration(self, request, client_id: str) -> None:
|
|
||||||
if not client_id.startswith("mcp:"):
|
|
||||||
return
|
|
||||||
|
|
||||||
client = self.get_client(client_id)
|
|
||||||
client_info = self.get_client_info(client_id)
|
|
||||||
|
|
||||||
if client is None or client_info is None:
|
|
||||||
raise HTTPException(status.HTTP_404_NOT_FOUND)
|
|
||||||
|
|
||||||
is_valid = await self._preflight_authorization_url(client, client_info)
|
|
||||||
if is_valid:
|
|
||||||
return
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"Detected invalid OAuth client %s; attempting re-registration",
|
|
||||||
client_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
re_registered = await self._re_register_client(request, client_id)
|
|
||||||
if not re_registered:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to re-register OAuth client",
|
|
||||||
)
|
|
||||||
|
|
||||||
client = self.get_client(client_id)
|
|
||||||
client_info = self.get_client_info(client_id)
|
|
||||||
if client is None or client_info is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="OAuth client unavailable after re-registration",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not await self._preflight_authorization_url(client, client_info):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="OAuth client registration is still invalid after re-registration",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_client(self, client_id):
|
def get_client(self, client_id):
|
||||||
client = self.clients.get(client_id)
|
client = self.clients.get(client_id)
|
||||||
return client["client"] if client else None
|
return client["client"] if client else None
|
||||||
|
|
@ -801,8 +668,6 @@ class OAuthClientManager:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
async def handle_authorize(self, request, client_id: str) -> RedirectResponse:
|
||||||
# await self._ensure_valid_client_registration(request, client_id)
|
|
||||||
|
|
||||||
client = self.get_client(client_id)
|
client = self.get_client(client_id)
|
||||||
if client is None:
|
if client is None:
|
||||||
raise HTTPException(404)
|
raise HTTPException(404)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue