From 54beeeaf72873c6e80ca9b24b3ff28787dee8d6a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 26 Sep 2025 19:01:22 -0500 Subject: [PATCH] refac: tools --- backend/open_webui/routers/tools.py | 34 ++++++++++++++++------- backend/open_webui/utils/plugin.py | 42 +++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 9 deletions(-) diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index eb66a86825..2fa3f6abf6 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -17,7 +17,11 @@ from open_webui.models.tools import ( ToolUserResponse, Tools, ) -from open_webui.utils.plugin import load_tool_module_by_id, replace_imports +from open_webui.utils.plugin import ( + load_tool_module_by_id, + replace_imports, + get_tool_module_from_cache, +) from open_webui.utils.tools import get_tool_specs from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission @@ -35,6 +39,14 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) router = APIRouter() +def get_tool_module(request, tool_id, load_from_db=True): + """ + Get the tool module by its ID. + """ + tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db) + return tool_module + + ############################ # GetTools ############################ @@ -42,15 +54,19 @@ router = APIRouter() @router.get("/", response_model=list[ToolUserResponse]) async def get_tools(request: Request, user=Depends(get_verified_user)): - tools = [ - ToolUserResponse( - **{ - **tool.model_dump(), - "has_user_valves": "class UserValves(BaseModel):" in tool.content, - } + tools = [] + + # Local Tools + for tool in Tools.get_tools(): + tool_module = get_tool_module(request, tool.id) + tools.append( + ToolUserResponse( + **{ + **tool.model_dump(), + "has_user_valves": hasattr(tool_module, "UserValves"), + } + ) ) - for tool in Tools.get_tools() - ] # OpenAPI Tool Servers for server in await get_tool_servers(request): diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 8d9729bae2..51c3f4f5f7 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -166,6 +166,48 @@ def load_function_module_by_id(function_id: str, content: str | None = None): os.unlink(temp_file.name) +def get_tool_module_from_cache(request, tool_id, load_from_db=True): + if load_from_db: + # Always load from the database by default + tool = Tools.get_tool_by_id(tool_id) + if not tool: + raise Exception(f"Tool not found: {tool_id}") + content = tool.content + + new_content = replace_imports(content) + if new_content != content: + content = new_content + # Update the tool content in the database + Tools.update_tool_by_id(tool_id, {"content": content}) + + if ( + hasattr(request.app.state, "TOOL_CONTENTS") + and tool_id in request.app.state.TOOL_CONTENTS + ) and ( + hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS + ): + if request.app.state.TOOL_CONTENTS[tool_id] == content: + return request.app.state.TOOLS[tool_id], None + + tool_module, frontmatter = load_tool_module_by_id(tool_id, content) + else: + if hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS: + return request.app.state.TOOLS[tool_id], None + + tool_module, frontmatter = load_tool_module_by_id(tool_id) + + if not hasattr(request.app.state, "TOOLS"): + request.app.state.TOOLS = {} + + if not hasattr(request.app.state, "TOOL_CONTENTS"): + request.app.state.TOOL_CONTENTS = {} + + request.app.state.TOOLS[tool_id] = tool_module + request.app.state.TOOL_CONTENTS[tool_id] = content + + return tool_module, frontmatter + + def get_function_module_from_cache(request, function_id, load_from_db=True): if load_from_db: # Always load from the database by default