feat: experimental mcp support

This commit is contained in:
Timothy Jaeryang Baek 2025-09-23 02:03:26 -04:00
parent aeb5288a3c
commit 777e81f7a8
10 changed files with 417 additions and 105 deletions

View file

@ -1531,6 +1531,14 @@ async def chat_completion(
except: except:
pass pass
finally:
try:
if mcp_clients := metadata.get("mcp_clients"):
for client in mcp_clients:
await client.disconnect()
except Exception as e:
log.debug(f"Error cleaning up: {e}")
pass
if ( if (
metadata.get("session_id") metadata.get("session_id")

View file

@ -1,3 +1,4 @@
from cmath import log
from fastapi import APIRouter, Depends, Request, HTTPException from fastapi import APIRouter, Depends, Request, HTTPException
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -12,7 +13,7 @@ from open_webui.utils.tools import (
get_tool_server_url, get_tool_server_url,
set_tool_servers, set_tool_servers,
) )
from open_webui.utils.mcp.client import MCPClient
router = APIRouter() router = APIRouter()
@ -87,6 +88,7 @@ async def set_connections_config(
class ToolServerConnection(BaseModel): class ToolServerConnection(BaseModel):
url: str url: str
path: str path: str
type: Optional[str] = "openapi" # openapi, mcp
auth_type: Optional[str] auth_type: Optional[str]
key: Optional[str] key: Optional[str]
config: Optional[dict] config: Optional[dict]
@ -129,12 +131,56 @@ async def verify_tool_servers_config(
Verify the connection to the tool server. Verify the connection to the tool server.
""" """
try: try:
if form_data.type == "mcp":
try:
async with MCPClient() as client:
auth = None
headers = None
token = None token = None
if form_data.auth_type == "bearer": if form_data.auth_type == "bearer":
token = form_data.key token = form_data.key
elif form_data.auth_type == "session": elif form_data.auth_type == "session":
token = request.state.token.credentials token = request.state.token.credentials
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
pass
if token:
headers = {"Authorization": f"Bearer {token}"}
await client.connect(form_data.url, auth=auth, headers=headers)
specs = await client.list_tool_specs()
return {
"status": True,
"specs": specs,
}
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Failed to create MCP client: {str(e)}",
)
else: # openapi
token = None
if form_data.auth_type == "bearer":
token = form_data.key
elif form_data.auth_type == "session":
token = request.state.token.credentials
elif form_data.auth_type == "system_oauth":
try:
if request.cookies.get("oauth_session_id", None):
token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
)
except Exception as e:
pass
url = get_tool_server_url(form_data.url, form_data.path) url = get_tool_server_url(form_data.url, form_data.path)
return await get_tool_server_data(token, url) return await get_tool_server_data(token, url)

View file

@ -43,6 +43,7 @@ router = APIRouter()
async def get_tools(request: Request, user=Depends(get_verified_user)): async def get_tools(request: Request, user=Depends(get_verified_user)):
tools = Tools.get_tools() tools = Tools.get_tools()
# OpenAPI Tool Servers
for server in await get_tool_servers(request): for server in await get_tool_servers(request):
tools.append( tools.append(
ToolUserResponse( ToolUserResponse(
@ -68,6 +69,29 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
) )
) )
# MCP Tool Servers
for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if server.get("type", "openapi") == "mcp":
tools.append(
ToolUserResponse(
**{
"id": f"server:mcp:{server.get('info', {}).get('id')}",
"user_id": f"server:mcp:{server.get('info', {}).get('id')}",
"name": server.get("info", {}).get("name", "MCP Tool Server"),
"meta": {
"description": server.get("info", {}).get(
"description", ""
),
},
"access_control": server.get("config", {}).get(
"access_control", None
),
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
)
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
# Admin can see all tools # Admin can see all tools
return tools return tools

View file

@ -0,0 +1,83 @@
import asyncio
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class MCPClient:
def __init__(self):
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
async def connect(
self, url: str, headers: Optional[dict] = None, auth: Optional[any] = None
):
self._streams_context = streamablehttp_client(url, headers=headers, auth=auth)
read_stream, write_stream, _ = (
await self._streams_context.__aenter__()
) # pylint: disable=E1101
self._session_context = ClientSession(
read_stream, write_stream
) # pylint: disable=W0201
self.session: ClientSession = (
await self._session_context.__aenter__()
) # pylint: disable=C2801
await self.session.initialize()
async def list_tool_specs(self) -> Optional[dict]:
if not self.session:
raise RuntimeError("MCP client is not connected.")
result = await self.session.list_tools()
tools = result.tools
tool_specs = []
for tool in tools:
name = tool.name
description = tool.description
inputSchema = tool.inputSchema
# TODO: handle outputSchema if needed
outputSchema = getattr(tool, "outputSchema", None)
tool_specs.append(
{"name": name, "description": description, "parameters": inputSchema}
)
return tool_specs
async def call_tool(
self, function_name: str, function_args: dict
) -> Optional[dict]:
if not self.session:
raise RuntimeError("MCP client is not connected.")
result = await self.session.call_tool(function_name, function_args)
return result.model_dump()
async def disconnect(self):
# Clean up and close the session
if self.session:
await self._session_context.__aexit__(
None, None, None
) # pylint: disable=E1101
if self._streams_context:
await self._streams_context.__aexit__(
None, None, None
) # pylint: disable=E1101
self.session = None
async def __aenter__(self):
await self.exit_stack.__aenter__()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.exit_stack.__aexit__(exc_type, exc_value, traceback)
await self.disconnect()

View file

@ -87,6 +87,7 @@ from open_webui.utils.filter import (
) )
from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.utils.code_interpreter import execute_code_jupyter
from open_webui.utils.payload import apply_system_prompt_to_body from open_webui.utils.payload import apply_system_prompt_to_body
from open_webui.utils.mcp.client import MCPClient
from open_webui.config import ( from open_webui.config import (
@ -988,14 +989,94 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# Server side tools # Server side tools
tool_ids = metadata.get("tool_ids", None) tool_ids = metadata.get("tool_ids", None)
# Client side tools # Client side tools
tool_servers = metadata.get("tool_servers", None) direct_tool_servers = metadata.get("tool_servers", None)
log.debug(f"{tool_ids=}") log.debug(f"{tool_ids=}")
log.debug(f"{tool_servers=}") log.debug(f"{direct_tool_servers=}")
tools_dict = {} tools_dict = {}
mcp_clients = []
mcp_tools_dict = {}
if tool_ids: if tool_ids:
for tool_id in tool_ids:
if tool_id.startswith("server:mcp:"):
try:
server_id = tool_id[len("server:mcp:") :]
mcp_server_connection = None
for (
server_connection
) in request.app.state.config.TOOL_SERVER_CONNECTIONS:
if (
server_connection.get("type", "") == "mcp"
and server_connection.get("info", {}).get("id") == server_id
):
mcp_server_connection = server_connection
break
if not mcp_server_connection:
log.error(f"MCP server with id {server_id} not found")
continue
auth_type = mcp_server_connection.get("auth_type", "")
headers = {}
if auth_type == "bearer":
headers["Authorization"] = (
f"Bearer {mcp_server_connection.get('key', '')}"
)
elif auth_type == "none":
# No authentication
pass
elif auth_type == "session":
headers["Authorization"] = (
f"Bearer {request.state.token.credentials}"
)
elif auth_type == "system_oauth":
oauth_token = extra_params.get("__oauth_token__", None)
if oauth_token:
headers["Authorization"] = (
f"Bearer {oauth_token.get('access_token', '')}"
)
mcp_client = MCPClient()
await mcp_client.connect(
url=mcp_server_connection.get("url", ""),
headers=headers if headers else None,
)
tool_specs = await mcp_client.list_tool_specs()
for tool_spec in tool_specs:
def make_tool_function(function_name):
async def tool_function(**kwargs):
print(
f"Calling MCP tool {function_name} with args {kwargs}"
)
return await mcp_client.call_tool(
function_name,
function_args=kwargs,
)
return tool_function
tool_function = make_tool_function(tool_spec["name"])
mcp_tools_dict[tool_spec["name"]] = {
"spec": tool_spec,
"callable": tool_function,
"type": "mcp",
"client": mcp_client,
"direct": False,
}
mcp_clients.append(mcp_client)
except Exception as e:
log.debug(e)
continue
tools_dict = await get_tools( tools_dict = await get_tools(
request, request,
tool_ids, tool_ids,
@ -1007,9 +1088,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []), "__files__": metadata.get("files", []),
}, },
) )
if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict}
if tool_servers: if direct_tool_servers:
for tool_server in tool_servers: for tool_server in direct_tool_servers:
tool_specs = tool_server.pop("specs", []) tool_specs = tool_server.pop("specs", [])
for tool in tool_specs: for tool in tool_specs:
@ -1019,7 +1102,11 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"server": tool_server, "server": tool_server,
} }
if mcp_clients:
metadata["mcp_clients"] = mcp_clients
if tools_dict: if tools_dict:
log.info(f"tools_dict: {tools_dict}")
if metadata.get("params", {}).get("function_calling") == "native": if metadata.get("params", {}).get("function_calling") == "native":
# If the function calling is native, then call the tools function calling handler # If the function calling is native, then call the tools function calling handler
metadata["tools"] = tools_dict metadata["tools"] = tools_dict
@ -1027,6 +1114,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
{"type": "function", "function": tool.get("spec", {})} {"type": "function", "function": tool.get("spec", {})}
for tool in tools_dict.values() for tool in tools_dict.values()
] ]
else: else:
# If the function calling is not native, then call the tools function calling handler # If the function calling is not native, then call the tools function calling handler
try: try:
@ -2330,6 +2418,8 @@ async def process_chat_response(
results = [] results = []
for tool_call in response_tool_calls: for tool_call in response_tool_calls:
print("tool_call", tool_call)
tool_call_id = tool_call.get("id", "") tool_call_id = tool_call.get("id", "")
tool_name = tool_call.get("function", {}).get("name", "") tool_name = tool_call.get("function", {}).get("name", "")
tool_args = tool_call.get("function", {}).get("arguments", "{}") tool_args = tool_call.get("function", {}).get("arguments", "{}")
@ -2397,9 +2487,14 @@ async def process_chat_response(
else: else:
tool_function = tool["callable"] tool_function = tool["callable"]
print("tool_name", tool_name)
print("tool_function", tool_function)
print("tool_function_params", tool_function_params)
tool_result = await tool_function( tool_result = await tool_function(
**tool_function_params **tool_function_params
) )
print("tool_result", tool_result)
except Exception as e: except Exception as e:
tool_result = str(e) tool_result = str(e)

View file

@ -96,8 +96,23 @@ async def get_tools(
for tool_id in tool_ids: for tool_id in tool_ids:
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
if tool is None: if tool is None:
if tool_id.startswith("server:"): if tool_id.startswith("server:"):
server_id = tool_id.split(":")[1] splits = tool_id.split(":")
if len(splits) == 2:
type = "openapi"
server_id = splits[1]
elif len(splits) == 3:
type = splits[1]
server_id = splits[2]
server_id_splits = server_id.split("|")
if len(server_id_splits) == 2:
server_id = server_id_splits[0]
function_names = server_id_splits[1].split(",")
if type == "openapi":
tool_server_data = None tool_server_data = None
for server in await get_tool_servers(request): for server in await get_tool_servers(request):
@ -111,7 +126,9 @@ async def get_tools(
tool_server_idx = tool_server_data.get("idx", 0) tool_server_idx = tool_server_data.get("idx", 0)
tool_server_connection = ( tool_server_connection = (
request.app.state.config.TOOL_SERVER_CONNECTIONS[tool_server_idx] request.app.state.config.TOOL_SERVER_CONNECTIONS[
tool_server_idx
]
) )
specs = tool_server_data.get("specs", []) specs = tool_server_data.get("specs", [])
@ -145,7 +162,9 @@ async def get_tools(
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
def make_tool_function(function_name, tool_server_data, headers): def make_tool_function(
function_name, tool_server_data, headers
):
async def tool_function(**kwargs): async def tool_function(**kwargs):
return await execute_tool_server( return await execute_tool_server(
url=tool_server_data["url"], url=tool_server_data["url"],
@ -184,6 +203,11 @@ async def get_tools(
function_name = f"{server_id}_{function_name}" function_name = f"{server_id}_{function_name}"
tools_dict[function_name] = tool_dict tools_dict[function_name] = tool_dict
else:
log.warning(f"Unsupported tool server type: {type}")
continue
else: else:
continue continue
else: else:
@ -579,7 +603,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str,
# Prepare list of enabled servers along with their original index # Prepare list of enabled servers along with their original index
server_entries = [] server_entries = []
for idx, server in enumerate(servers): for idx, server in enumerate(servers):
if server.get("config", {}).get("enable"): if (
server.get("config", {}).get("enable")
and server.get("type", "openapi") == "openapi"
):
# Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL # Path (to OpenAPI spec URL) can be either a full URL or a path to append to the base URL
openapi_path = server.get("path", "openapi.json") openapi_path = server.get("path", "openapi.json")
full_url = get_tool_server_url(server.get("url"), openapi_path) full_url = get_tool_server_url(server.get("url"), openapi_path)

View file

@ -100,6 +100,11 @@
// remove trailing slash from url // remove trailing slash from url
url = url.replace(/\/$/, ''); url = url.replace(/\/$/, '');
if (id.includes(':') || id.includes('|')) {
toast.error($i18n.t('ID cannot contain ":" or "|" characters'));
loading = false;
return;
}
const connection = { const connection = {
url, url,
@ -214,6 +219,7 @@
{$i18n.t('OpenAPI')} {$i18n.t('OpenAPI')}
{:else if type === 'mcp'} {:else if type === 'mcp'}
{$i18n.t('MCP')} {$i18n.t('MCP')}
<span class="text-gray-500">{$i18n.t('Streamable HTTP')}</span>
{/if} {/if}
</button> </button>
</div> </div>
@ -221,6 +227,25 @@
</div> </div>
{/if} {/if}
{#if type === 'mcp'}
<div
class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-2xl text-xs px-4 py-3 mb-2"
>
<span class="font-medium">
{$i18n.t('Warning')}:
</span>
{$i18n.t(
'MCP support is experimental and its specification changes often, which can lead to incompatibilities. OpenAPI specification support is directly maintained by the Open WebUI team, making it the more reliable option for compatibility.'
)}
<a
class="font-medium underline"
href="https://docs.openwebui.com/features/mcp"
target="_blank">{$i18n.t('Read more →')}</a
>
</div>
{/if}
<div class="flex gap-2"> <div class="flex gap-2">
<div class="flex flex-col w-full"> <div class="flex flex-col w-full">
<div class="flex justify-between mb-0.5"> <div class="flex justify-between mb-0.5">
@ -372,9 +397,12 @@
for="enter-id" for="enter-id"
class={`mb-0.5 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`} class={`mb-0.5 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100' : 'text-gray-500'}`}
>{$i18n.t('ID')} >{$i18n.t('ID')}
{#if type !== 'mcp'}
<span class="text-xs text-gray-200 dark:text-gray-800 ml-0.5" <span class="text-xs text-gray-200 dark:text-gray-800 ml-0.5"
>{$i18n.t('Optional')}</span >{$i18n.t('Optional')}</span
> >
{/if}
</label> </label>
<div class="flex-1"> <div class="flex-1">
@ -385,6 +413,7 @@
bind:value={id} bind:value={id}
placeholder={$i18n.t('Enter ID')} placeholder={$i18n.t('Enter ID')}
autocomplete="off" autocomplete="off"
required={type === 'mcp'}
/> />
</div> </div>
</div> </div>

View file

@ -14,7 +14,7 @@
import Plus from '$lib/components/icons/Plus.svelte'; import Plus from '$lib/components/icons/Plus.svelte';
import Connection from '$lib/components/chat/Settings/Tools/Connection.svelte'; import Connection from '$lib/components/chat/Settings/Tools/Connection.svelte';
import AddServerModal from '$lib/components/AddServerModal.svelte'; import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
import { getToolServerConnections, setToolServerConnections } from '$lib/apis/configs'; import { getToolServerConnections, setToolServerConnections } from '$lib/apis/configs';
export let saveSettings: Function; export let saveSettings: Function;
@ -47,7 +47,7 @@
}); });
</script> </script>
<AddServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} /> <AddToolServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} />
<form <form
class="flex flex-col h-full justify-between text-sm" class="flex flex-col h-full justify-between text-sm"

View file

@ -14,7 +14,7 @@
import Plus from '$lib/components/icons/Plus.svelte'; import Plus from '$lib/components/icons/Plus.svelte';
import Connection from './Tools/Connection.svelte'; import Connection from './Tools/Connection.svelte';
import AddServerModal from '$lib/components/AddServerModal.svelte'; import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
export let saveSettings: Function; export let saveSettings: Function;
@ -52,7 +52,7 @@
}); });
</script> </script>
<AddServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} direct /> <AddToolServerModal bind:show={showConnectionModal} onSubmit={addConnectionHandler} direct />
<form <form
id="tab-tools" id="tab-tools"

View file

@ -6,7 +6,7 @@
import SensitiveInput from '$lib/components/common/SensitiveInput.svelte'; import SensitiveInput from '$lib/components/common/SensitiveInput.svelte';
import Cog6 from '$lib/components/icons/Cog6.svelte'; import Cog6 from '$lib/components/icons/Cog6.svelte';
import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
import AddServerModal from '$lib/components/AddServerModal.svelte'; import AddToolServerModal from '$lib/components/AddToolServerModal.svelte';
export let onDelete = () => {}; export let onDelete = () => {};
export let onSubmit = () => {}; export let onSubmit = () => {};
@ -18,7 +18,7 @@
let showDeleteConfirmDialog = false; let showDeleteConfirmDialog = false;
</script> </script>
<AddServerModal <AddToolServerModal
edit edit
{direct} {direct}
bind:show={showConfigModal} bind:show={showConfigModal}