mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
fix: prevent cancellation scope corruption by exitting in LIFO and handling exceptions
This commit is contained in:
parent
23ea754061
commit
8f060ee2fa
2 changed files with 26 additions and 19 deletions
|
|
@ -1556,11 +1556,13 @@ async def chat_completion(
|
||||||
log.info("Chat processing was cancelled")
|
log.info("Chat processing was cancelled")
|
||||||
try:
|
try:
|
||||||
event_emitter = get_event_emitter(metadata)
|
event_emitter = get_event_emitter(metadata)
|
||||||
await event_emitter(
|
await asyncio.shield(event_emitter(
|
||||||
{"type": "chat:tasks:cancel"},
|
{"type": "chat:tasks:cancel"},
|
||||||
)
|
))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
finally:
|
||||||
|
raise # re-raise to ensure proper task cancellation handling
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Error processing chat payload: {e}")
|
log.debug(f"Error processing chat payload: {e}")
|
||||||
if metadata.get("chat_id") and metadata.get("message_id"):
|
if metadata.get("chat_id") and metadata.get("message_id"):
|
||||||
|
|
@ -1591,7 +1593,7 @@ async def chat_completion(
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
if mcp_clients := metadata.get("mcp_clients"):
|
if mcp_clients := metadata.get("mcp_clients"):
|
||||||
for client in mcp_clients.values():
|
for client in reversed(mcp_clients.values()):
|
||||||
await client.disconnect()
|
await client.disconnect()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug(f"Error cleaning up: {e}")
|
log.debug(f"Error cleaning up: {e}")
|
||||||
|
|
|
||||||
|
|
@ -2,35 +2,40 @@ import asyncio
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
|
||||||
|
import anyio
|
||||||
|
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
class MCPClient:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.session: Optional[ClientSession] = None
|
self.session: Optional[ClientSession] = None
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = None
|
||||||
|
|
||||||
async def connect(self, url: str, headers: Optional[dict] = None):
|
async def connect(self, url: str, headers: Optional[dict] = None):
|
||||||
try:
|
async with AsyncExitStack() as exit_stack:
|
||||||
self._streams_context = streamablehttp_client(url, headers=headers)
|
try:
|
||||||
|
self._streams_context = streamablehttp_client(url, headers=headers)
|
||||||
|
|
||||||
transport = await self.exit_stack.enter_async_context(self._streams_context)
|
transport = await exit_stack.enter_async_context(self._streams_context)
|
||||||
read_stream, write_stream, _ = transport
|
read_stream, write_stream, _ = transport
|
||||||
|
|
||||||
self._session_context = ClientSession(
|
self._session_context = ClientSession(
|
||||||
read_stream, write_stream
|
read_stream, write_stream
|
||||||
) # pylint: disable=W0201
|
) # pylint: disable=W0201
|
||||||
|
|
||||||
self.session = await self.exit_stack.enter_async_context(
|
self.session = await exit_stack.enter_async_context(
|
||||||
self._session_context
|
self._session_context
|
||||||
)
|
)
|
||||||
await self.session.initialize()
|
with anyio.fail_after(10):
|
||||||
except Exception as e:
|
await self.session.initialize()
|
||||||
await self.disconnect()
|
self.exit_stack = exit_stack.pop_all()
|
||||||
raise e
|
except Exception as e:
|
||||||
|
await asyncio.shield(self.disconnect())
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
async def list_tool_specs(self) -> Optional[dict]:
|
async def list_tool_specs(self) -> Optional[dict]:
|
||||||
if not self.session:
|
if not self.session:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue