mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: group members backend
This commit is contained in:
parent
f05e945a45
commit
bc576782d7
4 changed files with 218 additions and 103 deletions
|
|
@ -11,7 +11,7 @@ from open_webui.models.files import FileMetadataResponse
|
|||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON, func
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -35,7 +35,6 @@ class Group(Base):
|
|||
meta = Column(JSON, nullable=True)
|
||||
|
||||
permissions = Column(JSON, nullable=True)
|
||||
user_ids = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
|
@ -53,12 +52,33 @@ class GroupModel(BaseModel):
|
|||
meta: Optional[dict] = None
|
||||
|
||||
permissions: Optional[dict] = None
|
||||
user_ids: list[str] = []
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class GroupMember(Base):
|
||||
__tablename__ = "group_member"
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
group_id = Column(
|
||||
Text,
|
||||
ForeignKey("group.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
user_id = Column(Text, nullable=False)
|
||||
created_at = Column(BigInteger, nullable=True)
|
||||
updated_at = Column(BigInteger, nullable=True)
|
||||
|
||||
|
||||
class GroupMemberModel(BaseModel):
|
||||
id: str
|
||||
group_id: str
|
||||
user_id: str
|
||||
created_at: Optional[int] = None # timestamp in epoch
|
||||
updated_at: Optional[int] = None # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
|
@ -72,7 +92,7 @@ class GroupResponse(BaseModel):
|
|||
permissions: Optional[dict] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
user_ids: list[str] = []
|
||||
member_count: Optional[int] = None
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
|
|
@ -87,7 +107,7 @@ class UserIdsForm(BaseModel):
|
|||
user_ids: Optional[list[str]] = None
|
||||
|
||||
|
||||
class GroupUpdateForm(GroupForm, UserIdsForm):
|
||||
class GroupUpdateForm(GroupForm):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -131,12 +151,8 @@ class GroupTable:
|
|||
return [
|
||||
GroupModel.model_validate(group)
|
||||
for group in db.query(Group)
|
||||
.filter(
|
||||
func.json_array_length(Group.user_ids) > 0
|
||||
) # Ensure array exists
|
||||
.filter(
|
||||
Group.user_ids.cast(String).like(f'%"{user_id}"%')
|
||||
) # String-based check
|
||||
.join(GroupMember, GroupMember.group_id == Group.id)
|
||||
.filter(GroupMember.user_id == user_id)
|
||||
.order_by(Group.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
|
@ -149,13 +165,47 @@ class GroupTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
|
||||
group = self.get_group_by_id(id)
|
||||
if group:
|
||||
return group.user_ids
|
||||
else:
|
||||
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
|
||||
with get_db() as db:
|
||||
members = (
|
||||
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
||||
)
|
||||
|
||||
if not members:
|
||||
return None
|
||||
|
||||
return [m[0] for m in members]
|
||||
|
||||
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
|
||||
with get_db() as db:
|
||||
# Delete existing members
|
||||
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
||||
|
||||
# Insert new members
|
||||
now = int(time.time())
|
||||
new_members = [
|
||||
GroupMember(
|
||||
id=str(uuid.uuid4()),
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
for user_id in user_ids
|
||||
]
|
||||
|
||||
db.add_all(new_members)
|
||||
db.commit()
|
||||
|
||||
def get_group_member_count_by_id(self, id: str) -> int:
|
||||
with get_db() as db:
|
||||
count = (
|
||||
db.query(func.count(GroupMember.user_id))
|
||||
.filter(GroupMember.group_id == id)
|
||||
.scalar()
|
||||
)
|
||||
return count if count else 0
|
||||
|
||||
def update_group_by_id(
|
||||
self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
|
||||
) -> Optional[GroupModel]:
|
||||
|
|
@ -195,20 +245,29 @@ class GroupTable:
|
|||
def remove_user_from_all_groups(self, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
groups = self.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in groups:
|
||||
group.user_ids.remove(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
# Find all groups the user belongs to
|
||||
groups = (
|
||||
db.query(Group)
|
||||
.join(GroupMember, GroupMember.group_id == Group.id)
|
||||
.filter(GroupMember.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
db.commit()
|
||||
|
||||
# Remove the user from each group
|
||||
for group in groups:
|
||||
db.query(GroupMember).filter(
|
||||
GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
||||
).delete()
|
||||
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{"updated_at": int(time.time())}
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
db.rollback()
|
||||
return False
|
||||
|
||||
def create_groups_by_group_names(
|
||||
|
|
@ -246,37 +305,61 @@ class GroupTable:
|
|||
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
|
||||
with get_db() as db:
|
||||
try:
|
||||
groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
group_ids = [group.id for group in groups]
|
||||
now = int(time.time())
|
||||
|
||||
# Remove user from groups not in the new list
|
||||
existing_groups = self.get_groups_by_member_id(user_id)
|
||||
# 1. Groups that SHOULD contain the user
|
||||
target_groups = (
|
||||
db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
)
|
||||
target_group_ids = {g.id for g in target_groups}
|
||||
|
||||
for group in existing_groups:
|
||||
if group.id not in group_ids:
|
||||
group.user_ids.remove(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
# 2. Groups the user is CURRENTLY in
|
||||
existing_group_ids = {
|
||||
g.id
|
||||
for g in db.query(Group)
|
||||
.join(GroupMember, GroupMember.group_id == Group.id)
|
||||
.filter(GroupMember.user_id == user_id)
|
||||
.all()
|
||||
}
|
||||
|
||||
# 3. Determine adds + removals
|
||||
groups_to_add = target_group_ids - existing_group_ids
|
||||
groups_to_remove = existing_group_ids - target_group_ids
|
||||
|
||||
# 4. Remove in one bulk delete
|
||||
if groups_to_remove:
|
||||
db.query(GroupMember).filter(
|
||||
GroupMember.user_id == user_id,
|
||||
GroupMember.group_id.in_(groups_to_remove),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
|
||||
{"updated_at": now}, synchronize_session=False
|
||||
)
|
||||
|
||||
# Add user to new groups
|
||||
for group in groups:
|
||||
if user_id not in group.user_ids:
|
||||
group.user_ids.append(user_id)
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{
|
||||
"user_ids": group.user_ids,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
# 5. Bulk insert missing memberships
|
||||
for group_id in groups_to_add:
|
||||
db.add(
|
||||
GroupMember(
|
||||
id=str(uuid.uuid4()),
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
if groups_to_add:
|
||||
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
|
||||
{"updated_at": now}, synchronize_session=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
db.rollback()
|
||||
return False
|
||||
|
||||
def add_users_to_group(
|
||||
|
|
@ -288,21 +371,31 @@ class GroupTable:
|
|||
if not group:
|
||||
return None
|
||||
|
||||
group_user_ids = group.user_ids
|
||||
if not group_user_ids or not isinstance(group_user_ids, list):
|
||||
group_user_ids = []
|
||||
now = int(time.time())
|
||||
|
||||
group_user_ids = list(set(group_user_ids)) # Deduplicate
|
||||
for user_id in user_ids or []:
|
||||
try:
|
||||
db.add(
|
||||
GroupMember(
|
||||
id=str(uuid.uuid4()),
|
||||
group_id=id,
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
db.flush() # Detect unique constraint violation early
|
||||
except Exception:
|
||||
db.rollback() # Clear failed INSERT
|
||||
db.begin() # Start a new transaction
|
||||
continue # Duplicate → ignore
|
||||
|
||||
for user_id in user_ids:
|
||||
if user_id not in group_user_ids:
|
||||
group_user_ids.append(user_id)
|
||||
|
||||
group.user_ids = group_user_ids
|
||||
group.updated_at = int(time.time())
|
||||
group.updated_at = now
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
|
||||
return GroupModel.model_validate(group)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
|
@ -316,23 +409,22 @@ class GroupTable:
|
|||
if not group:
|
||||
return None
|
||||
|
||||
group_user_ids = group.user_ids
|
||||
|
||||
if not group_user_ids or not isinstance(group_user_ids, list):
|
||||
if not user_ids:
|
||||
return GroupModel.model_validate(group)
|
||||
|
||||
group_user_ids = list(set(group_user_ids)) # Deduplicate
|
||||
|
||||
# Remove each user from group_member
|
||||
for user_id in user_ids:
|
||||
if user_id in group_user_ids:
|
||||
group_user_ids.remove(user_id)
|
||||
db.query(GroupMember).filter(
|
||||
GroupMember.group_id == id, GroupMember.user_id == user_id
|
||||
).delete()
|
||||
|
||||
group.user_ids = group_user_ids
|
||||
# Update group timestamp
|
||||
group.updated_at = int(time.time())
|
||||
|
||||
db.commit()
|
||||
db.refresh(group)
|
||||
return GroupModel.model_validate(group)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -33,9 +33,18 @@ router = APIRouter()
|
|||
@router.get("/", response_model=list[GroupResponse])
|
||||
async def get_groups(user=Depends(get_verified_user)):
|
||||
if user.role == "admin":
|
||||
return Groups.get_groups()
|
||||
groups = Groups.get_groups()
|
||||
else:
|
||||
return Groups.get_groups_by_member_id(user.id)
|
||||
groups = Groups.get_groups_by_member_id(user.id)
|
||||
|
||||
return [
|
||||
GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
for group in groups
|
||||
if group
|
||||
]
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -48,7 +57,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
|||
try:
|
||||
group = Groups.insert_new_group(user.id, form_data)
|
||||
if group:
|
||||
return group
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -71,7 +83,10 @@ async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)):
|
|||
async def get_group_by_id(id: str, user=Depends(get_admin_user)):
|
||||
group = Groups.get_group_by_id(id)
|
||||
if group:
|
||||
return group
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -89,12 +104,12 @@ async def update_group_by_id(
|
|||
id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
try:
|
||||
if form_data.user_ids:
|
||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
||||
|
||||
group = Groups.update_group_by_id(id, form_data)
|
||||
if group:
|
||||
return group
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -123,7 +138,10 @@ async def add_user_to_group(
|
|||
|
||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
||||
if group:
|
||||
return group
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
@ -144,7 +162,10 @@ async def remove_users_from_group(
|
|||
try:
|
||||
group = Groups.remove_users_from_group(id, form_data.user_ids)
|
||||
if group:
|
||||
return group
|
||||
return GroupResponse(
|
||||
**group.model_dump(),
|
||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
|
|
|||
|
|
@ -349,8 +349,10 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
|||
|
||||
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
||||
"""Convert internal Group model to SCIM Group"""
|
||||
member_ids = Groups.get_group_user_ids_by_id(group.id)
|
||||
members = []
|
||||
for user_id in group.user_ids:
|
||||
|
||||
for user_id in member_ids:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
if user:
|
||||
members.append(
|
||||
|
|
@ -796,9 +798,11 @@ async def create_group(
|
|||
update_form = GroupUpdateForm(
|
||||
name=new_group.name,
|
||||
description=new_group.description,
|
||||
user_ids=member_ids,
|
||||
)
|
||||
|
||||
Groups.update_group_by_id(new_group.id, update_form)
|
||||
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
|
||||
|
||||
new_group = Groups.get_group_by_id(new_group.id)
|
||||
|
||||
return group_to_scim(new_group, request)
|
||||
|
|
@ -830,7 +834,7 @@ async def update_group(
|
|||
# Handle members if provided
|
||||
if group_data.members is not None:
|
||||
member_ids = [member.value for member in group_data.members]
|
||||
update_form.user_ids = member_ids
|
||||
Groups.set_group_user_ids_by_id(group_id, member_ids)
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
|
|
@ -863,7 +867,6 @@ async def patch_group(
|
|||
update_form = GroupUpdateForm(
|
||||
name=group.name,
|
||||
description=group.description,
|
||||
user_ids=group.user_ids.copy() if group.user_ids else [],
|
||||
)
|
||||
|
||||
for operation in patch_data.Operations:
|
||||
|
|
@ -876,21 +879,22 @@ async def patch_group(
|
|||
update_form.name = value
|
||||
elif path == "members":
|
||||
# Replace all members
|
||||
update_form.user_ids = [member["value"] for member in value]
|
||||
Groups.set_group_user_ids_by_id(
|
||||
group_id, [member["value"] for member in value]
|
||||
)
|
||||
|
||||
elif op == "add":
|
||||
if path == "members":
|
||||
# Add members
|
||||
if isinstance(value, list):
|
||||
for member in value:
|
||||
if isinstance(member, dict) and "value" in member:
|
||||
if member["value"] not in update_form.user_ids:
|
||||
update_form.user_ids.append(member["value"])
|
||||
Groups.add_users_to_group(group_id, [member["value"]])
|
||||
elif op == "remove":
|
||||
if path and path.startswith("members[value eq"):
|
||||
# Remove specific member
|
||||
member_id = path.split('"')[1]
|
||||
if member_id in update_form.user_ids:
|
||||
update_form.user_ids.remove(member_id)
|
||||
Groups.remove_users_from_group(group_id, [member_id])
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
|
|
|
|||
|
|
@ -1130,22 +1130,21 @@ class OAuthManager:
|
|||
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids = [i for i in user_ids if i != user.id]
|
||||
Groups.remove_users_from_group(group_model.id, [user.id])
|
||||
|
||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||
group_permissions = group_model.permissions
|
||||
if not group_permissions:
|
||||
group_permissions = default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id,
|
||||
form_data=GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
),
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
# Add user to new groups
|
||||
|
|
@ -1161,22 +1160,21 @@ class OAuthManager:
|
|||
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
||||
)
|
||||
|
||||
user_ids = group_model.user_ids
|
||||
user_ids.append(user.id)
|
||||
Groups.add_users_to_group(group_model.id, [user.id])
|
||||
|
||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||
group_permissions = group_model.permissions
|
||||
if not group_permissions:
|
||||
group_permissions = default_permissions
|
||||
|
||||
update_form = GroupUpdateForm(
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id,
|
||||
form_data=GroupUpdateForm(
|
||||
name=group_model.name,
|
||||
description=group_model.description,
|
||||
permissions=group_permissions,
|
||||
user_ids=user_ids,
|
||||
)
|
||||
Groups.update_group_by_id(
|
||||
id=group_model.id, form_data=update_form, overwrite=False
|
||||
),
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
async def _process_picture_url(
|
||||
|
|
|
|||
Loading…
Reference in a new issue