refac: tools

This commit is contained in:
Timothy Jaeryang Baek 2025-09-26 19:01:22 -05:00
parent 1de5827eb3
commit 54beeeaf72
2 changed files with 67 additions and 9 deletions

View file

@ -17,7 +17,11 @@ from open_webui.models.tools import (
ToolUserResponse, ToolUserResponse,
Tools, Tools,
) )
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports from open_webui.utils.plugin import (
load_tool_module_by_id,
replace_imports,
get_tool_module_from_cache,
)
from open_webui.utils.tools import get_tool_specs from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
@ -35,6 +39,14 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
router = APIRouter() router = APIRouter()
def get_tool_module(request, tool_id, load_from_db=True):
"""
Get the tool module by its ID.
"""
tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db)
return tool_module
############################ ############################
# GetTools # GetTools
############################ ############################
@ -42,15 +54,19 @@ router = APIRouter()
@router.get("/", response_model=list[ToolUserResponse]) @router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)): async def get_tools(request: Request, user=Depends(get_verified_user)):
tools = [ tools = []
ToolUserResponse(
**{ # Local Tools
**tool.model_dump(), for tool in Tools.get_tools():
"has_user_valves": "class UserValves(BaseModel):" in tool.content, tool_module = get_tool_module(request, tool.id)
} tools.append(
ToolUserResponse(
**{
**tool.model_dump(),
"has_user_valves": hasattr(tool_module, "UserValves"),
}
)
) )
for tool in Tools.get_tools()
]
# OpenAPI Tool Servers # OpenAPI Tool Servers
for server in await get_tool_servers(request): for server in await get_tool_servers(request):

View file

@ -166,6 +166,48 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
os.unlink(temp_file.name) os.unlink(temp_file.name)
def get_tool_module_from_cache(request, tool_id, load_from_db=True):
if load_from_db:
# Always load from the database by default
tool = Tools.get_tool_by_id(tool_id)
if not tool:
raise Exception(f"Tool not found: {tool_id}")
content = tool.content
new_content = replace_imports(content)
if new_content != content:
content = new_content
# Update the tool content in the database
Tools.update_tool_by_id(tool_id, {"content": content})
if (
hasattr(request.app.state, "TOOL_CONTENTS")
and tool_id in request.app.state.TOOL_CONTENTS
) and (
hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS
):
if request.app.state.TOOL_CONTENTS[tool_id] == content:
return request.app.state.TOOLS[tool_id], None
tool_module, frontmatter = load_tool_module_by_id(tool_id, content)
else:
if hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS:
return request.app.state.TOOLS[tool_id], None
tool_module, frontmatter = load_tool_module_by_id(tool_id)
if not hasattr(request.app.state, "TOOLS"):
request.app.state.TOOLS = {}
if not hasattr(request.app.state, "TOOL_CONTENTS"):
request.app.state.TOOL_CONTENTS = {}
request.app.state.TOOLS[tool_id] = tool_module
request.app.state.TOOL_CONTENTS[tool_id] = content
return tool_module, frontmatter
def get_function_module_from_cache(request, function_id, load_from_db=True): def get_function_module_from_cache(request, function_id, load_from_db=True):
if load_from_db: if load_from_db:
# Always load from the database by default # Always load from the database by default