wip: functions

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 17:20:13 +04:00
parent 3b9e454fb4
commit 9a210e743d
9 changed files with 107 additions and 90 deletions

View file

@ -30,7 +30,6 @@ from open_webui.models.functions import Functions
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.utils.plugin import ( from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache, get_function_module_from_cache,
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
@ -56,17 +55,17 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module_by_id(request: Request, pipe_id: str): async def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id) function_module, _, _ = await get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id) valves = await Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(**(valves if valves else {})) function_module.valves = function_module.Valves(**(valves if valves else {}))
return function_module return function_module
async def get_function_models(request): async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True) pipes = await Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = [] pipe_models = []
for pipe in pipes: for pipe in pipes:

View file

@ -477,7 +477,8 @@ from open_webui.constants import ERROR_MESSAGES
if SAFE_MODE: if SAFE_MODE:
print("SAFE MODE ENABLED") print("SAFE MODE ENABLED")
Functions.deactivate_all_functions() loop = asyncio.get_event_loop()
loop.run_until_complete(Functions.deactivate_all_functions())
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -38,7 +38,7 @@ router = APIRouter()
@router.get("/", response_model=list[FunctionResponse]) @router.get("/", response_model=list[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)): async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions() return await Functions.get_functions()
############################ ############################
@ -48,7 +48,7 @@ async def get_functions(user=Depends(get_verified_user)):
@router.get("/export", response_model=list[FunctionModel]) @router.get("/export", response_model=list[FunctionModel])
async def get_functions(user=Depends(get_admin_user)): async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions() return await Functions.get_functions()
############################ ############################
@ -142,12 +142,14 @@ async def sync_functions(
try: try:
for function in form_data.functions: for function in form_data.functions:
function.content = replace_imports(function.content) function.content = replace_imports(function.content)
function_module, function_type, frontmatter = load_function_module_by_id( function_module, function_type, frontmatter = (
await load_function_module_by_id(
function.id, function.id,
content=function.content, content=function.content,
) )
)
return Functions.sync_functions(user.id, form_data.functions) return await Functions.sync_functions(user.id, form_data.functions)
except Exception as e: except Exception as e:
log.exception(f"Failed to load a function: {e}") log.exception(f"Failed to load a function: {e}")
raise HTTPException( raise HTTPException(
@ -173,20 +175,24 @@ async def create_new_function(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id) function = await Functions.get_function_by_id(form_data.id)
if function is None: if function is None:
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
function_module, function_type, frontmatter = load_function_module_by_id( function_module, function_type, frontmatter = (
await load_function_module_by_id(
form_data.id, form_data.id,
content=form_data.content, content=form_data.content,
) )
)
form_data.meta.manifest = frontmatter form_data.meta.manifest = frontmatter
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data) function = await Functions.insert_new_function(
user.id, function_type, form_data
)
function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir = CACHE_DIR / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True) function_cache_dir.mkdir(parents=True, exist_ok=True)
@ -218,7 +224,7 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel]) @router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)): async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
return function return function
@ -236,9 +242,9 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel]) @router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function = Functions.update_function_by_id( function = await Functions.update_function_by_id(
id, {"is_active": not function.is_active} id, {"is_active": not function.is_active}
) )
@ -263,9 +269,9 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel]) @router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)): async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function = Functions.update_function_by_id( function = await unctions.update_function_by_id(
id, {"is_global": not function.is_global} id, {"is_global": not function.is_global}
) )
@ -294,7 +300,7 @@ async def update_function_by_id(
): ):
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
function_module, function_type, frontmatter = load_function_module_by_id( function_module, function_type, frontmatter = await load_function_module_by_id(
id, content=form_data.content id, content=form_data.content
) )
form_data.meta.manifest = frontmatter form_data.meta.manifest = frontmatter
@ -305,7 +311,7 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
log.debug(updated) log.debug(updated)
function = Functions.update_function_by_id(id, updated) function = await Functions.update_function_by_id(id, updated)
if function: if function:
return function return function
@ -331,7 +337,7 @@ async def update_function_by_id(
async def delete_function_by_id( async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user) request: Request, id: str, user=Depends(get_admin_user)
): ):
result = Functions.delete_function_by_id(id) result = await Functions.delete_function_by_id(id)
if result: if result:
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
@ -348,10 +354,10 @@ async def delete_function_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict]) @router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)): async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
try: try:
valves = Functions.get_function_valves_by_id(id) valves = await Functions.get_function_valves_by_id(id)
return valves return valves
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@ -374,10 +380,10 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
async def get_function_valves_spec_by_id( async def get_function_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user) request: Request, id: str, user=Depends(get_admin_user)
): ):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = (
request, id await get_function_module_from_cache(request, id)
) )
if hasattr(function_module, "Valves"): if hasattr(function_module, "Valves"):
@ -400,10 +406,10 @@ async def get_function_valves_spec_by_id(
async def update_function_valves_by_id( async def update_function_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user) request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
): ):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = (
request, id await get_function_module_from_cache(request, id)
) )
if hasattr(function_module, "Valves"): if hasattr(function_module, "Valves"):
@ -412,7 +418,7 @@ async def update_function_valves_by_id(
try: try:
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data) valves = Valves(**form_data)
Functions.update_function_valves_by_id(id, valves.model_dump()) await Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump() return valves.model_dump()
except Exception as e: except Exception as e:
log.exception(f"Error updating function values by id {id}: {e}") log.exception(f"Error updating function values by id {id}: {e}")
@ -461,10 +467,10 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
async def get_function_user_valves_spec_by_id( async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user)
): ):
function = Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = (
request, id await get_function_module_from_cache(request, id)
) )
if hasattr(function_module, "UserValves"): if hasattr(function_module, "UserValves"):
@ -485,8 +491,8 @@ async def update_function_user_valves_by_id(
function = await Functions.get_function_by_id(id) function = await Functions.get_function_by_id(id)
if function: if function:
function_module, function_type, frontmatter = get_function_module_from_cache( function_module, function_type, frontmatter = (
request, id await get_function_module_from_cache(request, id)
) )
if hasattr(function_module, "UserValves"): if hasattr(function_module, "UserValves"):

View file

@ -1433,7 +1433,7 @@ async def generate_openai_completion(
if ":" not in model_id: if ":" not in model_id:
model_id = f"{model_id}:latest" model_id = f"{model_id}:latest"
model_info = Models.get_model_by_id(model_id) model_info = await Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
@ -1514,7 +1514,7 @@ async def generate_openai_chat_completion(
if ":" not in model_id: if ":" not in model_id:
model_id = f"{model_id}:latest" model_id = f"{model_id}:latest"
model_info = Models.get_model_by_id(model_id) model_info = await Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
@ -1629,7 +1629,7 @@ async def get_openai_models(
# Filter models based on user access control # Filter models based on user access control
filtered_models = [] filtered_models = []
for model in models: for model in models:
model_info = Models.get_model_by_id(model["id"]) model_info = await Models.get_model_by_id(model["id"])
if model_info: if model_info:
if user.id == model_info.user_id or ( if user.id == model_info.user_id or (
await has_access( await has_access(

View file

@ -41,7 +41,6 @@ from open_webui.models.models import Models
from open_webui.utils.plugin import ( from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache, get_function_module_from_cache,
) )
from open_webui.utils.models import get_all_models, check_model_access from open_webui.utils.models import get_all_models, check_model_access
@ -328,8 +327,8 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
try: try:
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) await Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids( for filter_id in await get_sorted_filter_ids(
request, model, metadata.get("filter_ids", []) request, model, metadata.get("filter_ids", [])
) )
] ]
@ -352,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
else: else:
sub_action_id = None sub_action_id = None
action = Functions.get_function_by_id(action_id) action = await Functions.get_function_by_id(action_id)
if not action: if not action:
raise Exception(f"Action not found: {action_id}") raise Exception(f"Action not found: {action_id}")
@ -390,10 +389,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
} }
) )
function_module, _, _ = get_function_module_from_cache(request, action_id) function_module, _, _ = await get_function_module_from_cache(request, action_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id) valves = await Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {})) function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"): if hasattr(function_module, "action"):

View file

@ -2,7 +2,6 @@ import inspect
import logging import logging
from open_webui.utils.plugin import ( from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache, get_function_module_from_cache,
) )
from open_webui.models.functions import Functions from open_webui.models.functions import Functions
@ -12,35 +11,39 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
def get_function_module(request, function_id, load_from_db=True): async def get_function_module(request, function_id, load_from_db=True):
""" """
Get the function module by its ID. Get the function module by its ID.
""" """
function_module, _, _ = get_function_module_from_cache( function_module, _, _ = await get_function_module_from_cache(
request, function_id, load_from_db request, function_id, load_from_db
) )
return function_module return function_module
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None): async def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
def get_priority(function_id): async def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = await Functions.get_function_by_id(function_id)
if function is not None: if function is not None:
valves = Functions.get_function_valves_by_id(function_id) valves = await Functions.get_function_valves_by_id(function_id)
return valves.get("priority", 0) if valves else 0 return valves.get("priority", 0) if valves else 0
return 0 return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()] filter_ids = [
function.id for function in await Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]: if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids)) filter_ids = list(set(filter_ids))
active_filter_ids = [ active_filter_ids = [
function.id function.id
for function in Functions.get_functions_by_type("filter", active_only=True) for function in await Functions.get_functions_by_type(
"filter", active_only=True
)
] ]
def get_active_status(filter_id): async def get_active_status(filter_id):
function_module = get_function_module(request, filter_id) function_module = await get_function_module(request, filter_id)
if getattr(function_module, "toggle", None): if getattr(function_module, "toggle", None):
return filter_id in (enabled_filter_ids or []) return filter_id in (enabled_filter_ids or [])
@ -48,10 +51,13 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None)
return True return True
active_filter_ids = [ active_filter_ids = [
filter_id for filter_id in active_filter_ids if get_active_status(filter_id) filter_id
for filter_id in active_filter_ids
if await get_active_status(filter_id)
] ]
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids] filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
filter_ids.sort(key=get_priority) filter_ids.sort(key=get_priority)
return filter_ids return filter_ids
@ -68,7 +74,7 @@ async def process_filter_functions(
if not filter: if not filter:
continue continue
function_module = get_function_module( function_module = await get_function_module(
request, filter_id, load_from_db=(filter_type != "stream") request, filter_id, load_from_db=(filter_type != "stream")
) )
# Prepare handler function # Prepare handler function
@ -82,7 +88,7 @@ async def process_filter_functions(
# Apply valves to the function # Apply valves to the function
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id) valves = await Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves( function_module.valves = function_module.Valves(
**(valves if valves else {}) **(valves if valves else {})
) )

View file

@ -77,7 +77,6 @@ from open_webui.utils.misc import (
convert_logit_bias_input_to_json, convert_logit_bias_input_to_json,
) )
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.filter import ( from open_webui.utils.filter import (
get_sorted_filter_ids, get_sorted_filter_ids,
process_filter_functions, process_filter_functions,
@ -840,8 +839,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):
try: try:
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) await Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids( for filter_id in await get_sorted_filter_ids(
request, model, metadata.get("filter_ids", []) request, model, metadata.get("filter_ids", [])
) )
] ]
@ -908,7 +907,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tools_dict = {} tools_dict = {}
if tool_ids: if tool_ids:
tools_dict = get_tools( tools_dict = await get_tools(
request, request,
tool_ids, tool_ids,
user, user,
@ -1323,7 +1322,9 @@ async def process_chat_response(
# Send a webhook notification if the user is not active # Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id): if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id) webhook_url = await Users.get_user_webhook_url_by_id(
user.id
)
if webhook_url: if webhook_url:
post_webhook( post_webhook(
request.app.state.WEBUI_NAME, request.app.state.WEBUI_NAME,
@ -1394,8 +1395,8 @@ async def process_chat_response(
"__model__": model, "__model__": model,
} }
filter_functions = [ filter_functions = [
Functions.get_function_by_id(filter_id) await Functions.get_function_by_id(filter_id)
for filter_id in get_sorted_filter_ids( for filter_id in await get_sorted_filter_ids(
request, model, metadata.get("filter_ids", []) request, model, metadata.get("filter_ids", [])
) )
] ]
@ -2523,7 +2524,7 @@ async def process_chat_response(
# Send a webhook notification if the user is not active # Send a webhook notification if the user is not active
if not get_active_status_by_user_id(user.id): if not get_active_status_by_user_id(user.id):
webhook_url = Users.get_user_webhook_url_by_id(user.id) webhook_url = await Users.get_user_webhook_url_by_id(user.id)
if webhook_url: if webhook_url:
post_webhook( post_webhook(
request.app.state.WEBUI_NAME, request.app.state.WEBUI_NAME,

View file

@ -15,7 +15,6 @@ from open_webui.models.models import Models
from open_webui.utils.plugin import ( from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache, get_function_module_from_cache,
) )
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
@ -130,19 +129,23 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
models = models + arena_models models = models + arena_models
global_action_ids = [ global_action_ids = [
function.id for function in Functions.get_global_action_functions() function.id for function in await Functions.get_global_action_functions()
] ]
enabled_action_ids = [ enabled_action_ids = [
function.id function.id
for function in Functions.get_functions_by_type("action", active_only=True) for function in await Functions.get_functions_by_type(
"action", active_only=True
)
] ]
global_filter_ids = [ global_filter_ids = [
function.id for function in Functions.get_global_filter_functions() function.id for function in await Functions.get_global_filter_functions()
] ]
enabled_filter_ids = [ enabled_filter_ids = [
function.id function.id
for function in Functions.get_functions_by_type("filter", active_only=True) for function in await Functions.get_functions_by_type(
"filter", active_only=True
)
] ]
custom_models = await Models.get_all_models() custom_models = await Models.get_all_models()
@ -265,8 +268,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
} }
] ]
def get_function_module_by_id(function_id): async def get_function_module_by_id(function_id):
function_module, _, _ = get_function_module_from_cache(request, function_id) function_module, _, _ = await get_function_module_from_cache(
request, function_id
)
return function_module return function_module
for model in models: for model in models:
@ -283,22 +288,22 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
model["actions"] = [] model["actions"] = []
for action_id in action_ids: for action_id in action_ids:
action_function = Functions.get_function_by_id(action_id) action_function = await Functions.get_function_by_id(action_id)
if action_function is None: if action_function is None:
raise Exception(f"Action not found: {action_id}") raise Exception(f"Action not found: {action_id}")
function_module = get_function_module_by_id(action_id) function_module = await get_function_module_by_id(action_id)
model["actions"].extend( model["actions"].extend(
get_action_items_from_module(action_function, function_module) get_action_items_from_module(action_function, function_module)
) )
model["filters"] = [] model["filters"] = []
for filter_id in filter_ids: for filter_id in filter_ids:
filter_function = Functions.get_function_by_id(filter_id) filter_function = await Functions.get_function_by_id(filter_id)
if filter_function is None: if filter_function is None:
raise Exception(f"Filter not found: {filter_id}") raise Exception(f"Filter not found: {filter_id}")
function_module = get_function_module_by_id(filter_id) function_module = await get_function_module_by_id(filter_id)
if getattr(function_module, "toggle", None): if getattr(function_module, "toggle", None):
model["filters"].extend( model["filters"].extend(

View file

@ -115,15 +115,15 @@ async def load_tool_module_by_id(tool_id, content=None):
os.unlink(temp_file.name) os.unlink(temp_file.name)
def load_function_module_by_id(function_id: str, content: str | None = None): async def load_function_module_by_id(function_id: str, content: str | None = None):
if content is None: if content is None:
function = Functions.get_function_by_id(function_id) function = await Functions.get_function_by_id(function_id)
if not function: if not function:
raise Exception(f"Function not found: {function_id}") raise Exception(f"Function not found: {function_id}")
content = function.content content = function.content
content = replace_imports(content) content = replace_imports(content)
Functions.update_function_by_id(function_id, {"content": content}) await Functions.update_function_by_id(function_id, {"content": content})
else: else:
frontmatter = extract_frontmatter(content) frontmatter = extract_frontmatter(content)
install_frontmatter_requirements(frontmatter.get("requirements", "")) install_frontmatter_requirements(frontmatter.get("requirements", ""))
@ -160,19 +160,19 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
# Cleanup by removing the module in case of error # Cleanup by removing the module in case of error
del sys.modules[module_name] del sys.modules[module_name]
Functions.update_function_by_id(function_id, {"is_active": False}) await Functions.update_function_by_id(function_id, {"is_active": False})
raise e raise e
finally: finally:
os.unlink(temp_file.name) os.unlink(temp_file.name)
def get_function_module_from_cache(request, function_id, load_from_db=True): async def get_function_module_from_cache(request, function_id, load_from_db=True):
if load_from_db: if load_from_db:
# Always load from the database by default # Always load from the database by default
# This is useful for hooks like "inlet" or "outlet" where the content might change # This is useful for hooks like "inlet" or "outlet" where the content might change
# and we want to ensure the latest content is used. # and we want to ensure the latest content is used.
function = Functions.get_function_by_id(function_id) function = await Functions.get_function_by_id(function_id)
if not function: if not function:
raise Exception(f"Function not found: {function_id}") raise Exception(f"Function not found: {function_id}")
content = function.content content = function.content
@ -181,7 +181,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
if new_content != content: if new_content != content:
content = new_content content = new_content
# Update the function content in the database # Update the function content in the database
Functions.update_function_by_id(function_id, {"content": content}) await Functions.update_function_by_id(function_id, {"content": content})
if ( if (
hasattr(request.app.state, "FUNCTION_CONTENTS") hasattr(request.app.state, "FUNCTION_CONTENTS")
@ -193,7 +193,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
if request.app.state.FUNCTION_CONTENTS[function_id] == content: if request.app.state.FUNCTION_CONTENTS[function_id] == content:
return request.app.state.FUNCTIONS[function_id], None, None return request.app.state.FUNCTIONS[function_id], None, None
function_module, function_type, frontmatter = load_function_module_by_id( function_module, function_type, frontmatter = await load_function_module_by_id(
function_id, content function_id, content
) )
else: else:
@ -206,7 +206,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
): ):
return request.app.state.FUNCTIONS[function_id], None, None return request.app.state.FUNCTIONS[function_id], None, None
function_module, function_type, frontmatter = load_function_module_by_id( function_module, function_type, frontmatter = await load_function_module_by_id(
function_id function_id
) )