diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index e5c0612639..a7900e2c78 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -11,7 +11,18 @@ from open_webui.models.files import FileMetadataResponse from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey +from sqlalchemy import ( + BigInteger, + Column, + String, + Text, + JSON, + and_, + func, + ForeignKey, + cast, + or_, +) log = logging.getLogger(__name__) @@ -41,7 +52,6 @@ class Group(Base): class GroupModel(BaseModel): - model_config = ConfigDict(from_attributes=True) id: str user_id: str @@ -56,6 +66,8 @@ class GroupModel(BaseModel): created_at: int # timestamp in epoch updated_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) + class GroupMember(Base): __tablename__ = "group_member" @@ -84,17 +96,8 @@ class GroupMemberModel(BaseModel): #################### -class GroupResponse(BaseModel): - id: str - user_id: str - name: str - description: str - permissions: Optional[dict] = None - data: Optional[dict] = None - meta: Optional[dict] = None +class GroupResponse(GroupModel): member_count: Optional[int] = None - created_at: int # timestamp in epoch - updated_at: int # timestamp in epoch class GroupForm(BaseModel): @@ -112,6 +115,11 @@ class GroupUpdateForm(GroupForm): pass +class GroupListResponse(BaseModel): + items: list[GroupResponse] = [] + total: int = 0 + + class GroupTable: def insert_new_group( self, user_id: str, form_data: GroupForm @@ -140,13 +148,87 @@ class GroupTable: except Exception: return None - def get_groups(self) -> list[GroupModel]: + def get_all_groups(self) -> list[GroupModel]: with get_db() as db: + groups = db.query(Group).order_by(Group.updated_at.desc()).all() + return [GroupModel.model_validate(group) for group in groups] + + def get_groups(self, filter) -> list[GroupResponse]: + with get_db() as db: + query = db.query(Group) + + if filter: + if "query" in filter: + query = query.filter(Group.name.ilike(f"%{filter['query']}%")) + if "member_id" in filter: + query = query.join( + GroupMember, GroupMember.group_id == Group.id + ).filter(GroupMember.user_id == filter["member_id"]) + + if "share" in filter: + share_value = filter["share"] + json_share = Group.data["config"]["share"].as_boolean() + + if share_value: + query = query.filter( + or_( + Group.data.is_(None), + json_share.is_(None), + json_share == True, + ) + ) + else: + query = query.filter( + and_(Group.data.isnot(None), json_share == False) + ) + groups = query.order_by(Group.updated_at.desc()).all() return [ - GroupModel.model_validate(group) - for group in db.query(Group).order_by(Group.updated_at.desc()).all() + GroupResponse.model_validate( + { + **GroupModel.model_validate(group).model_dump(), + "member_count": self.get_group_member_count_by_id(group.id), + } + ) + for group in groups ] + def search_groups( + self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30 + ) -> GroupListResponse: + with get_db() as db: + query = db.query(Group) + + if filter: + if "query" in filter: + query = query.filter(Group.name.ilike(f"%{filter['query']}%")) + if "member_id" in filter: + query = query.join( + GroupMember, GroupMember.group_id == Group.id + ).filter(GroupMember.user_id == filter["member_id"]) + + if "share" in filter: + # 'share' is stored in data JSON, support both sqlite and postgres + share_value = filter["share"] + print("Filtering by share:", share_value) + query = query.filter( + Group.data.op("->>")("share") == str(share_value) + ) + + total = query.count() + query = query.order_by(Group.updated_at.desc()) + groups = query.offset(skip).limit(limit).all() + + return { + "items": [ + GroupResponse.model_validate( + **GroupModel.model_validate(group).model_dump(), + member_count=self.get_group_member_count_by_id(group.id), + ) + for group in groups + ], + "total": total, + } + def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: with get_db() as db: return [ @@ -293,7 +375,7 @@ class GroupTable: ) -> list[GroupModel]: # check for existing groups - existing_groups = self.get_groups() + existing_groups = self.get_all_groups() existing_group_names = {group.name for group in existing_groups} new_groups = [] diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index b68db3a15e..05d52c5c7b 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -32,31 +32,17 @@ router = APIRouter() @router.get("/", response_model=list[GroupResponse]) async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): - if user.role == "admin": - groups = Groups.get_groups() - else: - groups = Groups.get_groups_by_member_id(user.id) - group_list = [] + filter = {} + if user.role != "admin": + filter["member_id"] = user.id - for group in groups: - if share is not None: - # Check if the group has data and a config with share key - if ( - group.data - and "share" in group.data.get("config", {}) - and group.data["config"]["share"] != share - ): - continue + if share is not None: + filter["share"] = share - group_list.append( - GroupResponse( - **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id), - ) - ) + groups = Groups.get_groups(filter=filter) - return group_list + return groups ############################ diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index b5d0e029ec..c2ee4d1c35 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -719,7 +719,7 @@ async def get_groups( ): """List SCIM Groups""" # Get all groups - groups_list = Groups.get_groups() + groups_list = Groups.get_all_groups() # Apply pagination total = len(groups_list) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 6bd955e90c..9cd329a861 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1102,7 +1102,7 @@ class OAuthManager: user_oauth_groups = [] user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) - all_available_groups: list[GroupModel] = Groups.get_groups() + all_available_groups: list[GroupModel] = Groups.get_all_groups() # Create groups if they don't exist and creation is enabled if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: @@ -1146,7 +1146,7 @@ class OAuthManager: # Refresh the list of all available groups if any were created if groups_created: - all_available_groups = Groups.get_groups() + all_available_groups = Groups.get_all_groups() log.debug("Refreshed list of all available groups after creation.") log.debug(f"Oauth Groups claim: {oauth_claim}") diff --git a/src/lib/components/admin/Users/Groups.svelte b/src/lib/components/admin/Users/Groups.svelte index 65e4d4d120..3239a3f462 100644 --- a/src/lib/components/admin/Users/Groups.svelte +++ b/src/lib/components/admin/Users/Groups.svelte @@ -100,6 +100,7 @@