diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index a505c9d797..4601e332f5 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -133,39 +133,44 @@ async def verify_tool_servers_config( try: if form_data.type == "mcp": try: - async with MCPClient() as client: - auth = None - headers = None + client = MCPClient() + auth = None + headers = None - token = None - if form_data.auth_type == "bearer": - token = form_data.key - elif form_data.auth_type == "session": - token = request.state.token.credentials - elif form_data.auth_type == "system_oauth": - try: - if request.cookies.get("oauth_session_id", None): - token = await request.app.state.oauth_manager.get_oauth_token( + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = ( + await request.app.state.oauth_manager.get_oauth_token( user.id, request.cookies.get("oauth_session_id", None), ) - except Exception as e: - pass + ) + except Exception as e: + pass - if token: - headers = {"Authorization": f"Bearer {token}"} + if token: + headers = {"Authorization": f"Bearer {token}"} - await client.connect(form_data.url, auth=auth, headers=headers) - specs = await client.list_tool_specs() - return { - "status": True, - "specs": specs, - } + await client.connect(form_data.url, auth=auth, headers=headers) + specs = await client.list_tool_specs() + return { + "status": True, + "specs": specs, + } except Exception as e: raise HTTPException( status_code=400, detail=f"Failed to create MCP client: {str(e)}", ) + finally: + if client: + await client.disconnect() else: # openapi token = None if form_data.auth_type == "bearer": diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index fabdc5541c..45f3fdac22 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -16,19 +16,25 @@ class MCPClient: async def connect( self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None ): - self._streams_context = streamablehttp_client(url, headers=headers, auth=auth) - read_stream, write_stream, _ = ( - await self._streams_context.__aenter__() - ) # pylint: disable=E1101 + try: + self._streams_context = streamablehttp_client( + url, headers=headers, auth=auth + ) - self._session_context = ClientSession( - read_stream, write_stream - ) # pylint: disable=W0201 - self.session: ClientSession = ( - await self._session_context.__aenter__() - ) # pylint: disable=C2801 + transport = await self.exit_stack.enter_async_context(self._streams_context) + read_stream, write_stream, _ = transport - await self.session.initialize() + self._session_context = ClientSession( + read_stream, write_stream + ) # pylint: disable=W0201 + + self.session = await self.exit_stack.enter_async_context( + self._session_context + ) + await self.session.initialize() + except Exception as e: + await self.disconnect() + raise e async def list_tool_specs(self) -> Optional[dict]: if not self.session: @@ -97,15 +103,7 @@ class MCPClient: async def disconnect(self): # Clean up and close the session - if self.session: - await self._session_context.__aexit__( - None, None, None - ) # pylint: disable=E1101 - if self._streams_context: - await self._streams_context.__aexit__( - None, None, None - ) # pylint: disable=E1101 - self.session = None + await self.exit_stack.aclose() async def __aenter__(self): await self.exit_stack.__aenter__()