wip: functions

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 17:20:13 +04:00
parent 3b9e454fb4
commit 9a210e743d
9 changed files with 107 additions and 90 deletions

View file

@ -30,7 +30,6 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.tools import get_tools
@ -56,17 +55,17 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
async def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = await get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
valves = await Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module
async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipes = await Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:

View file

@ -477,7 +477,8 @@ from open_webui.constants import ERROR_MESSAGES
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
loop = asyncio.get_event_loop()
loop.run_until_complete(Functions.deactivate_all_functions())
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)

View file

@ -38,7 +38,7 @@ router = APIRouter()
@router.get("/", response_model=list[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
return await Functions.get_functions()
############################
@ -48,7 +48,7 @@ async def get_functions(user=Depends(get_verified_user)):
@router.get("/export", response_model=list[FunctionModel])
async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
return await Functions.get_functions()
############################
@ -142,12 +142,14 @@ async def sync_functions(
try:
for function in form_data.functions:
function.content = replace_imports(function.content)
function_module, function_type, frontmatter = load_function_module_by_id(
function.id,
content=function.content,
function_module, function_type, frontmatter = (
await load_function_module_by_id(
function.id,
content=function.content,
)
)
return Functions.sync_functions(user.id, form_data.functions)
return await Functions.sync_functions(user.id, form_data.functions)
except Exception as e:
log.exception(f"Failed to load a function: {e}")
raise HTTPException(
@ -173,20 +175,24 @@ async def create_new_function(
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
function = await Functions.get_function_by_id(form_data.id)
if function is None:
try:
form_data.content = replace_imports(form_data.content)
function_module, function_type, frontmatter = load_function_module_by_id(
form_data.id,
content=form_data.content,
function_module, function_type, frontmatter = (
await load_function_module_by_id(
form_data.id,
content=form_data.content,
)
)
form_data.meta.manifest = frontmatter
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function = await Functions.insert_new_function(
user.id, function_type, form_data
)
function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
@ -218,7 +224,7 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
return function
@ -236,9 +242,9 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
function = Functions.update_function_by_id(
function = await Functions.update_function_by_id(
id, {"is_active": not function.is_active}
)
@ -263,9 +269,9 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
function = Functions.update_function_by_id(
function = await unctions.update_function_by_id(
id, {"is_global": not function.is_global}
)
@ -294,7 +300,7 @@ async def update_function_by_id(
):
try:
form_data.content = replace_imports(form_data.content)
function_module, function_type, frontmatter = load_function_module_by_id(
function_module, function_type, frontmatter = await load_function_module_by_id(
id, content=form_data.content
)
form_data.meta.manifest = frontmatter
@ -305,7 +311,7 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
log.debug(updated)
function = Functions.update_function_by_id(id, updated)
function = await Functions.update_function_by_id(id, updated)
if function:
return function
@ -331,7 +337,7 @@ async def update_function_by_id(
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
result = Functions.delete_function_by_id(id)
result = await Functions.delete_function_by_id(id)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
@ -348,10 +354,10 @@ async def delete_function_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
try:
valves = Functions.get_function_valves_by_id(id)
valves = await Functions.get_function_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
@ -374,10 +380,10 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
async def get_function_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
function_module, function_type, frontmatter = (
await get_function_module_from_cache(request, id)
)
if hasattr(function_module, "Valves"):
@ -400,10 +406,10 @@ async def get_function_valves_spec_by_id(
async def update_function_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
function_module, function_type, frontmatter = (
await get_function_module_from_cache(request, id)
)
if hasattr(function_module, "Valves"):
@ -412,7 +418,7 @@ async def update_function_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Functions.update_function_valves_by_id(id, valves.model_dump())
await Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
log.exception(f"Error updating function values by id {id}: {e}")
@ -461,10 +467,10 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
function_module, function_type, frontmatter = (
await get_function_module_from_cache(request, id)
)
if hasattr(function_module, "UserValves"):
@ -485,8 +491,8 @@ async def update_function_user_valves_by_id(
function = await Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
request, id
function_module, function_type, frontmatter = (
await get_function_module_from_cache(request, id)
)
if hasattr(function_module, "UserValves"):

View file

@ -1433,7 +1433,7 @@ async def generate_openai_completion(
if ":" not in model_id:
model_id = f"{model_id}:latest"
model_info = Models.get_model_by_id(model_id)
model_info = await Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@ -1514,7 +1514,7 @@ async def generate_openai_chat_completion(
if ":" not in model_id:
model_id = f"{model_id}:latest"
model_info = Models.get_model_by_id(model_id)
model_info = await Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@ -1629,7 +1629,7 @@ async def get_openai_models(
# Filter models based on user access control
filtered_models = []
for model in models:
model_info = Models.get_model_by_id(model["id"])
model_info = await Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or (
await has_access(

View file

@ -41,7 +41,6 @@ from open_webui.models.models import Models
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.models import get_all_models, check_model_access
@ -328,8 +327,8 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
await Functions.get_function_by_id(filter_id)
for filter_id in await get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
@ -352,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
else:
sub_action_id = None
action = Functions.get_function_by_id(action_id)
action = await Functions.get_function_by_id(action_id)
if not action:
raise Exception(f"Action not found: {action_id}")
@ -390,10 +389,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
}
)
function_module, _, _ = get_function_module_from_cache(request, action_id)
function_module, _, _ = await get_function_module_from_cache(request, action_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
valves = await Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"):

View file

@ -2,7 +2,6 @@ import inspect
import logging
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.models.functions import Functions
@ -12,35 +11,39 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module(request, function_id, load_from_db=True):
async def get_function_module(request, function_id, load_from_db=True):
"""
Get the function module by its ID.
"""
function_module, _, _ = get_function_module_from_cache(
function_module, _, _ = await get_function_module_from_cache(
request, function_id, load_from_db
)
return function_module
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
async def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
async def get_priority(function_id):
function = await Functions.get_function_by_id(function_id)
if function is not None:
valves = Functions.get_function_valves_by_id(function_id)
valves = await Functions.get_function_valves_by_id(function_id)
return valves.get("priority", 0) if valves else 0
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
filter_ids = [
function.id for function in await Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
active_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
for function in await Functions.get_functions_by_type(
"filter", active_only=True
)
]
def get_active_status(filter_id):
function_module = get_function_module(request, filter_id)
async def get_active_status(filter_id):
function_module = await get_function_module(request, filter_id)
if getattr(function_module, "toggle", None):
return filter_id in (enabled_filter_ids or [])
@ -48,10 +51,13 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None)
return True
active_filter_ids = [
filter_id for filter_id in active_filter_ids if get_active_status(filter_id)
filter_id
for filter_id in active_filter_ids
if await get_active_status(filter_id)
]
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
filter_ids.sort(key=get_priority)
return filter_ids
@ -68,7 +74,7 @@ async def process_filter_functions(
if not filter:
continue
function_module = get_function_module(
function_module = await get_function_module(
request, filter_id, load_from_db=(filter_type != "stream")
)
# Prepare handler function
@ -82,7 +88,7 @@ async def process_filter_functions(
# Apply valves to the function
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
valves = await Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)

View file

@ -77,7 +77,6 @@ from open_webui.utils.misc import (
convert_logit_bias_input_to_json,
)
from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.filter import (
get_sorted_filter_ids,
process_filter_functions,
@ -840,8 +839,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):
try:
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
await Functions.get_function_by_id(filter_id)
for filter_id in await get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
@ -908,7 +907,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tools_dict = {}
if tool_ids:
tools_dict = get_tools(
tools_dict = await get_tools(
request,
tool_ids,
user,
@ -1323,7 +1322,9 @@ async def process_chat_response(
# Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
webhook_url = await Users.get_user_webhook_url_by_id(
user.id
)
if webhook_url:
post_webhook(
request.app.state.WEBUI_NAME,
@ -1394,8 +1395,8 @@ async def process_chat_response(
"__model__": model,
}
filter_functions = [
Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids(
await Functions.get_function_by_id(filter_id)
for filter_id in await get_sorted_filter_ids(
request, model, metadata.get("filter_ids", [])
)
]
@ -2523,7 +2524,7 @@ async def process_chat_response(
# Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id)
webhook_url = await Users.get_user_webhook_url_by_id(user.id)
if webhook_url:
post_webhook(
request.app.state.WEBUI_NAME,

View file

@ -15,7 +15,6 @@ from open_webui.models.models import Models
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.access_control import has_access
@ -130,19 +129,23 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
models = models + arena_models
global_action_ids = [
function.id for function in Functions.get_global_action_functions()
function.id for function in await Functions.get_global_action_functions()
]
enabled_action_ids = [
function.id
for function in Functions.get_functions_by_type("action", active_only=True)
for function in await Functions.get_functions_by_type(
"action", active_only=True
)
]
global_filter_ids = [
function.id for function in Functions.get_global_filter_functions()
function.id for function in await Functions.get_global_filter_functions()
]
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
for function in await Functions.get_functions_by_type(
"filter", active_only=True
)
]
custom_models = await Models.get_all_models()
@ -265,8 +268,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
}
]
def get_function_module_by_id(function_id):
function_module, _, _ = get_function_module_from_cache(request, function_id)
async def get_function_module_by_id(function_id):
function_module, _, _ = await get_function_module_from_cache(
request, function_id
)
return function_module
for model in models:
@ -283,22 +288,22 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
model["actions"] = []
for action_id in action_ids:
action_function = Functions.get_function_by_id(action_id)
action_function = await Functions.get_function_by_id(action_id)
if action_function is None:
raise Exception(f"Action not found: {action_id}")
function_module = get_function_module_by_id(action_id)
function_module = await get_function_module_by_id(action_id)
model["actions"].extend(
get_action_items_from_module(action_function, function_module)
)
model["filters"] = []
for filter_id in filter_ids:
filter_function = Functions.get_function_by_id(filter_id)
filter_function = await Functions.get_function_by_id(filter_id)
if filter_function is None:
raise Exception(f"Filter not found: {filter_id}")
function_module = get_function_module_by_id(filter_id)
function_module = await get_function_module_by_id(filter_id)
if getattr(function_module, "toggle", None):
model["filters"].extend(

View file

@ -115,15 +115,15 @@ async def load_tool_module_by_id(tool_id, content=None):
os.unlink(temp_file.name)
def load_function_module_by_id(function_id: str, content: str | None = None):
async def load_function_module_by_id(function_id: str, content: str | None = None):
if content is None:
function = Functions.get_function_by_id(function_id)
function = await Functions.get_function_by_id(function_id)
if not function:
raise Exception(f"Function not found: {function_id}")
content = function.content
content = replace_imports(content)
Functions.update_function_by_id(function_id, {"content": content})
await Functions.update_function_by_id(function_id, {"content": content})
else:
frontmatter = extract_frontmatter(content)
install_frontmatter_requirements(frontmatter.get("requirements", ""))
@ -160,19 +160,19 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
# Cleanup by removing the module in case of error
del sys.modules[module_name]
Functions.update_function_by_id(function_id, {"is_active": False})
await Functions.update_function_by_id(function_id, {"is_active": False})
raise e
finally:
os.unlink(temp_file.name)
def get_function_module_from_cache(request, function_id, load_from_db=True):
async def get_function_module_from_cache(request, function_id, load_from_db=True):
if load_from_db:
# Always load from the database by default
# This is useful for hooks like "inlet" or "outlet" where the content might change
# and we want to ensure the latest content is used.
function = Functions.get_function_by_id(function_id)
function = await Functions.get_function_by_id(function_id)
if not function:
raise Exception(f"Function not found: {function_id}")
content = function.content
@ -181,7 +181,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
if new_content != content:
content = new_content
# Update the function content in the database
Functions.update_function_by_id(function_id, {"content": content})
await Functions.update_function_by_id(function_id, {"content": content})
if (
hasattr(request.app.state, "FUNCTION_CONTENTS")
@ -193,7 +193,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
if request.app.state.FUNCTION_CONTENTS[function_id] == content:
return request.app.state.FUNCTIONS[function_id], None, None
function_module, function_type, frontmatter = load_function_module_by_id(
function_module, function_type, frontmatter = await load_function_module_by_id(
function_id, content
)
else:
@ -206,7 +206,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
):
return request.app.state.FUNCTIONS[function_id], None, None
function_module, function_type, frontmatter = load_function_module_by_id(
function_module, function_type, frontmatter = await load_function_module_by_id(
function_id
)