Merge pull request #18537 from OAburub/patch

fix: prevent cancellation scope corruption by exitting in LIFO and ha…
This commit is contained in:
Tim Baek 2025-10-25 22:47:40 -07:00 committed by GitHub
commit a4d0bd1073
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 19 deletions

View file

@ -1556,11 +1556,13 @@ async def chat_completion(
log.info("Chat processing was cancelled") log.info("Chat processing was cancelled")
try: try:
event_emitter = get_event_emitter(metadata) event_emitter = get_event_emitter(metadata)
await event_emitter( await asyncio.shield(event_emitter(
{"type": "chat:tasks:cancel"}, {"type": "chat:tasks:cancel"},
) ))
except Exception as e: except Exception as e:
pass pass
finally:
raise # re-raise to ensure proper task cancellation handling
except Exception as e: except Exception as e:
log.debug(f"Error processing chat payload: {e}") log.debug(f"Error processing chat payload: {e}")
if metadata.get("chat_id") and metadata.get("message_id"): if metadata.get("chat_id") and metadata.get("message_id"):
@ -1591,7 +1593,7 @@ async def chat_completion(
finally: finally:
try: try:
if mcp_clients := metadata.get("mcp_clients"): if mcp_clients := metadata.get("mcp_clients"):
for client in mcp_clients.values(): for client in reversed(mcp_clients.values()):
await client.disconnect() await client.disconnect()
except Exception as e: except Exception as e:
log.debug(f"Error cleaning up: {e}") log.debug(f"Error cleaning up: {e}")

View file

@ -2,35 +2,40 @@ import asyncio
from typing import Optional from typing import Optional
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
import anyio
from mcp import ClientSession from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamablehttp_client from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class MCPClient: class MCPClient:
def __init__(self): def __init__(self):
self.session: Optional[ClientSession] = None self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack() self.exit_stack = None
async def connect(self, url: str, headers: Optional[dict] = None): async def connect(self, url: str, headers: Optional[dict] = None):
try: async with AsyncExitStack() as exit_stack:
self._streams_context = streamablehttp_client(url, headers=headers) try:
self._streams_context = streamablehttp_client(url, headers=headers)
transport = await self.exit_stack.enter_async_context(self._streams_context) transport = await exit_stack.enter_async_context(self._streams_context)
read_stream, write_stream, _ = transport read_stream, write_stream, _ = transport
self._session_context = ClientSession( self._session_context = ClientSession(
read_stream, write_stream read_stream, write_stream
) # pylint: disable=W0201 ) # pylint: disable=W0201
self.session = await self.exit_stack.enter_async_context( self.session = await exit_stack.enter_async_context(
self._session_context self._session_context
) )
await self.session.initialize() with anyio.fail_after(10):
except Exception as e: await self.session.initialize()
await self.disconnect() self.exit_stack = exit_stack.pop_all()
raise e except Exception as e:
await asyncio.shield(self.disconnect())
raise e
async def list_tool_specs(self) -> Optional[dict]: async def list_tool_specs(self) -> Optional[dict]:
if not self.session: if not self.session: