mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 12:55:19 +00:00
refac: tool server redis cache
This commit is contained in:
parent
8a157578f4
commit
f592748011
4 changed files with 34 additions and 20 deletions
|
|
@ -9,8 +9,8 @@ from open_webui.config import BannerModel
|
||||||
|
|
||||||
from open_webui.utils.tools import (
|
from open_webui.utils.tools import (
|
||||||
get_tool_server_data,
|
get_tool_server_data,
|
||||||
get_tool_servers_data,
|
|
||||||
get_tool_server_url,
|
get_tool_server_url,
|
||||||
|
set_tool_servers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -114,10 +114,7 @@ async def set_tool_servers_config(
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
request.app.state.config.TOOL_SERVER_CONNECTIONS = [
|
||||||
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS
|
||||||
]
|
]
|
||||||
|
await set_tool_servers(request)
|
||||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
"TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS,
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
||||||
from open_webui.utils.tools import get_tool_specs
|
from open_webui.utils.tools import get_tool_specs
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
from open_webui.utils.tools import get_tool_servers_data
|
from open_webui.utils.tools import get_tool_servers
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
|
from open_webui.config import CACHE_DIR, ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
|
||||||
|
|
@ -32,6 +32,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetTools
|
# GetTools
|
||||||
############################
|
############################
|
||||||
|
|
@ -39,18 +40,9 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.get("/", response_model=list[ToolUserResponse])
|
@router.get("/", response_model=list[ToolUserResponse])
|
||||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
|
|
||||||
if not request.app.state.TOOL_SERVERS:
|
|
||||||
# If the tool servers are not set, we need to set them
|
|
||||||
# This is done only once when the server starts
|
|
||||||
# This is done to avoid loading the tool servers every time
|
|
||||||
|
|
||||||
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
|
||||||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = Tools.get_tools()
|
tools = Tools.get_tools()
|
||||||
for server in request.app.state.TOOL_SERVERS:
|
|
||||||
|
for server in await get_tool_servers(request):
|
||||||
tools.append(
|
tools.append(
|
||||||
ToolUserResponse(
|
ToolUserResponse(
|
||||||
**{
|
**{
|
||||||
|
|
|
||||||
|
|
@ -910,7 +910,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
tools_dict = {}
|
tools_dict = {}
|
||||||
|
|
||||||
if tool_ids:
|
if tool_ids:
|
||||||
tools_dict = get_tools(
|
tools_dict = await get_tools(
|
||||||
request,
|
request,
|
||||||
tool_ids,
|
tool_ids,
|
||||||
user,
|
user,
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ def get_async_tool_function_and_apply_extra_params(
|
||||||
return new_function
|
return new_function
|
||||||
|
|
||||||
|
|
||||||
def get_tools(
|
async def get_tools(
|
||||||
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||||
) -> dict[str, dict]:
|
) -> dict[str, dict]:
|
||||||
tools_dict = {}
|
tools_dict = {}
|
||||||
|
|
@ -80,7 +80,7 @@ def get_tools(
|
||||||
server_id = tool_id.split(":")[1]
|
server_id = tool_id.split(":")[1]
|
||||||
|
|
||||||
tool_server_data = None
|
tool_server_data = None
|
||||||
for server in request.app.state.TOOL_SERVERS:
|
for server in await get_tool_servers(request):
|
||||||
if server["id"] == server_id:
|
if server["id"] == server_id:
|
||||||
tool_server_data = server
|
tool_server_data = server
|
||||||
break
|
break
|
||||||
|
|
@ -447,6 +447,31 @@ def convert_openapi_to_tool_payload(openapi_spec):
|
||||||
return tool_payload
|
return tool_payload
|
||||||
|
|
||||||
|
|
||||||
|
async def set_tool_servers(request: Request):
|
||||||
|
request.app.state.TOOL_SERVERS = await get_tool_servers_data(
|
||||||
|
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.app.state.redis is not None:
|
||||||
|
await request.app.state.redis.hmset(
|
||||||
|
"tool_servers", request.app.state.TOOL_SERVERS
|
||||||
|
)
|
||||||
|
|
||||||
|
return request.app.state.TOOL_SERVERS
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tool_servers(request: Request):
|
||||||
|
tool_servers = []
|
||||||
|
if request.app.state.redis is not None:
|
||||||
|
tool_servers = await request.app.state.redis.hgetall("tool_servers")
|
||||||
|
|
||||||
|
if not tool_servers:
|
||||||
|
await set_tool_servers(request)
|
||||||
|
|
||||||
|
request.app.state.TOOL_SERVERS = tool_servers
|
||||||
|
return request.app.state.TOOL_SERVERS
|
||||||
|
|
||||||
|
|
||||||
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
async def get_tool_server_data(token: str, url: str) -> Dict[str, Any]:
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue