wip: groups

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:52:04 +04:00
parent eb86ac7a2b
commit 3b9e454fb4
7 changed files with 78 additions and 74 deletions

View file

@ -92,7 +92,7 @@ class GroupUpdateForm(GroupForm, UserIdsForm):
class GroupTable:
def insert_new_group(
async def insert_new_group(
self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]:
async with get_db() as db:
@ -108,9 +108,9 @@ class GroupTable:
try:
result = Group(**group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
await db.add(result)
await db.commit()
await db.refresh(result)
if result:
return GroupModel.model_validate(result)
else:
@ -119,18 +119,20 @@ class GroupTable:
except Exception:
return None
def get_groups(self) -> list[GroupModel]:
async def get_groups(self) -> list[GroupModel]:
async with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
for group in await db.query(Group)
.order_by(Group.updated_at.desc())
.all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
async def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
async with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
for group in await db.query(Group)
.filter(
func.json_array_length(Group.user_ids) > 0
) # Ensure array exists
@ -141,82 +143,82 @@ class GroupTable:
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
async def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
async with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
group = await db.query(Group).filter_by(id=id).first()
return GroupModel.model_validate(group) if group else None
except Exception:
return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = self.get_group_by_id(id)
async def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
group = await self.get_group_by_id(id)
if group:
return group.user_ids
else:
return None
def update_group_by_id(
async def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:
try:
async with get_db() as db:
db.query(Group).filter_by(id=id).update(
await db.query(Group).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_group_by_id(id=id)
await db.commit()
return await self.get_group_by_id(id=id)
except Exception as e:
log.exception(e)
return None
def delete_group_by_id(self, id: str) -> bool:
async def delete_group_by_id(self, id: str) -> bool:
try:
async with get_db() as db:
db.query(Group).filter_by(id=id).delete()
db.commit()
await db.query(Group).filter_by(id=id).delete()
await db.commit()
return True
except Exception:
return False
def delete_all_groups(self) -> bool:
async def delete_all_groups(self) -> bool:
async with get_db() as db:
try:
db.query(Group).delete()
db.commit()
await db.query(Group).delete()
await db.commit()
return True
except Exception:
return False
def remove_user_from_all_groups(self, user_id: str) -> bool:
async def remove_user_from_all_groups(self, user_id: str) -> bool:
async with get_db() as db:
try:
groups = self.get_groups_by_member_id(user_id)
groups = await self.get_groups_by_member_id(user_id)
for group in groups:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
await db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
await db.commit()
return True
except Exception:
return False
def create_groups_by_group_names(
async def create_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> list[GroupModel]:
# check for existing groups
existing_groups = self.get_groups()
existing_groups = await self.get_groups()
existing_group_names = {group.name for group in existing_groups}
new_groups = []
@ -234,28 +236,30 @@ class GroupTable:
)
try:
result = Group(**new_group.model_dump())
db.add(result)
db.commit()
db.refresh(result)
await db.add(result)
await db.commit()
await db.refresh(result)
new_groups.append(GroupModel.model_validate(result))
except Exception as e:
log.exception(e)
continue
return new_groups
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
async def sync_groups_by_group_names(
self, user_id: str, group_names: list[str]
) -> bool:
async with get_db() as db:
try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
groups = await db.query(Group).filter(Group.name.in_(group_names)).all()
group_ids = [group.id for group in groups]
# Remove user from groups not in the new list
existing_groups = self.get_groups_by_member_id(user_id)
existing_groups = await self.get_groups_by_member_id(user_id)
for group in existing_groups:
if group.id not in group_ids:
group.user_ids.remove(user_id)
db.query(Group).filter_by(id=group.id).update(
await db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
@ -266,25 +270,25 @@ class GroupTable:
for group in groups:
if user_id not in group.user_ids:
group.user_ids.append(user_id)
db.query(Group).filter_by(id=group.id).update(
await db.query(Group).filter_by(id=group.id).update(
{
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
)
db.commit()
await db.commit()
return True
except Exception as e:
log.exception(e)
return False
def add_users_to_group(
async def add_users_to_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
async with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
group = await db.query(Group).filter_by(id=id).first()
if not group:
return None
@ -296,19 +300,19 @@ class GroupTable:
group.user_ids.append(user_id)
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
await db.commit()
await db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)
return None
def remove_users_from_group(
async def remove_users_from_group(
self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
try:
async with get_db() as db:
group = db.query(Group).filter_by(id=id).first()
group = await db.query(Group).filter_by(id=id).first()
if not group:
return None
@ -320,8 +324,8 @@ class GroupTable:
group.user_ids.remove(user_id)
group.updated_at = int(time.time())
db.commit()
db.refresh(group)
await db.commit()
await db.refresh(group)
return GroupModel.model_validate(group)
except Exception as e:
log.exception(e)

View file

@ -108,7 +108,7 @@ async def get_session_user(
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
@ -406,7 +406,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
@ -416,10 +416,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
and user_groups
):
if ENABLE_LDAP_GROUP_CREATION:
Groups.create_groups_by_group_names(user.id, user_groups)
await Groups.create_groups_by_group_names(user.id, user_groups)
try:
Groups.sync_groups_by_group_names(user.id, user_groups)
await Groups.sync_groups_by_group_names(user.id, user_groups)
log.info(
f"Successfully synced groups for user {user.id}: {user_groups}"
)
@ -478,7 +478,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
group_names = [name.strip() for name in group_names if name.strip()]
if group_names:
Groups.sync_groups_by_group_names(user.id, group_names)
await Groups.sync_groups_by_group_names(user.id, group_names)
elif WEBUI_AUTH == False:
admin_email = "admin@localhost"
@ -530,7 +530,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
secure=WEBUI_AUTH_COOKIE_SECURE,
)
user_permissions = get_permissions(
user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
@ -638,7 +638,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
},
)
user_permissions = get_permissions(
user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)

View file

@ -33,9 +33,9 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse])
async def get_groups(user=Depends(get_verified_user)):
if user.role == "admin":
return Groups.get_groups()
return await Groups.get_groups()
else:
return Groups.get_groups_by_member_id(user.id)
return await Groups.get_groups_by_member_id(user.id)
############################
@ -46,7 +46,7 @@ async def get_groups(user=Depends(get_verified_user)):
@router.post("/create", response_model=Optional[GroupResponse])
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try:
group = Groups.insert_new_group(user.id, form_data)
group = await Groups.insert_new_group(user.id, form_data)
if group:
return group
else:
@ -69,7 +69,7 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
@router.get("/id/{id}", response_model=Optional[GroupResponse])
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id)
group = await Groups.get_group_by_id(id)
if group:
return group
else:
@ -92,7 +92,7 @@ async def update_group_by_id(
if form_data.user_ids:
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data)
group = await Groups.update_group_by_id(id, form_data)
if group:
return group
else:
@ -121,7 +121,7 @@ async def add_user_to_group(
if form_data.user_ids:
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.add_users_to_group(id, form_data.user_ids)
group = await Groups.add_users_to_group(id, form_data.user_ids)
if group:
return group
else:
@ -142,7 +142,7 @@ async def remove_users_from_group(
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
):
try:
group = Groups.remove_users_from_group(id, form_data.user_ids)
group = await Groups.remove_users_from_group(id, form_data.user_ids)
if group:
return group
else:
@ -166,7 +166,7 @@ async def remove_users_from_group(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
try:
result = Groups.delete_group_by_id(id)
result = await Groups.delete_group_by_id(id)
if result:
return result
else:

View file

@ -305,7 +305,7 @@ async def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
family_name = name_parts[1] if len(name_parts) > 1 else ""
# Get user's groups
user_groups = Groups.get_groups_by_member_id(user.id)
user_groups = await Groups.get_groups_by_member_id(user.id)
groups = [
{
"value": group.id,
@ -811,7 +811,7 @@ async def update_group(
_: bool = Depends(get_scim_auth),
):
"""Update SCIM Group (full update)"""
group = Groups.get_group_by_id(group_id)
group = await Groups.get_group_by_id(group_id)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -832,7 +832,7 @@ async def update_group(
update_form.user_ids = member_ids
# Update group
updated_group = Groups.update_group_by_id(group_id, update_form)
updated_group = await Groups.update_group_by_id(group_id, update_form)
if not updated_group:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -850,7 +850,7 @@ async def patch_group(
_: bool = Depends(get_scim_auth),
):
"""Update SCIM Group (partial update)"""
group = Groups.get_group_by_id(group_id)
group = await Groups.get_group_by_id(group_id)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -892,7 +892,7 @@ async def patch_group(
update_form.user_ids.remove(member_id)
# Update group
updated_group = Groups.update_group_by_id(group_id, update_form)
updated_group = await Groups.update_group_by_id(group_id, update_form)
if not updated_group:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -909,14 +909,14 @@ async def delete_group(
_: bool = Depends(get_scim_auth),
):
"""Delete SCIM Group"""
group = Groups.get_group_by_id(group_id)
group = await Groups.get_group_by_id(group_id)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group {group_id} not found",
)
success = Groups.delete_group_by_id(group_id)
success = await Groups.delete_group_by_id(group_id)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View file

@ -105,7 +105,7 @@ async def get_all_users(
@router.get("/groups")
async def get_user_groups(user=Depends(get_verified_user)):
return Groups.get_groups_by_member_id(user.id)
return await Groups.get_groups_by_member_id(user.id)
############################
@ -115,7 +115,7 @@ async def get_user_groups(user=Depends(get_verified_user)):
@router.get("/permissions")
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
user_permissions = get_permissions(
user_permissions = await get_permissions(
user.id, request.app.state.config.USER_PERMISSIONS
)
@ -512,4 +512,4 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
@router.get("/{user_id}/groups")
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
return Groups.get_groups_by_member_id(user_id)
return await Groups.get_groups_by_member_id(user_id)

View file

@ -25,7 +25,7 @@ def fill_missing_permissions(
return permissions
def get_permissions(
async def get_permissions(
user_id: str,
default_permissions: Dict[str, Any],
) -> Dict[str, Any]:
@ -53,7 +53,7 @@ def get_permissions(
) # Use the most permissive value (True > False)
return permissions
user_groups = Groups.get_groups_by_member_id(user_id)
user_groups = await Groups.get_groups_by_member_id(user_id)
# Deep copy default permissions to avoid modifying the original dict
permissions = json.loads(json.dumps(default_permissions))

View file

@ -255,7 +255,7 @@ class OAuthManager:
permissions=group_permissions,
user_ids=user_ids,
)
Groups.update_group_by_id(
await Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
)
@ -286,7 +286,7 @@ class OAuthManager:
permissions=group_permissions,
user_ids=user_ids,
)
Groups.update_group_by_id(
await Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
)