diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 62d1d95699..5070850896 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -144,6 +144,7 @@ class ToolServerConnection(BaseModel): path: str type: Optional[str] = "openapi" # openapi, mcp auth_type: Optional[str] + headers: Optional[dict] key: Optional[str] config: Optional[dict] @@ -282,10 +283,14 @@ async def verify_tool_servers_config( token = oauth_token.get("access_token", "") except Exception as e: pass - if token: headers = {"Authorization": f"Bearer {token}"} + if form_data.headers: + if headers is None: + headers = {} + headers.update(form_data.headers) + await client.connect(form_data.url, headers=headers) specs = await client.list_tool_specs() return { @@ -303,6 +308,7 @@ async def verify_tool_servers_config( await client.disconnect() else: # openapi token = None + headers = None if form_data.auth_type == "bearer": token = form_data.key elif form_data.auth_type == "session": @@ -323,8 +329,16 @@ async def verify_tool_servers_config( except Exception as e: pass + if token: + headers = {"Authorization": f"Bearer {token}"} + + if form_data.headers: + if headers is None: + headers = {} + headers.update(form_data.headers) + url = get_tool_server_url(form_data.url, form_data.path) - return await get_tool_server_data(token, url) + return await get_tool_server_data(url, headers=headers) except HTTPException as e: raise e except Exception as e: diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 7fd556dc57..3091757a04 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -312,7 +312,11 @@ async def chat_completion_tools_handler( for message in recent_messages ) - prompt = f"History:\n{chat_history}\nQuery: {user_message}" if chat_history else f"Query: {user_message}" + prompt = ( + f"History:\n{chat_history}\nQuery: {user_message}" + if chat_history + else f"Query: {user_message}" + ) return { "model": task_model_id, @@ -1327,7 +1331,6 @@ async def process_chat_payload(request, form_data, user, metadata, model): continue auth_type = mcp_server_connection.get("auth_type", "") - headers = {} if auth_type == "bearer": headers["Authorization"] = ( @@ -1363,6 +1366,11 @@ async def process_chat_payload(request, form_data, user, metadata, model): log.error(f"Error getting OAuth token: {e}") oauth_token = None + connection_headers = mcp_server_connection.get("headers", None) + if connection_headers: + for key, value in connection_headers.items(): + headers[key] = value + mcp_clients[server_id] = MCPClient() await mcp_clients[server_id].connect( url=mcp_server_connection.get("url", ""), diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 1d1254f184..f14b3ce8ff 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -155,7 +155,9 @@ async def get_tools( auth_type = tool_server_connection.get("auth_type", "bearer") cookies = {} - headers = {} + headers = { + "Content-Type": "application/json", + } if auth_type == "bearer": headers["Authorization"] = ( @@ -177,7 +179,10 @@ async def get_tools( f"Bearer {oauth_token.get('access_token', '')}" ) - headers["Content-Type"] = "application/json" + connection_headers = tool_server_connection.get("headers", None) + if connection_headers: + for key, value in connection_headers.items(): + headers[key] = value def make_tool_function( function_name, tool_server_data, headers @@ -561,20 +566,21 @@ async def get_tool_servers(request: Request): return tool_servers -async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: - headers = { +async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]: + _headers = { "Accept": "application/json", "Content-Type": "application/json", } - if token: - headers["Authorization"] = f"Bearer {token}" + + if headers: + _headers.update(headers) error = None try: timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL + url, headers=_headers, ssl=AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL ) as response: if response.status != 200: error_body = await response.json() @@ -644,7 +650,10 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, openapi_path = server.get("path", "openapi.json") spec_url = get_tool_server_url(server_url, openapi_path) # Fetch from URL - task = get_tool_server_data(token, spec_url) + task = get_tool_server_data( + spec_url, + {"Authorization": f"Bearer {token}"} if token else None, + ) elif spec_type == "json" and server.get("spec", ""): # Use provided JSON spec spec_json = None diff --git a/src/lib/components/AddConnectionModal.svelte b/src/lib/components/AddConnectionModal.svelte index 90bb60b406..5a75774fa0 100644 --- a/src/lib/components/AddConnectionModal.svelte +++ b/src/lib/components/AddConnectionModal.svelte @@ -426,7 +426,7 @@