mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 21:05:19 +00:00
refac: tools
This commit is contained in:
parent
1de5827eb3
commit
54beeeaf72
2 changed files with 67 additions and 9 deletions
|
|
@ -17,7 +17,11 @@ from open_webui.models.tools import (
|
||||||
ToolUserResponse,
|
ToolUserResponse,
|
||||||
Tools,
|
Tools,
|
||||||
)
|
)
|
||||||
from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
|
from open_webui.utils.plugin import (
|
||||||
|
load_tool_module_by_id,
|
||||||
|
replace_imports,
|
||||||
|
get_tool_module_from_cache,
|
||||||
|
)
|
||||||
from open_webui.utils.tools import get_tool_specs
|
from open_webui.utils.tools import get_tool_specs
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
|
|
@ -35,6 +39,14 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_module(request, tool_id, load_from_db=True):
|
||||||
|
"""
|
||||||
|
Get the tool module by its ID.
|
||||||
|
"""
|
||||||
|
tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db)
|
||||||
|
return tool_module
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetTools
|
# GetTools
|
||||||
############################
|
############################
|
||||||
|
|
@ -42,15 +54,19 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.get("/", response_model=list[ToolUserResponse])
|
@router.get("/", response_model=list[ToolUserResponse])
|
||||||
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
async def get_tools(request: Request, user=Depends(get_verified_user)):
|
||||||
tools = [
|
tools = []
|
||||||
ToolUserResponse(
|
|
||||||
**{
|
# Local Tools
|
||||||
**tool.model_dump(),
|
for tool in Tools.get_tools():
|
||||||
"has_user_valves": "class UserValves(BaseModel):" in tool.content,
|
tool_module = get_tool_module(request, tool.id)
|
||||||
}
|
tools.append(
|
||||||
|
ToolUserResponse(
|
||||||
|
**{
|
||||||
|
**tool.model_dump(),
|
||||||
|
"has_user_valves": hasattr(tool_module, "UserValves"),
|
||||||
|
}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for tool in Tools.get_tools()
|
|
||||||
]
|
|
||||||
|
|
||||||
# OpenAPI Tool Servers
|
# OpenAPI Tool Servers
|
||||||
for server in await get_tool_servers(request):
|
for server in await get_tool_servers(request):
|
||||||
|
|
|
||||||
|
|
@ -166,6 +166,48 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
|
||||||
os.unlink(temp_file.name)
|
os.unlink(temp_file.name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_module_from_cache(request, tool_id, load_from_db=True):
|
||||||
|
if load_from_db:
|
||||||
|
# Always load from the database by default
|
||||||
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
|
if not tool:
|
||||||
|
raise Exception(f"Tool not found: {tool_id}")
|
||||||
|
content = tool.content
|
||||||
|
|
||||||
|
new_content = replace_imports(content)
|
||||||
|
if new_content != content:
|
||||||
|
content = new_content
|
||||||
|
# Update the tool content in the database
|
||||||
|
Tools.update_tool_by_id(tool_id, {"content": content})
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(request.app.state, "TOOL_CONTENTS")
|
||||||
|
and tool_id in request.app.state.TOOL_CONTENTS
|
||||||
|
) and (
|
||||||
|
hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS
|
||||||
|
):
|
||||||
|
if request.app.state.TOOL_CONTENTS[tool_id] == content:
|
||||||
|
return request.app.state.TOOLS[tool_id], None
|
||||||
|
|
||||||
|
tool_module, frontmatter = load_tool_module_by_id(tool_id, content)
|
||||||
|
else:
|
||||||
|
if hasattr(request.app.state, "TOOLS") and tool_id in request.app.state.TOOLS:
|
||||||
|
return request.app.state.TOOLS[tool_id], None
|
||||||
|
|
||||||
|
tool_module, frontmatter = load_tool_module_by_id(tool_id)
|
||||||
|
|
||||||
|
if not hasattr(request.app.state, "TOOLS"):
|
||||||
|
request.app.state.TOOLS = {}
|
||||||
|
|
||||||
|
if not hasattr(request.app.state, "TOOL_CONTENTS"):
|
||||||
|
request.app.state.TOOL_CONTENTS = {}
|
||||||
|
|
||||||
|
request.app.state.TOOLS[tool_id] = tool_module
|
||||||
|
request.app.state.TOOL_CONTENTS[tool_id] = content
|
||||||
|
|
||||||
|
return tool_module, frontmatter
|
||||||
|
|
||||||
|
|
||||||
def get_function_module_from_cache(request, function_id, load_from_db=True):
|
def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||||
if load_from_db:
|
if load_from_db:
|
||||||
# Always load from the database by default
|
# Always load from the database by default
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue