diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 4097bab531..c714ae9d5e 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -92,7 +92,7 @@ class GroupUpdateForm(GroupForm, UserIdsForm): class GroupTable: - def insert_new_group( + async def insert_new_group( self, user_id: str, form_data: GroupForm ) -> Optional[GroupModel]: async with get_db() as db: @@ -108,9 +108,9 @@ class GroupTable: try: result = Group(**group.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return GroupModel.model_validate(result) else: @@ -119,18 +119,20 @@ class GroupTable: except Exception: return None - def get_groups(self) -> list[GroupModel]: + async def get_groups(self) -> list[GroupModel]: async with get_db() as db: return [ GroupModel.model_validate(group) - for group in db.query(Group).order_by(Group.updated_at.desc()).all() + for group in await db.query(Group) + .order_by(Group.updated_at.desc()) + .all() ] - def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: + async def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: async with get_db() as db: return [ GroupModel.model_validate(group) - for group in db.query(Group) + for group in await db.query(Group) .filter( func.json_array_length(Group.user_ids) > 0 ) # Ensure array exists @@ -141,82 +143,82 @@ class GroupTable: .all() ] - def get_group_by_id(self, id: str) -> Optional[GroupModel]: + async def get_group_by_id(self, id: str) -> Optional[GroupModel]: try: async with get_db() as db: - group = db.query(Group).filter_by(id=id).first() + group = await db.query(Group).filter_by(id=id).first() return GroupModel.model_validate(group) if group else None except Exception: return None - def get_group_user_ids_by_id(self, id: str) -> Optional[str]: - group = self.get_group_by_id(id) + async def get_group_user_ids_by_id(self, id: str) -> Optional[str]: + group = await self.get_group_by_id(id) if group: return group.user_ids else: return None - def update_group_by_id( + async def update_group_by_id( self, id: str, form_data: GroupUpdateForm, overwrite: bool = False ) -> Optional[GroupModel]: try: async with get_db() as db: - db.query(Group).filter_by(id=id).update( + await db.query(Group).filter_by(id=id).update( { **form_data.model_dump(exclude_none=True), "updated_at": int(time.time()), } ) - db.commit() - return self.get_group_by_id(id=id) + await db.commit() + return await self.get_group_by_id(id=id) except Exception as e: log.exception(e) return None - def delete_group_by_id(self, id: str) -> bool: + async def delete_group_by_id(self, id: str) -> bool: try: async with get_db() as db: - db.query(Group).filter_by(id=id).delete() - db.commit() + await db.query(Group).filter_by(id=id).delete() + await db.commit() return True except Exception: return False - def delete_all_groups(self) -> bool: + async def delete_all_groups(self) -> bool: async with get_db() as db: try: - db.query(Group).delete() - db.commit() + await db.query(Group).delete() + await db.commit() return True except Exception: return False - def remove_user_from_all_groups(self, user_id: str) -> bool: + async def remove_user_from_all_groups(self, user_id: str) -> bool: async with get_db() as db: try: - groups = self.get_groups_by_member_id(user_id) + groups = await self.get_groups_by_member_id(user_id) for group in groups: group.user_ids.remove(user_id) - db.query(Group).filter_by(id=group.id).update( + await db.query(Group).filter_by(id=group.id).update( { "user_ids": group.user_ids, "updated_at": int(time.time()), } ) - db.commit() + await db.commit() return True except Exception: return False - def create_groups_by_group_names( + async def create_groups_by_group_names( self, user_id: str, group_names: list[str] ) -> list[GroupModel]: # check for existing groups - existing_groups = self.get_groups() + existing_groups = await self.get_groups() existing_group_names = {group.name for group in existing_groups} new_groups = [] @@ -234,28 +236,30 @@ class GroupTable: ) try: result = Group(**new_group.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) new_groups.append(GroupModel.model_validate(result)) except Exception as e: log.exception(e) continue return new_groups - def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: + async def sync_groups_by_group_names( + self, user_id: str, group_names: list[str] + ) -> bool: async with get_db() as db: try: - groups = db.query(Group).filter(Group.name.in_(group_names)).all() + groups = await db.query(Group).filter(Group.name.in_(group_names)).all() group_ids = [group.id for group in groups] # Remove user from groups not in the new list - existing_groups = self.get_groups_by_member_id(user_id) + existing_groups = await self.get_groups_by_member_id(user_id) for group in existing_groups: if group.id not in group_ids: group.user_ids.remove(user_id) - db.query(Group).filter_by(id=group.id).update( + await db.query(Group).filter_by(id=group.id).update( { "user_ids": group.user_ids, "updated_at": int(time.time()), @@ -266,25 +270,25 @@ class GroupTable: for group in groups: if user_id not in group.user_ids: group.user_ids.append(user_id) - db.query(Group).filter_by(id=group.id).update( + await db.query(Group).filter_by(id=group.id).update( { "user_ids": group.user_ids, "updated_at": int(time.time()), } ) - db.commit() + await db.commit() return True except Exception as e: log.exception(e) return False - def add_users_to_group( + async def add_users_to_group( self, id: str, user_ids: Optional[list[str]] = None ) -> Optional[GroupModel]: try: async with get_db() as db: - group = db.query(Group).filter_by(id=id).first() + group = await db.query(Group).filter_by(id=id).first() if not group: return None @@ -296,19 +300,19 @@ class GroupTable: group.user_ids.append(user_id) group.updated_at = int(time.time()) - db.commit() - db.refresh(group) + await db.commit() + await db.refresh(group) return GroupModel.model_validate(group) except Exception as e: log.exception(e) return None - def remove_users_from_group( + async def remove_users_from_group( self, id: str, user_ids: Optional[list[str]] = None ) -> Optional[GroupModel]: try: async with get_db() as db: - group = db.query(Group).filter_by(id=id).first() + group = await db.query(Group).filter_by(id=id).first() if not group: return None @@ -320,8 +324,8 @@ class GroupTable: group.user_ids.remove(user_id) group.updated_at = int(time.time()) - db.commit() - db.refresh(group) + await db.commit() + await db.refresh(group) return GroupModel.model_validate(group) except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index aa562d18e5..d6ca7ec36d 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -108,7 +108,7 @@ async def get_session_user( secure=WEBUI_AUTH_COOKIE_SECURE, ) - user_permissions = get_permissions( + user_permissions = await get_permissions( user.id, request.app.state.config.USER_PERMISSIONS ) @@ -406,7 +406,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): secure=WEBUI_AUTH_COOKIE_SECURE, ) - user_permissions = get_permissions( + user_permissions = await get_permissions( user.id, request.app.state.config.USER_PERMISSIONS ) @@ -416,10 +416,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): and user_groups ): if ENABLE_LDAP_GROUP_CREATION: - Groups.create_groups_by_group_names(user.id, user_groups) + await Groups.create_groups_by_group_names(user.id, user_groups) try: - Groups.sync_groups_by_group_names(user.id, user_groups) + await Groups.sync_groups_by_group_names(user.id, user_groups) log.info( f"Successfully synced groups for user {user.id}: {user_groups}" ) @@ -478,7 +478,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): group_names = [name.strip() for name in group_names if name.strip()] if group_names: - Groups.sync_groups_by_group_names(user.id, group_names) + await Groups.sync_groups_by_group_names(user.id, group_names) elif WEBUI_AUTH == False: admin_email = "admin@localhost" @@ -530,7 +530,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): secure=WEBUI_AUTH_COOKIE_SECURE, ) - user_permissions = get_permissions( + user_permissions = await get_permissions( user.id, request.app.state.config.USER_PERMISSIONS ) @@ -638,7 +638,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): }, ) - user_permissions = get_permissions( + user_permissions = await get_permissions( user.id, request.app.state.config.USER_PERMISSIONS ) diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index e5b31ff8e3..026e33ecfa 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -33,9 +33,9 @@ router = APIRouter() @router.get("/", response_model=list[GroupResponse]) async def get_groups(user=Depends(get_verified_user)): if user.role == "admin": - return Groups.get_groups() + return await Groups.get_groups() else: - return Groups.get_groups_by_member_id(user.id) + return await Groups.get_groups_by_member_id(user.id) ############################ @@ -46,7 +46,7 @@ async def get_groups(user=Depends(get_verified_user)): @router.post("/create", response_model=Optional[GroupResponse]) async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): try: - group = Groups.insert_new_group(user.id, form_data) + group = await Groups.insert_new_group(user.id, form_data) if group: return group else: @@ -69,7 +69,7 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): @router.get("/id/{id}", response_model=Optional[GroupResponse]) async def get_group_by_id(id: str, user=Depends(get_admin_user)): - group = Groups.get_group_by_id(id) + group = await Groups.get_group_by_id(id) if group: return group else: @@ -92,7 +92,7 @@ async def update_group_by_id( if form_data.user_ids: form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids) - group = Groups.update_group_by_id(id, form_data) + group = await Groups.update_group_by_id(id, form_data) if group: return group else: @@ -121,7 +121,7 @@ async def add_user_to_group( if form_data.user_ids: form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids) - group = Groups.add_users_to_group(id, form_data.user_ids) + group = await Groups.add_users_to_group(id, form_data.user_ids) if group: return group else: @@ -142,7 +142,7 @@ async def remove_users_from_group( id: str, form_data: UserIdsForm, user=Depends(get_admin_user) ): try: - group = Groups.remove_users_from_group(id, form_data.user_ids) + group = await Groups.remove_users_from_group(id, form_data.user_ids) if group: return group else: @@ -166,7 +166,7 @@ async def remove_users_from_group( @router.delete("/id/{id}/delete", response_model=bool) async def delete_group_by_id(id: str, user=Depends(get_admin_user)): try: - result = Groups.delete_group_by_id(id) + result = await Groups.delete_group_by_id(id) if result: return result else: diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index fb8c8263dd..091a931681 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -305,7 +305,7 @@ async def user_to_scim(user: UserModel, request: Request) -> SCIMUser: family_name = name_parts[1] if len(name_parts) > 1 else "" # Get user's groups - user_groups = Groups.get_groups_by_member_id(user.id) + user_groups = await Groups.get_groups_by_member_id(user.id) groups = [ { "value": group.id, @@ -811,7 +811,7 @@ async def update_group( _: bool = Depends(get_scim_auth), ): """Update SCIM Group (full update)""" - group = Groups.get_group_by_id(group_id) + group = await Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -832,7 +832,7 @@ async def update_group( update_form.user_ids = member_ids # Update group - updated_group = Groups.update_group_by_id(group_id, update_form) + updated_group = await Groups.update_group_by_id(group_id, update_form) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -850,7 +850,7 @@ async def patch_group( _: bool = Depends(get_scim_auth), ): """Update SCIM Group (partial update)""" - group = Groups.get_group_by_id(group_id) + group = await Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -892,7 +892,7 @@ async def patch_group( update_form.user_ids.remove(member_id) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form) + updated_group = await Groups.update_group_by_id(group_id, update_form) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -909,14 +909,14 @@ async def delete_group( _: bool = Depends(get_scim_auth), ): """Delete SCIM Group""" - group = Groups.get_group_by_id(group_id) + group = await Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - success = Groups.delete_group_by_id(group_id) + success = await Groups.delete_group_by_id(group_id) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index d6bbbbadee..db96d9d98d 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -105,7 +105,7 @@ async def get_all_users( @router.get("/groups") async def get_user_groups(user=Depends(get_verified_user)): - return Groups.get_groups_by_member_id(user.id) + return await Groups.get_groups_by_member_id(user.id) ############################ @@ -115,7 +115,7 @@ async def get_user_groups(user=Depends(get_verified_user)): @router.get("/permissions") async def get_user_permissisions(request: Request, user=Depends(get_verified_user)): - user_permissions = get_permissions( + user_permissions = await get_permissions( user.id, request.app.state.config.USER_PERMISSIONS ) @@ -512,4 +512,4 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): @router.get("/{user_id}/groups") async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)): - return Groups.get_groups_by_member_id(user_id) + return await Groups.get_groups_by_member_id(user_id) diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 7c7b379e3e..2233c8913c 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -25,7 +25,7 @@ def fill_missing_permissions( return permissions -def get_permissions( +async def get_permissions( user_id: str, default_permissions: Dict[str, Any], ) -> Dict[str, Any]: @@ -53,7 +53,7 @@ def get_permissions( ) # Use the most permissive value (True > False) return permissions - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = await Groups.get_groups_by_member_id(user_id) # Deep copy default permissions to avoid modifying the original dict permissions = json.loads(json.dumps(default_permissions)) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 4c3986aab8..d71fdf1160 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -255,7 +255,7 @@ class OAuthManager: permissions=group_permissions, user_ids=user_ids, ) - Groups.update_group_by_id( + await Groups.update_group_by_id( id=group_model.id, form_data=update_form, overwrite=False ) @@ -286,7 +286,7 @@ class OAuthManager: permissions=group_permissions, user_ids=user_ids, ) - Groups.update_group_by_id( + await Groups.update_group_by_id( id=group_model.id, form_data=update_form, overwrite=False )