diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 39e9c66051..7ebb76cb58 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1531,6 +1531,14 @@ async def chat_completion( except: pass + finally: + try: + if mcp_clients := metadata.get("mcp_clients"): + for client in mcp_clients: + await client.disconnect() + except Exception as e: + log.debug(f"Error cleaning up: {e}") + pass if ( metadata.get("session_id") diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 8ce4e0d247..a505c9d797 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,3 +1,4 @@ +from cmath import log from fastapi import APIRouter, Depends, Request, HTTPException from pydantic import BaseModel, ConfigDict @@ -12,7 +13,7 @@ from open_webui.utils.tools import ( get_tool_server_url, set_tool_servers, ) - +from open_webui.utils.mcp.client import MCPClient router = APIRouter() @@ -87,6 +88,7 @@ async def set_connections_config( class ToolServerConnection(BaseModel): url: str path: str + type: Optional[str] = "openapi" # openapi, mcp auth_type: Optional[str] key: Optional[str] config: Optional[dict] @@ -129,15 +131,59 @@ async def verify_tool_servers_config( Verify the connection to the tool server. """ try: + if form_data.type == "mcp": + try: + async with MCPClient() as client: + auth = None + headers = None - token = None - if form_data.auth_type == "bearer": - token = form_data.key - elif form_data.auth_type == "session": - token = request.state.token.credentials + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + pass - url = get_tool_server_url(form_data.url, form_data.path) - return await get_tool_server_data(token, url) + if token: + headers = {"Authorization": f"Bearer {token}"} + + await client.connect(form_data.url, auth=auth, headers=headers) + specs = await client.list_tool_specs() + return { + "status": True, + "specs": specs, + } + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to create MCP client: {str(e)}", + ) + else: # openapi + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + pass + + url = get_tool_server_url(form_data.url, form_data.path) + return await get_tool_server_data(token, url) except Exception as e: raise HTTPException( status_code=400, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 5f82e7f1bd..71c7069fd3 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -43,6 +43,7 @@ router = APIRouter() async def get_tools(request: Request, user=Depends(get_verified_user)): tools = Tools.get_tools() + # OpenAPI Tool Servers for server in await get_tool_servers(request): tools.append( ToolUserResponse( @@ -68,6 +69,29 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): ) ) + # MCP Tool Servers + for server in request.app.state.config.TOOL_SERVER_CONNECTIONS: + if server.get("type", "openapi") == "mcp": + tools.append( + ToolUserResponse( + **{ + "id": f"server:mcp:{server.get('info', {}).get('id')}", + "user_id": f"server:mcp:{server.get('info', {}).get('id')}", + "name": server.get("info", {}).get("name", "MCP Tool Server"), + "meta": { + "description": server.get("info", {}).get( + "description", "" + ), + }, + "access_control": server.get("config", {}).get( + "access_control", None + ), + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + ) + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: # Admin can see all tools return tools diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py new file mode 100644 index 0000000000..9b9d92f10d --- /dev/null +++ b/backend/open_webui/utils/mcp/client.py @@ -0,0 +1,83 @@ +import asyncio +from typing import Optional +from contextlib import AsyncExitStack + +from mcp import ClientSession +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +class MCPClient: + def __init__(self): + self.session: Optional[ClientSession] = None + self.exit_stack = AsyncExitStack() + + async def connect( + self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None + ): + self._streams_context = streamablehttp_client(url, headers=headers, auth=auth) + read_stream, write_stream, _ = ( + await self._streams_context.__aenter__() + ) # pylint: disable=E1101 + + self._session_context = ClientSession( + read_stream, write_stream + ) # pylint: disable=W0201 + self.session: ClientSession = ( + await self._session_context.__aenter__() + ) # pylint: disable=C2801 + + await self.session.initialize() + + async def list_tool_specs(self) -> Optional[dict]: + if not self.session: + raise RuntimeError("MCP client is not connected.") + + result = await self.session.list_tools() + tools = result.tools + + tool_specs = [] + for tool in tools: + name = tool.name + description = tool.description + + inputSchema = tool.inputSchema + + # TODO: handle outputSchema if needed + outputSchema = getattr(tool, "outputSchema", None) + + tool_specs.append( + {"name": name, "description": description, "parameters": inputSchema} + ) + + return tool_specs + + async def call_tool( + self, function_name: str, function_args: dict + ) -> Optional[dict]: + if not self.session: + raise RuntimeError("MCP client is not connected.") + + result = await self.session.call_tool(function_name, function_args) + return result.model_dump() + + async def disconnect(self): + # Clean up and close the session + if self.session: + await self._session_context.__aexit__( + None, None, None + ) # pylint: disable=E1101 + if self._streams_context: + await self._streams_context.__aexit__( + None, None, None + ) # pylint: disable=E1101 + self.session = None + + async def __aenter__(self): + await self.exit_stack.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.exit_stack.__aexit__(exc_type, exc_value, traceback) + await self.disconnect() diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 3cd7d3a6e8..286e40ebad 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -87,6 +87,7 @@ from open_webui.utils.filter import ( ) from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.utils.payload import apply_system_prompt_to_body +from open_webui.utils.mcp.client import MCPClient from open_webui.config import ( @@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Server side tools tool_ids = metadata.get("tool_ids", None) # Client side tools - tool_servers = metadata.get("tool_servers", None) + direct_tool_servers = metadata.get("tool_servers", None) log.debug(f"{tool_ids=}") - log.debug(f"{tool_servers=}") + log.debug(f"{direct_tool_servers=}") tools_dict = {} + mcp_clients = [] + mcp_tools_dict = {} + if tool_ids: + for tool_id in tool_ids: + if tool_id.startswith("server:mcp:"): + try: + server_id = tool_id[len("server:mcp:") :] + + mcp_server_connection = None + for ( + server_connection + ) in request.app.state.config.TOOL_SERVER_CONNECTIONS: + if ( + server_connection.get("type", "") == "mcp" + and server_connection.get("info", {}).get("id") == server_id + ): + mcp_server_connection = server_connection + break + + if not mcp_server_connection: + log.error(f"MCP server with id {server_id} not found") + continue + + auth_type = mcp_server_connection.get("auth_type", "") + + headers = {} + if auth_type == "bearer": + headers["Authorization"] = ( + f"Bearer {mcp_server_connection.get('key', '')}" + ) + elif auth_type == "none": + # No authentication + pass + elif auth_type == "session": + headers["Authorization"] = ( + f"Bearer {request.state.token.credentials}" + ) + elif auth_type == "system_oauth": + oauth_token = extra_params.get("__oauth_token__", None) + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) + + mcp_client = MCPClient() + await mcp_client.connect( + url=mcp_server_connection.get("url", ""), + headers=headers if headers else None, + ) + + tool_specs = await mcp_client.list_tool_specs() + for tool_spec in tool_specs: + + def make_tool_function(function_name): + async def tool_function(**kwargs): + print( + f"Calling MCP tool {function_name} with args {kwargs}" + ) + return await mcp_client.call_tool( + function_name, + function_args=kwargs, + ) + + return tool_function + + tool_function = make_tool_function(tool_spec["name"]) + + mcp_tools_dict[tool_spec["name"]] = { + "spec": tool_spec, + "callable": tool_function, + "type": "mcp", + "client": mcp_client, + "direct": False, + } + + mcp_clients.append(mcp_client) + except Exception as e: + log.debug(e) + continue + tools_dict = await get_tools( request, tool_ids, @@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): "__files__": metadata.get("files", []), }, ) + if mcp_tools_dict: + tools_dict = {**tools_dict, **mcp_tools_dict} - if tool_servers: - for tool_server in tool_servers: + if direct_tool_servers: + for tool_server in direct_tool_servers: tool_specs = tool_server.pop("specs", []) for tool in tool_specs: @@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): "server": tool_server, } + if mcp_clients: + metadata["mcp_clients"] = mcp_clients + if tools_dict: + log.info(f"tools_dict: {tools_dict}") if metadata.get("params", {}).get("function_calling") == "native": # If the function calling is native, then call the tools function calling handler metadata["tools"] = tools_dict @@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): {"type": "function", "function": tool.get("spec", {})} for tool in tools_dict.values() ] + else: # If the function calling is not native, then call the tools function calling handler try: @@ -2330,6 +2418,8 @@ async def process_chat_response( results = [] for tool_call in response_tool_calls: + + print("tool_call", tool_call) tool_call_id = tool_call.get("id", "") tool_name = tool_call.get("function", {}).get("name", "") tool_args = tool_call.get("function", {}).get("arguments", "{}") @@ -2397,9 +2487,14 @@ async def process_chat_response( else: tool_function = tool["callable"] + + print("tool_name", tool_name) + print("tool_function", tool_function) + print("tool_function_params", tool_function_params) tool_result = await tool_function( **tool_function_params ) + print("tool_result", tool_result) except Exception as e: tool_result = str(e) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 63df47700f..cb3626146a 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -96,94 +96,118 @@ async def get_tools( for tool_id in tool_ids: tool = Tools.get_tool_by_id(tool_id) if tool is None: + if tool_id.startswith("server:"): - server_id = tool_id.split(":")[1] + splits = tool_id.split(":") - tool_server_data = None - for server in await get_tool_servers(request): - if server["id"] == server_id: - tool_server_data = server - break + if len(splits) == 2: + type = "openapi" + server_id = splits[1] + elif len(splits) == 3: + type = splits[1] + server_id = splits[2] - if tool_server_data is None: - log.warning(f"Tool server data not found for {server_id}") + server_id_splits = server_id.split("|") + if len(server_id_splits) == 2: + server_id = server_id_splits[0] + function_names = server_id_splits[1].split(",") + + if type == "openapi": + + tool_server_data = None + for server in await get_tool_servers(request): + if server["id"] == server_id: + tool_server_data = server + break + + if tool_server_data is None: + log.warning(f"Tool server data not found for {server_id}") + continue + + tool_server_idx = tool_server_data.get("idx", 0) + tool_server_connection = ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[ + tool_server_idx + ] + ) + + specs = tool_server_data.get("specs", []) + for spec in specs: + function_name = spec["name"] + + auth_type = tool_server_connection.get("auth_type", "bearer") + + cookies = {} + headers = {} + + if auth_type == "bearer": + headers["Authorization"] = ( + f"Bearer {tool_server_connection.get('key', '')}" + ) + elif auth_type == "none": + # No authentication + pass + elif auth_type == "session": + cookies = request.cookies + headers["Authorization"] = ( + f"Bearer {request.state.token.credentials}" + ) + elif auth_type == "system_oauth": + cookies = request.cookies + oauth_token = extra_params.get("__oauth_token__", None) + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) + + headers["Content-Type"] = "application/json" + + def make_tool_function( + function_name, tool_server_data, headers + ): + async def tool_function(**kwargs): + return await execute_tool_server( + url=tool_server_data["url"], + headers=headers, + cookies=cookies, + name=function_name, + params=kwargs, + server_data=tool_server_data, + ) + + return tool_function + + tool_function = make_tool_function( + function_name, tool_server_data, headers + ) + + callable = get_async_tool_function_and_apply_extra_params( + tool_function, + {}, + ) + + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + # Misc info + "type": "external", + } + + # Handle function name collisions + while function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" + ) + # Prepend server ID to function name + function_name = f"{server_id}_{function_name}" + + tools_dict[function_name] = tool_dict + + else: + log.warning(f"Unsupported tool server type: {type}") continue - tool_server_idx = tool_server_data.get("idx", 0) - tool_server_connection = ( - request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx] - ) - - specs = tool_server_data.get("specs", []) - for spec in specs: - function_name = spec["name"] - - auth_type = tool_server_connection.get("auth_type", "bearer") - - cookies = {} - headers = {} - - if auth_type == "bearer": - headers["Authorization"] = ( - f"Bearer {tool_server_connection.get('key', '')}" - ) - elif auth_type == "none": - # No authentication - pass - elif auth_type == "session": - cookies = request.cookies - headers["Authorization"] = ( - f"Bearer {request.state.token.credentials}" - ) - elif auth_type == "system_oauth": - cookies = request.cookies - oauth_token = extra_params.get("__oauth_token__", None) - if oauth_token: - headers["Authorization"] = ( - f"Bearer {oauth_token.get('access_token', '')}" - ) - - headers["Content-Type"] = "application/json" - - def make_tool_function(function_name, tool_server_data, headers): - async def tool_function(**kwargs): - return await execute_tool_server( - url=tool_server_data["url"], - headers=headers, - cookies=cookies, - name=function_name, - params=kwargs, - server_data=tool_server_data, - ) - - return tool_function - - tool_function = make_tool_function( - function_name, tool_server_data, headers - ) - - callable = get_async_tool_function_and_apply_extra_params( - tool_function, - {}, - ) - - tool_dict = { - "tool_id": tool_id, - "callable": callable, - "spec": spec, - # Misc info - "type": "external", - } - - # Handle function name collisions - while function_name in tools_dict: - log.warning( - f"Tool {function_name} already exists in another tools!" - ) - # Prepend server ID to function name - function_name = f"{server_id}_{function_name}" - - tools_dict[function_name] = tool_dict else: continue else: @@ -579,7 +603,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, # Prepare list of enabled servers along with their original index server_entries = [] for idx, server in enumerate(servers): - if server.get("config", {}).get("enable"): + if ( + server.get("config", {}).get("enable") + and server.get("type", "openapi") == "openapi" + ): # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL openapi_path = server.get("path", "openapi.json") full_url = get_tool_server_url(server.get("url"), openapi_path) diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddToolServerModal.svelte similarity index 92% rename from src/lib/components/AddServerModal.svelte rename to src/lib/components/AddToolServerModal.svelte index fe87cde45f..01c87010ef 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -100,6 +100,11 @@ // remove trailing slash from url url = url.replace(/\/$/, ''); + if (id.includes(':') || id.includes('|')) { + toast.error($i18n.t('ID cannot contain ":" or "|" characters')); + loading = false; + return; + } const connection = { url, @@ -214,6 +219,7 @@ {$i18n.t('OpenAPI')} {:else if type === 'mcp'} {$i18n.t('MCP')} + {$i18n.t('Streamable HTTP')} {/if} @@ -221,6 +227,25 @@ {/if} + {#if type === 'mcp'} +