feat: experimental mcp support

This commit is contained in:
Timothy Jaeryang Baek 2025-09-23 02:03:26 -04:00
parent aeb5288a3c
commit 777e81f7a8
10 changed files with 417 additions and 105 deletions

View file

@ -1531,6 +1531,14 @@ async def chat_completion(
except:
pass
finally:
try:
if mcp_clients := metadata.get("mcp_clients"):
for client in mcp_clients:
await client.disconnect()
except Exception as e:
log.debug(f"Error cleaning up: {e}")
pass
if (
metadata.get("session_id")

View file

@ -1,3 +1,4 @@
from cmath import log
from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict
@ -12,7 +13,7 @@ from open_webui.utils.tools import (
get_tool_server_url,
set_tool_servers,
)
from open_webui.utils.mcp.client import MCPClient
router = APIRouter()
@ -87,6 +88,7 @@ async def set_connections_config(
class ToolServerConnection(BaseModel):
url: str
path: str
type: Optional[str] = "openapi" # openapi, mcp
auth_type: Optional[str]
key: Optional[str]
config: Optional[dict]
@ -129,15 +131,59 @@ async def verify_tool_servers_config(
Verify the connection to the tool server.
"""
try:
if form_data.type == "mcp":
try:
async with MCPClient() as client:
auth = None
headers = None
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
pass
url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(token, url)
if token:
headers = {"Authorization": f"Bearer {token}"}
await client.connect(form_data.url, auth=auth, headers=headers)
specs = await client.list_tool_specs()
return {
"status": True,
"specs": specs,
}
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to create MCP client: {str(e)}",
)
else: # openapi
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
pass
url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(token, url)
except Exception as e:
raise HTTPException(
status_code=400,

View file

@ -43,6 +43,7 @@ router = APIRouter()
async def get_tools(request: Request, user=Depends(get_verified_user)):
tools = Tools.get_tools()
# OpenAPI Tool Servers
for server in await get_tool_servers(request):
tools.append(
ToolUserResponse(
@ -68,6 +69,29 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
)
)
# MCP Tool Servers
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if server.get("type", "openapi") == "mcp":
tools.append(
ToolUserResponse(
**{
"id": f"server:mcp:{server.get('info', {}).get('id')}",
"user_id": f"server:mcp:{server.get('info', {}).get('id')}",
"name": server.get("info", {}).get("name", "MCP Tool Server"),
"meta": {
"description": server.get("info", {}).get(
"description", ""
),
},
"access_control": server.get("config", {}).get(
"access_control", None
),
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
# Admin can see all tools
return tools

View file

@ -0,0 +1,83 @@
import asyncio
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class MCPClient:
def __init__(self):
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
async def connect(
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
):
self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
read_stream, write_stream, _ = (
await self._streams_context.__aenter__()
) # pylint: disable=E1101
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self.session: ClientSession = (
await self._session_context.__aenter__()
) # pylint: disable=C2801
await self.session.initialize()
async def list_tool_specs(self) -> Optional[dict]:
if not self.session:
raise RuntimeError("MCP client is not connected.")
result = await self.session.list_tools()
tools = result.tools
tool_specs = []
for tool in tools:
name = tool.name
description = tool.description
inputSchema = tool.inputSchema
# TODO: handle outputSchema if needed
outputSchema = getattr(tool, "outputSchema", None)
tool_specs.append(
{"name": name, "description": description, "parameters": inputSchema}
)
return tool_specs
async def call_tool(
self, function_name: str, function_args: dict
) -> Optional[dict]:
if not self.session:
raise RuntimeError("MCP client is not connected.")
result = await self.session.call_tool(function_name, function_args)
return result.model_dump()
async def disconnect(self):
# Clean up and close the session
if self.session:
await self._session_context.__aexit__(
None, None, None
) # pylint: disable=E1101
if self._streams_context:
await self._streams_context.__aexit__(
None, None, None
) # pylint: disable=E1101
self.session = None
async def __aenter__(self):
await self.exit_stack.__aenter__()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.exit_stack.__aexit__(exc_type, exc_value, traceback)
await self.disconnect()

View file

@ -87,6 +87,7 @@ from open_webui.utils.filter import (
)
from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.utils.payload import apply_system_prompt_to_body
from open_webui.utils.mcp.client import MCPClient
from open_webui.config import (
@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# Server side tools
tool_ids = metadata.get("tool_ids", None)
# Client side tools
tool_servers = metadata.get("tool_servers", None)
direct_tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}")
log.debug(f"{tool_servers=}")
log.debug(f"{direct_tool_servers=}")
tools_dict = {}
mcp_clients = []
mcp_tools_dict = {}
if tool_ids:
for tool_id in tool_ids:
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
mcp_server_connection = None
for (
server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if (
server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id
):
mcp_server_connection = server_connection
break
if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found")
continue
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
mcp_client = MCPClient()
await mcp_client.connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
tool_specs = await mcp_client.list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(function_name):
async def tool_function(**kwargs):
print(
f"Calling MCP tool {function_name} with args {kwargs}"
)
return await mcp_client.call_tool(
function_name,
function_args=kwargs,
)
return tool_function
tool_function = make_tool_function(tool_spec["name"])
mcp_tools_dict[tool_spec["name"]] = {
"spec": tool_spec,
"callable": tool_function,
"type": "mcp",
"client": mcp_client,
"direct": False,
}
mcp_clients.append(mcp_client)
except Exception as e:
log.debug(e)
continue
tools_dict = await get_tools(
request,
tool_ids,
@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []),
},
)
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
if tool_servers:
for tool_server in tool_servers:
if direct_tool_servers:
for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", [])
for tool in tool_specs:
@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server,
}
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
if tools_dict:
log.info(f"tools_dict: {tools_dict}")
if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools_dict
@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
{"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values()
]
else:
# If the function calling is not native, then call the tools function calling handler
try:
@ -2330,6 +2418,8 @@ async def process_chat_response(
results = []
for tool_call in response_tool_calls:
print("tool_call", tool_call)
tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get("name", "")
tool_args = tool_call.get("function", {}).get("arguments", "{}")
@ -2397,9 +2487,14 @@ async def process_chat_response(
else:
tool_function = tool["callable"]
print("tool_name", tool_name)
print("tool_function", tool_function)
print("tool_function_params", tool_function_params)
tool_result = await tool_function(
**tool_function_params
)
print("tool_result", tool_result)
except Exception as e:
tool_result = str(e)

View file

@ -96,94 +96,118 @@ async def get_tools(
for tool_id in tool_ids:
tool = Tools.get_tool_by_id(tool_id)
if tool is None:
if tool_id.startswith("server:"):
server_id = tool_id.split(":")[1]
splits = tool_id.split(":")
tool_server_data = None
for server in await get_tool_servers(request):
if server["id"] == server_id:
tool_server_data = server
break
if len(splits) == 2:
type = "openapi"
server_id = splits[1]
elif len(splits) == 3:
type = splits[1]
server_id = splits[2]
if tool_server_data is None:
log.warning(f"Tool server data not found for {server_id}")
server_id_splits = server_id.split("|")
if len(server_id_splits) == 2:
server_id = server_id_splits[0]
function_names = server_id_splits[1].split(",")
if type == "openapi":
tool_server_data = None
for server in await get_tool_servers(request):
if server["id"] == server_id:
tool_server_data = server
break
if tool_server_data is None:
log.warning(f"Tool server data not found for {server_id}")
continue
tool_server_idx = tool_server_data.get("idx", 0)
tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[
tool_server_idx
]
)
specs = tool_server_data.get("specs", [])
for spec in specs:
function_name = spec["name"]
auth_type = tool_server_connection.get("auth_type", "bearer")
cookies = {}
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {tool_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
cookies = request.cookies
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
cookies = request.cookies
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
headers["Content-Type"] = "application/json"
def make_tool_function(
function_name, tool_server_data, headers
):
async def tool_function(**kwargs):
return await execute_tool_server(
url=tool_server_data["url"],
headers=headers,
cookies=cookies,
name=function_name,
params=kwargs,
server_data=tool_server_data,
)
return tool_function
tool_function = make_tool_function(
function_name, tool_server_data, headers
)
callable = get_async_tool_function_and_apply_extra_params(
tool_function,
{},
)
tool_dict = {
"tool_id": tool_id,
"callable": callable,
"spec": spec,
# Misc info
"type": "external",
}
# Handle function name collisions
while function_name in tools_dict:
log.warning(
f"Tool {function_name} already exists in another tools!"
)
# Prepend server ID to function name
function_name = f"{server_id}_{function_name}"
tools_dict[function_name] = tool_dict
else:
log.warning(f"Unsupported tool server type: {type}")
continue
tool_server_idx = tool_server_data.get("idx", 0)
tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx]
)
specs = tool_server_data.get("specs", [])
for spec in specs:
function_name = spec["name"]
auth_type = tool_server_connection.get("auth_type", "bearer")
cookies = {}
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {tool_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
cookies = request.cookies
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
cookies = request.cookies
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
headers["Content-Type"] = "application/json"
def make_tool_function(function_name, tool_server_data, headers):
async def tool_function(**kwargs):
return await execute_tool_server(
url=tool_server_data["url"],
headers=headers,
cookies=cookies,
name=function_name,
params=kwargs,
server_data=tool_server_data,
)
return tool_function
tool_function = make_tool_function(
function_name, tool_server_data, headers
)
callable = get_async_tool_function_and_apply_extra_params(
tool_function,
{},
)
tool_dict = {
"tool_id": tool_id,
"callable": callable,
"spec": spec,
# Misc info
"type": "external",
}
# Handle function name collisions
while function_name in tools_dict:
log.warning(
f"Tool {function_name} already exists in another tools!"
)
# Prepend server ID to function name
function_name = f"{server_id}_{function_name}"
tools_dict[function_name] = tool_dict
else:
continue
else:
@ -579,7 +603,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
# Prepare list of enabled servers along with their original index
server_entries = []
for idx, server in enumerate(servers):
if server.get("config", {}).get("enable"):
if (
server.get("config", {}).get("enable")
and server.get("type", "openapi") == "openapi"
):
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
openapi_path = server.get("path", "openapi.json")
full_url = get_tool_server_url(server.get("url"), openapi_path)

View file

@ -100,6 +100,11 @@
// remove trailing slash from url
url = url.replace(/\/$/, '');
if (id.includes(':') || id.includes('|')) {
toast.error($i18n.t('ID cannot contain ":" or "|" characters'));
loading = false;
return;
}
const connection = {
url,
@ -214,6 +219,7 @@
{$i18n.t('OpenAPI')}
{:else if type === 'mcp'}
{$i18n.t('MCP')}
<span class="text-gray-500">{$i18n.t('Streamable HTTP')}</span>
{/if}
</button>
</div>
@ -221,6 +227,25 @@
</div>
{/if}
{#if type === 'mcp'}
<div
class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-2xl text-xs px-4 py-3 mb-2"
>
<span class="font-medium">
{$i18n.t('Warning')}:
</span>
{$i18n.t(
'MCP support is experimental and its specification changes often, which can lead to incompatibilities. OpenAPI specification support is directly maintained by the Open WebUI team, making it the more reliable option for compatibility.'
)}
<a
class="font-medium underline"
href="https://docs.openwebui.com/features/mcp"
target="_blank">{$i18n.t('Read more →')}</a
>
</div>
{/if}
<div class="flex gap-2">
<div class="flex flex-col w-full">
<div class="flex justify-between mb-0.5">
@ -372,9 +397,12 @@
for="enter-id"
class={`mb-0.5 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>{$i18n.t('ID')}
<span class="text-xs text-gray-200 dark:text-gray-800 ml-0.5"
>{$i18n.t('Optional')}</span
>
{#if type !== 'mcp'}
<span class="text-xs text-gray-200 dark:text-gray-800 ml-0.5"
>{$i18n.t('Optional')}</span
>
{/if}
</label>
<div class="flex-1">
@ -385,6 +413,7 @@
bind:value={id}
placeholder={$i18n.t('Enter ID')}
autocomplete="off"
required={type === 'mcp'}
/>
</div>
</div>

View file

@ -14,7 +14,7 @@
import Plus from '$lib/components/icons/Plus.svelte';
import Connection from '$lib/components/chat/Settings/Tools/Connection.svelte';
import AddServerModal from '$lib/components/AddServerModal.svelte';
import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
import { getToolServerConnections, setToolServerConnections } from '$lib/apis/configs';
export let saveSettings: Function;
@ -47,7 +47,7 @@
});
</script>
<AddServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} />
<AddToolServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} />
<form
class="flex flex-col h-full justify-between text-sm"

View file

@ -14,7 +14,7 @@
import Plus from '$lib/components/icons/Plus.svelte';
import Connection from './Tools/Connection.svelte';
import AddServerModal from '$lib/components/AddServerModal.svelte';
import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
export let saveSettings: Function;
@ -52,7 +52,7 @@
});
</script>
<AddServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} direct />
<AddToolServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} direct />
<form
id="tab-tools"

View file

@ -6,7 +6,7 @@
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import Cog6 from '$lib/components/icons/Cog6.svelte';
import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
import AddServerModal from '$lib/components/AddServerModal.svelte';
import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
export let onDelete = () => {};
export let onSubmit = () => {};
@ -18,7 +18,7 @@
let showDeleteConfirmDialog = false;
</script>
<AddServerModal
<AddToolServerModal
edit
{direct}
bind:show={showConfigModal}