diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 1591cb48ca..0791021063 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -27,7 +27,7 @@ from open_webui.env import ( WEBSOCKET_SENTINEL_HOSTS, ) from open_webui.utils.auth import decode_token -from open_webui.socket.utils import RedisDict, RedisLock +from open_webui.socket.utils import RedisDict, RedisLock, YdocManager from open_webui.tasks import create_task, stop_item_tasks from open_webui.utils.redis import get_redis_connection from open_webui.utils.access_control import has_access, get_users_with_access @@ -125,7 +125,10 @@ else: # TODO: Implement Yjs document management with Redis -DOCUMENTS = {} # document_id -> Y.YDoc instance +YDOC_MANAGER = YdocManager( + redis=REDIS, + redis_key_prefix="open-webui:ydoc:documents", +) async def periodic_usage_pool_cleanup(): @@ -374,16 +377,7 @@ async def ydoc_document_join(sid, data): user_color = data.get("user_color", "#000000") log.info(f"User {user_id} joining document {document_id}") - - # Initialize document if it doesn't exist - if document_id not in DOCUMENTS: - DOCUMENTS[document_id] = { - "updates": [], # Store updates for the document - "users": set(), - } - - # Add user to document - DOCUMENTS[document_id]["users"].add(sid) + await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid) # Join Socket.IO room await sio.enter_room(sid, f"doc_{document_id}") @@ -392,7 +386,8 @@ async def ydoc_document_join(sid, data): # Get the Yjs document state ydoc = Y.Doc() - for update in DOCUMENTS[document_id]["updates"]: + updates = await YDOC_MANAGER.get_updates(document_id) + for update in updates: ydoc.apply_update(bytes(update)) # Encode the entire document state as an update @@ -461,13 +456,14 @@ async def yjs_document_state(sid, data): log.warning(f"Session {sid} not in room {room}. 
Cannot send state.") return - if document_id not in DOCUMENTS: + if not await YDOC_MANAGER.document_exists(document_id): log.warning(f"Document {document_id} not found") return # Get the Yjs document state ydoc = Y.Doc() - for update in DOCUMENTS[document_id]["updates"]: + updates = await YDOC_MANAGER.get_updates(document_id) + for update in updates: ydoc.apply_update(bytes(update)) # Encode the entire document state as an update @@ -491,6 +487,7 @@ async def yjs_document_update(sid, data): """Handle Yjs document updates""" try: document_id = data["document_id"] + try: await stop_item_tasks(REDIS, document_id) except: @@ -500,12 +497,10 @@ async def yjs_document_update(sid, data): update = data["update"] # List of bytes from frontend - if document_id not in DOCUMENTS: - log.warning(f"Document {document_id} not found") - return - - updates = DOCUMENTS[document_id]["updates"] - updates.append(update) + await YDOC_MANAGER.append_to_updates( + document_id=document_id, + update=update, # Convert list of bytes to bytes + ) # Broadcast update to all other users in the document await sio.emit( @@ -541,8 +536,8 @@ async def yjs_document_leave(sid, data): log.info(f"User {user_id} leaving document {document_id}") - if document_id in DOCUMENTS: - DOCUMENTS[document_id]["users"].discard(sid) + # Remove user from the document + await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid) # Leave Socket.IO room await sio.leave_room(sid, f"doc_{document_id}") @@ -554,10 +549,12 @@ async def yjs_document_leave(sid, data): room=f"doc_{document_id}", ) - if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]: - # If no users left, clean up the document + if ( + YDOC_MANAGER.document_exists(document_id) + and len(await YDOC_MANAGER.get_users(document_id)) == 0 + ): log.info(f"Cleaning up document {document_id} as no users are left") - del DOCUMENTS[document_id] + await YDOC_MANAGER.clear_document(document_id) except Exception as e: log.error(f"Error in 
class YdocManager:
    """Store the Yjs update log and the connected users of each document.

    When a Redis client is supplied, every operation goes through Redis using
    the keys ``<prefix>:<document_id>:updates`` (a list of JSON-encoded byte
    arrays) and ``<prefix>:<document_id>:users`` (a set). Without a client the
    same information is kept in dictionaries on the instance.

    Document ids have ``":"`` replaced by ``"_"`` before use so that an id can
    never introduce extra segments into a Redis key.
    """

    def __init__(
        self,
        redis=None,
        redis_key_prefix: str = "open-webui:ydoc:documents",
    ):
        """Create a manager.

        Args:
            redis: Async Redis client, or ``None`` to keep state in memory.
            redis_key_prefix: Prefix prepended to every Redis key.
        """
        self._updates = {}  # normalized document id -> list of raw updates
        self._users = {}  # normalized document id -> set of user ids
        self._redis = redis
        self._redis_key_prefix = redis_key_prefix

    @staticmethod
    def _normalize_id(document_id: str) -> str:
        """Return *document_id* with ``":"`` replaced by ``"_"``."""
        return document_id.replace(":", "_")

    def _redis_key(self, document_id: str, suffix: str) -> str:
        """Build the Redis key for *document_id* ending in *suffix*."""
        return (
            f"{self._redis_key_prefix}:{self._normalize_id(document_id)}:{suffix}"
        )

    async def append_to_updates(self, document_id: str, update: bytes):
        """Append one Yjs update to the document's update log.

        Args:
            document_id: Document the update belongs to.
            update: Update payload; any iterable of byte values is accepted.
        """
        if self._redis:
            # Redis stores strings, so the payload is kept as a JSON int array.
            await self._redis.rpush(
                self._redis_key(document_id, "updates"),
                json.dumps(list(update)),
            )
        else:
            self._updates.setdefault(self._normalize_id(document_id), []).append(
                update
            )

    async def get_updates(self, document_id: str) -> List[bytes]:
        """Return every stored update of the document, oldest first.

        Both backends return a fresh list of ``bytes`` objects, so callers get
        the same type either way and cannot mutate the stored log.
        """
        if self._redis:
            stored = await self._redis.lrange(
                self._redis_key(document_id, "updates"), 0, -1
            )
            return [bytes(json.loads(item)) for item in stored]

        stored = self._updates.get(self._normalize_id(document_id), [])
        return [bytes(item) for item in stored]

    async def document_exists(self, document_id: str) -> bool:
        """Return ``True`` when at least one update is stored for the document."""
        if self._redis:
            found = await self._redis.exists(
                self._redis_key(document_id, "updates")
            )
            return found > 0

        return self._normalize_id(document_id) in self._updates

    async def get_users(self, document_id: str) -> List[str]:
        """Return the users currently registered on the document as a list.

        An unknown document yields an empty list.
        """
        if self._redis:
            members = await self._redis.smembers(
                self._redis_key(document_id, "users")
            )
            return list(members)

        # Copy into a list: the Redis branch returns one, and handing out the
        # internal set would let callers mutate the manager's state.
        return list(self._users.get(self._normalize_id(document_id), ()))

    async def add_user(self, document_id: str, user_id: str):
        """Register *user_id* on the document (no-op if already present)."""
        if self._redis:
            await self._redis.sadd(self._redis_key(document_id, "users"), user_id)
        else:
            self._users.setdefault(self._normalize_id(document_id), set()).add(
                user_id
            )

    async def remove_user(self, document_id: str, user_id: str):
        """Remove *user_id* from the document; unknown users are ignored."""
        if self._redis:
            await self._redis.srem(self._redis_key(document_id, "users"), user_id)
        else:
            members = self._users.get(self._normalize_id(document_id))
            if members is not None:
                members.discard(user_id)

    async def remove_user_from_all_documents(self, user_id: str):
        """Remove *user_id* from every document and drop documents left empty.

        A document whose user set becomes empty has both its users and its
        update log deleted via :meth:`clear_document`.
        """
        if self._redis:
            # Ask only for the user sets instead of fetching every key under
            # the prefix and filtering on the client.
            keys = await self._redis.keys(f"{self._redis_key_prefix}:*:users")
            for key in keys:
                # Keys arrive as bytes when the client does not decode
                # responses; the string handling below needs str.
                if isinstance(key, bytes):
                    key = key.decode()

                await self._redis.srem(key, user_id)

                # Key layout is <prefix>:<document_id>:users and the id holds
                # no ":" after normalization, so it is the penultimate part.
                document_id = key.split(":")[-2]
                if not await self.get_users(document_id):
                    await self.clear_document(document_id)
        else:
            # Iterate over a snapshot: clear_document mutates self._users.
            for document_id in list(self._users):
                members = self._users[document_id]
                if user_id not in members:
                    continue

                members.discard(user_id)
                if not members:
                    await self.clear_document(document_id)

    async def clear_document(self, document_id: str):
        """Delete the update log and the user set of the document."""
        if self._redis:
            await self._redis.delete(self._redis_key(document_id, "updates"))
            await self._redis.delete(self._redis_key(document_id, "users"))
        else:
            normalized = self._normalize_id(document_id)
            self._updates.pop(normalized, None)
            self._users.pop(normalized, None)