From 743199f2d097ae1458381bce450d9025a0ab3f3d Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 25 Nov 2025 02:31:34 -0500 Subject: [PATCH] feat/enh: tool server function name filter list --- backend/open_webui/retrieval/web/main.py | 3 +- backend/open_webui/retrieval/web/utils.py | 35 +----------------- backend/open_webui/utils/middleware.py | 14 +++++++ backend/open_webui/utils/misc.py | 39 ++++++++++++++++++++ backend/open_webui/utils/tools.py | 13 +++++++ src/lib/components/AddToolServerModal.svelte | 25 ++++++++++++- 6 files changed, 93 insertions(+), 36 deletions(-) diff --git a/backend/open_webui/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py index d8cfb11ba0..6d2fd1bc5a 100644 --- a/backend/open_webui/retrieval/web/main.py +++ b/backend/open_webui/retrieval/web/main.py @@ -5,7 +5,8 @@ from urllib.parse import urlparse from pydantic import BaseModel -from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname +from open_webui.retrieval.web.utils import resolve_hostname +from open_webui.utils.misc import is_string_allowed def get_filtered_results(results, filter_list): diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index 127c703442..bdbde0b3a9 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -42,7 +42,7 @@ from open_webui.config import ( WEB_FETCH_FILTER_LIST, ) from open_webui.env import SRC_LOG_LEVELS - +from open_webui.utils.misc import is_string_allowed log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -59,39 +59,6 @@ def resolve_hostname(hostname): return ipv4_addresses, ipv6_addresses -def get_allow_block_lists(filter_list): - allow_list = [] - block_list = [] - - if filter_list: - for d in filter_list: - if d.startswith("!"): - # Domains starting with "!" → blocked - block_list.append(d[1:]) - else: - # Domains starting without "!" → allowed - allow_list.append(d) - - return allow_list, block_list - - -def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool: - if not filter_list: - return True - - allow_list, block_list = get_allow_block_lists(filter_list) - # If allow list is non-empty, require domain to match one of them - if allow_list: - if not any(string.endswith(allowed) for allowed in allow_list): - return False - - # Block list always removes matches - if any(string.endswith(blocked) for blocked in block_list): - return False - - return True - - def validate_url(url: Union[str, Sequence[str]]): if isinstance(url, str): if isinstance(validators.url(url), validators.ValidationError): diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 4a4e0ea6be..323f93f450 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse from starlette.responses import Response, StreamingResponse, JSONResponse +from open_webui.utils.misc import is_string_allowed from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.chats import Chats from open_webui.models.folders import Folders @@ -1408,6 +1409,9 @@ async def process_chat_payload(request, form_data, user, metadata, model): headers=headers if headers else None, ) + function_name_filter_list = mcp_server_connection.get( + "function_name_filter_list", None + ) tool_specs = await mcp_clients[server_id].list_tool_specs() for tool_spec in tool_specs: @@ -1420,6 +1424,15 @@ async def process_chat_payload(request, form_data, user, metadata, model): return tool_function + if function_name_filter_list and isinstance( + function_name_filter_list, list + ): + if not is_string_allowed( + tool_spec["name"], function_name_filter_list + ): + # Skip this function + continue + tool_function = make_tool_function( mcp_clients[server_id], tool_spec["name"] ) @@ -1460,6 +1473,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): "__files__": metadata.get("files", []), }, ) + if mcp_tools_dict: tools_dict = {**tools_dict, **mcp_tools_dict} diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index ce16691365..466e235598 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -27,6 +27,45 @@ def deep_update(d, u): return d +def get_allow_block_lists(filter_list): + allow_list = [] + block_list = [] + + if filter_list: + for d in filter_list: + if d.startswith("!"): + # Domains starting with "!" → blocked + block_list.append(d[1:]) + else: + # Domains starting without "!" → allowed + allow_list.append(d) + + return allow_list, block_list + + +def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool: + """ + Checks if a string is allowed based on the provided filter list. + :param string: The string to check (e.g., domain or hostname). + :param filter_list: List of allowed/blocked strings. Strings starting with "!" are blocked. + :return: True if the string is allowed, False otherwise. + """ + if not filter_list: + return True + + allow_list, block_list = get_allow_block_lists(filter_list) + # If allow list is non-empty, require domain to match one of them + if allow_list: + if not any(string.endswith(allowed) for allowed in allow_list): + return False + + # Block list always removes matches + if any(string.endswith(blocked) for blocked in block_list): + return False + + return True + + def get_message_list(messages_map, message_id): """ Reconstructs a list of messages in order up to the specified message_id. diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index fb623ed332..ecdf7187e4 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -34,6 +34,7 @@ from langchain_core.utils.function_calling import ( ) +from open_webui.utils.misc import is_string_allowed from open_webui.models.tools import Tools from open_webui.models.users import UserModel from open_webui.utils.plugin import load_tool_module_by_id @@ -149,8 +150,20 @@ async def get_tools( ) specs = tool_server_data.get("specs", []) + function_name_filter_list = tool_server_connection.get( + "function_name_filter_list", None + ) + for spec in specs: function_name = spec["name"] + if function_name_filter_list and isinstance( + function_name_filter_list, list + ): + if not is_string_allowed( + function_name, function_name_filter_list + ): + # Skip this function + continue auth_type = tool_server_connection.get("auth_type", "bearer") diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte index a2098de912..2b639b3e64 100644 --- a/src/lib/components/AddToolServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -47,6 +47,7 @@ let key = ''; let headers = ''; + let functionNameFilterList = []; let accessControl = {}; let id = ''; @@ -303,7 +304,7 @@ key, config: { enable: enable, - + function_name_filter_list: functionNameFilterList, access_control: accessControl }, info: { @@ -333,9 +334,11 @@ id = ''; name = ''; description = ''; + oauthClientInfo = null; enable = true; + functionNameFilterList = []; accessControl = null; }; @@ -359,6 +362,7 @@ oauthClientInfo = connection.info?.oauth_client_info ?? null; enable = connection.config?.enable ?? true; + functionNameFilterList = connection.config?.function_name_filter_list ?? []; accessControl = connection.config?.access_control ?? null; } }; @@ -793,6 +797,25 @@ +
+ + +
+ +
+
+