diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 7901f3af66..6ed49ba597 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -111,6 +111,10 @@ class MessageReplyToResponse(MessageUserResponse): reply_to_message: Optional[MessageUserResponse] = None +class MessageWithReactionsResponse(MessageUserResponse): + reactions: list[Reactions] + + class MessageResponse(MessageReplyToResponse): latest_reply_at: Optional[int] reply_count: int @@ -306,6 +310,20 @@ class MessageTable: ) return MessageModel.model_validate(message) if message else None + def get_pinned_messages_by_channel_id( + self, channel_id: str, skip: int = 0, limit: int = 50 + ) -> list[MessageModel]: + with get_db() as db: + all_messages = ( + db.query(Message) + .filter_by(channel_id=channel_id, is_pinned=True) + .order_by(Message.pinned_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + return [MessageModel.model_validate(message) for message in all_messages] + def update_message_by_id( self, id: str, form_data: MessageForm ) -> Optional[MessageModel]: @@ -325,7 +343,7 @@ class MessageTable: db.refresh(message) return MessageModel.model_validate(message) if message else None - def update_message_pin_by_id( + def update_is_pinned_by_id( self, id: str, is_pinned: bool, pinned_by: Optional[str] = None ) -> Optional[MessageModel]: with get_db() as db: @@ -333,7 +351,6 @@ class MessageTable: message.is_pinned = is_pinned message.pinned_at = int(time.time_ns()) if is_pinned else None message.pinned_by = pinned_by if is_pinned else None - message.updated_at = int(time.time_ns()) db.commit() db.refresh(message) return MessageModel.model_validate(message) if message else None diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 394c9f0009..f6e3ebe47a 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -31,6 +31,7 @@ from open_webui.models.messages import ( Messages, MessageModel, MessageResponse, + MessageWithReactionsResponse, MessageForm, ) @@ -463,6 +464,62 @@ async def get_channel_messages( return messages +############################ +# GetPinnedChannelMessages +############################ + +PAGE_ITEM_COUNT_PINNED = 20 + + +@router.get("/{id}/messages/pinned", response_model=list[MessageWithReactionsResponse]) +async def get_pinned_channel_messages( + id: str, page: int = 1, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if channel.type == "dm": + if not Channels.is_user_channel_member(channel.id, user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + else: + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + page = max(1, page) + skip = (page - 1) * PAGE_ITEM_COUNT_PINNED + limit = PAGE_ITEM_COUNT_PINNED + + message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit) + users = {} + + messages = [] + for message in message_list: + if message.user_id not in users: + user = Users.get_user_by_id(message.user_id) + users[message.user_id] = user + + messages.append( + MessageWithReactionsResponse( + **{ + **message.model_dump(), + "reactions": Messages.get_reactions_by_message_id(message.id), + "user": UserNameResponse(**users[message.user_id].model_dump()), + } + ) + ) + + return messages + + ############################ # PostNewMessage ############################ @@ -834,6 +891,69 @@ async def get_channel_message( ) +############################ +# PinChannelMessage +############################ + + +class PinMessageForm(BaseModel): + is_pinned: bool + + +@router.post( + "/{id}/messages/{message_id}/pin", response_model=Optional[MessageUserResponse] +) +async def pin_channel_message( + id: str, message_id: str, form_data: PinMessageForm, user=Depends(get_verified_user) +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if channel.type == "dm": + if not Channels.is_user_channel_member(channel.id, user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + else: + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message = Messages.get_message_by_id(message_id) + if not message: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if message.channel_id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + try: + Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id) + message = Messages.get_message_by_id(message_id) + return MessageUserResponse( + **{ + **message.model_dump(), + "user": UserNameResponse( + **Users.get_user_by_id(message.user_id).model_dump() + ), + } + ) + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetChannelThreadMessages ############################ diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 5b510491fe..7a954a7507 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -299,6 +299,44 @@ export const getChannelMessages = async ( return res; }; +export const getChannelPinnedMessages = async ( + token: string = '', + channel_id: string, + page: number = 1 +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/pinned?page=${page}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getChannelThreadMessages = async ( token: string = '', channel_id: string, @@ -379,6 +417,46 @@ export const sendMessage = async (token: string = '', channel_id: string, messag return res; }; +export const pinMessage = async ( + token: string = '', + channel_id: string, + message_id: string, + is_pinned: boolean +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/pin`, + { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ is_pinned }) + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.error(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const updateMessage = async ( token: string = '', channel_id: string, diff --git a/src/lib/components/channel/Messages.svelte b/src/lib/components/channel/Messages.svelte index 9127fd8c38..b8e6dbb9ea 100644 --- a/src/lib/components/channel/Messages.svelte +++ b/src/lib/components/channel/Messages.svelte @@ -16,7 +16,13 @@ import Message from './Messages/Message.svelte'; import Loader from '../common/Loader.svelte'; import Spinner from '../common/Spinner.svelte'; - import { addReaction, deleteMessage, removeReaction, updateMessage } from '$lib/apis/channels'; + import { + addReaction, + deleteMessage, + pinMessage, + removeReaction, + updateMessage + } from '$lib/apis/channels'; import { WEBUI_API_BASE_URL } from '$lib/constants'; const i18n = getContext('i18n'); @@ -155,6 +161,26 @@ onReply={(message) => { onReply(message); }} + onPin={async (message) => { + messages = messages.map((m) => { + if (m.id === message.id) { + m.is_pinned = !m.is_pinned; + m.pinned_by = !m.is_pinned ? null : $user?.id; + m.pinned_at = !m.is_pinned ? null : Date.now() * 1000000; + } + return m; + }); + + const updatedMessage = await pinMessage( + localStorage.token, + message.channel_id, + message.id, + message.is_pinned + ).catch((error) => { + toast.error(`${error}`); + return null; + }); + }} onThread={(id) => { onThread(id); }} diff --git a/src/lib/components/channel/Messages/Message.svelte b/src/lib/components/channel/Messages/Message.svelte index 5379a7be00..5ed17336c9 100644 --- a/src/lib/components/channel/Messages/Message.svelte +++ b/src/lib/components/channel/Messages/Message.svelte @@ -36,6 +36,10 @@ import Emoji from '$lib/components/common/Emoji.svelte'; import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte'; import ArrowUpLeftAlt from '$lib/components/icons/ArrowUpLeftAlt.svelte'; + import PinSlash from '$lib/components/icons/PinSlash.svelte'; + import Pin from '$lib/components/icons/Pin.svelte'; + + export let className = ''; export let message; export let showUserProfile = true; @@ -47,6 +51,7 @@ export let onDelete: Function = () => {}; export let onEdit: Function = () => {}; export let onReply: Function = () => {}; + export let onPin: Function = () => {}; export let onThread: Function = () => {}; export let onReaction: Function = () => {}; @@ -69,13 +74,17 @@ {#if message}
{#if !edit && !disabled} @@ -85,37 +94,56 @@
- (showButtons = false)} - onSubmit={(name) => { - showButtons = false; - onReaction(name); - }} - > - - - - - - - + + + {/if} + + {#if onReply} + + + + {/if} + + + - {#if !thread} + {#if !thread && onThread} - + {#if onEdit} + + + + {/if} - - - + {#if onDelete} + + + + {/if} {/if}
{/if} + {#if message?.is_pinned} +
+
+ + {$i18n.t('Pinned')} +
+
+ {/if} + {#if message?.reply_to_message?.user}
{/if} +
-
+
{#if showUserProfile} {#if message?.meta?.model_id} -
+
{#if showUserProfile}
diff --git a/src/lib/components/channel/Navbar.svelte b/src/lib/components/channel/Navbar.svelte index 13e4b5c415..b8d6c81807 100644 --- a/src/lib/components/channel/Navbar.svelte +++ b/src/lib/components/channel/Navbar.svelte @@ -18,16 +18,20 @@ import UserAlt from '../icons/UserAlt.svelte'; import ChannelInfoModal from './ChannelInfoModal.svelte'; import Users from '../icons/Users.svelte'; + import Pin from '../icons/Pin.svelte'; + import PinnedMessagesModal from './PinnedMessagesModal.svelte'; const i18n = getContext('i18n'); + let showChannelPinnedMessagesModal = false; let showChannelInfoModal = false; export let channel; + -