diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 39e9c66051..7ebb76cb58 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1531,6 +1531,14 @@ async def chat_completion( except: pass + finally: + try: + if mcp_clients := metadata.get("mcp_clients"): + for client in mcp_clients: + await client.disconnect() + except Exception as e: + log.debug(f"Error cleaning up: {e}") + pass if ( metadata.get("session_id") diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 8ce4e0d247..a505c9d797 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,3 +1,4 @@ +from cmath import log from fastapi import APIRouter, Depends, Request, HTTPException from pydantic import BaseModel, ConfigDict @@ -12,7 +13,7 @@ from open_webui.utils.tools import ( get_tool_server_url, set_tool_servers, ) - +from open_webui.utils.mcp.client import MCPClient router = APIRouter() @@ -87,6 +88,7 @@ async def set_connections_config( class ToolServerConnection(BaseModel): url: str path: str + type: Optional[str] = "openapi" # openapi, mcp auth_type: Optional[str] key: Optional[str] config: Optional[dict] @@ -129,15 +131,59 @@ async def verify_tool_servers_config( Verify the connection to the tool server. """ try: + if form_data.type == "mcp": + try: + async with MCPClient() as client: + auth = None + headers = None - token = None - if form_data.auth_type == "bearer": - token = form_data.key - elif form_data.auth_type == "session": - token = request.state.token.credentials + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + pass - url = get_tool_server_url(form_data.url, form_data.path) - return await get_tool_server_data(token, url) + if token: + headers = {"Authorization": f"Bearer {token}"} + + await client.connect(form_data.url, auth=auth, headers=headers) + specs = await client.list_tool_specs() + return { + "status": True, + "specs": specs, + } + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to create MCP client: {str(e)}", + ) + else: # openapi + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + pass + + url = get_tool_server_url(form_data.url, form_data.path) + return await get_tool_server_data(token, url) except Exception as e: raise HTTPException( status_code=400, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 5f82e7f1bd..71c7069fd3 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -43,6 +43,7 @@ router = APIRouter() async def get_tools(request: Request, user=Depends(get_verified_user)): tools = Tools.get_tools() + # OpenAPI Tool Servers for server in await get_tool_servers(request): tools.append( ToolUserResponse( @@ -68,6 +69,29 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): ) ) + # MCP Tool Servers + for server in request.app.state.config.TOOL_SERVER_CONNECTIONS: + if server.get("type", "openapi") == "mcp": + tools.append( + ToolUserResponse( + **{ + "id": f"server:mcp:{server.get('info', {}).get('id')}", + "user_id": f"server:mcp:{server.get('info', {}).get('id')}", + "name": server.get("info", {}).get("name", "MCP Tool Server"), + "meta": { + "description": server.get("info", {}).get( + "description", "" + ), + }, + "access_control": server.get("config", {}).get( + "access_control", None + ), + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + ) + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: # Admin can see all tools return tools diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py new file mode 100644 index 0000000000..9b9d92f10d --- /dev/null +++ b/backend/open_webui/utils/mcp/client.py @@ -0,0 +1,83 @@ +import asyncio +from typing import Optional +from contextlib import AsyncExitStack + +from mcp import ClientSession +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +class MCPClient: + def __init__(self): + self.session: Optional[ClientSession] = None + self.exit_stack = AsyncExitStack() + + async def connect( + self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None + ): + self._streams_context = streamablehttp_client(url, headers=headers, auth=auth) + read_stream, write_stream, _ = ( + await self._streams_context.__aenter__() + ) # pylint: disable=E1101 + + self._session_context = ClientSession( + read_stream, write_stream + ) # pylint: disable=W0201 + self.session: ClientSession = ( + await self._session_context.__aenter__() + ) # pylint: disable=C2801 + + await self.session.initialize() + + async def list_tool_specs(self) -> Optional[dict]: + if not self.session: + raise RuntimeError("MCP client is not connected.") + + result = await self.session.list_tools() + tools = result.tools + + tool_specs = [] + for tool in tools: + name = tool.name + description = tool.description + + inputSchema = tool.inputSchema + + # TODO: handle outputSchema if needed + outputSchema = getattr(tool, "outputSchema", None) + + tool_specs.append( + {"name": name, "description": description, "parameters": inputSchema} + ) + + return tool_specs + + async def call_tool( + self, function_name: str, function_args: dict + ) -> Optional[dict]: + if not self.session: + raise RuntimeError("MCP client is not connected.") + + result = await self.session.call_tool(function_name, function_args) + return result.model_dump() + + async def disconnect(self): + # Clean up and close the session + if self.session: + await self._session_context.__aexit__( + None, None, None + ) # pylint: disable=E1101 + if self._streams_context: + await self._streams_context.__aexit__( + None, None, None + ) # pylint: disable=E1101 + self.session = None + + async def __aenter__(self): + await self.exit_stack.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.exit_stack.__aexit__(exc_type, exc_value, traceback) + await self.disconnect() diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 3cd7d3a6e8..286e40ebad 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -87,6 +87,7 @@ from open_webui.utils.filter import ( ) from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.utils.payload import apply_system_prompt_to_body +from open_webui.utils.mcp.client import MCPClient from open_webui.config import ( @@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model): # Server side tools tool_ids = metadata.get("tool_ids", None) # Client side tools - tool_servers = metadata.get("tool_servers", None) + direct_tool_servers = metadata.get("tool_servers", None) log.debug(f"{tool_ids=}") - log.debug(f"{tool_servers=}") + log.debug(f"{direct_tool_servers=}") tools_dict = {} + mcp_clients = [] + mcp_tools_dict = {} + if tool_ids: + for tool_id in tool_ids: + if tool_id.startswith("server:mcp:"): + try: + server_id = tool_id[len("server:mcp:") :] + + mcp_server_connection = None + for ( + server_connection + ) in request.app.state.config.TOOL_SERVER_CONNECTIONS: + if ( + server_connection.get("type", "") == "mcp" + and server_connection.get("info", {}).get("id") == server_id + ): + mcp_server_connection = server_connection + break + + if not mcp_server_connection: + log.error(f"MCP server with id {server_id} not found") + continue + + auth_type = mcp_server_connection.get("auth_type", "") + + headers = {} + if auth_type == "bearer": + headers["Authorization"] = ( + f"Bearer {mcp_server_connection.get('key', '')}" + ) + elif auth_type == "none": + # No authentication + pass + elif auth_type == "session": + headers["Authorization"] = ( + f"Bearer {request.state.token.credentials}" + ) + elif auth_type == "system_oauth": + oauth_token = extra_params.get("__oauth_token__", None) + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) + + mcp_client = MCPClient() + await mcp_client.connect( + url=mcp_server_connection.get("url", ""), + headers=headers if headers else None, + ) + + tool_specs = await mcp_client.list_tool_specs() + for tool_spec in tool_specs: + + def make_tool_function(function_name): + async def tool_function(**kwargs): + print( + f"Calling MCP tool {function_name} with args {kwargs}" + ) + return await mcp_client.call_tool( + function_name, + function_args=kwargs, + ) + + return tool_function + + tool_function = make_tool_function(tool_spec["name"]) + + mcp_tools_dict[tool_spec["name"]] = { + "spec": tool_spec, + "callable": tool_function, + "type": "mcp", + "client": mcp_client, + "direct": False, + } + + mcp_clients.append(mcp_client) + except Exception as e: + log.debug(e) + continue + tools_dict = await get_tools( request, tool_ids, @@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): "__files__": metadata.get("files", []), }, ) + if mcp_tools_dict: + tools_dict = {**tools_dict, **mcp_tools_dict} - if tool_servers: - for tool_server in tool_servers: + if direct_tool_servers: + for tool_server in direct_tool_servers: tool_specs = tool_server.pop("specs", []) for tool in tool_specs: @@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): "server": tool_server, } + if mcp_clients: + metadata["mcp_clients"] = mcp_clients + if tools_dict: + log.info(f"tools_dict: {tools_dict}") if metadata.get("params", {}).get("function_calling") == "native": # If the function calling is native, then call the tools function calling handler metadata["tools"] = tools_dict @@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): {"type": "function", "function": tool.get("spec", {})} for tool in tools_dict.values() ] + else: # If the function calling is not native, then call the tools function calling handler try: @@ -2330,6 +2418,8 @@ async def process_chat_response( results = [] for tool_call in response_tool_calls: + + print("tool_call", tool_call) tool_call_id = tool_call.get("id", "") tool_name = tool_call.get("function", {}).get("name", "") tool_args = tool_call.get("function", {}).get("arguments", "{}") @@ -2397,9 +2487,14 @@ async def process_chat_response( else: tool_function = tool["callable"] + + print("tool_name", tool_name) + print("tool_function", tool_function) + print("tool_function_params", tool_function_params) tool_result = await tool_function( **tool_function_params ) + print("tool_result", tool_result) except Exception as e: tool_result = str(e) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 63df47700f..cb3626146a 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -96,94 +96,118 @@ async def get_tools( for tool_id in tool_ids: tool = Tools.get_tool_by_id(tool_id) if tool is None: + if tool_id.startswith("server:"): - server_id = tool_id.split(":")[1] + splits = tool_id.split(":") - tool_server_data = None - for server in await get_tool_servers(request): - if server["id"] == server_id: - tool_server_data = server - break + if len(splits) == 2: + type = "openapi" + server_id = splits[1] + elif len(splits) == 3: + type = splits[1] + server_id = splits[2] - if tool_server_data is None: - log.warning(f"Tool server data not found for {server_id}") + server_id_splits = server_id.split("|") + if len(server_id_splits) == 2: + server_id = server_id_splits[0] + function_names = server_id_splits[1].split(",") + + if type == "openapi": + + tool_server_data = None + for server in await get_tool_servers(request): + if server["id"] == server_id: + tool_server_data = server + break + + if tool_server_data is None: + log.warning(f"Tool server data not found for {server_id}") + continue + + tool_server_idx = tool_server_data.get("idx", 0) + tool_server_connection = ( + request.app.state.config.TOOL_SERVER_CONNECTIONS[ + tool_server_idx + ] + ) + + specs = tool_server_data.get("specs", []) + for spec in specs: + function_name = spec["name"] + + auth_type = tool_server_connection.get("auth_type", "bearer") + + cookies = {} + headers = {} + + if auth_type == "bearer": + headers["Authorization"] = ( + f"Bearer {tool_server_connection.get('key', '')}" + ) + elif auth_type == "none": + # No authentication + pass + elif auth_type == "session": + cookies = request.cookies + headers["Authorization"] = ( + f"Bearer {request.state.token.credentials}" + ) + elif auth_type == "system_oauth": + cookies = request.cookies + oauth_token = extra_params.get("__oauth_token__", None) + if oauth_token: + headers["Authorization"] = ( + f"Bearer {oauth_token.get('access_token', '')}" + ) + + headers["Content-Type"] = "application/json" + + def make_tool_function( + function_name, tool_server_data, headers + ): + async def tool_function(**kwargs): + return await execute_tool_server( + url=tool_server_data["url"], + headers=headers, + cookies=cookies, + name=function_name, + params=kwargs, + server_data=tool_server_data, + ) + + return tool_function + + tool_function = make_tool_function( + function_name, tool_server_data, headers + ) + + callable = get_async_tool_function_and_apply_extra_params( + tool_function, + {}, + ) + + tool_dict = { + "tool_id": tool_id, + "callable": callable, + "spec": spec, + # Misc info + "type": "external", + } + + # Handle function name collisions + while function_name in tools_dict: + log.warning( + f"Tool {function_name} already exists in another tools!" + ) + # Prepend server ID to function name + function_name = f"{server_id}_{function_name}" + + tools_dict[function_name] = tool_dict + + else: + log.warning(f"Unsupported tool server type: {type}") continue - tool_server_idx = tool_server_data.get("idx", 0) - tool_server_connection = ( - request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx] - ) - - specs = tool_server_data.get("specs", []) - for spec in specs: - function_name = spec["name"] - - auth_type = tool_server_connection.get("auth_type", "bearer") - - cookies = {} - headers = {} - - if auth_type == "bearer": - headers["Authorization"] = ( - f"Bearer {tool_server_connection.get('key', '')}" - ) - elif auth_type == "none": - # No authentication - pass - elif auth_type == "session": - cookies = request.cookies - headers["Authorization"] = ( - f"Bearer {request.state.token.credentials}" - ) - elif auth_type == "system_oauth": - cookies = request.cookies - oauth_token = extra_params.get("__oauth_token__", None) - if oauth_token: - headers["Authorization"] = ( - f"Bearer {oauth_token.get('access_token', '')}" - ) - - headers["Content-Type"] = "application/json" - - def make_tool_function(function_name, tool_server_data, headers): - async def tool_function(**kwargs): - return await execute_tool_server( - url=tool_server_data["url"], - headers=headers, - cookies=cookies, - name=function_name, - params=kwargs, - server_data=tool_server_data, - ) - - return tool_function - - tool_function = make_tool_function( - function_name, tool_server_data, headers - ) - - callable = get_async_tool_function_and_apply_extra_params( - tool_function, - {}, - ) - - tool_dict = { - "tool_id": tool_id, - "callable": callable, - "spec": spec, - # Misc info - "type": "external", - } - - # Handle function name collisions - while function_name in tools_dict: - log.warning( - f"Tool {function_name} already exists in another tools!" - ) - # Prepend server ID to function name - function_name = f"{server_id}_{function_name}" - - tools_dict[function_name] = tool_dict else: continue else: @@ -579,7 +603,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, # Prepare list of enabled servers along with their original index server_entries = [] for idx, server in enumerate(servers): - if server.get("config", {}).get("enable"): + if ( + server.get("config", {}).get("enable") + and server.get("type", "openapi") == "openapi" + ): # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL openapi_path = server.get("path", "openapi.json") full_url = get_tool_server_url(server.get("url"), openapi_path) diff --git a/src/lib/components/AddServerModal.svelte b/src/lib/components/AddToolServerModal.svelte similarity index 92% rename from src/lib/components/AddServerModal.svelte rename to src/lib/components/AddToolServerModal.svelte index fe87cde45f..01c87010ef 100644 --- a/src/lib/components/AddServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -100,6 +100,11 @@ // remove trailing slash from url url = url.replace(/\/$/, ''); + if (id.includes(':') || id.includes('|')) { + toast.error($i18n.t('ID cannot contain ":" or "|" characters')); + loading = false; + return; + } const connection = { url, @@ -214,6 +219,7 @@ {$i18n.t('OpenAPI')} {:else if type === 'mcp'} {$i18n.t('MCP')} + {$i18n.t('Streamable HTTP')} {/if} @@ -221,6 +227,25 @@ {/if} + {#if type === 'mcp'} +
+ + {$i18n.t('Warning')}: + + {$i18n.t( + 'MCP support is experimental and its specification changes often, which can lead to incompatibilities. OpenAPI specification support is directly maintained by the Open WebUI team, making it the more reliable option for compatibility.' + )} + + {$i18n.t('Read more →')} +
+ {/if} +
@@ -372,9 +397,12 @@ for="enter-id" class={`mb-0.5 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`} >{$i18n.t('ID')} - {$i18n.t('Optional')} + + {#if type !== 'mcp'} + {$i18n.t('Optional')} + {/if}
@@ -385,6 +413,7 @@ bind:value={id} placeholder={$i18n.t('Enter ID')} autocomplete="off" + required={type === 'mcp'} />
diff --git a/src/lib/components/admin/Settings/Tools.svelte b/src/lib/components/admin/Settings/Tools.svelte index d59621be55..28a30b23b5 100644 --- a/src/lib/components/admin/Settings/Tools.svelte +++ b/src/lib/components/admin/Settings/Tools.svelte @@ -14,7 +14,7 @@ import Plus from '$lib/components/icons/Plus.svelte'; import Connection from '$lib/components/chat/Settings/Tools/Connection.svelte'; - import AddServerModal from '$lib/components/AddServerModal.svelte'; + import AddToolServerModal from '$lib/components/AddToolServerModal.svelte'; import { getToolServerConnections, setToolServerConnections } from '$lib/apis/configs'; export let saveSettings: Function; @@ -47,7 +47,7 @@ }); - +
- + {}; export let onSubmit = () => {}; @@ -18,7 +18,7 @@ let showDeleteConfirmDialog = false; -