refac: tool server redis cache

This commit is contained in:
Timothy Jaeryang Baek 2025-08-18 20:53:46 +04:00
parent 8a157578f4
commit f592748011
4 changed files with 34 additions and 20 deletions

View file

@ -9,8 +9,8 @@ from open_webui.config import BannerModel
from open_webui.utils.tools import (
get_tool_server_data,
get_tool_servers_data,
get_tool_server_url,
set_tool_servers,
)
@ -114,10 +114,7 @@ async def set_tool_servers_config(
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
]
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
await set_tool_servers(request)
return {
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,

View file

@ -19,7 +19,7 @@ from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.utils.tools import get_tool_servers_data
from open_webui.utils.tools import get_tool_servers
from open_webui.env import SRC_LOG_LEVELS
from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
@ -32,6 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter()
############################
# GetTools
############################
@ -39,18 +40,9 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)):
if not request.app.state.TOOL_SERVERS:
# If the tool servers are not set, we need to set them
# This is done only once when the server starts
# This is done to avoid loading the tool servers every time
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
tools = Tools.get_tools()
for server in request.app.state.TOOL_SERVERS:
for server in await get_tool_servers(request):
tools.append(
ToolUserResponse(
**{

View file

@ -910,7 +910,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tools_dict = {}
if tool_ids:
tools_dict = get_tools(
tools_dict = await get_tools(
request,
tool_ids,
user,

View file

@ -68,7 +68,7 @@ def get_async_tool_function_and_apply_extra_params(
return new_function
def get_tools(
async def get_tools(
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools_dict = {}
@ -80,7 +80,7 @@ def get_tools(
server_id = tool_id.split(":")[1]
tool_server_data = None
for server in request.app.state.TOOL_SERVERS:
for server in await get_tool_servers(request):
if server["id"] == server_id:
tool_server_data = server
break
@ -447,6 +447,31 @@ def convert_openapi_to_tool_payload(openapi_spec):
return tool_payload
async def set_tool_servers(request: Request):
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
if request.app.state.redis is not None:
await request.app.state.redis.hmset(
"tool_servers", request.app.state.TOOL_SERVERS
)
return request.app.state.TOOL_SERVERS
async def get_tool_servers(request: Request):
tool_servers = []
if request.app.state.redis is not None:
tool_servers = await request.app.state.redis.hgetall("tool_servers")
if not tool_servers:
await set_tool_servers(request)
request.app.state.TOOL_SERVERS = tool_servers
return request.app.state.TOOL_SERVERS
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
headers = {
"Accept": "application/json",