mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
wip: functions
This commit is contained in:
parent
3b9e454fb4
commit
9a210e743d
9 changed files with 107 additions and 90 deletions
|
|
@ -30,7 +30,6 @@ from open_webui.models.functions import Functions
|
||||||
from open_webui.models.models import Models
|
from open_webui.models.models import Models
|
||||||
|
|
||||||
from open_webui.utils.plugin import (
|
from open_webui.utils.plugin import (
|
||||||
load_function_module_by_id,
|
|
||||||
get_function_module_from_cache,
|
get_function_module_from_cache,
|
||||||
)
|
)
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools
|
||||||
|
|
@ -56,17 +55,17 @@ log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
async def get_function_module_by_id(request: Request, pipe_id: str):
|
||||||
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
function_module, _, _ = await get_function_module_from_cache(request, pipe_id)
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
valves = await Functions.get_function_valves_by_id(pipe_id)
|
||||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||||
return function_module
|
return function_module
|
||||||
|
|
||||||
|
|
||||||
async def get_function_models(request):
|
async def get_function_models(request):
|
||||||
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
pipes = await Functions.get_functions_by_type("pipe", active_only=True)
|
||||||
pipe_models = []
|
pipe_models = []
|
||||||
|
|
||||||
for pipe in pipes:
|
for pipe in pipes:
|
||||||
|
|
|
||||||
|
|
@ -477,7 +477,8 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
||||||
if SAFE_MODE:
|
if SAFE_MODE:
|
||||||
print("SAFE MODE ENABLED")
|
print("SAFE MODE ENABLED")
|
||||||
Functions.deactivate_all_functions()
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(Functions.deactivate_all_functions())
|
||||||
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.get("/", response_model=list[FunctionResponse])
|
@router.get("/", response_model=list[FunctionResponse])
|
||||||
async def get_functions(user=Depends(get_verified_user)):
|
async def get_functions(user=Depends(get_verified_user)):
|
||||||
return Functions.get_functions()
|
return await Functions.get_functions()
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -48,7 +48,7 @@ async def get_functions(user=Depends(get_verified_user)):
|
||||||
|
|
||||||
@router.get("/export", response_model=list[FunctionModel])
|
@router.get("/export", response_model=list[FunctionModel])
|
||||||
async def get_functions(user=Depends(get_admin_user)):
|
async def get_functions(user=Depends(get_admin_user)):
|
||||||
return Functions.get_functions()
|
return await Functions.get_functions()
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
@ -142,12 +142,14 @@ async def sync_functions(
|
||||||
try:
|
try:
|
||||||
for function in form_data.functions:
|
for function in form_data.functions:
|
||||||
function.content = replace_imports(function.content)
|
function.content = replace_imports(function.content)
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
function_module, function_type, frontmatter = (
|
||||||
function.id,
|
await load_function_module_by_id(
|
||||||
content=function.content,
|
function.id,
|
||||||
|
content=function.content,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return Functions.sync_functions(user.id, form_data.functions)
|
return await Functions.sync_functions(user.id, form_data.functions)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Failed to load a function: {e}")
|
log.exception(f"Failed to load a function: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -173,20 +175,24 @@ async def create_new_function(
|
||||||
|
|
||||||
form_data.id = form_data.id.lower()
|
form_data.id = form_data.id.lower()
|
||||||
|
|
||||||
function = Functions.get_function_by_id(form_data.id)
|
function = await Functions.get_function_by_id(form_data.id)
|
||||||
if function is None:
|
if function is None:
|
||||||
try:
|
try:
|
||||||
form_data.content = replace_imports(form_data.content)
|
form_data.content = replace_imports(form_data.content)
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
function_module, function_type, frontmatter = (
|
||||||
form_data.id,
|
await load_function_module_by_id(
|
||||||
content=form_data.content,
|
form_data.id,
|
||||||
|
content=form_data.content,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
form_data.meta.manifest = frontmatter
|
form_data.meta.manifest = frontmatter
|
||||||
|
|
||||||
FUNCTIONS = request.app.state.FUNCTIONS
|
FUNCTIONS = request.app.state.FUNCTIONS
|
||||||
FUNCTIONS[form_data.id] = function_module
|
FUNCTIONS[form_data.id] = function_module
|
||||||
|
|
||||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
function = await Functions.insert_new_function(
|
||||||
|
user.id, function_type, form_data
|
||||||
|
)
|
||||||
|
|
||||||
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
function_cache_dir = CACHE_DIR / "functions" / form_data.id
|
||||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
@ -218,7 +224,7 @@ async def create_new_function(
|
||||||
|
|
||||||
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
||||||
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
return function
|
return function
|
||||||
|
|
@ -236,9 +242,9 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
|
||||||
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function = Functions.update_function_by_id(
|
function = await Functions.update_function_by_id(
|
||||||
id, {"is_active": not function.is_active}
|
id, {"is_active": not function.is_active}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -263,9 +269,9 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
|
|
||||||
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
|
||||||
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function = Functions.update_function_by_id(
|
function = await unctions.update_function_by_id(
|
||||||
id, {"is_global": not function.is_global}
|
id, {"is_global": not function.is_global}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -294,7 +300,7 @@ async def update_function_by_id(
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
form_data.content = replace_imports(form_data.content)
|
form_data.content = replace_imports(form_data.content)
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
function_module, function_type, frontmatter = await load_function_module_by_id(
|
||||||
id, content=form_data.content
|
id, content=form_data.content
|
||||||
)
|
)
|
||||||
form_data.meta.manifest = frontmatter
|
form_data.meta.manifest = frontmatter
|
||||||
|
|
@ -305,7 +311,7 @@ async def update_function_by_id(
|
||||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||||
log.debug(updated)
|
log.debug(updated)
|
||||||
|
|
||||||
function = Functions.update_function_by_id(id, updated)
|
function = await Functions.update_function_by_id(id, updated)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
return function
|
return function
|
||||||
|
|
@ -331,7 +337,7 @@ async def update_function_by_id(
|
||||||
async def delete_function_by_id(
|
async def delete_function_by_id(
|
||||||
request: Request, id: str, user=Depends(get_admin_user)
|
request: Request, id: str, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
result = Functions.delete_function_by_id(id)
|
result = await Functions.delete_function_by_id(id)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
FUNCTIONS = request.app.state.FUNCTIONS
|
FUNCTIONS = request.app.state.FUNCTIONS
|
||||||
|
|
@ -348,10 +354,10 @@ async def delete_function_by_id(
|
||||||
|
|
||||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
try:
|
try:
|
||||||
valves = Functions.get_function_valves_by_id(id)
|
valves = await Functions.get_function_valves_by_id(id)
|
||||||
return valves
|
return valves
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -374,10 +380,10 @@ async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||||
async def get_function_valves_spec_by_id(
|
async def get_function_valves_spec_by_id(
|
||||||
request: Request, id: str, user=Depends(get_admin_user)
|
request: Request, id: str, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = (
|
||||||
request, id
|
await get_function_module_from_cache(request, id)
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(function_module, "Valves"):
|
if hasattr(function_module, "Valves"):
|
||||||
|
|
@ -400,10 +406,10 @@ async def get_function_valves_spec_by_id(
|
||||||
async def update_function_valves_by_id(
|
async def update_function_valves_by_id(
|
||||||
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = (
|
||||||
request, id
|
await get_function_module_from_cache(request, id)
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(function_module, "Valves"):
|
if hasattr(function_module, "Valves"):
|
||||||
|
|
@ -412,7 +418,7 @@ async def update_function_valves_by_id(
|
||||||
try:
|
try:
|
||||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||||
valves = Valves(**form_data)
|
valves = Valves(**form_data)
|
||||||
Functions.update_function_valves_by_id(id, valves.model_dump())
|
await Functions.update_function_valves_by_id(id, valves.model_dump())
|
||||||
return valves.model_dump()
|
return valves.model_dump()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error updating function values by id {id}: {e}")
|
log.exception(f"Error updating function values by id {id}: {e}")
|
||||||
|
|
@ -461,10 +467,10 @@ async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user
|
||||||
async def get_function_user_valves_spec_by_id(
|
async def get_function_user_valves_spec_by_id(
|
||||||
request: Request, id: str, user=Depends(get_verified_user)
|
request: Request, id: str, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
function = Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = (
|
||||||
request, id
|
await get_function_module_from_cache(request, id)
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(function_module, "UserValves"):
|
if hasattr(function_module, "UserValves"):
|
||||||
|
|
@ -485,8 +491,8 @@ async def update_function_user_valves_by_id(
|
||||||
function = await Functions.get_function_by_id(id)
|
function = await Functions.get_function_by_id(id)
|
||||||
|
|
||||||
if function:
|
if function:
|
||||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
function_module, function_type, frontmatter = (
|
||||||
request, id
|
await get_function_module_from_cache(request, id)
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(function_module, "UserValves"):
|
if hasattr(function_module, "UserValves"):
|
||||||
|
|
|
||||||
|
|
@ -1433,7 +1433,7 @@ async def generate_openai_completion(
|
||||||
if ":" not in model_id:
|
if ":" not in model_id:
|
||||||
model_id = f"{model_id}:latest"
|
model_id = f"{model_id}:latest"
|
||||||
|
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = await Models.get_model_by_id(model_id)
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
payload["model"] = model_info.base_model_id
|
payload["model"] = model_info.base_model_id
|
||||||
|
|
@ -1514,7 +1514,7 @@ async def generate_openai_chat_completion(
|
||||||
if ":" not in model_id:
|
if ":" not in model_id:
|
||||||
model_id = f"{model_id}:latest"
|
model_id = f"{model_id}:latest"
|
||||||
|
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = await Models.get_model_by_id(model_id)
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
payload["model"] = model_info.base_model_id
|
payload["model"] = model_info.base_model_id
|
||||||
|
|
@ -1629,7 +1629,7 @@ async def get_openai_models(
|
||||||
# Filter models based on user access control
|
# Filter models based on user access control
|
||||||
filtered_models = []
|
filtered_models = []
|
||||||
for model in models:
|
for model in models:
|
||||||
model_info = Models.get_model_by_id(model["id"])
|
model_info = await Models.get_model_by_id(model["id"])
|
||||||
if model_info:
|
if model_info:
|
||||||
if user.id == model_info.user_id or (
|
if user.id == model_info.user_id or (
|
||||||
await has_access(
|
await has_access(
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,6 @@ from open_webui.models.models import Models
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.plugin import (
|
from open_webui.utils.plugin import (
|
||||||
load_function_module_by_id,
|
|
||||||
get_function_module_from_cache,
|
get_function_module_from_cache,
|
||||||
)
|
)
|
||||||
from open_webui.utils.models import get_all_models, check_model_access
|
from open_webui.utils.models import get_all_models, check_model_access
|
||||||
|
|
@ -328,8 +327,8 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
filter_functions = [
|
filter_functions = [
|
||||||
Functions.get_function_by_id(filter_id)
|
await Functions.get_function_by_id(filter_id)
|
||||||
for filter_id in get_sorted_filter_ids(
|
for filter_id in await get_sorted_filter_ids(
|
||||||
request, model, metadata.get("filter_ids", [])
|
request, model, metadata.get("filter_ids", [])
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
@ -352,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||||
else:
|
else:
|
||||||
sub_action_id = None
|
sub_action_id = None
|
||||||
|
|
||||||
action = Functions.get_function_by_id(action_id)
|
action = await Functions.get_function_by_id(action_id)
|
||||||
if not action:
|
if not action:
|
||||||
raise Exception(f"Action not found: {action_id}")
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
|
|
@ -390,10 +389,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
function_module, _, _ = get_function_module_from_cache(request, action_id)
|
function_module, _, _ = await get_function_module_from_cache(request, action_id)
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
valves = Functions.get_function_valves_by_id(action_id)
|
valves = await Functions.get_function_valves_by_id(action_id)
|
||||||
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
function_module.valves = function_module.Valves(**(valves if valves else {}))
|
||||||
|
|
||||||
if hasattr(function_module, "action"):
|
if hasattr(function_module, "action"):
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import inspect
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from open_webui.utils.plugin import (
|
from open_webui.utils.plugin import (
|
||||||
load_function_module_by_id,
|
|
||||||
get_function_module_from_cache,
|
get_function_module_from_cache,
|
||||||
)
|
)
|
||||||
from open_webui.models.functions import Functions
|
from open_webui.models.functions import Functions
|
||||||
|
|
@ -12,35 +11,39 @@ log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
def get_function_module(request, function_id, load_from_db=True):
|
async def get_function_module(request, function_id, load_from_db=True):
|
||||||
"""
|
"""
|
||||||
Get the function module by its ID.
|
Get the function module by its ID.
|
||||||
"""
|
"""
|
||||||
function_module, _, _ = get_function_module_from_cache(
|
function_module, _, _ = await get_function_module_from_cache(
|
||||||
request, function_id, load_from_db
|
request, function_id, load_from_db
|
||||||
)
|
)
|
||||||
return function_module
|
return function_module
|
||||||
|
|
||||||
|
|
||||||
def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
async def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None):
|
||||||
def get_priority(function_id):
|
async def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = await Functions.get_function_by_id(function_id)
|
||||||
if function is not None:
|
if function is not None:
|
||||||
valves = Functions.get_function_valves_by_id(function_id)
|
valves = await Functions.get_function_valves_by_id(function_id)
|
||||||
return valves.get("priority", 0) if valves else 0
|
return valves.get("priority", 0) if valves else 0
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
filter_ids = [
|
||||||
|
function.id for function in await Functions.get_global_filter_functions()
|
||||||
|
]
|
||||||
if "info" in model and "meta" in model["info"]:
|
if "info" in model and "meta" in model["info"]:
|
||||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||||
filter_ids = list(set(filter_ids))
|
filter_ids = list(set(filter_ids))
|
||||||
active_filter_ids = [
|
active_filter_ids = [
|
||||||
function.id
|
function.id
|
||||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
for function in await Functions.get_functions_by_type(
|
||||||
|
"filter", active_only=True
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_active_status(filter_id):
|
async def get_active_status(filter_id):
|
||||||
function_module = get_function_module(request, filter_id)
|
function_module = await get_function_module(request, filter_id)
|
||||||
|
|
||||||
if getattr(function_module, "toggle", None):
|
if getattr(function_module, "toggle", None):
|
||||||
return filter_id in (enabled_filter_ids or [])
|
return filter_id in (enabled_filter_ids or [])
|
||||||
|
|
@ -48,10 +51,13 @@ def get_sorted_filter_ids(request, model: dict, enabled_filter_ids: list = None)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
active_filter_ids = [
|
active_filter_ids = [
|
||||||
filter_id for filter_id in active_filter_ids if get_active_status(filter_id)
|
filter_id
|
||||||
|
for filter_id in active_filter_ids
|
||||||
|
if await get_active_status(filter_id)
|
||||||
]
|
]
|
||||||
|
|
||||||
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
filter_ids = [fid for fid in filter_ids if fid in active_filter_ids]
|
||||||
|
|
||||||
filter_ids.sort(key=get_priority)
|
filter_ids.sort(key=get_priority)
|
||||||
|
|
||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
@ -68,7 +74,7 @@ async def process_filter_functions(
|
||||||
if not filter:
|
if not filter:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
function_module = get_function_module(
|
function_module = await get_function_module(
|
||||||
request, filter_id, load_from_db=(filter_type != "stream")
|
request, filter_id, load_from_db=(filter_type != "stream")
|
||||||
)
|
)
|
||||||
# Prepare handler function
|
# Prepare handler function
|
||||||
|
|
@ -82,7 +88,7 @@ async def process_filter_functions(
|
||||||
|
|
||||||
# Apply valves to the function
|
# Apply valves to the function
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||||
valves = Functions.get_function_valves_by_id(filter_id)
|
valves = await Functions.get_function_valves_by_id(filter_id)
|
||||||
function_module.valves = function_module.Valves(
|
function_module.valves = function_module.Valves(
|
||||||
**(valves if valves else {})
|
**(valves if valves else {})
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,6 @@ from open_webui.utils.misc import (
|
||||||
convert_logit_bias_input_to_json,
|
convert_logit_bias_input_to_json,
|
||||||
)
|
)
|
||||||
from open_webui.utils.tools import get_tools
|
from open_webui.utils.tools import get_tools
|
||||||
from open_webui.utils.plugin import load_function_module_by_id
|
|
||||||
from open_webui.utils.filter import (
|
from open_webui.utils.filter import (
|
||||||
get_sorted_filter_ids,
|
get_sorted_filter_ids,
|
||||||
process_filter_functions,
|
process_filter_functions,
|
||||||
|
|
@ -840,8 +839,8 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
filter_functions = [
|
filter_functions = [
|
||||||
Functions.get_function_by_id(filter_id)
|
await Functions.get_function_by_id(filter_id)
|
||||||
for filter_id in get_sorted_filter_ids(
|
for filter_id in await get_sorted_filter_ids(
|
||||||
request, model, metadata.get("filter_ids", [])
|
request, model, metadata.get("filter_ids", [])
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
@ -908,7 +907,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||||
tools_dict = {}
|
tools_dict = {}
|
||||||
|
|
||||||
if tool_ids:
|
if tool_ids:
|
||||||
tools_dict = get_tools(
|
tools_dict = await get_tools(
|
||||||
request,
|
request,
|
||||||
tool_ids,
|
tool_ids,
|
||||||
user,
|
user,
|
||||||
|
|
@ -1323,7 +1322,9 @@ async def process_chat_response(
|
||||||
|
|
||||||
# Send a webhook notification if the user is not active
|
# Send a webhook notification if the user is not active
|
||||||
if not get_active_status_by_user_id(user.id):
|
if not get_active_status_by_user_id(user.id):
|
||||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
webhook_url = await Users.get_user_webhook_url_by_id(
|
||||||
|
user.id
|
||||||
|
)
|
||||||
if webhook_url:
|
if webhook_url:
|
||||||
post_webhook(
|
post_webhook(
|
||||||
request.app.state.WEBUI_NAME,
|
request.app.state.WEBUI_NAME,
|
||||||
|
|
@ -1394,8 +1395,8 @@ async def process_chat_response(
|
||||||
"__model__": model,
|
"__model__": model,
|
||||||
}
|
}
|
||||||
filter_functions = [
|
filter_functions = [
|
||||||
Functions.get_function_by_id(filter_id)
|
await Functions.get_function_by_id(filter_id)
|
||||||
for filter_id in get_sorted_filter_ids(
|
for filter_id in await get_sorted_filter_ids(
|
||||||
request, model, metadata.get("filter_ids", [])
|
request, model, metadata.get("filter_ids", [])
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
@ -2523,7 +2524,7 @@ async def process_chat_response(
|
||||||
|
|
||||||
# Send a webhook notification if the user is not active
|
# Send a webhook notification if the user is not active
|
||||||
if not get_active_status_by_user_id(user.id):
|
if not get_active_status_by_user_id(user.id):
|
||||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
webhook_url = await Users.get_user_webhook_url_by_id(user.id)
|
||||||
if webhook_url:
|
if webhook_url:
|
||||||
post_webhook(
|
post_webhook(
|
||||||
request.app.state.WEBUI_NAME,
|
request.app.state.WEBUI_NAME,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ from open_webui.models.models import Models
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.plugin import (
|
from open_webui.utils.plugin import (
|
||||||
load_function_module_by_id,
|
|
||||||
get_function_module_from_cache,
|
get_function_module_from_cache,
|
||||||
)
|
)
|
||||||
from open_webui.utils.access_control import has_access
|
from open_webui.utils.access_control import has_access
|
||||||
|
|
@ -130,19 +129,23 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
||||||
models = models + arena_models
|
models = models + arena_models
|
||||||
|
|
||||||
global_action_ids = [
|
global_action_ids = [
|
||||||
function.id for function in Functions.get_global_action_functions()
|
function.id for function in await Functions.get_global_action_functions()
|
||||||
]
|
]
|
||||||
enabled_action_ids = [
|
enabled_action_ids = [
|
||||||
function.id
|
function.id
|
||||||
for function in Functions.get_functions_by_type("action", active_only=True)
|
for function in await Functions.get_functions_by_type(
|
||||||
|
"action", active_only=True
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
global_filter_ids = [
|
global_filter_ids = [
|
||||||
function.id for function in Functions.get_global_filter_functions()
|
function.id for function in await Functions.get_global_filter_functions()
|
||||||
]
|
]
|
||||||
enabled_filter_ids = [
|
enabled_filter_ids = [
|
||||||
function.id
|
function.id
|
||||||
for function in Functions.get_functions_by_type("filter", active_only=True)
|
for function in await Functions.get_functions_by_type(
|
||||||
|
"filter", active_only=True
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
custom_models = await Models.get_all_models()
|
custom_models = await Models.get_all_models()
|
||||||
|
|
@ -265,8 +268,10 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_function_module_by_id(function_id):
|
async def get_function_module_by_id(function_id):
|
||||||
function_module, _, _ = get_function_module_from_cache(request, function_id)
|
function_module, _, _ = await get_function_module_from_cache(
|
||||||
|
request, function_id
|
||||||
|
)
|
||||||
return function_module
|
return function_module
|
||||||
|
|
||||||
for model in models:
|
for model in models:
|
||||||
|
|
@ -283,22 +288,22 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
||||||
|
|
||||||
model["actions"] = []
|
model["actions"] = []
|
||||||
for action_id in action_ids:
|
for action_id in action_ids:
|
||||||
action_function = Functions.get_function_by_id(action_id)
|
action_function = await Functions.get_function_by_id(action_id)
|
||||||
if action_function is None:
|
if action_function is None:
|
||||||
raise Exception(f"Action not found: {action_id}")
|
raise Exception(f"Action not found: {action_id}")
|
||||||
|
|
||||||
function_module = get_function_module_by_id(action_id)
|
function_module = await get_function_module_by_id(action_id)
|
||||||
model["actions"].extend(
|
model["actions"].extend(
|
||||||
get_action_items_from_module(action_function, function_module)
|
get_action_items_from_module(action_function, function_module)
|
||||||
)
|
)
|
||||||
|
|
||||||
model["filters"] = []
|
model["filters"] = []
|
||||||
for filter_id in filter_ids:
|
for filter_id in filter_ids:
|
||||||
filter_function = Functions.get_function_by_id(filter_id)
|
filter_function = await Functions.get_function_by_id(filter_id)
|
||||||
if filter_function is None:
|
if filter_function is None:
|
||||||
raise Exception(f"Filter not found: {filter_id}")
|
raise Exception(f"Filter not found: {filter_id}")
|
||||||
|
|
||||||
function_module = get_function_module_by_id(filter_id)
|
function_module = await get_function_module_by_id(filter_id)
|
||||||
|
|
||||||
if getattr(function_module, "toggle", None):
|
if getattr(function_module, "toggle", None):
|
||||||
model["filters"].extend(
|
model["filters"].extend(
|
||||||
|
|
|
||||||
|
|
@ -115,15 +115,15 @@ async def load_tool_module_by_id(tool_id, content=None):
|
||||||
os.unlink(temp_file.name)
|
os.unlink(temp_file.name)
|
||||||
|
|
||||||
|
|
||||||
def load_function_module_by_id(function_id: str, content: str | None = None):
|
async def load_function_module_by_id(function_id: str, content: str | None = None):
|
||||||
if content is None:
|
if content is None:
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = await Functions.get_function_by_id(function_id)
|
||||||
if not function:
|
if not function:
|
||||||
raise Exception(f"Function not found: {function_id}")
|
raise Exception(f"Function not found: {function_id}")
|
||||||
content = function.content
|
content = function.content
|
||||||
|
|
||||||
content = replace_imports(content)
|
content = replace_imports(content)
|
||||||
Functions.update_function_by_id(function_id, {"content": content})
|
await Functions.update_function_by_id(function_id, {"content": content})
|
||||||
else:
|
else:
|
||||||
frontmatter = extract_frontmatter(content)
|
frontmatter = extract_frontmatter(content)
|
||||||
install_frontmatter_requirements(frontmatter.get("requirements", ""))
|
install_frontmatter_requirements(frontmatter.get("requirements", ""))
|
||||||
|
|
@ -160,19 +160,19 @@ def load_function_module_by_id(function_id: str, content: str | None = None):
|
||||||
# Cleanup by removing the module in case of error
|
# Cleanup by removing the module in case of error
|
||||||
del sys.modules[module_name]
|
del sys.modules[module_name]
|
||||||
|
|
||||||
Functions.update_function_by_id(function_id, {"is_active": False})
|
await Functions.update_function_by_id(function_id, {"is_active": False})
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
os.unlink(temp_file.name)
|
os.unlink(temp_file.name)
|
||||||
|
|
||||||
|
|
||||||
def get_function_module_from_cache(request, function_id, load_from_db=True):
|
async def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||||
if load_from_db:
|
if load_from_db:
|
||||||
# Always load from the database by default
|
# Always load from the database by default
|
||||||
# This is useful for hooks like "inlet" or "outlet" where the content might change
|
# This is useful for hooks like "inlet" or "outlet" where the content might change
|
||||||
# and we want to ensure the latest content is used.
|
# and we want to ensure the latest content is used.
|
||||||
|
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = await Functions.get_function_by_id(function_id)
|
||||||
if not function:
|
if not function:
|
||||||
raise Exception(f"Function not found: {function_id}")
|
raise Exception(f"Function not found: {function_id}")
|
||||||
content = function.content
|
content = function.content
|
||||||
|
|
@ -181,7 +181,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||||
if new_content != content:
|
if new_content != content:
|
||||||
content = new_content
|
content = new_content
|
||||||
# Update the function content in the database
|
# Update the function content in the database
|
||||||
Functions.update_function_by_id(function_id, {"content": content})
|
await Functions.update_function_by_id(function_id, {"content": content})
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(request.app.state, "FUNCTION_CONTENTS")
|
hasattr(request.app.state, "FUNCTION_CONTENTS")
|
||||||
|
|
@ -193,7 +193,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||||
if request.app.state.FUNCTION_CONTENTS[function_id] == content:
|
if request.app.state.FUNCTION_CONTENTS[function_id] == content:
|
||||||
return request.app.state.FUNCTIONS[function_id], None, None
|
return request.app.state.FUNCTIONS[function_id], None, None
|
||||||
|
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
function_module, function_type, frontmatter = await load_function_module_by_id(
|
||||||
function_id, content
|
function_id, content
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
@ -206,7 +206,7 @@ def get_function_module_from_cache(request, function_id, load_from_db=True):
|
||||||
):
|
):
|
||||||
return request.app.state.FUNCTIONS[function_id], None, None
|
return request.app.state.FUNCTIONS[function_id], None, None
|
||||||
|
|
||||||
function_module, function_type, frontmatter = load_function_module_by_id(
|
function_module, function_type, frontmatter = await load_function_module_by_id(
|
||||||
function_id
|
function_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue