diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 0eb88e767e..b48db49222 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -147,6 +147,7 @@ class ToolServerConnection(BaseModel): headers: Optional[dict | str] = None key: Optional[str] config: Optional[dict] + placeholders: Optional[list[str]] = None model_config = ConfigDict(extra="allow") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index d397471dd9..d0b97156b6 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -91,7 +91,11 @@ from open_webui.utils.misc import ( convert_logit_bias_input_to_json, get_content_from_message, ) -from open_webui.utils.tools import get_tools, get_updated_tool_function +from open_webui.utils.tools import ( + get_tools, + get_updated_tool_function, + replace_placeholders_in_headers +) from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.filter import ( get_sorted_filter_ids, @@ -1139,10 +1143,16 @@ async def process_chat_payload(request, form_data, user, metadata, model): except Exception as e: log.error(f"Error getting OAuth token: {e}") + __user__ = user.model_dump() if isinstance(user, UserModel) else {} + if isinstance(user, UserModel) and user.settings: + user_settings = user.settings.model_dump() if user.settings else {} + if "tool_server_placeholders" in user_settings: + __user__["tool_server_placeholders"] = user_settings["tool_server_placeholders"] + extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_caller, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, + "__user__": __user__, "__metadata__": metadata, "__oauth_token__": oauth_token, "__request__": request, @@ -1401,6 +1411,16 @@ async def process_chat_payload(request, form_data, user, metadata, model): connection_headers = mcp_server_connection.get("headers", None) if connection_headers and isinstance(connection_headers, dict): + user_placeholders = ( + extra_params.get("__user__", {}) + .get("tool_server_placeholders", {}) + .get(server_id, {}) + ) + + connection_headers = replace_placeholders_in_headers( + connection_headers, user_placeholders + ) + for key, value in connection_headers.items(): headers[key] = value @@ -1993,10 +2013,16 @@ async def process_chat_response( except Exception as e: log.error(f"Error getting OAuth token: {e}") + __user__ = user.model_dump() if isinstance(user, UserModel) else {} + if isinstance(user, UserModel) and user.settings: + user_settings = user.settings.model_dump() if user.settings else {} + if "tool_server_placeholders" in user_settings: + __user__["tool_server_placeholders"] = user_settings["tool_server_placeholders"] + extra_params = { "__event_emitter__": event_emitter, "__event_call__": event_caller, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, + "__user__": __user__, "__metadata__": metadata, "__oauth_token__": oauth_token, "__request__": request, diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 2baff503ee..ed4dbe8c9c 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -51,6 +51,28 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) +def replace_placeholders_in_headers( + headers: dict, placeholders: dict[str, str] +) -> dict: + if not headers or not placeholders: + return headers + + replaced_headers = {} + for key, value in headers.items(): + if isinstance(value, str): + replaced_value = value + for placeholder_name, placeholder_value in placeholders.items(): + placeholder_pattern = f"{{{{{placeholder_name}}}}}" + replaced_value = replaced_value.replace( + placeholder_pattern, placeholder_value + ) + replaced_headers[key] = replaced_value + else: + replaced_headers[key] = value + + return replaced_headers + + def get_async_tool_function_and_apply_extra_params( function: Callable, extra_params: dict ) -> Callable[..., Awaitable]: @@ -195,6 +217,17 @@ async def get_tools( connection_headers = tool_server_connection.get("headers", None) if connection_headers and isinstance(connection_headers, dict): + user_placeholders = ( + extra_params.get("__user__", {}) + .get("tool_server_placeholders", {}) + .get(server_id, {}) + ) + + # Replace placeholders in headers + connection_headers = replace_placeholders_in_headers( + connection_headers, user_placeholders + ) + for key, value in connection_headers.items(): headers[key] = value diff --git a/src/lib/components/AddToolServerModal.svelte b/src/lib/components/AddToolServerModal.svelte index 764e2259bd..fd426ac850 100644 --- a/src/lib/components/AddToolServerModal.svelte +++ b/src/lib/components/AddToolServerModal.svelte @@ -46,6 +46,7 @@ let auth_type = 'bearer'; let key = ''; let headers = ''; + let placeholders: string[] = []; let functionNameFilterList = ''; let accessControl = {}; @@ -197,6 +198,7 @@ if (data.auth_type) auth_type = data.auth_type; if (data.headers) headers = JSON.stringify(data.headers, null, 2); if (data.key) key = data.key; + if (data.placeholders) placeholders = data.placeholders; if (data.info) { id = data.info.id ?? ''; @@ -231,6 +233,7 @@ auth_type, headers: headers ? JSON.parse(headers) : undefined, key, + placeholders: placeholders.length > 0 ? placeholders : undefined, info: { id: id, @@ -302,6 +305,7 @@ headers: headers ? JSON.parse(headers) : undefined, key, + placeholders: placeholders.length > 0 ? placeholders : undefined, config: { enable: enable, function_name_filter_list: functionNameFilterList, @@ -330,6 +334,7 @@ key = ''; auth_type = 'bearer'; + placeholders = []; id = ''; name = ''; @@ -355,6 +360,7 @@ headers = connection?.headers ? JSON.stringify(connection.headers, null, 2) : ''; key = connection?.key ?? ''; + placeholders = connection?.placeholders ?? []; id = connection.info?.id ?? ''; name = connection.info?.name ?? ''; @@ -726,6 +732,66 @@ +