diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index b2c0a14352..bde6739bc7 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -30,7 +30,6 @@ from open_webui.models.functions import Functions from open_webui.models.models import Models from open_webui.utils.plugin import ( - load_function_module_by_id, get_function_module_from_cache, ) from open_webui.utils.tools import get_tools @@ -56,17 +55,17 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -def get_function_module_by_id(request: Request, pipe_id: str): - function_module, _, _ = get_function_module_from_cache(request, pipe_id) +async def get_function_module_by_id(request: Request, pipe_id: str): + function_module, _, _ = await get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(pipe_id) + valves = await Functions.get_function_valves_by_id(pipe_id) function_module.valves = function_module.Valves(**(valves if valves else {})) return function_module async def get_function_models(request): - pipes = Functions.get_functions_by_type("pipe", active_only=True) + pipes = await Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 423f604824..f7b0557b3b 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -477,7 +477,8 @@ from open_webui.constants import ERROR_MESSAGES if SAFE_MODE: print("SAFE MODE ENABLED") - Functions.deactivate_all_functions() + loop = asyncio.get_event_loop() + loop.run_until_complete(Functions.deactivate_all_functions()) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index a37f1cb3a6..48c533ca3c 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -38,7 +38,7 @@ router = APIRouter() @router.get("/", response_model=list[FunctionResponse]) async def get_functions(user=Depends(get_verified_user)): - return Functions.get_functions() + return await Functions.get_functions() ############################ @@ -48,7 +48,7 @@ async def get_functions(user=Depends(get_verified_user)): @router.get("/export", response_model=list[FunctionModel]) async def get_functions(user=Depends(get_admin_user)): - return Functions.get_functions() + return await Functions.get_functions() ############################ @@ -142,12 +142,14 @@ async def sync_functions( try: for function in form_data.functions: function.content = replace_imports(function.content) - function_module, function_type, frontmatter = load_function_module_by_id( - function.id, - content=function.content, + function_module, function_type, frontmatter = ( + await load_function_module_by_id( + function.id, + content=function.content, + ) ) - return Functions.sync_functions(user.id, form_data.functions) + return await Functions.sync_functions(user.id, form_data.functions) except Exception as e: log.exception(f"Failed to load a function: {e}") raise HTTPException( @@ -173,20 +175,24 @@ async def create_new_function( form_data.id = form_data.id.lower() - function = Functions.get_function_by_id(form_data.id) + function = await Functions.get_function_by_id(form_data.id) if function is None: try: form_data.content = replace_imports(form_data.content) - function_module, function_type, frontmatter = load_function_module_by_id( - form_data.id, - content=form_data.content, + function_module, function_type, frontmatter = ( + await load_function_module_by_id( + form_data.id, + content=form_data.content, + ) ) form_data.meta.manifest = frontmatter FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(user.id, function_type, form_data) + function = await Functions.insert_new_function( + user.id, function_type, form_data + ) function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) @@ -218,7 +224,7 @@ async def create_new_function( @router.get("/id/{id}", response_model=Optional[FunctionModel]) async def get_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: return function @@ -236,9 +242,9 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: - function = Functions.update_function_by_id( + function = await Functions.update_function_by_id( id, {"is_active": not function.is_active} ) @@ -263,9 +269,9 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: - function = Functions.update_function_by_id( + function = await unctions.update_function_by_id( id, {"is_global": not function.is_global} ) @@ -294,7 +300,7 @@ async def update_function_by_id( ): try: form_data.content = replace_imports(form_data.content) - function_module, function_type, frontmatter = load_function_module_by_id( + function_module, function_type, frontmatter = await load_function_module_by_id( id, content=form_data.content ) form_data.meta.manifest = frontmatter @@ -305,7 +311,7 @@ async def update_function_by_id( updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} log.debug(updated) - function = Functions.update_function_by_id(id, updated) + function = await Functions.update_function_by_id(id, updated) if function: return function @@ -331,7 +337,7 @@ async def update_function_by_id( async def delete_function_by_id( request: Request, id: str, user=Depends(get_admin_user) ): - result = Functions.delete_function_by_id(id) + result = await Functions.delete_function_by_id(id) if result: FUNCTIONS = request.app.state.FUNCTIONS @@ -348,10 +354,10 @@ async def delete_function_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: try: - valves = Functions.get_function_valves_by_id(id) + valves = await Functions.get_function_valves_by_id(id) return valves except Exception as e: raise HTTPException( @@ -374,10 +380,10 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): async def get_function_valves_spec_by_id( request: Request, id: str, user=Depends(get_admin_user) ): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id + function_module, function_type, frontmatter = ( + await get_function_module_from_cache(request, id) ) if hasattr(function_module, "Valves"): @@ -400,10 +406,10 @@ async def get_function_valves_spec_by_id( async def update_function_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_admin_user) ): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id + function_module, function_type, frontmatter = ( + await get_function_module_from_cache(request, id) ) if hasattr(function_module, "Valves"): @@ -412,7 +418,7 @@ async def update_function_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) - Functions.update_function_valves_by_id(id, valves.model_dump()) + await Functions.update_function_valves_by_id(id, valves.model_dump()) return valves.model_dump() except Exception as e: log.exception(f"Error updating function values by id {id}: {e}") @@ -461,10 +467,10 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user async def get_function_user_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id + function_module, function_type, frontmatter = ( + await get_function_module_from_cache(request, id) ) if hasattr(function_module, "UserValves"): @@ -485,8 +491,8 @@ async def update_function_user_valves_by_id( function = await Functions.get_function_by_id(id) if function: - function_module, function_type, frontmatter = get_function_module_from_cache( - request, id + function_module, function_type, frontmatter = ( + await get_function_module_from_cache(request, id) ) if hasattr(function_module, "UserValves"): diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 56096579dd..37b53c260f 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -1433,7 +1433,7 @@ async def generate_openai_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) + model_info = await Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1514,7 +1514,7 @@ async def generate_openai_chat_completion( if ":" not in model_id: model_id = f"{model_id}:latest" - model_info = Models.get_model_by_id(model_id) + model_info = await Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: payload["model"] = model_info.base_model_id @@ -1629,7 +1629,7 @@ async def get_openai_models( # Filter models based on user access control filtered_models = [] for model in models: - model_info = Models.get_model_by_id(model["id"]) + model_info = await Models.get_model_by_id(model["id"]) if model_info: if user.id == model_info.user_id or ( await has_access( diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 16cfb2e82f..58049821b5 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -41,7 +41,6 @@ from open_webui.models.models import Models from open_webui.utils.plugin import ( - load_function_module_by_id, get_function_module_from_cache, ) from open_webui.utils.models import get_all_models, check_model_access @@ -328,8 +327,8 @@ async def chat_completed(request: Request, form_data: dict, user: Any): try: filter_functions = [ - Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids( + await Functions.get_function_by_id(filter_id) + for filter_id in await get_sorted_filter_ids( request, model, metadata.get("filter_ids", []) ) ] @@ -352,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A else: sub_action_id = None - action = Functions.get_function_by_id(action_id) + action = await Functions.get_function_by_id(action_id) if not action: raise Exception(f"Action not found: {action_id}") @@ -390,10 +389,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A } ) - function_module, _, _ = get_function_module_from_cache(request, action_id) + function_module, _, _ = await get_function_module_from_cache(request, action_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) + valves = await Functions.get_function_valves_by_id(action_id) function_module.valves = function_module.Valves(**(valves if valves else {})) if hasattr(function_module, "action"): diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index dbb25d9368..c3d2805872 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -2,7 +2,6 @@ import inspect import logging from open_webui.utils.plugin import ( - load_function_module_by_id, get_function_module_from_cache, ) from open_webui.models.functions import Functions @@ -12,35 +11,39 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -def get_function_module(request, function_id, load_from_db=True): +async def get_function_module(request, function_id, load_from_db=True): """ Get the function module by its ID. """ - function_module, _, _ = get_function_module_from_cache( + function_module, _, _ = await get_function_module_from_cache( request, function_id, load_from_db ) return function_module -def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None): - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) +async def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None): + async def get_priority(function_id): + function = await Functions.get_function_by_id(function_id) if function is not None: - valves = Functions.get_function_valves_by_id(function_id) + valves = await Functions.get_function_valves_by_id(function_id) return valves.get("priority", 0) if valves else 0 return 0 - filter_ids = [function.id for function in Functions.get_global_filter_functions()] + filter_ids = [ + function.id for function in await Functions.get_global_filter_functions() + ] if "info" in model and "meta" in model["info"]: filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids = list(set(filter_ids)) active_filter_ids = [ function.id - for function in Functions.get_functions_by_type("filter", active_only=True) + for function in await Functions.get_functions_by_type( + "filter", active_only=True + ) ] - def get_active_status(filter_id): - function_module = get_function_module(request, filter_id) + async def get_active_status(filter_id): + function_module = await get_function_module(request, filter_id) if getattr(function_module, "toggle", None): return filter_id in (enabled_filter_ids or []) @@ -48,10 +51,13 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None) return True active_filter_ids = [ - filter_id for filter_id in active_filter_ids if get_active_status(filter_id) + filter_id + for filter_id in active_filter_ids + if await get_active_status(filter_id) ] filter_ids = [fid for fid in filter_ids if fid in active_filter_ids] + filter_ids.sort(key=get_priority) return filter_ids @@ -68,7 +74,7 @@ async def process_filter_functions( if not filter: continue - function_module = get_function_module( + function_module = await get_function_module( request, filter_id, load_from_db=(filter_type != "stream") ) # Prepare handler function @@ -82,7 +88,7 @@ async def process_filter_functions( # Apply valves to the function if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) + valves = await Functions.get_function_valves_by_id(filter_id) function_module.valves = function_module.Valves( **(valves if valves else {}) ) diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 72c2707846..62b3fcf120 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -77,7 +77,6 @@ from open_webui.utils.misc import ( convert_logit_bias_input_to_json, ) from open_webui.utils.tools import get_tools -from open_webui.utils.plugin import load_function_module_by_id from open_webui.utils.filter import ( get_sorted_filter_ids, process_filter_functions, @@ -840,8 +839,8 @@ async def process_chat_payload(request, form_data, user, metadata, model): try: filter_functions = [ - Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids( + await Functions.get_function_by_id(filter_id) + for filter_id in await get_sorted_filter_ids( request, model, metadata.get("filter_ids", []) ) ] @@ -908,7 +907,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): tools_dict = {} if tool_ids: - tools_dict = get_tools( + tools_dict = await get_tools( request, tool_ids, user, @@ -1323,7 +1322,9 @@ async def process_chat_response( # Send a webhook notification if the user is not active if not get_active_status_by_user_id(user.id): - webhook_url = Users.get_user_webhook_url_by_id(user.id) + webhook_url = await Users.get_user_webhook_url_by_id( + user.id + ) if webhook_url: post_webhook( request.app.state.WEBUI_NAME, @@ -1394,8 +1395,8 @@ async def process_chat_response( "__model__": model, } filter_functions = [ - Functions.get_function_by_id(filter_id) - for filter_id in get_sorted_filter_ids( + await Functions.get_function_by_id(filter_id) + for filter_id in await get_sorted_filter_ids( request, model, metadata.get("filter_ids", []) ) ] @@ -2523,7 +2524,7 @@ async def process_chat_response( # Send a webhook notification if the user is not active if not get_active_status_by_user_id(user.id): - webhook_url = Users.get_user_webhook_url_by_id(user.id) + webhook_url = await Users.get_user_webhook_url_by_id(user.id) if webhook_url: post_webhook( request.app.state.WEBUI_NAME, diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index ae637615c9..43e1b9e44a 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -15,7 +15,6 @@ from open_webui.models.models import Models from open_webui.utils.plugin import ( - load_function_module_by_id, get_function_module_from_cache, ) from open_webui.utils.access_control import has_access @@ -130,19 +129,23 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) models = models + arena_models global_action_ids = [ - function.id for function in Functions.get_global_action_functions() + function.id for function in await Functions.get_global_action_functions() ] enabled_action_ids = [ function.id - for function in Functions.get_functions_by_type("action", active_only=True) + for function in await Functions.get_functions_by_type( + "action", active_only=True + ) ] global_filter_ids = [ - function.id for function in Functions.get_global_filter_functions() + function.id for function in await Functions.get_global_filter_functions() ] enabled_filter_ids = [ function.id - for function in Functions.get_functions_by_type("filter", active_only=True) + for function in await Functions.get_functions_by_type( + "filter", active_only=True + ) ] custom_models = await Models.get_all_models() @@ -265,8 +268,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) } ] - def get_function_module_by_id(function_id): - function_module, _, _ = get_function_module_from_cache(request, function_id) + async def get_function_module_by_id(function_id): + function_module, _, _ = await get_function_module_from_cache( + request, function_id + ) return function_module for model in models: @@ -283,22 +288,22 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) model["actions"] = [] for action_id in action_ids: - action_function = Functions.get_function_by_id(action_id) + action_function = await Functions.get_function_by_id(action_id) if action_function is None: raise Exception(f"Action not found: {action_id}") - function_module = get_function_module_by_id(action_id) + function_module = await get_function_module_by_id(action_id) model["actions"].extend( get_action_items_from_module(action_function, function_module) ) model["filters"] = [] for filter_id in filter_ids: - filter_function = Functions.get_function_by_id(filter_id) + filter_function = await Functions.get_function_by_id(filter_id) if filter_function is None: raise Exception(f"Filter not found: {filter_id}") - function_module = get_function_module_by_id(filter_id) + function_module = await get_function_module_by_id(filter_id) if getattr(function_module, "toggle", None): model["filters"].extend( diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index a38657a854..15d885ba58 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -115,15 +115,15 @@ async def load_tool_module_by_id(tool_id, content=None): os.unlink(temp_file.name) -def load_function_module_by_id(function_id: str, content: str | None = None): +async def load_function_module_by_id(function_id: str, content: str | None = None): if content is None: - function = Functions.get_function_by_id(function_id) + function = await Functions.get_function_by_id(function_id) if not function: raise Exception(f"Function not found: {function_id}") content = function.content content = replace_imports(content) - Functions.update_function_by_id(function_id, {"content": content}) + await Functions.update_function_by_id(function_id, {"content": content}) else: frontmatter = extract_frontmatter(content) install_frontmatter_requirements(frontmatter.get("requirements", "")) @@ -160,19 +160,19 @@ def load_function_module_by_id(function_id: str, content: str | None = None): # Cleanup by removing the module in case of error del sys.modules[module_name] - Functions.update_function_by_id(function_id, {"is_active": False}) + await Functions.update_function_by_id(function_id, {"is_active": False}) raise e finally: os.unlink(temp_file.name) -def get_function_module_from_cache(request, function_id, load_from_db=True): +async def get_function_module_from_cache(request, function_id, load_from_db=True): if load_from_db: # Always load from the database by default # This is useful for hooks like "inlet" or "outlet" where the content might change # and we want to ensure the latest content is used. - function = Functions.get_function_by_id(function_id) + function = await Functions.get_function_by_id(function_id) if not function: raise Exception(f"Function not found: {function_id}") content = function.content @@ -181,7 +181,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True): if new_content != content: content = new_content # Update the function content in the database - Functions.update_function_by_id(function_id, {"content": content}) + await Functions.update_function_by_id(function_id, {"content": content}) if ( hasattr(request.app.state, "FUNCTION_CONTENTS") @@ -193,7 +193,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True): if request.app.state.FUNCTION_CONTENTS[function_id] == content: return request.app.state.FUNCTIONS[function_id], None, None - function_module, function_type, frontmatter = load_function_module_by_id( + function_module, function_type, frontmatter = await load_function_module_by_id( function_id, content ) else: @@ -206,7 +206,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True): ): return request.app.state.FUNCTIONS[function_id], None, None - function_module, function_type, frontmatter = load_function_module_by_id( + function_module, function_type, frontmatter = await load_function_module_by_id( function_id )