From 8f060ee2fa24173f98297632c1d8dc3059628cc1 Mon Sep 17 00:00:00 2001 From: Omar Aburub Date: Thu, 23 Oct 2025 15:34:47 +0300 Subject: [PATCH] fix: prevent cancellation scope corruption by exitting in LIFO and handling exceptions --- backend/open_webui/main.py | 8 +++--- backend/open_webui/utils/mcp/client.py | 37 +++++++++++++++----------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 9998af0e73..76cb9d7e07 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1556,11 +1556,13 @@ async def chat_completion( log.info("Chat processing was cancelled") try: event_emitter = get_event_emitter(metadata) - await event_emitter( + await asyncio.shield(event_emitter( {"type": "chat:tasks:cancel"}, - ) + )) except Exception as e: pass + finally: + raise # re-raise to ensure proper task cancellation handling except Exception as e: log.debug(f"Error processing chat payload: {e}") if metadata.get("chat_id") and metadata.get("message_id"): @@ -1591,7 +1593,7 @@ async def chat_completion( finally: try: if mcp_clients := metadata.get("mcp_clients"): - for client in mcp_clients.values(): + for client in reversed(mcp_clients.values()): await client.disconnect() except Exception as e: log.debug(f"Error cleaning up: {e}") diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index 01df38886c..67903b94d8 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -2,35 +2,40 @@ import asyncio from typing import Optional from contextlib import AsyncExitStack +import anyio + from mcp import ClientSession from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken - class MCPClient: def __init__(self): self.session: Optional[ClientSession] = None - self.exit_stack = AsyncExitStack() + self.exit_stack = None async def connect(self, url: str, headers: Optional[dict] = None): - try: - self._streams_context = streamablehttp_client(url, headers=headers) + async with AsyncExitStack() as exit_stack: + try: + self._streams_context = streamablehttp_client(url, headers=headers) - transport = await self.exit_stack.enter_async_context(self._streams_context) - read_stream, write_stream, _ = transport + transport = await exit_stack.enter_async_context(self._streams_context) + read_stream, write_stream, _ = transport - self._session_context = ClientSession( - read_stream, write_stream - ) # pylint: disable=W0201 + self._session_context = ClientSession( + read_stream, write_stream + ) # pylint: disable=W0201 - self.session = await self.exit_stack.enter_async_context( - self._session_context - ) - await self.session.initialize() - except Exception as e: - await self.disconnect() - raise e + self.session = await exit_stack.enter_async_context( + self._session_context + ) + with anyio.fail_after(10): + await self.session.initialize() + self.exit_stack = exit_stack.pop_all() + except Exception as e: + await asyncio.shield(self.disconnect()) + raise e + async def list_tool_specs(self) -> Optional[dict]: if not self.session: