From 5649a668fad15393a52c27a2f188841af8b66989 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 29 Dec 2025 01:42:13 +0400 Subject: [PATCH] refac --- backend/open_webui/routers/channels.py | 61 ++++++++++++++------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 50ad88b546..79aad9fb5d 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -353,14 +353,14 @@ async def get_channel_by_id( ) user_ids = [ - member.user_id for member in Channels.get_members_by_channel_id(channel.id) + member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db) ] users = [ UserIdNameStatusResponse( - **{**user.model_dump(), "is_active": Users.is_user_active(user.id)} + **{**user.model_dump(), "is_active": Users.is_user_active(user.id, db=db)} ) - for user in Users.get_users_by_user_ids(user_ids) + for user in Users.get_users_by_user_ids(user_ids, db=db) ] channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) @@ -453,15 +453,15 @@ async def get_channel_members_by_id( if channel.type == "dm": user_ids = [ - member.user_id for member in Channels.get_members_by_channel_id(channel.id) + member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db) ] - users = Users.get_users_by_user_ids(user_ids) + users = Users.get_users_by_user_ids(user_ids, db=db) total = len(users) return { "users": [ UserModelResponse( - **user.model_dump(), is_active=Users.is_user_active(user.id) + **user.model_dump(), is_active=Users.is_user_active(user.id, db=db) ) for user in users ], @@ -488,7 +488,7 @@ async def get_channel_members_by_id( filter["user_ids"] = permitted_ids.get("user_ids") filter["group_ids"] = permitted_ids.get("group_ids") - result = Users.get_users(filter=filter, skip=skip, limit=limit) + result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db) users = result["users"] total = result["total"] @@ -496,7 +496,7 @@ async def get_channel_members_by_id( return { "users": [ UserModelResponse( - **user.model_dump(), is_active=Users.is_user_active(user.id) + **user.model_dump(), is_active=Users.is_user_active(user.id, db=db) ) for user in users ], @@ -533,7 +533,7 @@ async def update_is_active_member_by_id_and_user_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - Channels.update_member_active_status(channel.id, user.id, form_data.is_active) + Channels.update_member_active_status(channel.id, user.id, form_data.is_active, db=db) return True @@ -577,7 +577,7 @@ async def add_members_by_id( try: memberships = Channels.add_members_to_channel( - channel.id, user.id, form_data.user_ids, form_data.group_ids + channel.id, user.id, form_data.user_ids, form_data.group_ids, db=db ) return memberships @@ -769,10 +769,10 @@ async def get_channel_messages( ) channel_member = Channels.join_channel( - id, user.id + id, user.id, db=db ) # Ensure user is a member of the channel - message_list = Messages.get_messages_by_channel_id(id, skip, limit) + message_list = Messages.get_messages_by_channel_id(id, skip, limit, db=db) users = {} messages = [] @@ -781,7 +781,7 @@ async def get_channel_messages( user = Users.get_user_by_id(message.user_id, db=db) users[message.user_id] = user - thread_replies = Messages.get_thread_replies_by_message_id(message.id) + thread_replies = Messages.get_thread_replies_by_message_id(message.id, db=db) latest_thread_reply_at = ( thread_replies[0].created_at if thread_replies else None ) @@ -792,7 +792,7 @@ async def get_channel_messages( **message.model_dump(), "reply_count": len(thread_replies), "latest_reply_at": latest_thread_reply_at, - "reactions": Messages.get_reactions_by_message_id(message.id), + "reactions": Messages.get_reactions_by_message_id(message.id, db=db), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -840,7 +840,7 @@ async def get_pinned_channel_messages( skip = (page - 1) * PAGE_ITEM_COUNT_PINNED limit = PAGE_ITEM_COUNT_PINNED - message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit) + message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit, db=db) users = {} messages = [] @@ -853,7 +853,7 @@ async def get_pinned_channel_messages( MessageWithReactionsResponse( **{ **message.model_dump(), - "reactions": Messages.get_reactions_by_message_id(message.id), + "reactions": Messages.get_reactions_by_message_id(message.id, db=db), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -867,12 +867,12 @@ async def get_pinned_channel_messages( ############################ -async def send_notification(name, webui_url, channel, message, active_user_ids): +async def send_notification(name, webui_url, channel, message, active_user_ids, db=None): users = get_users_with_access("read", channel.access_control) for user in users: if (user.id not in active_user_ids) and Channels.is_user_channel_member( - channel.id, user.id + channel.id, user.id, db=db ): if user.settings: webhook_url = user.settings.ui.get("notifications", {}).get( @@ -894,7 +894,7 @@ async def send_notification(name, webui_url, channel, message, active_user_ids): return True -async def model_response_handler(request, channel, message, user): +async def model_response_handler(request, channel, message, user, db=None): MODELS = { model["id"]: model for model in get_filtered_models(await get_all_models(request, user=user), user) @@ -932,6 +932,7 @@ async def model_response_handler(request, channel, message, user): thread_messages = Messages.get_messages_by_parent_id( channel.id, message.parent_id if message.parent_id else message.id, + db=db, )[::-1] response_message, channel = await new_message_handler( @@ -951,6 +952,7 @@ async def model_response_handler(request, channel, message, user): } ), user, + db, ) thread_history = [] @@ -1051,6 +1053,7 @@ async def model_response_handler(request, channel, message, user): } ), user, + db, ) elif res.get("error", None): await update_message_by_id( @@ -1066,6 +1069,7 @@ async def model_response_handler(request, channel, message, user): } ), user, + db, ) except Exception as e: log.info(e) @@ -1179,13 +1183,14 @@ async def post_new_message( active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") async def background_handler(): - await model_response_handler(request, channel, message, user) + await model_response_handler(request, channel, message, user, db) await send_notification( request.app.state.WEBUI_NAME, request.app.state.config.WEBUI_URL, channel, message, active_user_ids, + db=db, ) background_tasks.add_task(background_handler) @@ -1354,7 +1359,7 @@ async def pin_channel_message( ) try: - Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id) + Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id, db=db) message = Messages.get_message_by_id(message_id, db=db) return MessageUserResponse( **{ @@ -1408,7 +1413,7 @@ async def get_channel_thread_messages( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit) + message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit, db=db) users = {} messages = [] @@ -1423,7 +1428,7 @@ async def get_channel_thread_messages( **message.model_dump(), "reply_count": 0, "latest_reply_at": None, - "reactions": Messages.get_reactions_by_message_id(message.id), + "reactions": Messages.get_reactions_by_message_id(message.id, db=db), "user": UserNameResponse(**users[message.user_id].model_dump()), } ) @@ -1484,7 +1489,7 @@ async def update_message_by_id( ) try: - message = Messages.update_message_by_id(message_id, form_data) + message = Messages.update_message_by_id(message_id, form_data, db=db) message = Messages.get_message_by_id(message_id, db=db) if message: @@ -1561,7 +1566,7 @@ async def add_reaction_to_message( ) try: - Messages.add_reaction_to_message(message_id, user.id, form_data.name) + Messages.add_reaction_to_message(message_id, user.id, form_data.name, db=db) message = Messages.get_message_by_id(message_id, db=db) await sio.emit( @@ -1637,7 +1642,7 @@ async def remove_reaction_by_id_and_user_id_and_name( try: Messages.remove_reaction_by_id_and_user_id_and_name( - message_id, user.id, form_data.name + message_id, user.id, form_data.name, db=db ) message = Messages.get_message_by_id(message_id, db=db) @@ -1721,7 +1726,7 @@ async def delete_message_by_id( ) try: - Messages.delete_message_by_id(message_id) + Messages.delete_message_by_id(message_id, db=db) await sio.emit( "events:channel", { @@ -1742,7 +1747,7 @@ async def delete_message_by_id( if message.parent_id: # If this message is a reply, emit to the parent message as well - parent_message = Messages.get_message_by_id(message.parent_id) + parent_message = Messages.get_message_by_id(message.parent_id, db=db) if parent_message: await sio.emit(