diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index c8badfa112..8ce4e0d247 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -9,8 +9,8 @@ from open_webui.config import BannerModel from open_webui.utils.tools import ( get_tool_server_data, - get_tool_servers_data, get_tool_server_url, + set_tool_servers, ) @@ -114,10 +114,7 @@ async def set_tool_servers_config( request.app.state.config.TOOL_SERVER_CONNECTIONS = [ connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS ] - - request.app.state.TOOL_SERVERS = await get_tool_servers_data( - request.app.state.config.TOOL_SERVER_CONNECTIONS - ) + await set_tool_servers(request) return { "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 184e55952e..183bd28397 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -19,7 +19,7 @@ from open_webui.utils.plugin import load_tool_module_by_id, replace_imports from open_webui.utils.tools import get_tool_specs from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission -from open_webui.utils.tools import get_tool_servers_data +from open_webui.utils.tools import get_tool_servers from open_webui.env import SRC_LOG_LEVELS from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS @@ -32,6 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() + ############################ # GetTools ############################ @@ -39,18 +40,9 @@ router = APIRouter() @router.get("/", response_model=list[ToolUserResponse]) async def get_tools(request: Request, user=Depends(get_verified_user)): - - if not request.app.state.TOOL_SERVERS: - # If the tool servers are not set, we need to set them - # This is done only once when the server starts - # This is done to avoid loading the tool servers every time - - request.app.state.TOOL_SERVERS = await get_tool_servers_data( - request.app.state.config.TOOL_SERVER_CONNECTIONS - ) - tools = Tools.get_tools() - for server in request.app.state.TOOL_SERVERS: + + for server in await get_tool_servers(request): tools.append( ToolUserResponse( **{ diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 29fc617115..0eb3aa8853 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -910,7 +910,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): tools_dict = {} if tool_ids: - tools_dict = get_tools( + tools_dict = await get_tools( request, tool_ids, user, diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 5c6317375c..99f967026f 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -68,7 +68,7 @@ def get_async_tool_function_and_apply_extra_params( return new_function -def get_tools( +async def get_tools( request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools_dict = {} @@ -80,7 +80,7 @@ def get_tools( server_id = tool_id.split(":")[1] tool_server_data = None - for server in request.app.state.TOOL_SERVERS: + for server in await get_tool_servers(request): if server["id"] == server_id: tool_server_data = server break @@ -447,6 +447,31 @@ def convert_openapi_to_tool_payload(openapi_spec): return tool_payload +async def set_tool_servers(request: Request): + request.app.state.TOOL_SERVERS = await get_tool_servers_data( + request.app.state.config.TOOL_SERVER_CONNECTIONS + ) + + if request.app.state.redis is not None: + await request.app.state.redis.hmset( + "tool_servers", request.app.state.TOOL_SERVERS + ) + + return request.app.state.TOOL_SERVERS + + +async def get_tool_servers(request: Request): + tool_servers = [] + if request.app.state.redis is not None: + tool_servers = await request.app.state.redis.hgetall("tool_servers") + + if not tool_servers: + await set_tool_servers(request) + + request.app.state.TOOL_SERVERS = tool_servers + return request.app.state.TOOL_SERVERS + + async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]: headers = { "Accept": "application/json",