feat/enh: tool server function name filter list

This commit is contained in:
Timothy Jaeryang Baek 2025-11-25 02:31:34 -05:00
parent 488631db98
commit 743199f2d0
6 changed files with 93 additions and 36 deletions

View file

@ -5,7 +5,8 @@ from urllib.parse import urlparse
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname from open_webui.retrieval.web.utils import resolve_hostname
from open_webui.utils.misc import is_string_allowed
def get_filtered_results(results, filter_list): def get_filtered_results(results, filter_list):

View file

@ -42,7 +42,7 @@ from open_webui.config import (
WEB_FETCH_FILTER_LIST, WEB_FETCH_FILTER_LIST,
) )
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.misc import is_string_allowed
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -59,39 +59,6 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses return ipv4_addresses, ipv6_addresses
def get_allow_block_lists(filter_list):
allow_list = []
block_list = []
if filter_list:
for d in filter_list:
if d.startswith("!"):
# Domains starting with "!" → blocked
block_list.append(d[1:])
else:
# Domains starting without "!" → allowed
allow_list.append(d)
return allow_list, block_list
def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
if not filter_list:
return True
allow_list, block_list = get_allow_block_lists(filter_list)
# If allow list is non-empty, require domain to match one of them
if allow_list:
if not any(string.endswith(allowed) for allowed in allow_list):
return False
# Block list always removes matches
if any(string.endswith(blocked) for blocked in block_list):
return False
return True
def validate_url(url: Union[str, Sequence[str]]): def validate_url(url: Union[str, Sequence[str]]):
if isinstance(url, str): if isinstance(url, str):
if isinstance(validators.url(url), validators.ValidationError): if isinstance(validators.url(url), validators.ValidationError):

View file

@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
from starlette.responses import Response, StreamingResponse, JSONResponse from starlette.responses import Response, StreamingResponse, JSONResponse
from open_webui.utils.misc import is_string_allowed
from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.oauth_sessions import OAuthSessions
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.folders import Folders from open_webui.models.folders import Folders
@ -1408,6 +1409,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
headers=headers if headers else None, headers=headers if headers else None,
) )
function_name_filter_list = mcp_server_connection.get(
"function_name_filter_list", None
)
tool_specs = await mcp_clients[server_id].list_tool_specs() tool_specs = await mcp_clients[server_id].list_tool_specs()
for tool_spec in tool_specs: for tool_spec in tool_specs:
@ -1420,6 +1424,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
return tool_function return tool_function
if function_name_filter_list and isinstance(
function_name_filter_list, list
):
if not is_string_allowed(
tool_spec["name"], function_name_filter_list
):
# Skip this function
continue
tool_function = make_tool_function( tool_function = make_tool_function(
mcp_clients[server_id], tool_spec["name"] mcp_clients[server_id], tool_spec["name"]
) )
@ -1460,6 +1473,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
"__files__": metadata.get("files", []), "__files__": metadata.get("files", []),
}, },
) )
if mcp_tools_dict: if mcp_tools_dict:
tools_dict = {**tools_dict, **mcp_tools_dict} tools_dict = {**tools_dict, **mcp_tools_dict}

View file

@ -27,6 +27,45 @@ def deep_update(d, u):
return d return d
def get_allow_block_lists(filter_list):
allow_list = []
block_list = []
if filter_list:
for d in filter_list:
if d.startswith("!"):
# Domains starting with "!" → blocked
block_list.append(d[1:])
else:
# Domains starting without "!" → allowed
allow_list.append(d)
return allow_list, block_list
def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
"""
Checks if a string is allowed based on the provided filter list.
:param string: The string to check (e.g., domain or hostname).
:param filter_list: List of allowed/blocked strings. Strings starting with "!" are blocked.
:return: True if the string is allowed, False otherwise.
"""
if not filter_list:
return True
allow_list, block_list = get_allow_block_lists(filter_list)
# If allow list is non-empty, require domain to match one of them
if allow_list:
if not any(string.endswith(allowed) for allowed in allow_list):
return False
# Block list always removes matches
if any(string.endswith(blocked) for blocked in block_list):
return False
return True
def get_message_list(messages_map, message_id): def get_message_list(messages_map, message_id):
""" """
Reconstructs a list of messages in order up to the specified message_id. Reconstructs a list of messages in order up to the specified message_id.

View file

@ -34,6 +34,7 @@ from langchain_core.utils.function_calling import (
) )
from open_webui.utils.misc import is_string_allowed
from open_webui.models.tools import Tools from open_webui.models.tools import Tools
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.utils.plugin import load_tool_module_by_id from open_webui.utils.plugin import load_tool_module_by_id
@ -149,8 +150,20 @@ async def get_tools(
) )
specs = tool_server_data.get("specs", []) specs = tool_server_data.get("specs", [])
function_name_filter_list = tool_server_connection.get(
"function_name_filter_list", None
)
for spec in specs: for spec in specs:
function_name = spec["name"] function_name = spec["name"]
if function_name_filter_list and isinstance(
function_name_filter_list, list
):
if not is_string_allowed(
function_name, function_name_filter_list
):
# Skip this function
continue
auth_type = tool_server_connection.get("auth_type", "bearer") auth_type = tool_server_connection.get("auth_type", "bearer")

View file

@ -47,6 +47,7 @@
let key = ''; let key = '';
let headers = ''; let headers = '';
let functionNameFilterList = [];
let accessControl = {}; let accessControl = {};
let id = ''; let id = '';
@ -303,7 +304,7 @@
key, key,
config: { config: {
enable: enable, enable: enable,
function_name_filter_list: functionNameFilterList,
access_control: accessControl access_control: accessControl
}, },
info: { info: {
@ -333,9 +334,11 @@
id = ''; id = '';
name = ''; name = '';
description = ''; description = '';
oauthClientInfo = null; oauthClientInfo = null;
enable = true; enable = true;
functionNameFilterList = [];
accessControl = null; accessControl = null;
}; };
@ -359,6 +362,7 @@
oauthClientInfo = connection.info?.oauth_client_info ?? null; oauthClientInfo = connection.info?.oauth_client_info ?? null;
enable = connection.config?.enable ?? true; enable = connection.config?.enable ?? true;
functionNameFilterList = connection.config?.function_name_filter_list ?? [];
accessControl = connection.config?.access_control ?? null; accessControl = connection.config?.access_control ?? null;
} }
}; };
@ -793,6 +797,25 @@
</div> </div>
</div> </div>
<div class="flex flex-col w-full mt-2">
<label
for="function-name-filter-list"
class={`mb-1 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100 placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700 text-gray-500'}`}
>{$i18n.t('Function Name Filter List')}</label
>
<div class="flex-1">
<input
id="function-name-filter-list"
class={`w-full text-sm bg-transparent ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
type="text"
bind:value={functionNameFilterList}
placeholder={$i18n.t('Enter function name filter list (e.g. func1, !func2)')}
autocomplete="off"
/>
</div>
</div>
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" /> <hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
<div class="my-2 -mx-2"> <div class="my-2 -mx-2">