mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
wip: groups
This commit is contained in:
parent
eb86ac7a2b
commit
3b9e454fb4
7 changed files with 78 additions and 74 deletions
|
|
@ -92,7 +92,7 @@ class GroupUpdateForm(GroupForm, UserIdsForm):
|
|||
|
||||
|
||||
class GroupTable:
|
||||
def insert_new_group(
|
||||
async def insert_new_group(
|
||||
self, user_id: str, form_data: GroupForm
|
||||
) -> Optional[GroupModel]:
|
||||
async with get_db() as db:
|
||||
|
|
@ -108,9 +108,9 @@ class GroupTable:
|
|||
|
||||
try:
|
||||
result = Group(**group.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
await db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
if result:
|
||||
return GroupModel.model_validate(result)
|
||||
else:
|
||||
|
|
@ -119,18 +119,20 @@ class GroupTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_groups(self) -> list[GroupModel]:
|
||||
async def get_groups(self) -> list[GroupModel]:
|
||||
async with get_db() as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||
for group in await db.query(Group)
|
||||
.order_by(Group.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||
async def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||
async with get_db() as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group)
|
||||
for group in await db.query(Group)
|
||||
.filter(
|
||||
func.json_array_length(Group.user_ids) > 0
|
||||
) # Ensure array exists
|
||||
|
|
@ -141,82 +143,82 @@ class GroupTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
|
||||
async def get_group_by_id(self, id: str) -> Optional[GroupModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
group = await db.query(Group).filter_by(id=id).first()
|
||||
return GroupModel.model_validate(group) if group else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
|
||||
group = self.get_group_by_id(id)
|
||||
async def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
|
||||
group = await self.get_group_by_id(id)
|
||||
if group:
|
||||
return group.user_ids
|
||||
else:
|
||||
return None
|
||||
|
||||
def update_group_by_id(
|
||||
async def update_group_by_id(
|
||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Group).filter_by(id=id).update(
|
||||
await db.query(Group).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_group_by_id(id=id)
|
||||
await db.commit()
|
||||
return await self.get_group_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def delete_group_by_id(self, id: str) -> bool:
|
||||
async def delete_group_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Group).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
await db.query(Group).filter_by(id=id).delete()
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_groups(self) -> bool:
|
||||
async def delete_all_groups(self) -> bool:
|
||||
async with get_db() as db:
|
||||
try:
|
||||
db.query(Group).delete()
|
||||
db.commit()
|
||||
await db.query(Group).delete()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
||||
async def remove_user_from_all_groups(self, user_id: str) -> bool:
|
||||
async with get_db() as db:
|
||||
try:
|
||||
groups = self.get_groups_by_member_id(user_id)
|
||||
groups = await self.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in groups:
|
||||
group.user_ids.remove(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
await db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def create_groups_by_group_names(
|
||||
async def create_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str]
|
||||
) -> list[GroupModel]:
|
||||
|
||||
# check for existing groups
|
||||
existing_groups = self.get_groups()
|
||||
existing_groups = await self.get_groups()
|
||||
existing_group_names = {group.name for group in existing_groups}
|
||||
|
||||
new_groups = []
|
||||
|
|
@ -234,28 +236,30 @@ class GroupTable:
|
|||
)
|
||||
try:
|
||||
result = Group(**new_group.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
await db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
new_groups.append(GroupModel.model_validate(result))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
continue
|
||||
return new_groups
|
||||
|
||||
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
|
||||
async def sync_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str]
|
||||
) -> bool:
|
||||
async with get_db() as db:
|
||||
try:
|
||||
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
groups = await db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
group_ids = [group.id for group in groups]
|
||||
|
||||
# Remove user from groups not in the new list
|
||||
existing_groups = self.get_groups_by_member_id(user_id)
|
||||
existing_groups = await self.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in existing_groups:
|
||||
if group.id not in group_ids:
|
||||
group.user_ids.remove(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
await db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
|
|
@ -266,25 +270,25 @@ class GroupTable:
|
|||
for group in groups:
|
||||
if user_id not in group.user_ids:
|
||||
group.user_ids.append(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
await db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
db.commit()
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return False
|
||||
|
||||
def add_users_to_group(
|
||||
async def add_users_to_group(
|
||||
self, id: str, user_ids: Optional[list[str]] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
group = await db.query(Group).filter_by(id=id).first()
|
||||
if not group:
|
||||
return None
|
||||
|
||||
|
|
@ -296,19 +300,19 @@ class GroupTable:
|
|||
group.user_ids.append(user_id)
|
||||
|
||||
group.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
await db.commit()
|
||||
await db.refresh(group)
|
||||
return GroupModel.model_validate(group)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
||||
def remove_users_from_group(
|
||||
async def remove_users_from_group(
|
||||
self, id: str, user_ids: Optional[list[str]] = None
|
||||
) -> Optional[GroupModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
group = await db.query(Group).filter_by(id=id).first()
|
||||
if not group:
|
||||
return None
|
||||
|
||||
|
|
@ -320,8 +324,8 @@ class GroupTable:
|
|||
group.user_ids.remove(user_id)
|
||||
|
||||
group.updated_at = int(time.time())
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
await db.commit()
|
||||
await db.refresh(group)
|
||||
return GroupModel.model_validate(group)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ async def get_session_user(
|
|||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user_permissions = await get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
|
|
@ -406,7 +406,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user_permissions = await get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
|
|
@ -416,10 +416,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
and user_groups
|
||||
):
|
||||
if ENABLE_LDAP_GROUP_CREATION:
|
||||
Groups.create_groups_by_group_names(user.id, user_groups)
|
||||
await Groups.create_groups_by_group_names(user.id, user_groups)
|
||||
|
||||
try:
|
||||
Groups.sync_groups_by_group_names(user.id, user_groups)
|
||||
await Groups.sync_groups_by_group_names(user.id, user_groups)
|
||||
log.info(
|
||||
f"Successfully synced groups for user {user.id}: {user_groups}"
|
||||
)
|
||||
|
|
@ -478,7 +478,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
group_names = [name.strip() for name in group_names if name.strip()]
|
||||
|
||||
if group_names:
|
||||
Groups.sync_groups_by_group_names(user.id, group_names)
|
||||
await Groups.sync_groups_by_group_names(user.id, group_names)
|
||||
|
||||
elif WEBUI_AUTH == False:
|
||||
admin_email = "admin@localhost"
|
||||
|
|
@ -530,7 +530,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
secure=WEBUI_AUTH_COOKIE_SECURE,
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user_permissions = await get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
|
|
@ -638,7 +638,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
},
|
||||
)
|
||||
|
||||
user_permissions = get_permissions(
|
||||
user_permissions = await get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -33,9 +33,9 @@ router = APIRouter()
|
|||
@router.get("/", response_model=list[GroupResponse])
|
||||
async def get_groups(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
return Groups.get_groups()
|
||||
return await Groups.get_groups()
|
||||
else:
|
||||
return Groups.get_groups_by_member_id(user.id)
|
||||
return await Groups.get_groups_by_member_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -46,7 +46,7 @@ async def get_groups(user=Depends(get_verified_user)):
|
|||
@router.post("/create", response_model=Optional[GroupResponse])
|
||||
async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
||||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
group = await Groups.insert_new_group(user.id, form_data)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
|
|
@ -69,7 +69,7 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
|||
|
||||
@router.get("/id/{id}", response_model=Optional[GroupResponse])
|
||||
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
group = Groups.get_group_by_id(id)
|
||||
group = await Groups.get_group_by_id(id)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
|
|
@ -92,7 +92,7 @@ async def update_group_by_id(
|
|||
if form_data.user_ids:
|
||||
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
|
||||
|
||||
group = Groups.update_group_by_id(id, form_data)
|
||||
group = await Groups.update_group_by_id(id, form_data)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
|
|
@ -121,7 +121,7 @@ async def add_user_to_group(
|
|||
if form_data.user_ids:
|
||||
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
|
||||
|
||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
||||
group = await Groups.add_users_to_group(id, form_data.user_ids)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
|
|
@ -142,7 +142,7 @@ async def remove_users_from_group(
|
|||
id: str, form_data: UserIdsForm, user=Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
||||
group = await Groups.remove_users_from_group(id, form_data.user_ids)
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
|
|
@ -166,7 +166,7 @@ async def remove_users_from_group(
|
|||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
try:
|
||||
result = Groups.delete_group_by_id(id)
|
||||
result = await Groups.delete_group_by_id(id)
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ async def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
|||
family_name = name_parts[1] if len(name_parts) > 1 else ""
|
||||
|
||||
# Get user's groups
|
||||
user_groups = Groups.get_groups_by_member_id(user.id)
|
||||
user_groups = await Groups.get_groups_by_member_id(user.id)
|
||||
groups = [
|
||||
{
|
||||
"value": group.id,
|
||||
|
|
@ -811,7 +811,7 @@ async def update_group(
|
|||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Update SCIM Group (full update)"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = await Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -832,7 +832,7 @@ async def update_group(
|
|||
update_form.user_ids = member_ids
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
updated_group = await Groups.update_group_by_id(group_id, update_form)
|
||||
if not updated_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -850,7 +850,7 @@ async def patch_group(
|
|||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Update SCIM Group (partial update)"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = await Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -892,7 +892,7 @@ async def patch_group(
|
|||
update_form.user_ids.remove(member_id)
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
updated_group = await Groups.update_group_by_id(group_id, update_form)
|
||||
if not updated_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -909,14 +909,14 @@ async def delete_group(
|
|||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Delete SCIM Group"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = await Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
success = Groups.delete_group_by_id(group_id)
|
||||
success = await Groups.delete_group_by_id(group_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ async def get_all_users(
|
|||
|
||||
@router.get("/groups")
|
||||
async def get_user_groups(user=Depends(get_verified_user)):
|
||||
return Groups.get_groups_by_member_id(user.id)
|
||||
return await Groups.get_groups_by_member_id(user.id)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -115,7 +115,7 @@ async def get_user_groups(user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/permissions")
|
||||
async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
|
||||
user_permissions = get_permissions(
|
||||
user_permissions = await get_permissions(
|
||||
user.id, request.app.state.config.USER_PERMISSIONS
|
||||
)
|
||||
|
||||
|
|
@ -512,4 +512,4 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
|||
|
||||
@router.get("/{user_id}/groups")
|
||||
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
return Groups.get_groups_by_member_id(user_id)
|
||||
return await Groups.get_groups_by_member_id(user_id)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def fill_missing_permissions(
|
|||
return permissions
|
||||
|
||||
|
||||
def get_permissions(
|
||||
async def get_permissions(
|
||||
user_id: str,
|
||||
default_permissions: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
|
|
@ -53,7 +53,7 @@ def get_permissions(
|
|||
) # Use the most permissive value (True > False)
|
||||
return permissions
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id)
|
||||
|
||||
# Deep copy default permissions to avoid modifying the original dict
|
||||
permissions = json.loads(json.dumps(default_permissions))
|
||||
|
|
|
|||
|
|
@ -255,7 +255,7 @@ class OAuthManager:
|
|||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
await Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
|
|
@ -286,7 +286,7 @@ class OAuthManager:
|
|||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
await Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue