diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index d9c0ff3a04..c3f7eef101 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -95,7 +95,7 @@ class MessageResponse(MessageModel): class MessageTable: - def insert_new_message( + async def insert_new_message( self, form_data: MessageForm, channel_id: str, user_id: str ) -> Optional[MessageModel]: async with get_db() as db: @@ -117,19 +117,19 @@ class MessageTable: ) result = Message(**message.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) return MessageModel.model_validate(result) if result else None - def get_message_by_id(self, id: str) -> Optional[MessageResponse]: + async def get_message_by_id(self, id: str) -> Optional[MessageResponse]: async with get_db() as db: - message = db.get(Message, id) + message = await db.get(Message, id) if not message: return None - reactions = self.get_reactions_by_message_id(id) - replies = self.get_replies_by_message_id(id) + reactions = await self.get_reactions_by_message_id(id) + replies = await self.get_replies_by_message_id(id) return MessageResponse( **{ @@ -140,29 +140,29 @@ class MessageTable: } ) - def get_replies_by_message_id(self, id: str) -> list[MessageModel]: + async def get_replies_by_message_id(self, id: str) -> list[MessageModel]: async with get_db() as db: all_messages = ( - db.query(Message) + await db.query(Message) .filter_by(parent_id=id) .order_by(Message.created_at.desc()) .all() ) return [MessageModel.model_validate(message) for message in all_messages] - def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: + async def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: async with get_db() as db: return [ message.user_id - for message in db.query(Message).filter_by(parent_id=id).all() + for message in await db.query(Message).filter_by(parent_id=id).all() ] - def get_messages_by_channel_id( + async def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 ) -> list[MessageModel]: async with get_db() as db: all_messages = ( - db.query(Message) + await db.query(Message) .filter_by(channel_id=channel_id, parent_id=None) .order_by(Message.created_at.desc()) .offset(skip) @@ -171,17 +171,17 @@ class MessageTable: ) return [MessageModel.model_validate(message) for message in all_messages] - def get_messages_by_parent_id( + async def get_messages_by_parent_id( self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 ) -> list[MessageModel]: async with get_db() as db: - message = db.get(Message, parent_id) + message = await db.get(Message, parent_id) if not message: return [] all_messages = ( - db.query(Message) + await db.query(Message) .filter_by(channel_id=channel_id, parent_id=parent_id) .order_by(Message.created_at.desc()) .offset(skip) @@ -195,20 +195,20 @@ class MessageTable: return [MessageModel.model_validate(message) for message in all_messages] - def update_message_by_id( + async def update_message_by_id( self, id: str, form_data: MessageForm ) -> Optional[MessageModel]: async with get_db() as db: - message = db.get(Message, id) + message = await db.get(Message, id) message.content = form_data.content message.data = form_data.data message.meta = form_data.meta message.updated_at = int(time.time_ns()) - db.commit() - db.refresh(message) + await db.commit() + await db.refresh(message) return MessageModel.model_validate(message) if message else None - def add_reaction_to_message( + async def add_reaction_to_message( self, id: str, user_id: str, name: str ) -> Optional[MessageReactionModel]: async with get_db() as db: @@ -221,14 +221,16 @@ class MessageTable: created_at=int(time.time_ns()), ) result = MessageReaction(**reaction.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) return MessageReactionModel.model_validate(result) if result else None - def get_reactions_by_message_id(self, id: str) -> list[Reactions]: + async def get_reactions_by_message_id(self, id: str) -> list[Reactions]: async with get_db() as db: - all_reactions = db.query(MessageReaction).filter_by(message_id=id).all() + all_reactions = ( + await db.query(MessageReaction).filter_by(message_id=id).all() + ) reactions = {} for reaction in all_reactions: @@ -243,36 +245,36 @@ class MessageTable: return [Reactions(**reaction) for reaction in reactions.values()] - def remove_reaction_by_id_and_user_id_and_name( + async def remove_reaction_by_id_and_user_id_and_name( self, id: str, user_id: str, name: str ) -> bool: async with get_db() as db: - db.query(MessageReaction).filter_by( + await db.query(MessageReaction).filter_by( message_id=id, user_id=user_id, name=name ).delete() - db.commit() + await db.commit() return True - def delete_reactions_by_id(self, id: str) -> bool: + async def delete_reactions_by_id(self, id: str) -> bool: async with get_db() as db: - db.query(MessageReaction).filter_by(message_id=id).delete() - db.commit() + await db.query(MessageReaction).filter_by(message_id=id).delete() + await db.commit() return True - def delete_replies_by_id(self, id: str) -> bool: + async def delete_replies_by_id(self, id: str) -> bool: async with get_db() as db: - db.query(Message).filter_by(parent_id=id).delete() - db.commit() + await db.query(Message).filter_by(parent_id=id).delete() + await db.commit() return True - def delete_message_by_id(self, id: str) -> bool: + async def delete_message_by_id(self, id: str) -> bool: async with get_db() as db: - db.query(Message).filter_by(id=id).delete() + await db.query(Message).filter_by(id=id).delete() # Delete all reactions to this message - db.query(MessageReaction).filter_by(message_id=id).delete() + await db.query(MessageReaction).filter_by(message_id=id).delete() - db.commit() + await db.commit() return True diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 39958a0ccb..020c6ba0e0 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -164,7 +164,7 @@ async def get_channel_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message_list = Messages.get_messages_by_channel_id(id, skip, limit) + message_list = await Messages.get_messages_by_channel_id(id, skip, limit) users = {} messages = [] @@ -173,7 +173,7 @@ async def get_channel_messages( user = await Users.get_user_by_id(message.user_id) users[message.user_id] = user - replies = Messages.get_replies_by_message_id(message.id) + replies = await Messages.get_replies_by_message_id(message.id) latest_reply_at = replies[0].created_at if replies else None messages.append( @@ -182,7 +182,7 @@ async def get_channel_messages( **message.model_dump(), "reply_count": len(replies), "latest_reply_at": latest_reply_at, - "reactions": Messages.get_reactions_by_message_id(message.id), + "reactions": await Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -244,7 +244,7 @@ async def post_new_message( ) try: - message = Messages.insert_new_message(form_data, channel.id, user.id) + message = await Messages.insert_new_message(form_data, channel.id, user.id) if message: event_data = { @@ -257,7 +257,7 @@ async def post_new_message( **message.model_dump(), "reply_count": 0, "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id( + "reactions": await Messages.get_reactions_by_message_id( message.id ), "user": UserNameResponse(**user.model_dump()), @@ -276,7 +276,7 @@ async def post_new_message( if message.parent_id: # If this message is a reply, emit to the parent message as well - parent_message = Messages.get_message_by_id(message.parent_id) + parent_message = await Messages.get_message_by_id(message.parent_id) if parent_message: await sio.emit( @@ -348,7 +348,7 @@ async def get_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = await Messages.get_message_by_id(message_id) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -397,7 +397,7 @@ async def get_channel_thread_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit) + message_list = await Messages.get_messages_by_parent_id(id, message_id, skip, limit) users = {} messages = [] @@ -412,7 +412,7 @@ async def get_channel_thread_messages( **message.model_dump(), "reply_count": 0, "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id(message.id), + "reactions": await Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -438,7 +438,7 @@ async def update_message_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - message = Messages.get_message_by_id(message_id) + message = await Messages.get_message_by_id(message_id) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -461,8 +461,8 @@ async def update_message_by_id( ) try: - message = Messages.update_message_by_id(message_id, form_data) - message = Messages.get_message_by_id(message_id) + message = await Messages.update_message_by_id(message_id, form_data) + message = await Messages.get_message_by_id(message_id) if message: await sio.emit( @@ -521,7 +521,7 @@ async def add_reaction_to_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = await Messages.get_message_by_id(message_id) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -533,8 +533,8 @@ async def add_reaction_to_message( ) try: - Messages.add_reaction_to_message(message_id, user.id, form_data.name) - message = Messages.get_message_by_id(message_id) + await Messages.add_reaction_to_message(message_id, user.id, form_data.name) + message = await Messages.get_message_by_id(message_id) await sio.emit( "channel-events", @@ -587,7 +587,7 @@ async def remove_reaction_by_id_and_user_id_and_name( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = await Messages.get_message_by_id(message_id) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -599,11 +599,11 @@ async def remove_reaction_by_id_and_user_id_and_name( ) try: - Messages.remove_reaction_by_id_and_user_id_and_name( + await Messages.remove_reaction_by_id_and_user_id_and_name( message_id, user.id, form_data.name ) - message = Messages.get_message_by_id(message_id) + message = await Messages.get_message_by_id(message_id) await sio.emit( "channel-events", @@ -649,7 +649,7 @@ async def delete_message_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - message = Messages.get_message_by_id(message_id) + message = await Messages.get_message_by_id(message_id) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -672,7 +672,7 @@ async def delete_message_by_id( ) try: - Messages.delete_message_by_id(message_id) + await Messages.delete_message_by_id(message_id) await sio.emit( "channel-events", { @@ -693,7 +693,7 @@ async def delete_message_by_id( if message.parent_id: # If this message is a reply, emit to the parent message as well - parent_message = Messages.get_message_by_id(message.parent_id) + parent_message = await Messages.get_message_by_id(message.parent_id) if parent_message: await sio.emit(