mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 12:55:19 +00:00
refac: tool server redis cache
This commit is contained in:
parent
8a157578f4
commit
f592748011
4 changed files with 34 additions and 20 deletions
|
|
@ -9,8 +9,8 @@ from open_webui.config import BannerModel
|
|||
|
||||
from open_webui.utils.tools import (
|
||||
get_tool_server_data,
|
||||
get_tool_servers_data,
|
||||
get_tool_server_url,
|
||||
set_tool_servers,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -114,10 +114,7 @@ async def set_tool_servers_config(
|
|||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||
]
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
await set_tool_servers(request)
|
||||
|
||||
return {
|
||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
|||
from open_webui.utils.tools import get_tool_specs
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.access_control import has_access, has_permission
|
||||
from open_webui.utils.tools import get_tool_servers_data
|
||||
from open_webui.utils.tools import get_tool_servers
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
|
||||
|
|
@ -32,6 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
############################
|
||||
# GetTools
|
||||
############################
|
||||
|
|
@ -39,18 +40,9 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=list[ToolUserResponse])
|
||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
if not request.app.state.TOOL_SERVERS:
|
||||
# If the tool servers are not set, we need to set them
|
||||
# This is done only once when the server starts
|
||||
# This is done to avoid loading the tool servers every time
|
||||
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
tools = Tools.get_tools()
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
|
||||
for server in await get_tool_servers(request):
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
**{
|
||||
|
|
|
|||
|
|
@ -910,7 +910,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
tools_dict = {}
|
||||
|
||||
if tool_ids:
|
||||
tools_dict = get_tools(
|
||||
tools_dict = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def get_async_tool_function_and_apply_extra_params(
|
|||
return new_function
|
||||
|
||||
|
||||
def get_tools(
|
||||
async def get_tools(
|
||||
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||
) -> dict[str, dict]:
|
||||
tools_dict = {}
|
||||
|
|
@ -80,7 +80,7 @@ def get_tools(
|
|||
server_id = tool_id.split(":")[1]
|
||||
|
||||
tool_server_data = None
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
for server in await get_tool_servers(request):
|
||||
if server["id"] == server_id:
|
||||
tool_server_data = server
|
||||
break
|
||||
|
|
@ -447,6 +447,31 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
|||
return tool_payload
|
||||
|
||||
|
||||
async def set_tool_servers(request: Request):
|
||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
if request.app.state.redis is not None:
|
||||
await request.app.state.redis.hmset(
|
||||
"tool_servers", request.app.state.TOOL_SERVERS
|
||||
)
|
||||
|
||||
return request.app.state.TOOL_SERVERS
|
||||
|
||||
|
||||
async def get_tool_servers(request: Request):
|
||||
tool_servers = []
|
||||
if request.app.state.redis is not None:
|
||||
tool_servers = await request.app.state.redis.hgetall("tool_servers")
|
||||
|
||||
if not tool_servers:
|
||||
await set_tool_servers(request)
|
||||
|
||||
request.app.state.TOOL_SERVERS = tool_servers
|
||||
return request.app.state.TOOL_SERVERS
|
||||
|
||||
|
||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
|
|
|
|||
Loading…
Reference in a new issue