fix: prevent cancellation scope corruption by exitting in LIFO and handling exceptions

This commit is contained in:
Omar Aburub 2025-10-23 15:34:47 +03:00
parent 23ea754061
commit 8f060ee2fa
2 changed files with 26 additions and 19 deletions

View file

@ -1556,11 +1556,13 @@ async def chat_completion(
log.info("Chat processing was cancelled")
try:
event_emitter = get_event_emitter(metadata)
await event_emitter(
await asyncio.shield(event_emitter(
{"type": "chat:tasks:cancel"},
)
))
except Exception as e:
pass
finally:
raise # re-raise to ensure proper task cancellation handling
except Exception as e:
log.debug(f"Error processing chat payload: {e}")
if metadata.get("chat_id") and metadata.get("message_id"):
@ -1591,7 +1593,7 @@ async def chat_completion(
finally:
try:
if mcp_clients := metadata.get("mcp_clients"):
for client in mcp_clients.values():
for client in reversed(mcp_clients.values()):
await client.disconnect()
except Exception as e:
log.debug(f"Error cleaning up: {e}")

View file

@ -2,35 +2,40 @@ import asyncio
from typing import Optional
from contextlib import AsyncExitStack
import anyio
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class MCPClient:
def __init__(self):
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self.exit_stack = None
async def connect(self, url: str, headers: Optional[dict] = None):
try:
self._streams_context = streamablehttp_client(url, headers=headers)
async with AsyncExitStack() as exit_stack:
try:
self._streams_context = streamablehttp_client(url, headers=headers)
transport = await self.exit_stack.enter_async_context(self._streams_context)
read_stream, write_stream, _ = transport
transport = await exit_stack.enter_async_context(self._streams_context)
read_stream, write_stream, _ = transport
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self.session = await self.exit_stack.enter_async_context(
self._session_context
)
await self.session.initialize()
except Exception as e:
await self.disconnect()
raise e
self.session = await exit_stack.enter_async_context(
self._session_context
)
with anyio.fail_after(10):
await self.session.initialize()
self.exit_stack = exit_stack.pop_all()
except Exception as e:
await asyncio.shield(self.disconnect())
raise e
async def list_tool_specs(self) -> Optional[dict]:
if not self.session: