diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index c80e8f645a..c1a4e9c3f5 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -6,7 +6,7 @@ from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.models.chats import Chats -from open_webui.models.groups import Groups +from open_webui.models.groups import Groups, GroupMember from open_webui.utils.misc import throttle @@ -95,8 +95,12 @@ class UpdateProfileForm(BaseModel): date_of_birth: Optional[datetime.date] = None +class UserGroupIdsModel(UserModel): + group_ids: list[str] = [] + + class UserListResponse(BaseModel): - users: list[UserModel] + users: list[UserGroupIdsModel] total: int @@ -222,7 +226,10 @@ class UsersTable: limit: Optional[int] = None, ) -> dict: with get_db() as db: - query = db.query(User) + # Join GroupMember so we can order by group_id when requested + query = db.query(User).outerjoin( + GroupMember, GroupMember.user_id == User.id + ) if filter: query_key = filter.get("query") @@ -237,7 +244,16 @@ class UsersTable: order_by = filter.get("order_by") direction = filter.get("direction") - if order_by == "name": + if order_by and order_by.startswith("group_id:"): + group_id = order_by.split(":", 1)[1] + + if direction == "asc": + query = query.order_by((GroupMember.group_id == group_id).asc()) + else: + query = query.order_by( + (GroupMember.group_id == group_id).desc() + ) + elif order_by == "name": if direction == "asc": query = query.order_by(User.name.asc()) else: diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 9ee3f9f88c..9d95c3d71a 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -16,6 +16,7 @@ from open_webui.models.groups import Groups from open_webui.models.chats import Chats from open_webui.models.users import ( UserModel, + UserGroupIdsModel, UserListResponse, UserInfoListResponse, UserIdNameListResponse, @@ -91,7 +92,25 @@ async def get_users( if direction: filter["direction"] = direction - return Users.get_users(filter=filter, skip=skip, limit=limit) + result = Users.get_users(filter=filter, skip=skip, limit=limit) + + users = result["users"] + total = result["total"] + + return { + "users": [ + UserGroupIdsModel( + **{ + **user.model_dump(), + "group_ids": [ + group.id for group in Groups.get_groups_by_member_id(user.id) + ], + } + ) + for user in users + ], + "total": total, + } @router.get("/all", response_model=UserInfoListResponse) diff --git a/src/lib/apis/groups/index.ts b/src/lib/apis/groups/index.ts index c55f477af5..51b49bf4d9 100644 --- a/src/lib/apis/groups/index.ts +++ b/src/lib/apis/groups/index.ts @@ -160,3 +160,73 @@ export const deleteGroupById = async (token: string, id: string) => { return res; }; + +export const addUserToGroup = async (token: string, id: string, userIds: string[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/users/add`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + user_ids: userIds + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const removeUserFromGroup = async (token: string, id: string, userIds: string[]) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/groups/id/${id}/users/remove`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + user_ids: userIds + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/admin/Users/Groups/EditGroupModal.svelte b/src/lib/components/admin/Users/Groups/EditGroupModal.svelte index 7f753c537c..af5e732499 100644 --- a/src/lib/components/admin/Users/Groups/EditGroupModal.svelte +++ b/src/lib/components/admin/Users/Groups/EditGroupModal.svelte @@ -219,7 +219,7 @@