From 3f1d9ccbf8443a2fa5278f36202bad930a216680 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 30 Nov 2025 10:33:50 -0500 Subject: [PATCH] feat/enh: add/remove users from group channel --- backend/open_webui/models/channels.py | 227 ++++++++++++++---- backend/open_webui/models/users.py | 14 ++ backend/open_webui/routers/channels.py | 117 ++++++++- backend/open_webui/socket/main.py | 2 + src/lib/apis/channels/index.ts | 82 +++++++ src/lib/components/channel/Channel.svelte | 5 + .../channel/ChannelInfoModal.svelte | 52 +++- .../ChannelInfoModal/AddMembersModal.svelte | 96 ++++++++ .../channel/ChannelInfoModal/UserList.svelte | 80 ++++-- src/lib/components/channel/Navbar.svelte | 3 +- .../workspace/common/MemberSelector.svelte | 50 ++-- 11 files changed, 625 insertions(+), 103 deletions(-) create mode 100644 src/lib/components/channel/ChannelInfoModal/AddMembersModal.svelte diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 23d8c407d9..0364a123c7 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -5,7 +5,6 @@ from typing import Optional from open_webui.internal.db import Base, get_db from open_webui.models.groups import Groups -from open_webui.utils.access_control import has_access from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case @@ -175,7 +174,9 @@ class ChannelWebhookModel(BaseModel): class ChannelResponse(ChannelModel): + is_manager: bool = False write_access: bool = False + user_count: Optional[int] = None @@ -196,32 +197,42 @@ class CreateChannelForm(ChannelForm): class ChannelTable: - def _create_memberships_by_user_ids_and_group_ids( + def _collect_unique_user_ids( self, - channel_id: str, invited_by: str, user_ids: Optional[list[str]] = None, group_ids: Optional[list[str]] = None, - ) -> list[ChannelMemberModel]: - # For group and direct message channels, automatically add the specified users as members - user_ids = user_ids or [] - if invited_by not in user_ids: - user_ids.append(invited_by) # Ensure the creator is also a member + ) -> set[str]: + """ + Collect unique user ids from: + - invited_by + - user_ids + - each group in group_ids + Returns a set for efficient SQL diffing. + """ + users = set(user_ids or []) + users.add(invited_by) - # Add users from specified groups - group_ids = group_ids or [] - for group_id in group_ids: - group_user_ids = Groups.get_group_user_ids_by_id(group_id) - for uid in group_user_ids: - if uid not in user_ids: - user_ids.append(uid) + for group_id in group_ids or []: + users.update(Groups.get_group_user_ids_by_id(group_id)) - # Ensure uniqueness - user_ids = list(set(user_ids)) + return users + def _create_membership_models( + self, + channel_id: str, + invited_by: str, + user_ids: set[str], + ) -> list[ChannelMember]: + """ + Takes a set of NEW user IDs (already filtered to exclude existing members). + Returns ORM ChannelMember objects to be added. + """ + now = int(time.time_ns()) memberships = [] + for uid in user_ids: - channel_member = ChannelMemberModel( + model = ChannelMemberModel( **{ "id": str(uuid.uuid4()), "channel_id": channel_id, @@ -230,17 +241,16 @@ class ChannelTable: "is_active": True, "is_channel_muted": False, "is_channel_pinned": False, - "invited_at": int(time.time_ns()), + "invited_at": now, "invited_by": invited_by, - "joined_at": int(time.time_ns()), + "joined_at": now, "left_at": None, - "last_read_at": int(time.time_ns()), - "created_at": int(time.time_ns()), - "updated_at": int(time.time_ns()), + "last_read_at": now, + "created_at": now, + "updated_at": now, } ) - - memberships.append(ChannelMember(**channel_member.model_dump())) + memberships.append(ChannelMember(**model.model_dump())) return memberships @@ -262,14 +272,18 @@ class ChannelTable: new_channel = Channel(**channel.model_dump()) if form_data.type in ["group", "dm"]: - memberships = self._create_memberships_by_user_ids_and_group_ids( - channel.id, - user_id, - form_data.user_ids, - form_data.group_ids, + users = self._collect_unique_user_ids( + invited_by=user_id, + user_ids=form_data.user_ids, + group_ids=form_data.group_ids, + ) + memberships = self._create_membership_models( + channel_id=new_channel.id, + invited_by=user_id, + user_ids=users, ) - db.add_all(memberships) + db.add_all(memberships) db.add(new_channel) db.commit() return channel @@ -279,24 +293,71 @@ class ChannelTable: channels = db.query(Channel).all() return [ChannelModel.model_validate(channel) for channel in channels] - def get_channels_by_user_id( - self, user_id: str, permission: str = "read" - ) -> list[ChannelModel]: - channels = self.get_channels() + def _has_permission(self, query, filter: dict, permission: str = "read"): + group_ids = filter.get("group_ids", []) + user_id = filter.get("user_id") - channel_list = [] - for channel in channels: - if channel.type == "dm": - membership = self.get_member_by_channel_and_user_id(channel.id, user_id) - if membership and membership.is_active: - channel_list.append(channel) - else: - if channel.user_id == user_id or has_access( - user_id, permission, channel.access_control - ): - channel_list.append(channel) + json_group_ids = Channel.access_control[permission]["group_ids"] - return channel_list + conditions = [] + if group_ids or user_id: + conditions.append(Channel.access_control.is_(None)) + + if user_id: + conditions.append(Channel.user_id == user_id) + + if group_ids: + group_conditions = [] + + for gid in group_ids: + # CASE: gid IN JSON array + # SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%' + # Postgres → access_control->'write'->'group_ids' @> '[gid]' + group_conditions.append(json_group_ids.contains([gid])) + + conditions.append(or_(*group_conditions)) + + if conditions: + query = query.filter(or_(*conditions)) + + return query + + def get_channels_by_user_id(self, user_id: str) -> list[ChannelModel]: + with get_db() as db: + user_group_ids = [ + group.id for group in Groups.get_groups_by_member_id(user_id) + ] + + membership_channels = ( + db.query(Channel) + .join(ChannelMember, Channel.id == ChannelMember.channel_id) + .filter( + Channel.deleted_at.is_(None), + Channel.archived_at.is_(None), + Channel.type.in_(["group", "dm"]), + ChannelMember.user_id == user_id, + ChannelMember.is_active.is_(True), + ) + .all() + ) + + query = db.query(Channel).filter( + Channel.deleted_at.is_(None), + Channel.archived_at.is_(None), + or_( + Channel.type.is_(None), # True NULL/None + Channel.type == "", # Empty string + and_(Channel.type != "group", Channel.type != "dm"), + ), + ) + query = self._has_permission( + query, {"user_id": user_id, "group_ids": user_group_ids} + ) + + standard_channels = query.all() + + all_channels = membership_channels + standard_channels + return [ChannelModel.model_validate(c) for c in all_channels] def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]: with get_db() as db: @@ -331,6 +392,78 @@ class ChannelTable: return ChannelModel.model_validate(channel) if channel else None + def add_members_to_channel( + self, + channel_id: str, + invited_by: str, + user_ids: Optional[list[str]] = None, + group_ids: Optional[list[str]] = None, + ) -> list[ChannelMemberModel]: + with get_db() as db: + # 1. Collect all user_ids including groups + inviter + requested_users = self._collect_unique_user_ids( + invited_by, user_ids, group_ids + ) + + existing_users = { + row.user_id + for row in db.query(ChannelMember.user_id) + .filter(ChannelMember.channel_id == channel_id) + .all() + } + + new_user_ids = requested_users - existing_users + if not new_user_ids: + return [] # Nothing to add + + new_memberships = self._create_membership_models( + channel_id, invited_by, new_user_ids + ) + + db.add_all(new_memberships) + db.commit() + + return [ + ChannelMemberModel.model_validate(membership) + for membership in new_memberships + ] + + def remove_members_from_channel( + self, + channel_id: str, + user_ids: list[str], + ) -> int: + with get_db() as db: + result = ( + db.query(ChannelMember) + .filter( + ChannelMember.channel_id == channel_id, + ChannelMember.user_id.in_(user_ids), + ) + .delete(synchronize_session=False) + ) + db.commit() + return result # number of rows deleted + + def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool: + with get_db() as db: + # Check if the user is the creator of the channel + # or has a 'manager' role in ChannelMember + channel = db.query(Channel).filter(Channel.id == channel_id).first() + if channel and channel.user_id == user_id: + return True + + membership = ( + db.query(ChannelMember) + .filter( + ChannelMember.channel_id == channel_id, + ChannelMember.user_id == user_id, + ChannelMember.role == "manager", + ) + .first() + ) + return membership is not None + def join_channel( self, channel_id: str, user_id: str ) -> Optional[ChannelMemberModel]: diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index ede5f5e761..42918f59a5 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -7,6 +7,9 @@ from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.models.chats import Chats from open_webui.models.groups import Groups, GroupMember +from open_webui.models.channels import ChannelMember + + from open_webui.utils.misc import throttle @@ -311,6 +314,17 @@ class UsersTable: ) ) + channel_id = filter.get("channel_id") + if channel_id: + query = query.filter( + exists( + select(ChannelMember.id).where( + ChannelMember.user_id == User.id, + ChannelMember.channel_id == channel_id, + ) + ) + ) + user_ids = filter.get("user_ids") group_ids = filter.get("group_ids") diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index fa71698f90..7e4b8a3fdd 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -261,6 +261,7 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)): **channel.model_dump(), "user_ids": user_ids, "users": users, + "is_manager": Channels.is_user_channel_manager(channel.id, user.id), "write_access": True, "user_count": len(user_ids), "last_read_at": channel_member.last_read_at if channel_member else None, @@ -291,6 +292,7 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)): **channel.model_dump(), "user_ids": user_ids, "users": users, + "is_manager": Channels.is_user_channel_manager(channel.id, user.id), "write_access": write_access or user.role == "admin", "user_count": user_count, "last_read_at": channel_member.last_read_at if channel_member else None, @@ -334,6 +336,7 @@ async def get_channel_members_by_id( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) + if channel.type == "dm": user_ids = [ member.user_id for member in Channels.get_members_by_channel_id(channel.id) ] @@ -349,11 +352,8 @@ async def get_channel_members_by_id( ], "total": total, } - else: - filter = { - "roles": ["!pending"], - } + filter = {} if query: filter["query"] = query @@ -362,10 +362,16 @@ async def get_channel_members_by_id( if direction: filter["direction"] = direction - permitted_ids = get_permitted_group_and_user_ids("read", channel.access_control) - if permitted_ids: - filter["user_ids"] = permitted_ids.get("user_ids") - filter["group_ids"] = permitted_ids.get("group_ids") + if channel.type == "group": + filter["channel_id"] = channel.id + else: + filter["roles"] = ["!pending"] + permitted_ids = get_permitted_group_and_user_ids( + "read", channel.access_control + ) + if permitted_ids: + filter["user_ids"] = permitted_ids.get("user_ids") + filter["group_ids"] = permitted_ids.get("group_ids") result = Users.get_users(filter=filter, skip=skip, limit=limit) @@ -413,6 +419,101 @@ async def update_is_active_member_by_id_and_user_id( return True +################################################# +# AddMembersById +################################################# + + +class UpdateMembersForm(BaseModel): + user_ids: list[str] = [] + group_ids: list[str] = [] + + +@router.post("/{id}/update/members/add") +async def add_members_by_id( + request: Request, + id: str, + form_data: UpdateMembersForm, + user=Depends(get_verified_user), +): + if user.role != "admin" and not has_permission( + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if channel.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + memberships = Channels.add_members_to_channel( + channel.id, user.id, form_data.user_ids, form_data.group_ids + ) + + return memberships + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +################################################# +# +################################################# + + +class RemoveMembersForm(BaseModel): + user_ids: list[str] = [] + + +@router.post("/{id}/update/members/remove") +async def remove_members_by_id( + request: Request, + id: str, + form_data: RemoveMembersForm, + user=Depends(get_verified_user), +): + if user.role != "admin" and not has_permission( + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.UNAUTHORIZED, + ) + + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if channel.user_id != user.id and user.role != "admin": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids) + + return deleted + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # UpdateChannelById ############################ diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 84705648d9..6e47a058b6 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -309,11 +309,13 @@ async def user_join(sid, data): ) await sio.enter_room(sid, f"user:{user.id}") + # Join all the channels channels = Channels.get_channels_by_user_id(user.id) log.debug(f"{channels=}") for channel in channels: await sio.enter_room(sid, f"channel:{channel.id}") + return {"id": user.id, "name": user.name} diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 4b444bb6ba..549c9004a1 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -194,6 +194,88 @@ export const updateChannelMemberActiveStatusById = async ( return res; }; +type UpdateMembersForm = { + user_ids?: string[]; + group_ids?: string[]; +}; + +export const addMembersById = async ( + token: string = '', + channel_id: string, + formData: UpdateMembersForm +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/update/members/add`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ ...formData }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +type RemoveMembersForm = { + user_ids?: string[]; + group_ids?: string[]; +}; + +export const removeMembersById = async ( + token: string = '', + channel_id: string, + formData: RemoveMembersForm +) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/channels/${channel_id}/update/members/remove`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ ...formData }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const updateChannelById = async ( token: string = '', channel_id: string, diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 24ccd65cef..0fdc6e835d 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -309,6 +309,11 @@ return message; }); }} + onUpdate={async () => { + channel = await getChannelById(localStorage.token, id).catch((error) => { + return null; + }); + }} /> {#if channel && messages !== null} diff --git a/src/lib/components/channel/ChannelInfoModal.svelte b/src/lib/components/channel/ChannelInfoModal.svelte index 4ed79828c0..44094f7801 100644 --- a/src/lib/components/channel/ChannelInfoModal.svelte +++ b/src/lib/components/channel/ChannelInfoModal.svelte @@ -3,22 +3,41 @@ import { getContext, onMount } from 'svelte'; const i18n = getContext('i18n'); + import { removeMembersById } from '$lib/apis/channels'; + import Spinner from '$lib/components/common/Spinner.svelte'; import Modal from '$lib/components/common/Modal.svelte'; - import UserPlusSolid from '$lib/components/icons/UserPlusSolid.svelte'; - import WrenchSolid from '$lib/components/icons/WrenchSolid.svelte'; - import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; import XMark from '$lib/components/icons/XMark.svelte'; import Hashtag from '../icons/Hashtag.svelte'; import Lock from '../icons/Lock.svelte'; import UserList from './ChannelInfoModal/UserList.svelte'; + import AddMembersModal from './ChannelInfoModal/AddMembersModal.svelte'; export let show = false; export let channel = null; + export let onUpdate = () => {}; + + let showAddMembersModal = false; const submitHandler = async () => {}; + const removeMemberHandler = async (userId) => { + const res = await removeMembersById(localStorage.token, channel.id, { + user_ids: [userId] + }).catch((error) => { + toast.error(`${error}`); + return null; + }); + + if (res) { + toast.success($i18n.t('Member removed successfully')); + onUpdate(); + } else { + toast.error($i18n.t('Failed to remove member')); + } + }; + const init = () => {}; $: if (show) { @@ -31,15 +50,14 @@ {#if channel} +
{#if channel?.type === 'dm'} -
+
{$i18n.t('Direct Message')}
{:else} @@ -51,9 +69,7 @@ {/if}
-
+
{channel.name}
{/if} @@ -69,7 +85,7 @@
-
+
- + { + showAddMembersModal = true; + } + : null} + onRemove={channel?.type === 'group' && channel?.is_manager + ? (userId) => { + removeMemberHandler(userId); + } + : null} + search={channel?.type !== 'dm'} + sort={channel?.type !== 'dm'} + />
diff --git a/src/lib/components/channel/ChannelInfoModal/AddMembersModal.svelte b/src/lib/components/channel/ChannelInfoModal/AddMembersModal.svelte new file mode 100644 index 0000000000..241f863591 --- /dev/null +++ b/src/lib/components/channel/ChannelInfoModal/AddMembersModal.svelte @@ -0,0 +1,96 @@ + + +{#if channel} + +
+
+
+
+ {$i18n.t('Add Members')} +
+
+ +
+ +
+
+
{ + e.preventDefault(); + submitHandler(); + }} + > +
+ +
+ +
+ +
+
+
+
+
+
+{/if} diff --git a/src/lib/components/channel/ChannelInfoModal/UserList.svelte b/src/lib/components/channel/ChannelInfoModal/UserList.svelte index 612ab139b1..7f283991b1 100644 --- a/src/lib/components/channel/ChannelInfoModal/UserList.svelte +++ b/src/lib/components/channel/ChannelInfoModal/UserList.svelte @@ -1,6 +1,6 @@ @@ -96,10 +94,33 @@
{:else} +
+
+ + {$i18n.t('Members')} + + {total} +
+ + {#if onAdd} +
+ +
+ {/if} +
+ + {#if search} -
+
-
+
0}
-
-
+
-->
{#each users as user, userIdx (user.id)}
-
+
-
-
+
+
+ + {#if onRemove} +
+ +
+ {/if}
{/each} diff --git a/src/lib/components/channel/Navbar.svelte b/src/lib/components/channel/Navbar.svelte index 67f4d92610..00ffb87b36 100644 --- a/src/lib/components/channel/Navbar.svelte +++ b/src/lib/components/channel/Navbar.svelte @@ -29,10 +29,11 @@ export let channel; export let onPin = (messageId, pinned) => {}; + export let onUpdate = () => {}; - +