mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
feat/enh: tool server function name filter list
This commit is contained in:
parent
488631db98
commit
743199f2d0
6 changed files with 93 additions and 36 deletions
|
|
@ -5,7 +5,8 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname
|
from open_webui.retrieval.web.utils import resolve_hostname
|
||||||
|
from open_webui.utils.misc import is_string_allowed
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_results(results, filter_list):
|
def get_filtered_results(results, filter_list):
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ from open_webui.config import (
|
||||||
WEB_FETCH_FILTER_LIST,
|
WEB_FETCH_FILTER_LIST,
|
||||||
)
|
)
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
from open_webui.utils.misc import is_string_allowed
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
@ -59,39 +59,6 @@ def resolve_hostname(hostname):
|
||||||
return ipv4_addresses, ipv6_addresses
|
return ipv4_addresses, ipv6_addresses
|
||||||
|
|
||||||
|
|
||||||
def get_allow_block_lists(filter_list):
|
|
||||||
allow_list = []
|
|
||||||
block_list = []
|
|
||||||
|
|
||||||
if filter_list:
|
|
||||||
for d in filter_list:
|
|
||||||
if d.startswith("!"):
|
|
||||||
# Domains starting with "!" → blocked
|
|
||||||
block_list.append(d[1:])
|
|
||||||
else:
|
|
||||||
# Domains starting without "!" → allowed
|
|
||||||
allow_list.append(d)
|
|
||||||
|
|
||||||
return allow_list, block_list
|
|
||||||
|
|
||||||
|
|
||||||
def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
|
|
||||||
if not filter_list:
|
|
||||||
return True
|
|
||||||
|
|
||||||
allow_list, block_list = get_allow_block_lists(filter_list)
|
|
||||||
# If allow list is non-empty, require domain to match one of them
|
|
||||||
if allow_list:
|
|
||||||
if not any(string.endswith(allowed) for allowed in allow_list):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Block list always removes matches
|
|
||||||
if any(string.endswith(blocked) for blocked in block_list):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def validate_url(url: Union[str, Sequence[str]]):
|
def validate_url(url: Union[str, Sequence[str]]):
|
||||||
if isinstance(url, str):
|
if isinstance(url, str):
|
||||||
if isinstance(validators.url(url), validators.ValidationError):
|
if isinstance(validators.url(url), validators.ValidationError):
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from fastapi.responses import HTMLResponse
|
||||||
from starlette.responses import Response, StreamingResponse, JSONResponse
|
from starlette.responses import Response, StreamingResponse, JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.misc import is_string_allowed
|
||||||
from open_webui.models.oauth_sessions import OAuthSessions
|
from open_webui.models.oauth_sessions import OAuthSessions
|
||||||
from open_webui.models.chats import Chats
|
from open_webui.models.chats import Chats
|
||||||
from open_webui.models.folders import Folders
|
from open_webui.models.folders import Folders
|
||||||
|
|
@ -1408,6 +1409,9 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
headers=headers if headers else None,
|
headers=headers if headers else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
function_name_filter_list = mcp_server_connection.get(
|
||||||
|
"function_name_filter_list", None
|
||||||
|
)
|
||||||
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
tool_specs = await mcp_clients[server_id].list_tool_specs()
|
||||||
for tool_spec in tool_specs:
|
for tool_spec in tool_specs:
|
||||||
|
|
||||||
|
|
@ -1420,6 +1424,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
|
|
||||||
return tool_function
|
return tool_function
|
||||||
|
|
||||||
|
if function_name_filter_list and isinstance(
|
||||||
|
function_name_filter_list, list
|
||||||
|
):
|
||||||
|
if not is_string_allowed(
|
||||||
|
tool_spec["name"], function_name_filter_list
|
||||||
|
):
|
||||||
|
# Skip this function
|
||||||
|
continue
|
||||||
|
|
||||||
tool_function = make_tool_function(
|
tool_function = make_tool_function(
|
||||||
mcp_clients[server_id], tool_spec["name"]
|
mcp_clients[server_id], tool_spec["name"]
|
||||||
)
|
)
|
||||||
|
|
@ -1460,6 +1473,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
"__files__": metadata.get("files", []),
|
"__files__": metadata.get("files", []),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if mcp_tools_dict:
|
if mcp_tools_dict:
|
||||||
tools_dict = {**tools_dict, **mcp_tools_dict}
|
tools_dict = {**tools_dict, **mcp_tools_dict}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,45 @@ def deep_update(d, u):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def get_allow_block_lists(filter_list):
|
||||||
|
allow_list = []
|
||||||
|
block_list = []
|
||||||
|
|
||||||
|
if filter_list:
|
||||||
|
for d in filter_list:
|
||||||
|
if d.startswith("!"):
|
||||||
|
# Domains starting with "!" → blocked
|
||||||
|
block_list.append(d[1:])
|
||||||
|
else:
|
||||||
|
# Domains starting without "!" → allowed
|
||||||
|
allow_list.append(d)
|
||||||
|
|
||||||
|
return allow_list, block_list
|
||||||
|
|
||||||
|
|
||||||
|
def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if a string is allowed based on the provided filter list.
|
||||||
|
:param string: The string to check (e.g., domain or hostname).
|
||||||
|
:param filter_list: List of allowed/blocked strings. Strings starting with "!" are blocked.
|
||||||
|
:return: True if the string is allowed, False otherwise.
|
||||||
|
"""
|
||||||
|
if not filter_list:
|
||||||
|
return True
|
||||||
|
|
||||||
|
allow_list, block_list = get_allow_block_lists(filter_list)
|
||||||
|
# If allow list is non-empty, require domain to match one of them
|
||||||
|
if allow_list:
|
||||||
|
if not any(string.endswith(allowed) for allowed in allow_list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Block list always removes matches
|
||||||
|
if any(string.endswith(blocked) for blocked in block_list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_message_list(messages_map, message_id):
|
def get_message_list(messages_map, message_id):
|
||||||
"""
|
"""
|
||||||
Reconstructs a list of messages in order up to the specified message_id.
|
Reconstructs a list of messages in order up to the specified message_id.
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ from langchain_core.utils.function_calling import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.misc import is_string_allowed
|
||||||
from open_webui.models.tools import Tools
|
from open_webui.models.tools import Tools
|
||||||
from open_webui.models.users import UserModel
|
from open_webui.models.users import UserModel
|
||||||
from open_webui.utils.plugin import load_tool_module_by_id
|
from open_webui.utils.plugin import load_tool_module_by_id
|
||||||
|
|
@ -149,8 +150,20 @@ async def get_tools(
|
||||||
)
|
)
|
||||||
|
|
||||||
specs = tool_server_data.get("specs", [])
|
specs = tool_server_data.get("specs", [])
|
||||||
|
function_name_filter_list = tool_server_connection.get(
|
||||||
|
"function_name_filter_list", None
|
||||||
|
)
|
||||||
|
|
||||||
for spec in specs:
|
for spec in specs:
|
||||||
function_name = spec["name"]
|
function_name = spec["name"]
|
||||||
|
if function_name_filter_list and isinstance(
|
||||||
|
function_name_filter_list, list
|
||||||
|
):
|
||||||
|
if not is_string_allowed(
|
||||||
|
function_name, function_name_filter_list
|
||||||
|
):
|
||||||
|
# Skip this function
|
||||||
|
continue
|
||||||
|
|
||||||
auth_type = tool_server_connection.get("auth_type", "bearer")
|
auth_type = tool_server_connection.get("auth_type", "bearer")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@
|
||||||
let key = '';
|
let key = '';
|
||||||
let headers = '';
|
let headers = '';
|
||||||
|
|
||||||
|
let functionNameFilterList = [];
|
||||||
let accessControl = {};
|
let accessControl = {};
|
||||||
|
|
||||||
let id = '';
|
let id = '';
|
||||||
|
|
@ -303,7 +304,7 @@
|
||||||
key,
|
key,
|
||||||
config: {
|
config: {
|
||||||
enable: enable,
|
enable: enable,
|
||||||
|
function_name_filter_list: functionNameFilterList,
|
||||||
access_control: accessControl
|
access_control: accessControl
|
||||||
},
|
},
|
||||||
info: {
|
info: {
|
||||||
|
|
@ -333,9 +334,11 @@
|
||||||
id = '';
|
id = '';
|
||||||
name = '';
|
name = '';
|
||||||
description = '';
|
description = '';
|
||||||
|
|
||||||
oauthClientInfo = null;
|
oauthClientInfo = null;
|
||||||
|
|
||||||
enable = true;
|
enable = true;
|
||||||
|
functionNameFilterList = [];
|
||||||
accessControl = null;
|
accessControl = null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -359,6 +362,7 @@
|
||||||
oauthClientInfo = connection.info?.oauth_client_info ?? null;
|
oauthClientInfo = connection.info?.oauth_client_info ?? null;
|
||||||
|
|
||||||
enable = connection.config?.enable ?? true;
|
enable = connection.config?.enable ?? true;
|
||||||
|
functionNameFilterList = connection.config?.function_name_filter_list ?? [];
|
||||||
accessControl = connection.config?.access_control ?? null;
|
accessControl = connection.config?.access_control ?? null;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -793,6 +797,25 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="flex flex-col w-full mt-2">
|
||||||
|
<label
|
||||||
|
for="function-name-filter-list"
|
||||||
|
class={`mb-1 text-xs ${($settings?.highContrastMode ?? false) ? 'text-gray-800 dark:text-gray-100 placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700 text-gray-500'}`}
|
||||||
|
>{$i18n.t('Function Name Filter List')}</label
|
||||||
|
>
|
||||||
|
|
||||||
|
<div class="flex-1">
|
||||||
|
<input
|
||||||
|
id="function-name-filter-list"
|
||||||
|
class={`w-full text-sm bg-transparent ${($settings?.highContrastMode ?? false) ? 'placeholder:text-gray-700 dark:placeholder:text-gray-100' : 'outline-hidden placeholder:text-gray-300 dark:placeholder:text-gray-700'}`}
|
||||||
|
type="text"
|
||||||
|
bind:value={functionNameFilterList}
|
||||||
|
placeholder={$i18n.t('Enter function name filter list (e.g. func1, !func2)')}
|
||||||
|
autocomplete="off"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
|
<hr class=" border-gray-100 dark:border-gray-700/10 my-2.5 w-full" />
|
||||||
|
|
||||||
<div class="my-2 -mx-2">
|
<div class="my-2 -mx-2">
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue