mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
wip: functions
This commit is contained in:
parent
3b9e454fb4
commit
9a210e743d
9 changed files with 107 additions and 90 deletions
|
|
@ -30,7 +30,6 @@ from open_webui.models.functions import Functions
|
|||
from open_webui.models.models import Models
|
||||
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
|
|
@ -56,17 +55,17 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
||||
async def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
function_module, _, _ = await get_function_module_from_cache(request, pipe_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
valves = await Functions.get_function_valves_by_id(pipe_id)
|
||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||
return function_module
|
||||
|
||||
|
||||
async def get_function_models(request):
|
||||
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
||||
pipes = await Functions.get_functions_by_type("pipe", active_only=True)
|
||||
pipe_models = []
|
||||
|
||||
for pipe in pipes:
|
||||
|
|
|
|||
|
|
@ -477,7 +477,8 @@ from open_webui.constants import ERROR_MESSAGES
|
|||
|
||||
if SAFE_MODE:
|
||||
print("SAFE MODE ENABLED")
|
||||
Functions.deactivate_all_functions()
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(Functions.deactivate_all_functions())
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/", response_model=list[FunctionResponse])
|
||||
async def get_functions(user=Depends(get_verified_user)):
|
||||
return Functions.get_functions()
|
||||
return await Functions.get_functions()
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -48,7 +48,7 @@ async def get_functions(user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/export", response_model=list[FunctionModel])
|
||||
async def get_functions(user=Depends(get_admin_user)):
|
||||
return Functions.get_functions()
|
||||
return await Functions.get_functions()
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -142,12 +142,14 @@ async def sync_functions(
|
|||
try:
|
||||
for function in form_data.functions:
|
||||
function.content = replace_imports(function.content)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function.id,
|
||||
content=function.content,
|
||||
function_module, function_type, frontmatter = (
|
||||
await load_function_module_by_id(
|
||||
function.id,
|
||||
content=function.content,
|
||||
)
|
||||
)
|
||||
|
||||
return Functions.sync_functions(user.id, form_data.functions)
|
||||
return await Functions.sync_functions(user.id, form_data.functions)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to load a function: {e}")
|
||||
raise HTTPException(
|
||||
|
|
@ -173,20 +175,24 @@ async def create_new_function(
|
|||
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
function = Functions.get_function_by_id(form_data.id)
|
||||
function = await Functions.get_function_by_id(form_data.id)
|
||||
if function is None:
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
form_data.id,
|
||||
content=form_data.content,
|
||||
function_module, function_type, frontmatter = (
|
||||
await load_function_module_by_id(
|
||||
form_data.id,
|
||||
content=form_data.content,
|
||||
)
|
||||
)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
FUNCTIONS[form_data.id] = function_module
|
||||
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
||||
function = await Functions.insert_new_function(
|
||||
user.id, function_type, form_data
|
||||
)
|
||||
|
||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -218,7 +224,7 @@ async def create_new_function(
|
|||
|
||||
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
||||
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
return function
|
||||
|
|
@ -236,9 +242,9 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
||||
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
function = await Functions.update_function_by_id(
|
||||
id, {"is_active": not function.is_active}
|
||||
)
|
||||
|
||||
|
|
@ -263,9 +269,9 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
|||
|
||||
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function = Functions.update_function_by_id(
|
||||
function = await unctions.update_function_by_id(
|
||||
id, {"is_global": not function.is_global}
|
||||
)
|
||||
|
||||
|
|
@ -294,7 +300,7 @@ async def update_function_by_id(
|
|||
):
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function_module, function_type, frontmatter = await load_function_module_by_id(
|
||||
id, content=form_data.content
|
||||
)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
|
@ -305,7 +311,7 @@ async def update_function_by_id(
|
|||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
log.debug(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
function = await Functions.update_function_by_id(id, updated)
|
||||
|
||||
if function:
|
||||
return function
|
||||
|
|
@ -331,7 +337,7 @@ async def update_function_by_id(
|
|||
async def delete_function_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
result = Functions.delete_function_by_id(id)
|
||||
result = await Functions.delete_function_by_id(id)
|
||||
|
||||
if result:
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
|
|
@ -348,10 +354,10 @@ async def delete_function_by_id(
|
|||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
try:
|
||||
valves = Functions.get_function_valves_by_id(id)
|
||||
valves = await Functions.get_function_valves_by_id(id)
|
||||
return valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -374,10 +380,10 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
|||
async def get_function_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
function_module, function_type, frontmatter = (
|
||||
await get_function_module_from_cache(request, id)
|
||||
)
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
|
|
@ -400,10 +406,10 @@ async def get_function_valves_spec_by_id(
|
|||
async def update_function_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
function_module, function_type, frontmatter = (
|
||||
await get_function_module_from_cache(request, id)
|
||||
)
|
||||
|
||||
if hasattr(function_module, "Valves"):
|
||||
|
|
@ -412,7 +418,7 @@ async def update_function_valves_by_id(
|
|||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
valves = Valves(**form_data)
|
||||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
await Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
log.exception(f"Error updating function values by id {id}: {e}")
|
||||
|
|
@ -461,10 +467,10 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
|
|||
async def get_function_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
function_module, function_type, frontmatter = (
|
||||
await get_function_module_from_cache(request, id)
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
|
|
@ -485,8 +491,8 @@ async def update_function_user_valves_by_id(
|
|||
function = await Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
request, id
|
||||
function_module, function_type, frontmatter = (
|
||||
await get_function_module_from_cache(request, id)
|
||||
)
|
||||
|
||||
if hasattr(function_module, "UserValves"):
|
||||
|
|
|
|||
|
|
@ -1433,7 +1433,7 @@ async def generate_openai_completion(
|
|||
if ":" not in model_id:
|
||||
model_id = f"{model_id}:latest"
|
||||
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
model_info = await Models.get_model_by_id(model_id)
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
|
@ -1514,7 +1514,7 @@ async def generate_openai_chat_completion(
|
|||
if ":" not in model_id:
|
||||
model_id = f"{model_id}:latest"
|
||||
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
model_info = await Models.get_model_by_id(model_id)
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
|
@ -1629,7 +1629,7 @@ async def get_openai_models(
|
|||
# Filter models based on user access control
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
model_info = await Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or (
|
||||
await has_access(
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ from open_webui.models.models import Models
|
|||
|
||||
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.models import get_all_models, check_model_access
|
||||
|
|
@ -328,8 +327,8 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||
|
||||
try:
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
await Functions.get_function_by_id(filter_id)
|
||||
for filter_id in await get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
|
@ -352,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
|||
else:
|
||||
sub_action_id = None
|
||||
|
||||
action = Functions.get_function_by_id(action_id)
|
||||
action = await Functions.get_function_by_id(action_id)
|
||||
if not action:
|
||||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
|
|
@ -390,10 +389,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
|||
}
|
||||
)
|
||||
|
||||
function_module, _, _ = get_function_module_from_cache(request, action_id)
|
||||
function_module, _, _ = await get_function_module_from_cache(request, action_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(action_id)
|
||||
valves = await Functions.get_function_valves_by_id(action_id)
|
||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||
|
||||
if hasattr(function_module, "action"):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import inspect
|
|||
import logging
|
||||
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.models.functions import Functions
|
||||
|
|
@ -12,35 +11,39 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
def get_function_module(request, function_id, load_from_db=True):
|
||||
async def get_function_module(request, function_id, load_from_db=True):
|
||||
"""
|
||||
Get the function module by its ID.
|
||||
"""
|
||||
function_module, _, _ = get_function_module_from_cache(
|
||||
function_module, _, _ = await get_function_module_from_cache(
|
||||
request, function_id, load_from_db
|
||||
)
|
||||
return function_module
|
||||
|
||||
|
||||
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
async def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
||||
async def get_priority(function_id):
|
||||
function = await Functions.get_function_by_id(function_id)
|
||||
if function is not None:
|
||||
valves = Functions.get_function_valves_by_id(function_id)
|
||||
valves = await Functions.get_function_valves_by_id(function_id)
|
||||
return valves.get("priority", 0) if valves else 0
|
||||
return 0
|
||||
|
||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
||||
filter_ids = [
|
||||
function.id for function in await Functions.get_global_filter_functions()
|
||||
]
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||
filter_ids = list(set(filter_ids))
|
||||
active_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
for function in await Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
|
||||
def get_active_status(filter_id):
|
||||
function_module = get_function_module(request, filter_id)
|
||||
async def get_active_status(filter_id):
|
||||
function_module = await get_function_module(request, filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None):
|
||||
return filter_id in (enabled_filter_ids or [])
|
||||
|
|
@ -48,10 +51,13 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None)
|
|||
return True
|
||||
|
||||
active_filter_ids = [
|
||||
filter_id for filter_id in active_filter_ids if get_active_status(filter_id)
|
||||
filter_id
|
||||
for filter_id in active_filter_ids
|
||||
if await get_active_status(filter_id)
|
||||
]
|
||||
|
||||
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
||||
|
||||
filter_ids.sort(key=get_priority)
|
||||
|
||||
return filter_ids
|
||||
|
|
@ -68,7 +74,7 @@ async def process_filter_functions(
|
|||
if not filter:
|
||||
continue
|
||||
|
||||
function_module = get_function_module(
|
||||
function_module = await get_function_module(
|
||||
request, filter_id, load_from_db=(filter_type != "stream")
|
||||
)
|
||||
# Prepare handler function
|
||||
|
|
@ -82,7 +88,7 @@ async def process_filter_functions(
|
|||
|
||||
# Apply valves to the function
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
valves = Functions.get_function_valves_by_id(filter_id)
|
||||
valves = await Functions.get_function_valves_by_id(filter_id)
|
||||
function_module.valves = function_module.Valves(
|
||||
**(valves if valves else {})
|
||||
)
|
||||
|
|
|
|||
|
|
@ -77,7 +77,6 @@ from open_webui.utils.misc import (
|
|||
convert_logit_bias_input_to_json,
|
||||
)
|
||||
from open_webui.utils.tools import get_tools
|
||||
from open_webui.utils.plugin import load_function_module_by_id
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
|
|
@ -840,8 +839,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
|
||||
try:
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
await Functions.get_function_by_id(filter_id)
|
||||
for filter_id in await get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
|
@ -908,7 +907,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
tools_dict = {}
|
||||
|
||||
if tool_ids:
|
||||
tools_dict = get_tools(
|
||||
tools_dict = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
|
|
@ -1323,7 +1322,9 @@ async def process_chat_response(
|
|||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
webhook_url = await Users.get_user_webhook_url_by_id(
|
||||
user.id
|
||||
)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
|
|
@ -1394,8 +1395,8 @@ async def process_chat_response(
|
|||
"__model__": model,
|
||||
}
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
await Functions.get_function_by_id(filter_id)
|
||||
for filter_id in await get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
|
@ -2523,7 +2524,7 @@ async def process_chat_response(
|
|||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
webhook_url = await Users.get_user_webhook_url_by_id(user.id)
|
||||
if webhook_url:
|
||||
post_webhook(
|
||||
request.app.state.WEBUI_NAME,
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from open_webui.models.models import Models
|
|||
|
||||
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
|
@ -130,19 +129,23 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
|||
models = models + arena_models
|
||||
|
||||
global_action_ids = [
|
||||
function.id for function in Functions.get_global_action_functions()
|
||||
function.id for function in await Functions.get_global_action_functions()
|
||||
]
|
||||
enabled_action_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("action", active_only=True)
|
||||
for function in await Functions.get_functions_by_type(
|
||||
"action", active_only=True
|
||||
)
|
||||
]
|
||||
|
||||
global_filter_ids = [
|
||||
function.id for function in Functions.get_global_filter_functions()
|
||||
function.id for function in await Functions.get_global_filter_functions()
|
||||
]
|
||||
enabled_filter_ids = [
|
||||
function.id
|
||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
||||
for function in await Functions.get_functions_by_type(
|
||||
"filter", active_only=True
|
||||
)
|
||||
]
|
||||
|
||||
custom_models = await Models.get_all_models()
|
||||
|
|
@ -265,8 +268,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
|||
}
|
||||
]
|
||||
|
||||
def get_function_module_by_id(function_id):
|
||||
function_module, _, _ = get_function_module_from_cache(request, function_id)
|
||||
async def get_function_module_by_id(function_id):
|
||||
function_module, _, _ = await get_function_module_from_cache(
|
||||
request, function_id
|
||||
)
|
||||
return function_module
|
||||
|
||||
for model in models:
|
||||
|
|
@ -283,22 +288,22 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
|||
|
||||
model["actions"] = []
|
||||
for action_id in action_ids:
|
||||
action_function = Functions.get_function_by_id(action_id)
|
||||
action_function = await Functions.get_function_by_id(action_id)
|
||||
if action_function is None:
|
||||
raise Exception(f"Action not found: {action_id}")
|
||||
|
||||
function_module = get_function_module_by_id(action_id)
|
||||
function_module = await get_function_module_by_id(action_id)
|
||||
model["actions"].extend(
|
||||
get_action_items_from_module(action_function, function_module)
|
||||
)
|
||||
|
||||
model["filters"] = []
|
||||
for filter_id in filter_ids:
|
||||
filter_function = Functions.get_function_by_id(filter_id)
|
||||
filter_function = await Functions.get_function_by_id(filter_id)
|
||||
if filter_function is None:
|
||||
raise Exception(f"Filter not found: {filter_id}")
|
||||
|
||||
function_module = get_function_module_by_id(filter_id)
|
||||
function_module = await get_function_module_by_id(filter_id)
|
||||
|
||||
if getattr(function_module, "toggle", None):
|
||||
model["filters"].extend(
|
||||
|
|
|
|||
|
|
@ -115,15 +115,15 @@ async def load_tool_module_by_id(tool_id, content=None):
|
|||
os.unlink(temp_file.name)
|
||||
|
||||
|
||||
def load_function_module_by_id(function_id: str, content: str | None = None):
|
||||
async def load_function_module_by_id(function_id: str, content: str | None = None):
|
||||
if content is None:
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
function = await Functions.get_function_by_id(function_id)
|
||||
if not function:
|
||||
raise Exception(f"Function not found: {function_id}")
|
||||
content = function.content
|
||||
|
||||
content = replace_imports(content)
|
||||
Functions.update_function_by_id(function_id, {"content": content})
|
||||
await Functions.update_function_by_id(function_id, {"content": content})
|
||||
else:
|
||||
frontmatter = extract_frontmatter(content)
|
||||
install_frontmatter_requirements(frontmatter.get("requirements", ""))
|
||||
|
|
@ -160,19 +160,19 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
|
|||
# Cleanup by removing the module in case of error
|
||||
del sys.modules[module_name]
|
||||
|
||||
Functions.update_function_by_id(function_id, {"is_active": False})
|
||||
await Functions.update_function_by_id(function_id, {"is_active": False})
|
||||
raise e
|
||||
finally:
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
|
||||
def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||
async def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||
if load_from_db:
|
||||
# Always load from the database by default
|
||||
# This is useful for hooks like "inlet" or "outlet" where the content might change
|
||||
# and we want to ensure the latest content is used.
|
||||
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
function = await Functions.get_function_by_id(function_id)
|
||||
if not function:
|
||||
raise Exception(f"Function not found: {function_id}")
|
||||
content = function.content
|
||||
|
|
@ -181,7 +181,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
|
|||
if new_content != content:
|
||||
content = new_content
|
||||
# Update the function content in the database
|
||||
Functions.update_function_by_id(function_id, {"content": content})
|
||||
await Functions.update_function_by_id(function_id, {"content": content})
|
||||
|
||||
if (
|
||||
hasattr(request.app.state, "FUNCTION_CONTENTS")
|
||||
|
|
@ -193,7 +193,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
|
|||
if request.app.state.FUNCTION_CONTENTS[function_id] == content:
|
||||
return request.app.state.FUNCTIONS[function_id], None, None
|
||||
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function_module, function_type, frontmatter = await load_function_module_by_id(
|
||||
function_id, content
|
||||
)
|
||||
else:
|
||||
|
|
@ -206,7 +206,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
|
|||
):
|
||||
return request.app.state.FUNCTIONS[function_id], None, None
|
||||
|
||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
||||
function_module, function_type, frontmatter = await load_function_module_by_id(
|
||||
function_id
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue