wip: groups

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:52:04 +04:00
parent eb86ac7a2b
commit 3b9e454fb4
7 changed files with 78 additions and 74 deletions

View file

@ -92,7 +92,7 @@ class GroupUpdateForm(GroupForm, UserIdsForm):
class GroupTable: class GroupTable:
def insert_new_group( async def insert_new_group(
self, user_id: str, form_data: GroupForm self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
async with get_db() as db: async with get_db() as db:
@ -108,9 +108,9 @@ class GroupTable:
try: try:
result = Group(**group.model_dump()) result = Group(**group.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return GroupModel.model_validate(result) return GroupModel.model_validate(result)
else: else:
@ -119,18 +119,20 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_groups(self) -> list[GroupModel]: async def get_groups(self) -> list[GroupModel]:
async with get_db() as db: async with get_db() as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all() for group in await db.query(Group)
.order_by(Group.updated_at.desc())
.all()
] ]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: async def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
async with get_db() as db: async with get_db() as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group) for group in await db.query(Group)
.filter( .filter(
func.json_array_length(Group.user_ids) > 0 func.json_array_length(Group.user_ids) > 0
) # Ensure array exists ) # Ensure array exists
@ -141,82 +143,82 @@ class GroupTable:
.all() .all()
] ]
def get_group_by_id(self, id: str) -> Optional[GroupModel]: async def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try: try:
async with get_db() as db: async with get_db() as db:
group = db.query(Group).filter_by(id=id).first() group = await db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None return GroupModel.model_validate(group) if group else None
except Exception: except Exception:
return None return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]: async def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id) group = await self.get_group_by_id(id)
if group: if group:
return group.user_ids return group.user_ids
else: else:
return None return None
def update_group_by_id( async def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Group).filter_by(id=id).update( await db.query(Group).filter_by(id=id).update(
{ {
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return self.get_group_by_id(id=id) return await self.get_group_by_id(id=id)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_group_by_id(self, id: str) -> bool: async def delete_group_by_id(self, id: str) -> bool:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Group).filter_by(id=id).delete() await db.query(Group).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_all_groups(self) -> bool: async def delete_all_groups(self) -> bool:
async with get_db() as db: async with get_db() as db:
try: try:
db.query(Group).delete() await db.query(Group).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def remove_user_from_all_groups(self, user_id: str) -> bool: async def remove_user_from_all_groups(self, user_id: str) -> bool:
async with get_db() as db: async with get_db() as db:
try: try:
groups = self.get_groups_by_member_id(user_id) groups = await self.get_groups_by_member_id(user_id)
for group in groups: for group in groups:
group.user_ids.remove(user_id) group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update( await db.query(Group).filter_by(id=group.id).update(
{ {
"user_ids": group.user_ids, "user_ids": group.user_ids,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def create_groups_by_group_names( async def create_groups_by_group_names(
self, user_id: str, group_names: list[str] self, user_id: str, group_names: list[str]
) -> list[GroupModel]: ) -> list[GroupModel]:
# check for existing groups # check for existing groups
existing_groups = self.get_groups() existing_groups = await self.get_groups()
existing_group_names = {group.name for group in existing_groups} existing_group_names = {group.name for group in existing_groups}
new_groups = [] new_groups = []
@ -234,28 +236,30 @@ class GroupTable:
) )
try: try:
result = Group(**new_group.model_dump()) result = Group(**new_group.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
new_groups.append(GroupModel.model_validate(result)) new_groups.append(GroupModel.model_validate(result))
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
continue continue
return new_groups return new_groups
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: async def sync_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> bool:
async with get_db() as db: async with get_db() as db:
try: try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all() groups = await db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups] group_ids = [group.id for group in groups]
# Remove user from groups not in the new list # Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id) existing_groups = await self.get_groups_by_member_id(user_id)
for group in existing_groups: for group in existing_groups:
if group.id not in group_ids: if group.id not in group_ids:
group.user_ids.remove(user_id) group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update( await db.query(Group).filter_by(id=group.id).update(
{ {
"user_ids": group.user_ids, "user_ids": group.user_ids,
"updated_at": int(time.time()), "updated_at": int(time.time()),
@ -266,25 +270,25 @@ class GroupTable:
for group in groups: for group in groups:
if user_id not in group.user_ids: if user_id not in group.user_ids:
group.user_ids.append(user_id) group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update( await db.query(Group).filter_by(id=group.id).update(
{ {
"user_ids": group.user_ids, "user_ids": group.user_ids,
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
db.commit() await db.commit()
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return False return False
def add_users_to_group( async def add_users_to_group(
self, id: str, user_ids: Optional[list[str]] = None self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
async with get_db() as db: async with get_db() as db:
group = db.query(Group).filter_by(id=id).first() group = await db.query(Group).filter_by(id=id).first()
if not group: if not group:
return None return None
@ -296,19 +300,19 @@ class GroupTable:
group.user_ids.append(user_id) group.user_ids.append(user_id)
group.updated_at = int(time.time()) group.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(group) await db.refresh(group)
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def remove_users_from_group( async def remove_users_from_group(
self, id: str, user_ids: Optional[list[str]] = None self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
try: try:
async with get_db() as db: async with get_db() as db:
group = db.query(Group).filter_by(id=id).first() group = await db.query(Group).filter_by(id=id).first()
if not group: if not group:
return None return None
@ -320,8 +324,8 @@ class GroupTable:
group.user_ids.remove(user_id) group.user_ids.remove(user_id)
group.updated_at = int(time.time()) group.updated_at = int(time.time())
db.commit() await db.commit()
db.refresh(group) await db.refresh(group)
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View file

@ -108,7 +108,7 @@ async def get_session_user(
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
) )
user_permissions = get_permissions( user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )
@ -406,7 +406,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
) )
user_permissions = get_permissions( user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )
@ -416,10 +416,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
and user_groups and user_groups
): ):
if ENABLE_LDAP_GROUP_CREATION: if ENABLE_LDAP_GROUP_CREATION:
Groups.create_groups_by_group_names(user.id, user_groups) await Groups.create_groups_by_group_names(user.id, user_groups)
try: try:
Groups.sync_groups_by_group_names(user.id, user_groups) await Groups.sync_groups_by_group_names(user.id, user_groups)
log.info( log.info(
f"Successfully synced groups for user {user.id}: {user_groups}" f"Successfully synced groups for user {user.id}: {user_groups}"
) )
@ -478,7 +478,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
group_names = [name.strip() for name in group_names if name.strip()] group_names = [name.strip() for name in group_names if name.strip()]
if group_names: if group_names:
Groups.sync_groups_by_group_names(user.id, group_names) await Groups.sync_groups_by_group_names(user.id, group_names)
elif WEBUI_AUTH == False: elif WEBUI_AUTH == False:
admin_email = "admin@localhost" admin_email = "admin@localhost"
@ -530,7 +530,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
secure=WEBUI_AUTH_COOKIE_SECURE, secure=WEBUI_AUTH_COOKIE_SECURE,
) )
user_permissions = get_permissions( user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )
@ -638,7 +638,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
}, },
) )
user_permissions = get_permissions( user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )

View file

@ -33,9 +33,9 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse]) @router.get("/", response_model=list[GroupResponse])
async def get_groups(user=Depends(get_verified_user)): async def get_groups(user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
return Groups.get_groups() return await Groups.get_groups()
else: else:
return Groups.get_groups_by_member_id(user.id) return await Groups.get_groups_by_member_id(user.id)
############################ ############################
@ -46,7 +46,7 @@ async def get_groups(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[GroupResponse]) @router.post("/create", response_model=Optional[GroupResponse])
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try: try:
group = Groups.insert_new_group(user.id, form_data) group = await Groups.insert_new_group(user.id, form_data)
if group: if group:
return group return group
else: else:
@ -69,7 +69,7 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
@router.get("/id/{id}", response_model=Optional[GroupResponse]) @router.get("/id/{id}", response_model=Optional[GroupResponse])
async def get_group_by_id(id: str, user=Depends(get_admin_user)): async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id) group = await Groups.get_group_by_id(id)
if group: if group:
return group return group
else: else:
@ -92,7 +92,7 @@ async def update_group_by_id(
if form_data.user_ids: if form_data.user_ids:
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids) form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data) group = await Groups.update_group_by_id(id, form_data)
if group: if group:
return group return group
else: else:
@ -121,7 +121,7 @@ async def add_user_to_group(
if form_data.user_ids: if form_data.user_ids:
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids) form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.add_users_to_group(id, form_data.user_ids) group = await Groups.add_users_to_group(id, form_data.user_ids)
if group: if group:
return group return group
else: else:
@ -142,7 +142,7 @@ async def remove_users_from_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user) id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
): ):
try: try:
group = Groups.remove_users_from_group(id, form_data.user_ids) group = await Groups.remove_users_from_group(id, form_data.user_ids)
if group: if group:
return group return group
else: else:
@ -166,7 +166,7 @@ async def remove_users_from_group(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_group_by_id(id: str, user=Depends(get_admin_user)): async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
try: try:
result = Groups.delete_group_by_id(id) result = await Groups.delete_group_by_id(id)
if result: if result:
return result return result
else: else:

View file

@ -305,7 +305,7 @@ async def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
family_name = name_parts[1] if len(name_parts) > 1 else "" family_name = name_parts[1] if len(name_parts) > 1 else ""
# Get user's groups # Get user's groups
user_groups = Groups.get_groups_by_member_id(user.id) user_groups = await Groups.get_groups_by_member_id(user.id)
groups = [ groups = [
{ {
"value": group.id, "value": group.id,
@ -811,7 +811,7 @@ async def update_group(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Update SCIM Group (full update)""" """Update SCIM Group (full update)"""
group = Groups.get_group_by_id(group_id) group = await Groups.get_group_by_id(group_id)
if not group: if not group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -832,7 +832,7 @@ async def update_group(
update_form.user_ids = member_ids update_form.user_ids = member_ids
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = await Groups.update_group_by_id(group_id, update_form)
if not updated_group: if not updated_group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -850,7 +850,7 @@ async def patch_group(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Update SCIM Group (partial update)""" """Update SCIM Group (partial update)"""
group = Groups.get_group_by_id(group_id) group = await Groups.get_group_by_id(group_id)
if not group: if not group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -892,7 +892,7 @@ async def patch_group(
update_form.user_ids.remove(member_id) update_form.user_ids.remove(member_id)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = await Groups.update_group_by_id(group_id, update_form)
if not updated_group: if not updated_group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -909,14 +909,14 @@ async def delete_group(
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
): ):
"""Delete SCIM Group""" """Delete SCIM Group"""
group = Groups.get_group_by_id(group_id) group = await Groups.get_group_by_id(group_id)
if not group: if not group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group {group_id} not found", detail=f"Group {group_id} not found",
) )
success = Groups.delete_group_by_id(group_id) success = await Groups.delete_group_by_id(group_id)
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View file

@ -105,7 +105,7 @@ async def get_all_users(
@router.get("/groups") @router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)): async def get_user_groups(user=Depends(get_verified_user)):
return Groups.get_groups_by_member_id(user.id) return await Groups.get_groups_by_member_id(user.id)
############################ ############################
@ -115,7 +115,7 @@ async def get_user_groups(user=Depends(get_verified_user)):
@router.get("/permissions") @router.get("/permissions")
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)): async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
user_permissions = get_permissions( user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )
@ -512,4 +512,4 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
@router.get("/{user_id}/groups") @router.get("/{user_id}/groups")
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)): async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
return Groups.get_groups_by_member_id(user_id) return await Groups.get_groups_by_member_id(user_id)

View file

@ -25,7 +25,7 @@ def fill_missing_permissions(
return permissions return permissions
def get_permissions( async def get_permissions(
user_id: str, user_id: str,
default_permissions: Dict[str, Any], default_permissions: Dict[str, Any],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -53,7 +53,7 @@ def get_permissions(
) # Use the most permissive value (True > False) ) # Use the most permissive value (True > False)
return permissions return permissions
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = await Groups.get_groups_by_member_id(user_id)
# Deep copy default permissions to avoid modifying the original dict # Deep copy default permissions to avoid modifying the original dict
permissions = json.loads(json.dumps(default_permissions)) permissions = json.loads(json.dumps(default_permissions))

View file

@ -255,7 +255,7 @@ class OAuthManager:
permissions=group_permissions, permissions=group_permissions,
user_ids=user_ids, user_ids=user_ids,
) )
Groups.update_group_by_id( await Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False id=group_model.id, form_data=update_form, overwrite=False
) )
@ -286,7 +286,7 @@ class OAuthManager:
permissions=group_permissions, permissions=group_permissions,
user_ids=user_ids, user_ids=user_ids,
) )
Groups.update_group_by_id( await Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False id=group_model.id, form_data=update_form, overwrite=False
) )