refac: group members backend

This commit is contained in:
Timothy Jaeryang Baek 2025-11-17 05:09:06 -05:00
parent f05e945a45
commit bc576782d7
4 changed files with 218 additions and 103 deletions

View file

@ -11,7 +11,7 @@ from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,7 +35,6 @@ class Group(Base):
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
permissions = Column(JSON, nullable=True) permissions = Column(JSON, nullable=True)
user_ids = Column(JSON, nullable=True)
created_at = Column(BigInteger) created_at = Column(BigInteger)
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
@ -53,12 +52,33 @@ class GroupModel(BaseModel):
meta: Optional[dict] = None meta: Optional[dict] = None
permissions: Optional[dict] = None permissions: Optional[dict] = None
user_ids: list[str] = []
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
class GroupMember(Base):
__tablename__ = "group_member"
id = Column(Text, unique=True, primary_key=True)
group_id = Column(
Text,
ForeignKey("group.id", ondelete="CASCADE"),
nullable=False,
)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=True)
updated_at = Column(BigInteger, nullable=True)
class GroupMemberModel(BaseModel):
id: str
group_id: str
user_id: str
created_at: Optional[int] = None # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch
#################### ####################
# Forms # Forms
#################### ####################
@ -72,7 +92,7 @@ class GroupResponse(BaseModel):
permissions: Optional[dict] = None permissions: Optional[dict] = None
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
user_ids: list[str] = [] member_count: Optional[int] = None
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
@ -87,7 +107,7 @@ class UserIdsForm(BaseModel):
user_ids: Optional[list[str]] = None user_ids: Optional[list[str]] = None
class GroupUpdateForm(GroupForm, UserIdsForm): class GroupUpdateForm(GroupForm):
pass pass
@ -131,12 +151,8 @@ class GroupTable:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
for group in db.query(Group) for group in db.query(Group)
.filter( .join(GroupMember, GroupMember.group_id == Group.id)
func.json_array_length(Group.user_ids) > 0 .filter(GroupMember.user_id == user_id)
) # Ensure array exists
.filter(
Group.user_ids.cast(String).like(f'%"{user_id}"%')
) # String-based check
.order_by(Group.updated_at.desc()) .order_by(Group.updated_at.desc())
.all() .all()
] ]
@ -149,13 +165,47 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_group_user_ids_by_id(self, id: str) -> Optional[str]: def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
group = self.get_group_by_id(id) with get_db() as db:
if group: members = (
return group.user_ids db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
else: )
if not members:
return None return None
return [m[0] for m in members]
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
with get_db() as db:
# Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
# Insert new members
now = int(time.time())
new_members = [
GroupMember(
id=str(uuid.uuid4()),
group_id=group_id,
user_id=user_id,
created_at=now,
updated_at=now,
)
for user_id in user_ids
]
db.add_all(new_members)
db.commit()
def get_group_member_count_by_id(self, id: str) -> int:
with get_db() as db:
count = (
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
return count if count else 0
def update_group_by_id( def update_group_by_id(
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
@ -195,20 +245,29 @@ class GroupTable:
def remove_user_from_all_groups(self, user_id: str) -> bool: def remove_user_from_all_groups(self, user_id: str) -> bool:
with get_db() as db: with get_db() as db:
try: try:
groups = self.get_groups_by_member_id(user_id) # Find all groups the user belongs to
groups = (
for group in groups: db.query(Group)
group.user_ids.remove(user_id) .join(GroupMember, GroupMember.group_id == Group.id)
db.query(Group).filter_by(id=group.id).update( .filter(GroupMember.user_id == user_id)
{ .all()
"user_ids": group.user_ids,
"updated_at": int(time.time()),
}
) )
db.commit()
# Remove the user from each group
for group in groups:
db.query(GroupMember).filter(
GroupMember.group_id == group.id, GroupMember.user_id == user_id
).delete()
db.query(Group).filter_by(id=group.id).update(
{"updated_at": int(time.time())}
)
db.commit()
return True return True
except Exception: except Exception:
db.rollback()
return False return False
def create_groups_by_group_names( def create_groups_by_group_names(
@ -246,37 +305,61 @@ class GroupTable:
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
with get_db() as db: with get_db() as db:
try: try:
groups = db.query(Group).filter(Group.name.in_(group_names)).all() now = int(time.time())
group_ids = [group.id for group in groups]
# Remove user from groups not in the new list # 1. Groups that SHOULD contain the user
existing_groups = self.get_groups_by_member_id(user_id) target_groups = (
db.query(Group).filter(Group.name.in_(group_names)).all()
)
target_group_ids = {g.id for g in target_groups}
for group in existing_groups: # 2. Groups the user is CURRENTLY in
if group.id not in group_ids: existing_group_ids = {
group.user_ids.remove(user_id) g.id
db.query(Group).filter_by(id=group.id).update( for g in db.query(Group)
{ .join(GroupMember, GroupMember.group_id == Group.id)
"user_ids": group.user_ids, .filter(GroupMember.user_id == user_id)
"updated_at": int(time.time()), .all()
} }
# 3. Determine adds + removals
groups_to_add = target_group_ids - existing_group_ids
groups_to_remove = existing_group_ids - target_group_ids
# 4. Remove in one bulk delete
if groups_to_remove:
db.query(GroupMember).filter(
GroupMember.user_id == user_id,
GroupMember.group_id.in_(groups_to_remove),
).delete(synchronize_session=False)
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
{"updated_at": now}, synchronize_session=False
) )
# Add user to new groups # 5. Bulk insert missing memberships
for group in groups: for group_id in groups_to_add:
if user_id not in group.user_ids: db.add(
group.user_ids.append(user_id) GroupMember(
db.query(Group).filter_by(id=group.id).update( id=str(uuid.uuid4()),
{ group_id=group_id,
"user_ids": group.user_ids, user_id=user_id,
"updated_at": int(time.time()), created_at=now,
} updated_at=now,
)
)
if groups_to_add:
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
{"updated_at": now}, synchronize_session=False
) )
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
db.rollback()
return False return False
def add_users_to_group( def add_users_to_group(
@ -288,21 +371,31 @@ class GroupTable:
if not group: if not group:
return None return None
group_user_ids = group.user_ids now = int(time.time())
if not group_user_ids or not isinstance(group_user_ids, list):
group_user_ids = []
group_user_ids = list(set(group_user_ids)) # Deduplicate for user_id in user_ids or []:
try:
db.add(
GroupMember(
id=str(uuid.uuid4()),
group_id=id,
user_id=user_id,
created_at=now,
updated_at=now,
)
)
db.flush() # Detect unique constraint violation early
except Exception:
db.rollback() # Clear failed INSERT
db.begin() # Start a new transaction
continue # Duplicate → ignore
for user_id in user_ids: group.updated_at = now
if user_id not in group_user_ids:
group_user_ids.append(user_id)
group.user_ids = group_user_ids
group.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(group) db.refresh(group)
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
@ -316,23 +409,22 @@ class GroupTable:
if not group: if not group:
return None return None
group_user_ids = group.user_ids if not user_ids:
if not group_user_ids or not isinstance(group_user_ids, list):
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
group_user_ids = list(set(group_user_ids)) # Deduplicate # Remove each user from group_member
for user_id in user_ids: for user_id in user_ids:
if user_id in group_user_ids: db.query(GroupMember).filter(
group_user_ids.remove(user_id) GroupMember.group_id == id, GroupMember.user_id == user_id
).delete()
group.user_ids = group_user_ids # Update group timestamp
group.updated_at = int(time.time()) group.updated_at = int(time.time())
db.commit() db.commit()
db.refresh(group) db.refresh(group)
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None

View file

@ -33,9 +33,18 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse]) @router.get("/", response_model=list[GroupResponse])
async def get_groups(user=Depends(get_verified_user)): async def get_groups(user=Depends(get_verified_user)):
if user.role == "admin": if user.role == "admin":
return Groups.get_groups() groups = Groups.get_groups()
else: else:
return Groups.get_groups_by_member_id(user.id) groups = Groups.get_groups_by_member_id(user.id)
return [
GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
for group in groups
if group
]
############################ ############################
@ -48,7 +57,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
try: try:
group = Groups.insert_new_group(user.id, form_data) group = Groups.insert_new_group(user.id, form_data)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -71,7 +83,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
async def get_group_by_id(id: str, user=Depends(get_admin_user)): async def get_group_by_id(id: str, user=Depends(get_admin_user)):
group = Groups.get_group_by_id(id) group = Groups.get_group_by_id(id)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -89,12 +104,12 @@ async def update_group_by_id(
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
): ):
try: try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data) group = Groups.update_group_by_id(id, form_data)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -123,7 +138,10 @@ async def add_user_to_group(
group = Groups.add_users_to_group(id, form_data.user_ids) group = Groups.add_users_to_group(id, form_data.user_ids)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@ -144,7 +162,10 @@ async def remove_users_from_group(
try: try:
group = Groups.remove_users_from_group(id, form_data.user_ids) group = Groups.remove_users_from_group(id, form_data.user_ids)
if group: if group:
return group return GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View file

@ -349,8 +349,10 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
"""Convert internal Group model to SCIM Group""" """Convert internal Group model to SCIM Group"""
member_ids = Groups.get_group_user_ids_by_id(group.id)
members = [] members = []
for user_id in group.user_ids:
for user_id in member_ids:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
if user: if user:
members.append( members.append(
@ -796,9 +798,11 @@ async def create_group(
update_form = GroupUpdateForm( update_form = GroupUpdateForm(
name=new_group.name, name=new_group.name,
description=new_group.description, description=new_group.description,
user_ids=member_ids,
) )
Groups.update_group_by_id(new_group.id, update_form) Groups.update_group_by_id(new_group.id, update_form)
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
new_group = Groups.get_group_by_id(new_group.id) new_group = Groups.get_group_by_id(new_group.id)
return group_to_scim(new_group, request) return group_to_scim(new_group, request)
@ -830,7 +834,7 @@ async def update_group(
# Handle members if provided # Handle members if provided
if group_data.members is not None: if group_data.members is not None:
member_ids = [member.value for member in group_data.members] member_ids = [member.value for member in group_data.members]
update_form.user_ids = member_ids Groups.set_group_user_ids_by_id(group_id, member_ids)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = Groups.update_group_by_id(group_id, update_form)
@ -863,7 +867,6 @@ async def patch_group(
update_form = GroupUpdateForm( update_form = GroupUpdateForm(
name=group.name, name=group.name,
description=group.description, description=group.description,
user_ids=group.user_ids.copy() if group.user_ids else [],
) )
for operation in patch_data.Operations: for operation in patch_data.Operations:
@ -876,21 +879,22 @@ async def patch_group(
update_form.name = value update_form.name = value
elif path == "members": elif path == "members":
# Replace all members # Replace all members
update_form.user_ids = [member["value"] for member in value] Groups.set_group_user_ids_by_id(
group_id, [member["value"] for member in value]
)
elif op == "add": elif op == "add":
if path == "members": if path == "members":
# Add members # Add members
if isinstance(value, list): if isinstance(value, list):
for member in value: for member in value:
if isinstance(member, dict) and "value" in member: if isinstance(member, dict) and "value" in member:
if member["value"] not in update_form.user_ids: Groups.add_users_to_group(group_id, [member["value"]])
update_form.user_ids.append(member["value"])
elif op == "remove": elif op == "remove":
if path and path.startswith("members[value eq"): if path and path.startswith("members[value eq"):
# Remove specific member # Remove specific member
member_id = path.split('"')[1] member_id = path.split('"')[1]
if member_id in update_form.user_ids: Groups.remove_users_from_group(group_id, [member_id])
update_form.user_ids.remove(member_id)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = Groups.update_group_by_id(group_id, update_form)

View file

@ -1130,22 +1130,21 @@ class OAuthManager:
f"Removing user from group {group_model.name} as it is no longer in their oauth groups" f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
) )
user_ids = group_model.user_ids Groups.remove_users_from_group(group_model.id, [user.id])
user_ids = [i for i in user_ids if i != user.id]
# In case a group is created, but perms are never assigned to the group by hitting "save" # In case a group is created, but perms are never assigned to the group by hitting "save"
group_permissions = group_model.permissions group_permissions = group_model.permissions
if not group_permissions: if not group_permissions:
group_permissions = default_permissions group_permissions = default_permissions
update_form = GroupUpdateForm( Groups.update_group_by_id(
id=group_model.id,
form_data=GroupUpdateForm(
name=group_model.name, name=group_model.name,
description=group_model.description, description=group_model.description,
permissions=group_permissions, permissions=group_permissions,
user_ids=user_ids, ),
) overwrite=False,
Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
) )
# Add user to new groups # Add user to new groups
@ -1161,22 +1160,21 @@ class OAuthManager:
f"Adding user to group {group_model.name} as it was found in their oauth groups" f"Adding user to group {group_model.name} as it was found in their oauth groups"
) )
user_ids = group_model.user_ids Groups.add_users_to_group(group_model.id, [user.id])
user_ids.append(user.id)
# In case a group is created, but perms are never assigned to the group by hitting "save" # In case a group is created, but perms are never assigned to the group by hitting "save"
group_permissions = group_model.permissions group_permissions = group_model.permissions
if not group_permissions: if not group_permissions:
group_permissions = default_permissions group_permissions = default_permissions
update_form = GroupUpdateForm( Groups.update_group_by_id(
id=group_model.id,
form_data=GroupUpdateForm(
name=group_model.name, name=group_model.name,
description=group_model.description, description=group_model.description,
permissions=group_permissions, permissions=group_permissions,
user_ids=user_ids, ),
) overwrite=False,
Groups.update_group_by_id(
id=group_model.id, form_data=update_form, overwrite=False
) )
async def _process_picture_url( async def _process_picture_url(