diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 818a57807f..60df810785 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -282,6 +282,7 @@ async def connect(sid, environ, auth): else: USER_POOL[user.id] = [sid] + await sio.enter_room(sid, f"user:{user.id}") @sio.on("user-join") async def user_join(sid, data): @@ -304,6 +305,7 @@ async def user_join(sid, data): else: USER_POOL[user.id] = [sid] + await sio.enter_room(sid, f"user:{user.id}") # Join all the channels channels = Channels.get_channels_by_user_id(user.id) log.debug(f"{channels=}") @@ -650,34 +652,18 @@ def get_event_emitter(request_info, update_db=True): async def __event_emitter__(event_data): user_id = request_info["user_id"] - session_ids = list( - set( - USER_POOL.get(user_id, []) - + ( - [request_info.get("session_id")] - if request_info.get("session_id") - else [] - ) - ) - ) - chat_id = request_info.get("chat_id", None) message_id = request_info.get("message_id", None) - emit_tasks = [ - sio.emit( - "events", - { - "chat_id": chat_id, - "message_id": message_id, - "data": event_data, - }, - to=session_id, - ) - for session_id in session_ids - ] - - await asyncio.gather(*emit_tasks) + await sio.emit( + "events", + { + "chat_id": chat_id, + "message_id": message_id, + "data": event_data, + }, + room=f"user:{user_id}", + ) if ( update_db and message_id