diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 6eb5c1bbdb..422d121382 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -174,7 +174,7 @@ async def generate_function_chat_completion( pipe_id, _ = pipe_id.split(".", 1) return pipe_id - def get_function_params(function_module, form_data, user, extra_params=None): + async def get_function_params(function_module, form_data, user, extra_params=None): if extra_params is None: extra_params = {} @@ -187,7 +187,9 @@ async def generate_function_chat_completion( } if "__user__" in params and hasattr(function_module, "UserValves"): - user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) + user_valves = await Functions.get_user_valves_by_id_and_user_id( + pipe_id, user.id + ) try: params["__user__"]["valves"] = function_module.UserValves(**user_valves) except Exception as e: @@ -232,7 +234,7 @@ async def generate_function_chat_completion( "__metadata__": metadata, "__request__": request, } - extra_params["__tools__"] = get_tools( + extra_params["__tools__"] = await get_tools( request, tool_ids, user, @@ -261,7 +263,7 @@ async def generate_function_chat_completion( function_module = get_function_module_by_id(request, pipe_id) pipe = function_module.pipe - params = get_function_params(function_module, form_data, user, extra_params) + params = await get_function_params(function_module, form_data, user, extra_params) if form_data.get("stream", False): diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 4e867103a8..d3b302ec5c 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -530,7 +530,7 @@ async def lifespan(app: FastAPI): # This should be blocking (sync) so functions are not deactivated on first /get_models calls # when the first user lands on the / route. log.info("Installing external dependencies of functions and tools...") - install_tool_and_function_dependencies() + await install_tool_and_function_dependencies() app.state.redis = get_redis_connection( redis_url=REDIS_URL, @@ -1267,11 +1267,11 @@ if audit_level != AuditLevel.NONE: async def get_models( request: Request, refresh: bool = False, user=Depends(get_verified_user) ): - def get_filtered_models(models, user): + async def get_filtered_models(models, user): filtered_models = [] for model in models: if model.get("arena"): - if has_access( + if await has_access( user.id, type="read", access_control=model.get("info", {}) @@ -1286,7 +1286,7 @@ async def get_models( if ( (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) or user.id == model_info.user_id - or has_access( + or await has_access( user.id, type="read", access_control=model_info.access_control ) ): @@ -1334,7 +1334,7 @@ async def get_models( user.role == "user" or (user.role == "admin" and not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) ) and not BYPASS_MODEL_ACCESS_CONTROL: - models = get_filtered_models(models, user) + models = await get_filtered_models(models, user) log.debug( f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}" @@ -1408,7 +1408,7 @@ async def chat_completion( user.role != "admin" or not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS ): try: - check_model_access(user, model) + await check_model_access(user, model) except Exception as e: raise e else: @@ -1628,7 +1628,7 @@ async def get_app_config(request: Request): if data is not None and "id" in data: user = await Users.get_user_by_id(data["id"]) - user_count = Users.get_num_users() + user_count = await Users.get_num_users() onboarding = False if user is None: diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index ac018f5a7a..f034e2b332 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -101,7 +101,7 @@ class ChannelTable: channel for channel in channels if channel.user_id == user_id - or has_access(user_id, permission, channel.access_control) + or await has_access(user_id, permission, channel.access_control) ] async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]: diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index 6dff07ae4e..729d3608d9 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -143,7 +143,7 @@ class KnowledgeTable: ) return knowledge_bases - def get_knowledge_bases_by_user_id( + async def get_knowledge_bases_by_user_id( self, user_id: str, permission: str = "write" ) -> list[KnowledgeUserModel]: knowledge_bases = self.get_knowledge_bases() @@ -151,31 +151,31 @@ class KnowledgeTable: knowledge_base for knowledge_base in knowledge_bases if knowledge_base.user_id == user_id - or has_access(user_id, permission, knowledge_base.access_control) + or await has_access(user_id, permission, knowledge_base.access_control) ] - def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: + async def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: try: async with get_db() as db: - knowledge = db.query(Knowledge).filter_by(id=id).first() + knowledge = await db.query(Knowledge).filter_by(id=id).first() return KnowledgeModel.model_validate(knowledge) if knowledge else None except Exception: return None - def update_knowledge_by_id( + async def update_knowledge_by_id( self, id: str, form_data: KnowledgeForm, overwrite: bool = False ) -> Optional[KnowledgeModel]: try: async with get_db() as db: - knowledge = self.get_knowledge_by_id(id=id) - db.query(Knowledge).filter_by(id=id).update( + knowledge = await self.get_knowledge_by_id(id=id) + await db.query(Knowledge).filter_by(id=id).update( { **form_data.model_dump(), "updated_at": int(time.time()), } ) - db.commit() - return self.get_knowledge_by_id(id=id) + await db.commit() + return await self.get_knowledge_by_id(id=id) except Exception as e: log.exception(e) return None diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 202dd9ac5f..28fd9bd51b 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -195,7 +195,7 @@ class ModelsTable: for model in db.query(Model).filter(Model.base_model_id == None).all() ] - def get_models_by_user_id( + async def get_models_by_user_id( self, user_id: str, permission: str = "write" ) -> list[ModelUserResponse]: models = self.get_models() @@ -203,7 +203,7 @@ class ModelsTable: model for model in models if model.user_id == user_id - or has_access(user_id, permission, model.access_control) + or await has_access(user_id, permission, model.access_control) ] def get_model_by_id(self, id: str) -> Optional[ModelModel]: diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index ec24a7de8b..5df49050ac 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -96,20 +96,20 @@ class NoteTable: db.commit() return note - def get_notes(self) -> list[NoteModel]: + async def get_notes(self) -> list[NoteModel]: async with get_db() as db: - notes = db.query(Note).order_by(Note.updated_at.desc()).all() + notes = await db.query(Note).order_by(Note.updated_at.desc()).all() return [NoteModel.model_validate(note) for note in notes] - def get_notes_by_user_id( + async def get_notes_by_user_id( self, user_id: str, permission: str = "write" ) -> list[NoteModel]: - notes = self.get_notes() + notes = await self.get_notes() return [ note for note in notes if note.user_id == user_id - or has_access(user_id, permission, note.access_control) + or await has_access(user_id, permission, note.access_control) ] def get_note_by_id(self, id: str) -> Optional[NoteModel]: diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index faeb0b85a1..6c94db504b 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -127,7 +127,7 @@ class PromptsTable: prompt for prompt in prompts if prompt.user_id == user_id - or has_access(user_id, permission, prompt.access_control) + or await has_access(user_id, permission, prompt.access_control) ] def update_prompt_by_command( diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index f9a5d56829..2899648fc4 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -107,7 +107,7 @@ class ToolValves(BaseModel): class ToolsTable: - def insert_new_tool( + async def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: list[dict] ) -> Optional[ToolModel]: async with get_db() as db: @@ -123,9 +123,9 @@ class ToolsTable: try: result = Tool(**tool.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return ToolModel.model_validate(result) else: @@ -134,10 +134,10 @@ class ToolsTable: log.exception(f"Error creating a new tool: {e}") return None - def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + async def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: async with get_db() as db: - tool = db.get(Tool, id) + tool = await db.get(Tool, id) return ToolModel.model_validate(tool) except Exception: return None @@ -145,7 +145,7 @@ class ToolsTable: async def get_tools(self) -> list[ToolUserModel]: async with get_db() as db: tools = [] - for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): + for tool in await db.query(Tool).order_by(Tool.updated_at.desc()).all(): user = await Users.get_user_by_id(tool.user_id) tools.append( ToolUserModel.model_validate( @@ -157,35 +157,37 @@ class ToolsTable: ) return tools - def get_tools_by_user_id( + async def get_tools_by_user_id( self, user_id: str, permission: str = "write" ) -> list[ToolUserModel]: - tools = self.get_tools() + tools = await self.get_tools() return [ tool for tool in tools if tool.user_id == user_id - or has_access(user_id, permission, tool.access_control) + or await has_access(user_id, permission, tool.access_control) ] - def get_tool_valves_by_id(self, id: str) -> Optional[dict]: + async def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: async with get_db() as db: - tool = db.get(Tool, id) + tool = await db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: log.exception(f"Error getting tool valves by id {id}: {e}") return None - def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: + async def update_tool_valves_by_id( + self, id: str, valves: dict + ) -> Optional[ToolValves]: try: async with get_db() as db: - db.query(Tool).filter_by(id=id).update( + await db.query(Tool).filter_by(id=id).update( {"valves": valves, "updated_at": int(time.time())} ) - db.commit() - return self.get_tool_by_id(id) + await db.commit() + return await self.get_tool_by_id(id) except Exception: return None @@ -225,7 +227,7 @@ class ToolsTable: user_settings["tools"]["valves"][id] = valves # Update the user settings in the database - Users.update_user_by_id(user_id, {"settings": user_settings}) + await Users.update_user_by_id(user_id, {"settings": user_settings}) return user_settings["tools"]["valves"][id] except Exception as e: @@ -234,25 +236,25 @@ class ToolsTable: ) return None - def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: + async def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: async with get_db() as db: - db.query(Tool).filter_by(id=id).update( + await db.query(Tool).filter_by(id=id).update( {**updated, "updated_at": int(time.time())} ) - db.commit() + await db.commit() - tool = db.query(Tool).get(id) - db.refresh(tool) + tool = await db.query(Tool).get(id) + await db.refresh(tool) return ToolModel.model_validate(tool) except Exception: return None - def delete_tool_by_id(self, id: str) -> bool: + async def delete_tool_by_id(self, id: str) -> bool: try: async with get_db() as db: - db.query(Tool).filter_by(id=id).delete() - db.commit() + await db.query(Tool).filter_by(id=id).delete() + await db.commit() return True except Exception: diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 7e13cbf164..69ace9668f 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -517,7 +517,7 @@ def get_sources_from_items( if note and ( user.role == "admin" or note.user_id == user.id - or has_access(user.id, "read", note.access_control) + or await has_access(user.id, "read", note.access_control) ): # User has access to the note query_result = { @@ -581,7 +581,7 @@ def get_sources_from_items( if knowledge_base and ( user.role == "admin" - or has_access(user.id, "read", knowledge_base.access_control) + or await has_access(user.id, "read", knowledge_base.access_control) ): file_ids = knowledge_base.data.get("file_ids", []) diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index cab71d076b..aa562d18e5 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -135,7 +135,7 @@ async def update_profile( form_data: UpdateProfileForm, session_user=Depends(get_verified_user) ): if session_user: - user = Users.update_user_by_id( + user = await Users.update_user_by_id( session_user.id, {"profile_image_url": form_data.profile_image_url, "name": form_data.name}, ) @@ -348,12 +348,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if not connection_user.bind(): raise HTTPException(400, "Authentication failed.") - user = Users.get_user_by_email(email) + user = await Users.get_user_by_email(email) if not user: try: role = ( "admin" - if not Users.has_users() + if not await Users.has_users() else request.app.state.config.DEFAULT_USER_ROLE ) @@ -463,7 +463,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): if WEBUI_AUTH_TRUSTED_NAME_HEADER: name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) - if not Users.get_user_by_email(email.lower()): + if not await Users.get_user_by_email(email.lower()): await signup( request, response, @@ -484,10 +484,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm): admin_email = "admin@localhost" admin_password = "admin" - if Users.get_user_by_email(admin_email.lower()): + if await Users.get_user_by_email(admin_email.lower()): user = await Auths.authenticate_user(admin_email.lower(), admin_password) else: - if Users.has_users(): + if await Users.has_users(): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( @@ -556,7 +556,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): @router.post("/signup", response_model=SessionUserResponse) async def signup(request: Request, response: Response, form_data: SignupForm): - has_users = Users.has_users() + has_users = await Users.has_users() if WEBUI_AUTH: if ( @@ -577,7 +577,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if await Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -733,7 +733,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if await Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -780,11 +780,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") if admin_email: - admin = Users.get_user_by_email(admin_email) + admin = await Users.get_user_by_email(admin_email) if admin: admin_name = admin.name else: - admin = Users.get_first_user() + admin = await Users.get_first_user() if admin: admin_email = admin.email admin_name = admin.name @@ -1025,7 +1025,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): ) api_key = create_api_key() - success = Users.update_user_api_key_by_id(user.id, api_key) + success = await Users.update_user_api_key_by_id(user.id, api_key) if success: return { @@ -1038,14 +1038,14 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)): # delete api key @router.delete("/api_key", response_model=bool) async def delete_api_key(user=Depends(get_current_user)): - success = Users.update_user_api_key_by_id(user.id, None) + success = await Users.update_user_api_key_by_id(user.id, None) return success # get api key @router.get("/api_key", response_model=ApiKey) async def get_api_key(user=Depends(get_current_user)): - api_key = Users.get_user_api_key_by_id(user.id) + api_key = await Users.get_user_api_key_by_id(user.id) if api_key: return { "api_key": api_key, diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index df606d61cc..39958a0ccb 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -80,7 +80,7 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -157,7 +157,7 @@ async def get_channel_messages( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -197,7 +197,7 @@ async def get_channel_messages( async def send_notification(name, webui_url, channel, message, active_user_ids): - users = get_users_with_access("read", channel.access_control) + users = await get_users_with_access("read", channel.access_control) for user in users: if user.id in active_user_ids: @@ -236,7 +236,7 @@ async def post_new_message( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -341,7 +341,7 @@ async def get_channel_message( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -390,7 +390,7 @@ async def get_channel_thread_messages( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -452,7 +452,9 @@ async def update_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access(user.id, type="read", access_control=channel.access_control) + and not await has_access( + user.id, type="read", access_control=channel.access_control + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -512,7 +514,7 @@ async def add_reaction_to_message( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -578,7 +580,7 @@ async def remove_reaction_by_id_and_user_id_and_name( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( + if user.role != "admin" and not await has_access( user.id, type="read", access_control=channel.access_control ): raise HTTPException( @@ -661,7 +663,9 @@ async def delete_message_by_id( if ( user.role != "admin" and message.user_id != user.id - and not has_access(user.id, type="read", access_control=channel.access_control) + and not await has_access( + user.id, type="read", access_control=channel.access_control + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index b5beb96cf0..a37f1cb3a6 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -440,10 +440,10 @@ async def update_function_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: try: - user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id) + user_valves = await Functions.get_user_valves_by_id_and_user_id(id, user.id) return user_valves except Exception as e: raise HTTPException( @@ -482,7 +482,7 @@ async def get_function_user_valves_spec_by_id( async def update_function_user_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - function = Functions.get_function_by_id(id) + function = await Functions.get_function_by_id(id) if function: function_module, function_type, frontmatter = get_function_module_from_cache( @@ -495,7 +495,7 @@ async def update_function_user_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) - Functions.update_user_valves_by_id_and_user_id( + await Functions.update_user_valves_by_id_and_user_id( id, user.id, user_valves.model_dump() ) return user_valves.model_dump() diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index bf286fe001..e5b31ff8e3 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -90,7 +90,7 @@ async def update_group_by_id( ): try: if form_data.user_ids: - form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) + form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids) group = Groups.update_group_by_id(id, form_data) if group: @@ -119,7 +119,7 @@ async def add_user_to_group( ): try: if form_data.user_ids: - form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) + form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids) group = Groups.add_users_to_group(id, form_data.user_ids) if group: diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 69198816b3..555581a95d 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -261,11 +261,11 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( user.role == "admin" or knowledge.user_id == user.id - or has_access(user.id, "read", knowledge.access_control) + or await has_access(user.id, "read", knowledge.access_control) ): file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] - files = Files.get_file_metadatas_by_ids(file_ids) + files = await Files.get_file_metadatas_by_ids(file_ids) return KnowledgeFilesResponse( **knowledge.model_dump(), @@ -298,7 +298,7 @@ async def update_knowledge_by_id( # Is the user the original creator, in a group with write access, or an admin if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): raise HTTPException( @@ -348,7 +348,7 @@ def add_file_to_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): raise HTTPException( @@ -432,7 +432,7 @@ def update_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): @@ -503,7 +503,7 @@ def remove_file_from_knowledge_by_id( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): raise HTTPException( @@ -591,7 +591,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): raise HTTPException( @@ -654,7 +654,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)): if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): raise HTTPException( @@ -697,7 +697,7 @@ def add_files_to_knowledge_batch( if ( knowledge.user_id != user.id - and not has_access(user.id, "write", knowledge.access_control) + and not await has_access(user.id, "write", knowledge.access_control) and user.role != "admin" ): raise HTTPException( diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index e1a5ec1937..7ae7f07301 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -119,7 +119,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): if ( (user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS) or model.user_id == user.id - or has_access(user.id, "read", model.access_control) + or await has_access(user.id, "read", model.access_control) ): return model else: @@ -141,7 +141,7 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)): if ( user.role == "admin" or model.user_id == user.id - or has_access(user.id, "write", model.access_control) + or await has_access(user.id, "write", model.access_control) ): model = Models.toggle_model_by_id(id) @@ -185,7 +185,7 @@ async def update_model_by_id( if ( model.user_id != user.id - and not has_access(user.id, "write", model.access_control) + and not await has_access(user.id, "write", model.access_control) and user.role != "admin" ): raise HTTPException( @@ -204,7 +204,7 @@ async def update_model_by_id( @router.delete("/model/delete", response_model=bool) async def delete_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) + model = await Models.get_model_by_id(id) if not model: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -214,18 +214,18 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)): if ( user.role != "admin" and model.user_id != user.id - and not has_access(user.id, "write", model.access_control) + and not await has_access(user.id, "write", model.access_control) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - result = Models.delete_model_by_id(id) + result = await Models.delete_model_by_id(id) return result @router.delete("/delete/all", response_model=bool) async def delete_all_models(user=Depends(get_admin_user)): - result = Models.delete_all_models() + result = await Models.delete_all_models() return result diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index e9e63788be..d151338db8 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -133,7 +133,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us if user.role != "admin" and ( user.id != note.user_id - and (not has_access(user.id, type="read", access_control=note.access_control)) + and ( + not await has_access( + user.id, type="read", access_control=note.access_control + ) + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -167,7 +171,9 @@ async def update_note_by_id( if user.role != "admin" and ( user.id != note.user_id - and not has_access(user.id, type="write", access_control=note.access_control) + and not await has_access( + user.id, type="write", access_control=note.access_control + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -212,7 +218,9 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified if user.role != "admin" and ( user.id != note.user_id - and not has_access(user.id, type="write", access_control=note.access_control) + and not await has_access( + user.id, type="write", access_control=note.access_control + ) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 5894a72c35..7c0bdc3f4c 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -443,8 +443,10 @@ async def get_filtered_models(models, user): for model in models.get("models", []): model_info = Models.get_model_by_id(model["model"]) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + if user.id == model_info.user_id or ( + await has_access( + user.id, type="read", access_control=model_info.access_control + ) ): filtered_models.append(model) return filtered_models @@ -1336,8 +1338,10 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control + or ( + await has_access( + user.id, type="read", access_control=model_info.access_control + ) ) ): raise HTTPException( @@ -1442,8 +1446,10 @@ async def generate_openai_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control + or ( + await has_access( + user.id, type="read", access_control=model_info.access_control + ) ) ): raise HTTPException( @@ -1525,8 +1531,10 @@ async def generate_openai_chat_completion( if user.role == "user": if not ( user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control + or ( + await has_access( + user.id, type="read", access_control=model_info.access_control + ) ) ): raise HTTPException( @@ -1623,8 +1631,10 @@ async def get_openai_models( for model in models: model_info = Models.get_model_by_id(model["id"]) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + if user.id == model_info.user_id or ( + await has_access( + user.id, type="read", access_control=model_info.access_control + ) ): filtered_models.append(model) models = filtered_models diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index c8a3aebdd0..c643591669 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -385,8 +385,10 @@ async def get_filtered_models(models, user): for model in models.get("data", []): model_info = Models.get_model_by_id(model["id"]) if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + if user.id == model_info.user_id or ( + await has_access( + user.id, type="read", access_control=model_info.access_control + ) ): filtered_models.append(model) return filtered_models @@ -756,7 +758,7 @@ async def generate_chat_completion( if not bypass_filter and user.role == "user": if not ( user.id == model_info.user_id - or has_access( + or await has_access( user.id, type="read", access_control=model_info.access_control ) ): diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index afc00951fd..3988bbd236 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -79,13 +79,13 @@ async def create_new_prompt( @router.get("/command/{command}", response_model=Optional[PromptModel]) async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") + prompt = await Prompts.get_prompt_by_command(f"/{command}") if prompt: if ( user.role == "admin" or prompt.user_id == user.id - or has_access(user.id, "read", prompt.access_control) + or await has_access(user.id, "read", prompt.access_control) ): return prompt else: @@ -106,7 +106,7 @@ async def update_prompt_by_command( form_data: PromptForm, user=Depends(get_verified_user), ): - prompt = Prompts.get_prompt_by_command(f"/{command}") + prompt = await Prompts.get_prompt_by_command(f"/{command}") if not prompt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -116,7 +116,7 @@ async def update_prompt_by_command( # Is the user the original creator, in a group with write access, or an admin if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control) + and not await has_access(user.id, "write", prompt.access_control) and user.role != "admin" ): raise HTTPException( @@ -124,7 +124,7 @@ async def update_prompt_by_command( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + prompt = await Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: return prompt else: @@ -150,7 +150,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user) if ( prompt.user_id != user.id - and not has_access(user.id, "write", prompt.access_control) + and not await has_access(user.id, "write", prompt.access_control) and user.role != "admin" ): raise HTTPException( diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index 77a9f035b1..fb8c8263dd 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -297,7 +297,7 @@ def get_scim_auth( ) -def user_to_scim(user: UserModel, request: Request) -> SCIMUser: +async def user_to_scim(user: UserModel, request: Request) -> SCIMUser: """Convert internal User model to SCIM User""" # Parse display name into name components name_parts = user.name.split(" ", 1) if user.name else ["", ""] @@ -346,7 +346,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser: ) -def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: +async def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: """Convert internal Group model to SCIM Group""" members = [] for user_id in group.user_ids: @@ -493,20 +493,20 @@ async def get_users( # In production, you'd want a more robust filter parser if "userName eq" in filter: email = filter.split('"')[1] - user = Users.get_user_by_email(email) + user = await Users.get_user_by_email(email) users_list = [user] if user else [] total = 1 if user else 0 else: - response = Users.get_users(skip=skip, limit=limit) + response = await Users.get_users(skip=skip, limit=limit) users_list = response["users"] total = response["total"] else: - response = Users.get_users(skip=skip, limit=limit) + response = await Users.get_users(skip=skip, limit=limit) users_list = response["users"] total = response["total"] # Convert to SCIM format - scim_users = [user_to_scim(user, request) for user in users_list] + scim_users = [await user_to_scim(user, request) for user in users_list] return SCIMListResponse( totalResults=total, @@ -529,7 +529,7 @@ async def get_user( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" ) - return user_to_scim(user, request) + return await user_to_scim(user, request) @router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) @@ -540,7 +540,7 @@ async def create_user( ): """Create SCIM User""" # Check if user already exists - existing_user = Users.get_user_by_email(user_data.userName) + existing_user = await Users.get_user_by_email(user_data.userName) if existing_user: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -579,7 +579,7 @@ async def create_user( detail="Failed to create user", ) - return user_to_scim(new_user, request) + return await user_to_scim(new_user, request) @router.put("/Users/{user_id}", response_model=SCIMUser) @@ -623,14 +623,14 @@ async def update_user( update_data["profile_image_url"] = user_data.photos[0].value # Update user - updated_user = Users.update_user_by_id(user_id, update_data) + updated_user = await Users.update_user_by_id(user_id, update_data) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update user", ) - return user_to_scim(updated_user, request) + return await user_to_scim(updated_user, request) @router.patch("/Users/{user_id}", response_model=SCIMUser) @@ -669,7 +669,7 @@ async def patch_user( # Update user if update_data: - updated_user = Users.update_user_by_id(user_id, update_data) + updated_user = await Users.update_user_by_id(user_id, update_data) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -678,7 +678,7 @@ async def patch_user( else: updated_user = user - return user_to_scim(updated_user, request) + return await user_to_scim(updated_user, request) @router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -695,7 +695,7 @@ async def delete_user( detail=f"User {user_id} not found", ) - success = Users.delete_user_by_id(user_id) + success = await Users.delete_user_by_id(user_id) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -716,7 +716,7 @@ async def get_groups( ): """List SCIM Groups""" # Get all groups - groups_list = Groups.get_groups() + groups_list = await Groups.get_groups() # Apply pagination total = len(groups_list) @@ -725,7 +725,7 @@ async def get_groups( paginated_groups = groups_list[start:end] # Convert to SCIM format - scim_groups = [group_to_scim(group, request) for group in paginated_groups] + scim_groups = [await group_to_scim(group, request) for group in paginated_groups] return SCIMListResponse( totalResults=total, @@ -742,14 +742,14 @@ async def get_group( _: bool = Depends(get_scim_auth), ): """Get SCIM Group by ID""" - group = Groups.get_group_by_id(group_id) + group = await Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - return group_to_scim(group, request) + return await group_to_scim(group, request) @router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) @@ -774,14 +774,14 @@ async def create_group( ) # Need to get the creating user's ID - we'll use the first admin - admin_user = Users.get_super_admin_user() + admin_user = await Users.get_super_admin_user() if not admin_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No admin user found", ) - new_group = Groups.insert_new_group(admin_user.id, form) + new_group = await Groups.insert_new_group(admin_user.id, form) if not new_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -797,10 +797,10 @@ async def create_group( description=new_group.description, user_ids=member_ids, ) - Groups.update_group_by_id(new_group.id, update_form) - new_group = Groups.get_group_by_id(new_group.id) + await Groups.update_group_by_id(new_group.id, update_form) + new_group = await Groups.get_group_by_id(new_group.id) - return group_to_scim(new_group, request) + return await group_to_scim(new_group, request) @router.put("/Groups/{group_id}", response_model=SCIMGroup) @@ -839,7 +839,7 @@ async def update_group( detail="Failed to update group", ) - return group_to_scim(updated_group, request) + return await group_to_scim(updated_group, request) @router.patch("/Groups/{group_id}", response_model=SCIMGroup) @@ -899,7 +899,7 @@ async def patch_group( detail="Failed to update group", ) - return group_to_scim(updated_group, request) + return await group_to_scim(updated_group, request) @router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 3c3e06a985..befa084d9a 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -49,7 +49,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): request.app.state.config.TOOL_SERVER_CONNECTIONS ) - tools = Tools.get_tools() + tools = await Tools.get_tools() for server in request.app.state.TOOL_SERVERS: tools.append( ToolUserResponse( @@ -83,7 +83,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control) + or await has_access(user.id, "read", tool.access_control) ] return tools @@ -96,9 +96,9 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): @router.get("/list", response_model=list[ToolUserResponse]) async def get_tool_list(user=Depends(get_verified_user)): if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: - tools = Tools.get_tools() + tools = await Tools.get_tools() else: - tools = Tools.get_tools_by_user_id(user.id, "write") + tools = await Tools.get_tools_by_user_id(user.id, "write") return tools @@ -184,7 +184,7 @@ async def load_tool_from_url( @router.get("/export", response_model=list[ToolModel]) async def export_tools(user=Depends(get_admin_user)): - tools = Tools.get_tools() + tools = await Tools.get_tools() return tools @@ -215,11 +215,11 @@ async def create_new_tools( form_data.id = form_data.id.lower() - tools = Tools.get_tool_by_id(form_data.id) + tools = await Tools.get_tool_by_id(form_data.id) if tools is None: try: form_data.content = replace_imports(form_data.content) - tool_module, frontmatter = load_tool_module_by_id( + tool_module, frontmatter = await load_tool_module_by_id( form_data.id, content=form_data.content ) form_data.meta.manifest = frontmatter @@ -228,7 +228,7 @@ async def create_new_tools( TOOLS[form_data.id] = tool_module specs = get_tool_specs(TOOLS[form_data.id]) - tools = Tools.insert_new_tool(user.id, form_data, specs) + tools = await Tools.insert_new_tool(user.id, form_data, specs) tool_cache_dir = CACHE_DIR / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) @@ -260,13 +260,13 @@ async def create_new_tools( @router.get("/id/{id}", response_model=Optional[ToolModel]) async def get_tools_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if tools: if ( user.role == "admin" or tools.user_id == user.id - or has_access(user.id, "read", tools.access_control) + or await has_access(user.id, "read", tools.access_control) ): return tools else: @@ -288,7 +288,7 @@ async def update_tools_by_id( form_data: ToolForm, user=Depends(get_verified_user), ): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -298,7 +298,7 @@ async def update_tools_by_id( # Is the user the original creator, in a group with write access, or an admin if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not await has_access(user.id, "write", tools.access_control) and user.role != "admin" ): raise HTTPException( @@ -308,7 +308,9 @@ async def update_tools_by_id( try: form_data.content = replace_imports(form_data.content) - tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content) + tool_module, frontmatter = await load_tool_module_by_id( + id, content=form_data.content + ) form_data.meta.manifest = frontmatter TOOLS = request.app.state.TOOLS @@ -322,7 +324,7 @@ async def update_tools_by_id( } log.debug(updated) - tools = Tools.update_tool_by_id(id, updated) + tools = await Tools.update_tool_by_id(id, updated) if tools: return tools @@ -348,7 +350,7 @@ async def update_tools_by_id( async def delete_tools_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -357,7 +359,7 @@ async def delete_tools_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not await has_access(user.id, "write", tools.access_control) and user.role != "admin" ): raise HTTPException( @@ -365,7 +367,7 @@ async def delete_tools_by_id( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - result = Tools.delete_tool_by_id(id) + result = await Tools.delete_tool_by_id(id) if result: TOOLS = request.app.state.TOOLS if id in TOOLS: @@ -381,10 +383,10 @@ async def delete_tools_by_id( @router.get("/id/{id}/valves", response_model=Optional[dict]) async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if tools: try: - valves = Tools.get_tool_valves_by_id(id) + valves = await Tools.get_tool_valves_by_id(id) return valves except Exception as e: raise HTTPException( @@ -407,12 +409,12 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if tools: if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] else: - tools_module, _ = load_tool_module_by_id(id) + tools_module, _ = await load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module if hasattr(tools_module, "Valves"): @@ -435,7 +437,7 @@ async def get_tools_valves_spec_by_id( async def update_tools_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if not tools: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -444,7 +446,7 @@ async def update_tools_valves_by_id( if ( tools.user_id != user.id - and not has_access(user.id, "write", tools.access_control) + and not await has_access(user.id, "write", tools.access_control) and user.role != "admin" ): raise HTTPException( @@ -455,7 +457,7 @@ async def update_tools_valves_by_id( if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] else: - tools_module, _ = load_tool_module_by_id(id) + tools_module, _ = await load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module if not hasattr(tools_module, "Valves"): @@ -468,7 +470,7 @@ async def update_tools_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) - Tools.update_tool_valves_by_id(id, valves.model_dump()) + await Tools.update_tool_valves_by_id(id, valves.model_dump()) return valves.model_dump() except Exception as e: log.exception(f"Failed to update tool valves by id {id}: {e}") @@ -485,10 +487,10 @@ async def update_tools_valves_by_id( @router.get("/id/{id}/valves/user", response_model=Optional[dict]) async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if tools: try: - user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) + user_valves = await Tools.get_user_valves_by_id_and_user_id(id, user.id) return user_valves except Exception as e: raise HTTPException( @@ -506,12 +508,12 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_user_valves_spec_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if tools: if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] else: - tools_module, _ = load_tool_module_by_id(id) + tools_module, _ = await load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module if hasattr(tools_module, "UserValves"): @@ -529,13 +531,13 @@ async def get_tools_user_valves_spec_by_id( async def update_tools_user_valves_by_id( request: Request, id: str, form_data: dict, user=Depends(get_verified_user) ): - tools = Tools.get_tool_by_id(id) + tools = await Tools.get_tool_by_id(id) if tools: if id in request.app.state.TOOLS: tools_module = request.app.state.TOOLS[id] else: - tools_module, _ = load_tool_module_by_id(id) + tools_module, _ = await load_tool_module_by_id(id) request.app.state.TOOLS[id] = tools_module if hasattr(tools_module, "UserValves"): @@ -544,7 +546,7 @@ async def update_tools_user_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) - Tools.update_user_valves_by_id_and_user_id( + await Tools.update_user_valves_by_id_and_user_id( id, user.id, user_valves.model_dump() ) return user_valves.model_dump() diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index cad29c475d..d6bbbbadee 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -88,14 +88,14 @@ async def get_users( if direction: filter["direction"] = direction - return Users.get_users(filter=filter, skip=skip, limit=limit) + return await Users.get_users(filter=filter, skip=skip, limit=limit) @router.get("/all", response_model=UserInfoListResponse) async def get_all_users( user=Depends(get_admin_user), ): - return Users.get_users() + return await Users.get_users() ############################ @@ -205,7 +205,7 @@ async def update_default_user_permissions( @router.get("/user/settings", response_model=Optional[UserSettings]) async def get_user_settings_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) + user = await Users.get_user_by_id(user.id) if user: return user.settings else: @@ -237,7 +237,7 @@ async def update_user_settings_by_session_user( # If the user is not an admin and does not have permission to use tool servers, remove the key updated_user_settings["ui"].pop("toolServers", None) - user = Users.update_user_settings_by_id(user.id, updated_user_settings) + user = await Users.update_user_settings_by_id(user.id, updated_user_settings) if user: return user.settings else: @@ -254,7 +254,7 @@ async def update_user_settings_by_session_user( @router.get("/user/info", response_model=Optional[dict]) async def get_user_info_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) + user = await Users.get_user_by_id(user.id) if user: return user.info else: @@ -273,12 +273,14 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): async def update_user_info_by_session_user( form_data: dict, user=Depends(get_verified_user) ): - user = Users.get_user_by_id(user.id) + user = await Users.get_user_by_id(user.id) if user: if user.info is None: user.info = {} - user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) + user = await Users.update_user_by_id( + user.id, {"info": {**user.info, **form_data}} + ) if user: return user.info else: @@ -343,7 +345,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): @router.get("/{user_id}/profile/image") async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): - user = Users.get_user_by_id(user_id) + user = await Users.get_user_by_id(user_id) if user: if user.profile_image_url: # check if it's url or base64 @@ -398,7 +400,7 @@ async def update_user_by_id( ): # Prevent modification of the primary admin user by other admins try: - first_user = Users.get_first_user() + first_user = await Users.get_first_user() if first_user: if user_id == first_user.id: if session_user.id != user_id: @@ -472,7 +474,7 @@ async def update_user_by_id( async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): # Prevent deletion of the primary admin user try: - first_user = Users.get_first_user() + first_user = await Users.get_first_user() if first_user and user_id == first_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 63dda91502..96217c895a 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -263,7 +263,7 @@ async def connect(sid, environ, auth): data = decode_token(auth["token"]) if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + user = await Users.get_user_by_id(data["id"]) if user: SESSION_POOL[sid] = user.model_dump() @@ -284,7 +284,7 @@ async def user_join(sid, data): if data is None or "id" not in data: return - user = Users.get_user_by_id(data["id"]) + user = await Users.get_user_by_id(data["id"]) if not user: return @@ -312,7 +312,7 @@ async def join_channel(sid, data): if data is None or "id" not in data: return - user = Users.get_user_by_id(data["id"]) + user = await Users.get_user_by_id(data["id"]) if not user: return @@ -333,7 +333,7 @@ async def join_note(sid, data): if token_data is None or "id" not in token_data: return - user = Users.get_user_by_id(token_data["id"]) + user = await Users.get_user_by_id(token_data["id"]) if not user: return @@ -345,7 +345,9 @@ async def join_note(sid, data): if ( user.role != "admin" and user.id != note.user_id - and not has_access(user.id, type="read", access_control=note.access_control) + and not await has_access( + user.id, type="read", access_control=note.access_control + ) ): log.error(f"User {user.id} does not have access to note {data['note_id']}") return @@ -400,7 +402,7 @@ async def ydoc_document_join(sid, data): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( + and not await has_access( user.get("id"), type="read", access_control=note.access_control ) ): @@ -470,14 +472,14 @@ async def document_save_handler(document_id, data, user): if ( user.get("role") != "admin" and user.get("id") != note.user_id - and not has_access( + and not await has_access( user.get("id"), type="read", access_control=note.access_control ) ): log.error(f"User {user.get('id')} does not have access to note {note_id}") return - Notes.update_note_by_id(note_id, NoteUpdateForm(data=data)) + await Notes.update_note_by_id(note_id, NoteUpdateForm(data=data)) @sio.on("ydoc:document:state") diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index c93574527f..ec732c0408 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -107,7 +107,7 @@ def has_permission( return get_permission(default_permissions, permission_hierarchy) -def has_access( +async def has_access( user_id: str, type: str = "write", access_control: Optional[dict] = None, @@ -115,7 +115,7 @@ def has_access( if access_control is None: return type == "read" - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = await Groups.get_groups_by_member_id(user_id) user_group_ids = [group.id for group in user_groups] permission_access = access_control.get(type, {}) permitted_group_ids = permission_access.get("group_ids", []) @@ -127,11 +127,11 @@ def has_access( # Get all users with access to a resource -def get_users_with_access( +async def get_users_with_access( type: str = "write", access_control: Optional[dict] = None ) -> List[UserModel]: if access_control is None: - return Users.get_users() + return await Users.get_users() permission_access = access_control.get(type, {}) permitted_group_ids = permission_access.get("group_ids", []) @@ -140,8 +140,8 @@ def get_users_with_access( user_ids_with_access = set(permitted_user_ids) for group_id in permitted_group_ids: - group_user_ids = Groups.get_group_user_ids_by_id(group_id) + group_user_ids = await Groups.get_group_user_ids_by_id(group_id) if group_user_ids: user_ids_with_access.update(group_user_ids) - return Users.get_users_by_user_ids(list(user_ids_with_access)) + return await Users.get_users_by_user_ids(list(user_ids_with_access)) diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 228dd3e30a..79959e3ea2 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -206,7 +206,7 @@ def get_http_authorization_cred(auth_header: Optional[str]): return None -def get_current_user( +async def get_current_user( request: Request, response: Response, background_tasks: BackgroundTasks, @@ -270,7 +270,7 @@ def get_current_user( ) if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + user = await Users.get_user_by_id(data["id"]) if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -303,6 +303,7 @@ def get_current_user( # Refresh the user's last active timestamp asynchronously # to prevent blocking the request if background_tasks: + background_tasks.add_task(Users.update_user_last_active_by_id, user.id) return user else: @@ -312,8 +313,8 @@ def get_current_user( ) -def get_current_user_by_api_key(api_key: str): - user = Users.get_user_by_api_key(api_key) +async def get_current_user_by_api_key(api_key: str): + user = await Users.get_user_by_api_key(api_key) if user is None: raise HTTPException( @@ -329,7 +330,7 @@ def get_current_user_by_api_key(api_key: str): current_span.set_attribute("client.user.role", user.role) current_span.set_attribute("client.auth.type", "api_key") - Users.update_user_last_active_by_id(user.id) + await Users.update_user_last_active_by_id(user.id) return user diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 83483f391b..16cfb2e82f 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -199,7 +199,7 @@ async def generate_chat_completion( # Check if user has access to the model if not bypass_filter and user.role == "user": try: - check_model_access(user, model) + await check_model_access(user, model) except Exception as e: raise e @@ -424,8 +424,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A try: if hasattr(function_module, "UserValves"): __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id + **( + await Functions.get_user_valves_by_id_and_user_id( + action_id, user.id + ) ) ) except Exception as e: diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py index 49ce72c3c5..040b78f982 100644 --- a/backend/open_webui/utils/embeddings.py +++ b/backend/open_webui/utils/embeddings.py @@ -70,7 +70,7 @@ async def generate_embeddings( # Access filtering if not getattr(request.state, "direct", False): if not bypass_filter and user.role == "user": - check_model_access(user, model) + await check_model_access(user, model) # Ollama backend if model.get("owned_by") == "ollama": diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 1986e55b64..dbb25d9368 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -109,8 +109,10 @@ async def process_filter_functions( if hasattr(function_module, "UserValves"): try: params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, params["__user__"]["id"] + **( + await Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) ) ) except Exception as e: diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index b713b84307..27bf5d2f2d 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -311,9 +311,9 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None) return models -def check_model_access(user, model): +async def check_model_access(user, model): if model.get("arena"): - if not has_access( + if not await has_access( user.id, type="read", access_control=model.get("info", {}) @@ -322,12 +322,12 @@ def check_model_access(user, model): ): raise Exception("Model not found") else: - model_info = Models.get_model_by_id(model.get("id")) + model_info = await Models.get_model_by_id(model.get("id")) if not model_info: raise Exception("Model not found") elif not ( user.id == model_info.user_id - or has_access( + or await has_access( user.id, type="read", access_control=model_info.access_control ) ): diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 9ad067eeb1..4c3986aab8 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -89,8 +89,8 @@ class OAuthManager: def get_client(self, provider_name): return self.oauth.create_client(provider_name) - def get_user_role(self, user, user_data): - user_count = Users.get_num_users() + async def get_user_role(self, user, user_data): + user_count = await Users.get_num_users() if user and user_count == 1: # If the user is the only user, assign the role "admin" - actually repairs role for single user on login log.debug("Assigning the only user the admin role") @@ -147,7 +147,7 @@ class OAuthManager: return role - def update_user_groups(self, user, user_data, default_permissions): + async def update_user_groups(self, user, user_data, default_permissions): log.debug("Running OAUTH Group management") oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM @@ -172,8 +172,10 @@ class OAuthManager: else: user_oauth_groups = [] - user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) - all_available_groups: list[GroupModel] = Groups.get_groups() + user_current_groups: list[GroupModel] = await Groups.get_groups_by_member_id( + user.id + ) + all_available_groups: list[GroupModel] = await Groups.get_groups() # Create groups if they don't exist and creation is enabled if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: @@ -181,7 +183,7 @@ class OAuthManager: all_group_names = {g.name for g in all_available_groups} groups_created = False # Determine creator ID: Prefer admin, fallback to current user if no admin exists - admin_user = Users.get_super_admin_user() + admin_user = await Users.get_super_admin_user() creator_id = admin_user.id if admin_user else user.id log.debug(f"Using creator ID {creator_id} for potential group creation.") @@ -198,7 +200,7 @@ class OAuthManager: user_ids=[], # Start with no users, user will be added later by subsequent logic ) # Use determined creator ID (admin or fallback to current user) - created_group = Groups.insert_new_group( + created_group = await Groups.insert_new_group( creator_id, new_group_form ) if created_group: @@ -217,7 +219,7 @@ class OAuthManager: # Refresh the list of all available groups if any were created if groups_created: - all_available_groups = Groups.get_groups() + all_available_groups = await Groups.get_groups() log.debug("Refreshed list of all available groups after creation.") log.debug(f"Oauth Groups claim: {oauth_claim}") @@ -430,21 +432,21 @@ class OAuthManager: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) + user = await Users.get_user_by_oauth_sub(provider_sub) if not user: # If the user does not exist, check if merging is enabled if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: # Check if the user exists by email - user = Users.get_user_by_email(email) + user = await Users.get_user_by_email(email) if user: # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) + await Users.update_user_oauth_sub_by_id(user.id, provider_sub) if user: - determined_role = self.get_user_role(user, user_data) + determined_role = await self.get_user_role(user, user_data) if user.role != determined_role: - Users.update_user_role_by_id(user.id, determined_role) + await Users.update_user_role_by_id(user.id, determined_role) # Update profile picture if enabled and different from current if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN: @@ -457,7 +459,7 @@ class OAuthManager: new_picture_url, token.get("access_token") ) if processed_picture_url != user.profile_image_url: - Users.update_user_profile_image_url_by_id( + await Users.update_user_profile_image_url_by_id( user.id, processed_picture_url ) log.debug(f"Updated profile picture for user {user.email}") @@ -466,7 +468,7 @@ class OAuthManager: # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(email) + existing_user = await Users.get_user_by_email(email) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -488,7 +490,7 @@ class OAuthManager: log.warning("Username claim is missing, using email as name") name = email - role = self.get_user_role(None, user_data) + role = await self.get_user_role(None, user_data) user = await Auths.insert_new_auth( email=email, @@ -523,7 +525,7 @@ class OAuthManager: ) if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin": - self.update_user_groups( + await self.update_user_groups( user=user, user_data=user_data, default_permissions=request.app.state.config.USER_PERMISSIONS, diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 9d539f4840..a38657a854 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -68,17 +68,17 @@ def replace_imports(content): return content -def load_tool_module_by_id(tool_id, content=None): +async def load_tool_module_by_id(tool_id, content=None): if content is None: - tool = Tools.get_tool_by_id(tool_id) + tool = await Tools.get_tool_by_id(tool_id) if not tool: raise Exception(f"Toolkit not found: {tool_id}") content = tool.content content = replace_imports(content) - Tools.update_tool_by_id(tool_id, {"content": content}) + await Tools.update_tool_by_id(tool_id, {"content": content}) else: frontmatter = extract_frontmatter(content) # Install required packages found within the frontmatter @@ -241,7 +241,7 @@ def install_frontmatter_requirements(requirements: str): log.info("No requirements found in frontmatter.") -def install_tool_and_function_dependencies(): +async def install_tool_and_function_dependencies(): """ Install all dependencies for all admin tools and active functions. @@ -249,8 +249,8 @@ def install_tool_and_function_dependencies(): and then installing them using pip. Duplicates or similar version specifications are handled by pip as much as possible. """ - function_list = Functions.get_functions(active_only=True) - tool_list = Tools.get_tools() + function_list = await Functions.get_functions(active_only=True) + tool_list = await Tools.get_tools() all_dependencies = "" try: diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 45f0ef7dda..1a108f0302 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -68,13 +68,13 @@ def get_async_tool_function_and_apply_extra_params( return new_function -def get_tools( +async def get_tools( request: Request, tool_ids: list[str], user: UserModel, extra_params: dict ) -> dict[str, dict]: tools_dict = {} for tool_id in tool_ids: - tool = Tools.get_tool_by_id(tool_id) + tool = await Tools.get_tool_by_id(tool_id) if tool is None: if tool_id.startswith("server:"): server_idx = int(tool_id.split(":")[1]) @@ -140,18 +140,18 @@ def get_tools( else: module = request.app.state.TOOLS.get(tool_id, None) if module is None: - module, _ = load_tool_module_by_id(tool_id) + module, _ = await load_tool_module_by_id(tool_id) request.app.state.TOOLS[tool_id] = module extra_params["__id__"] = tool_id # Set valves for the tool if hasattr(module, "valves") and hasattr(module, "Valves"): - valves = Tools.get_tool_valves_by_id(tool_id) or {} + valves = await Tools.get_tool_valves_by_id(tool_id) or {} module.valves = module.Valves(**valves) if hasattr(module, "UserValves"): extra_params["__user__"]["valves"] = module.UserValves( # type: ignore - **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) + **(await Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)) ) for spec in tool.specs: