This commit is contained in:
Timothy Jaeryang Baek 2025-09-23 03:32:25 -04:00
parent c55afc4255
commit 61f20acf61
2 changed files with 45 additions and 42 deletions

View file

@ -133,7 +133,7 @@ async def verify_tool_servers_config(
try:
if form_data.type == "mcp":
try:
async with MCPClient() as client:
client = MCPClient()
auth = None
headers = None
@ -145,10 +145,12 @@ async def verify_tool_servers_config(
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token(
token = (
await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
)
except Exception as e:
pass
@ -166,6 +168,9 @@ async def verify_tool_servers_config(
status_code=400,
detail=f"Failed to create MCP client: {str(e)}",
)
finally:
if client:
await client.disconnect()
else: # openapi
token = None
if form_data.auth_type == "bearer":

View file

@ -16,19 +16,25 @@ class MCPClient:
async def connect(
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
):
self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
read_stream, write_stream, _ = (
await self._streams_context.__aenter__()
) # pylint: disable=E1101
try:
self._streams_context = streamablehttp_client(
url, headers=headers, auth=auth
)
transport = await self.exit_stack.enter_async_context(self._streams_context)
read_stream, write_stream, _ = transport
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self.session: ClientSession = (
await self._session_context.__aenter__()
) # pylint: disable=C2801
self.session = await self.exit_stack.enter_async_context(
self._session_context
)
await self.session.initialize()
except Exception as e:
await self.disconnect()
raise e
async def list_tool_specs(self) -> Optional[dict]:
if not self.session:
@ -97,15 +103,7 @@ class MCPClient:
async def disconnect(self):
# Clean up and close the session
if self.session:
await self._session_context.__aexit__(
None, None, None
) # pylint: disable=E1101
if self._streams_context:
await self._streams_context.__aexit__(
None, None, None
) # pylint: disable=E1101
self.session = None
await self.exit_stack.aclose()
async def __aenter__(self):
await self.exit_stack.__aenter__()