From 2fbff741dacf19df19df5174104a786cc0a0b1fa Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 11 Jul 2025 23:59:48 +0400 Subject: [PATCH] feat: collaborative note --- backend/open_webui/socket/main.py | 221 ++++++++++ .../components/common/RichTextInput.svelte | 399 +++++++++++++----- src/lib/components/notes/NoteEditor.svelte | 14 +- 3 files changed, 520 insertions(+), 114 deletions(-) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 96bcbcf1b5..f8a92bc6f4 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -5,11 +5,14 @@ import socketio import logging import sys import time +from typing import Dict, Set from redis import asyncio as aioredis +import pycrdt as Y from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels from open_webui.models.chats import Chats +from open_webui.models.notes import Notes, NoteUpdateForm from open_webui.utils.redis import ( get_sentinels_from_env, get_sentinel_url_from_env, @@ -25,6 +28,10 @@ from open_webui.env import ( ) from open_webui.utils.auth import decode_token from open_webui.socket.utils import RedisDict, RedisLock +from open_webui.tasks import create_task, stop_item_tasks +from open_webui.utils.redis import get_redis_connection +from open_webui.utils.access_control import has_access, get_users_with_access + from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -37,6 +44,14 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["SOCKET"]) +REDIS = get_redis_connection( + redis_url=WEBSOCKET_REDIS_URL, + redis_sentinels=get_sentinels_from_env( + WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT + ), + async_mode=True, +) + if WEBSOCKET_MANAGER == "redis": if WEBSOCKET_SENTINEL_HOSTS: mgr = socketio.AsyncRedisManager( @@ -90,6 +105,9 @@ if WEBSOCKET_MANAGER == "redis": redis_sentinels=redis_sentinels, ) + DOCUMENTS = {} + DOCUMENT_USERS = {} + clean_up_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, lock_name="usage_cleanup_lock", @@ -103,6 +121,9 @@ else: SESSION_POOL = {} USER_POOL = {} USAGE_POOL = {} + + DOCUMENTS = {} # document_id -> Y.YDoc instance + DOCUMENT_USERS = {} # document_id -> set of user sids aquire_func = release_func = renew_func = lambda: True @@ -316,6 +337,206 @@ async def channel_events(sid, data): ) +@sio.on("yjs:document:join") +async def yjs_document_join(sid, data): + """Handle user joining a document""" + user = SESSION_POOL.get(sid) + + try: + document_id = data["document_id"] + + if document_id.startswith("note:"): + note_id = document_id.split(":")[1] + note = Notes.get_note_by_id(note_id) + if not note: + log.error(f"Note {note_id} not found") + return + + if user.get("role") != "admin" and has_access( + user.get("id"), type="read", access_control=note.access_control + ): + log.error( + f"User {user.get('id')} does not have access to note {note_id}" + ) + return + + user_id = data.get("user_id", sid) + user_name = data.get("user_name", "Anonymous") + user_color = data.get("user_color", "#000000") + + log.info(f"User {user_id} joining document {document_id}") + + # Initialize document if it doesn't exist + if document_id not in DOCUMENTS: + DOCUMENTS[document_id] = { + "ydoc": Y.Doc(), # Create actual Yjs document + "users": set(), + } + DOCUMENT_USERS[document_id] = set() + + # Add user to document + DOCUMENTS[document_id]["users"].add(sid) + DOCUMENT_USERS[document_id].add(sid) + + # Join Socket.IO room + await sio.enter_room(sid, f"doc_{document_id}") + + # Send current document state as a proper Yjs update + ydoc = DOCUMENTS[document_id]["ydoc"] + + # Encode the entire document state as an update + state_update = ydoc.get_update() + await sio.emit( + "yjs:document:state", + { + "document_id": document_id, + "state": list(state_update), # Convert bytes to list for JSON + }, + room=sid, + ) + + # Notify other users about the new user + await sio.emit( + "yjs:user:joined", + { + "document_id": document_id, + "user_id": user_id, + "user_name": user_name, + "user_color": user_color, + }, + room=f"doc_{document_id}", + skip_sid=sid, + ) + + log.info(f"User {user_id} successfully joined document {document_id}") + + except Exception as e: + log.error(f"Error in yjs_document_join: {e}") + await sio.emit("error", {"message": "Failed to join document"}, room=sid) + + +async def document_save_handler(document_id, data, user): + if document_id.startswith("note:"): + note_id = document_id.split(":")[1] + note = Notes.get_note_by_id(note_id) + if not note: + log.error(f"Note {note_id} not found") + return + + if user.get("role") != "admin" and has_access( + user.get("id"), type="read", access_control=note.access_control + ): + log.error(f"User {user.get('id')} does not have access to note {note_id}") + return + + Notes.update_note_by_id(note_id, NoteUpdateForm(data=data)) + + +@sio.on("yjs:document:update") +async def yjs_document_update(sid, data): + """Handle Yjs document updates""" + try: + document_id = data["document_id"] + await stop_item_tasks(REDIS, document_id) + + user_id = data.get("user_id", sid) + update = data["update"] # List of bytes from frontend + + if document_id not in DOCUMENTS: + log.warning(f"Document {document_id} not found") + return + + # Apply the update to the server's Yjs document + ydoc = DOCUMENTS[document_id]["ydoc"] + update_bytes = bytes(update) + + try: + ydoc.apply_update(update_bytes) + except Exception as e: + log.error(f"Failed to apply Yjs update: {e}") + return + + # Broadcast update to all other users in the document + await sio.emit( + "yjs:document:update", + { + "document_id": document_id, + "user_id": user_id, + "update": update, + "socket_id": sid, # Add socket_id to match frontend filtering + }, + room=f"doc_{document_id}", + skip_sid=sid, + ) + + async def debounced_save(): + await asyncio.sleep(0.5) + await document_save_handler( + document_id, data.get("data", {}), SESSION_POOL.get(sid) + ) + + await stop_item_tasks(REDIS, document_id) # Cancel previous in-flight save + await create_task(REDIS, debounced_save(), document_id) + + except Exception as e: + log.error(f"Error in yjs_document_update: {e}") + + +@sio.on("yjs:document:leave") +async def yjs_document_leave(sid, data): + """Handle user leaving a document""" + try: + document_id = data["document_id"] + user_id = data.get("user_id", sid) + + log.info(f"User {user_id} leaving document {document_id}") + + if document_id in DOCUMENTS: + DOCUMENTS[document_id]["users"].discard(sid) + + if document_id in DOCUMENT_USERS: + DOCUMENT_USERS[document_id].discard(sid) + + # Leave Socket.IO room + await sio.leave_room(sid, f"doc_{document_id}") + + # Notify other users + await sio.emit( + "yjs:user:left", + {"document_id": document_id, "user_id": user_id}, + room=f"doc_{document_id}", + ) + + if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]: + # If no users left, clean up the document + log.info(f"Cleaning up document {document_id} as no users are left") + del DOCUMENTS[document_id] + del DOCUMENT_USERS[document_id] + + except Exception as e: + log.error(f"Error in yjs_document_leave: {e}") + + +@sio.on("yjs:awareness:update") +async def yjs_awareness_update(sid, data): + """Handle awareness updates (cursors, selections, etc.)""" + try: + document_id = data["document_id"] + user_id = data.get("user_id", sid) + update = data["update"] + + # Broadcast awareness update to all other users in the document + await sio.emit( + "yjs:awareness:update", + {"document_id": document_id, "user_id": user_id, "update": update}, + room=f"doc_{document_id}", + skip_sid=sid, + ) + + except Exception as e: + log.error(f"Error in yjs_awareness_update: {e}") + + @sio.event async def disconnect(sid): if sid in SESSION_POOL: diff --git a/src/lib/components/common/RichTextInput.svelte b/src/lib/components/common/RichTextInput.svelte index 4befba5c4f..be9469cdff 100644 --- a/src/lib/components/common/RichTextInput.svelte +++ b/src/lib/components/common/RichTextInput.svelte @@ -56,13 +56,22 @@ import { Fragment, DOMParser } from 'prosemirror-model'; import { EditorState, Plugin, PluginKey, TextSelection, Selection } from 'prosemirror-state'; - import { receiveTransaction, sendableSteps, getVersion } from 'prosemirror-collab'; - import { Step } from 'prosemirror-transform'; - import { Decoration, DecorationSet } from 'prosemirror-view'; import { Editor, Extension } from '@tiptap/core'; + // Yjs imports + import * as Y from 'yjs'; + import { + ySyncPlugin, + yCursorPlugin, + yUndoPlugin, + undo, + redo, + prosemirrorJSONToYDoc, + yDocToProsemirrorJSON + } from 'y-prosemirror'; + import { keymap } from 'prosemirror-keymap'; + import { AIAutocompletion } from './RichTextInput/AutoCompletion.js'; - import History from '@tiptap/extension-history'; import Table from '@tiptap/extension-table'; import TableRow from '@tiptap/extension-table-row'; import TableHeader from '@tiptap/extension-table-header'; @@ -126,98 +135,292 @@ export let largeTextAsFile = false; export let insertPromptAsRichText = false; - let isConnected = false; - let collaborators = new Map(); - let version = 0; + let content = null; + let htmlValue = ''; + let jsonValue = ''; + let mdValue = ''; - // Custom collaboration plugin - const collaborationPlugin = () => { - return new Plugin({ - key: new PluginKey('collaboration'), - state: { - init: () => ({ version: 0 }), - apply: (tr, pluginState) => { - const newState = { ...pluginState }; + // Yjs setup + let ydoc = null; + let yXmlFragment = null; + let awareness = null; - if (tr.getMeta('collaboration')) { - newState.version = tr.getMeta('collaboration').version; - } + // Custom Yjs Socket.IO provider + class SocketIOProvider { + constructor(doc, documentId, socket, user) { + this.doc = doc; + this.documentId = documentId; + this.socket = socket; + this.user = user; + this.isConnected = false; + this.synced = false; - return newState; - } - }, - view: () => ({ - update: (view, prevState) => { - const sendable = sendableSteps(view.state); - if (sendable) { - socket.emit('document_steps', { - document_id: documentId, - user_id: user?.id, - version: sendable.version, - steps: sendable.steps.map((step) => step.toJSON()), - clientID: sendable.clientID - }); - } - } - }) - }); - }; - - function initializeCollaboration() { - if (!socket || !user || !documentId) { - console.warn('Collaboration not initialized: missing socket, user, or documentId'); - return; + this.setupEventListeners(); } - socket.emit('join_document', { - document_id: documentId, - user_id: user?.id, - user_name: user?.name, - user_color: user?.color - }); + onConnect() { + this.isConnected = true; + this.joinDocument(); + } - socket.on('document_steps', handleDocumentSteps); - socket.on('document_state', handleDocumentState); - socket.on('user_joined', handleUserJoined); - socket.on('user_left', handleUserLeft); - socket.on('connect', () => { - isConnected = true; - }); - socket.on('disconnect', () => { - isConnected = false; - }); - } + onDisconnect() { + this.isConnected = false; + this.synced = false; + } - function handleDocumentSteps(data) { - if (data.user_id !== user?.id && editor) { - const steps = data.steps.map((stepJSON) => Step.fromJSON(editor.schema, stepJSON)); - const tr = receiveTransaction(editor.state, steps, data.clientID); + setupEventListeners() { + // Listen for document updates from server + this.socket.on('yjs:document:update', (data) => { + if (data.document_id === this.documentId && data.socket_id !== this.socket.id) { + try { + const update = new Uint8Array(data.update); + Y.applyUpdate(this.doc, update); + } catch (error) { + console.error('Error applying Yjs update:', error); + } + } + }); - if (tr) { - editor.view.dispatch(tr); + // Listen for document state from server + this.socket.on('yjs:document:state', async (data) => { + if (data.document_id === this.documentId) { + try { + if (data.state) { + const state = new Uint8Array(data.state); + + if (state.length === 2 && state[0] === 0 && state[1] === 0) { + // Empty state, check if we have content to initialize + if (content) { + const pydoc = prosemirrorJSONToYDoc(editor.schema, content); + if (pydoc) { + Y.applyUpdate(this.doc, Y.encodeStateAsUpdate(pydoc)); + } + } + } else { + Y.applyUpdate(this.doc, state); + } + } + this.synced = true; + } catch (error) { + console.error('Error applying Yjs state:', error); + } + } + }); + + // Listen for awareness updates + this.socket.on('yjs:awareness:update', (data) => { + if (data.document_id === this.documentId && awareness) { + try { + const awarenessUpdate = new Uint8Array(data.update); + awareness.applyUpdate(awarenessUpdate, 'server'); + } catch (error) { + console.error('Error applying awareness update:', error); + } + } + }); + + // Handle connection events + this.socket.on('connect', this.onConnect); + this.socket.on('disconnect', this.onDisconnect); + + // Listen for document updates from Yjs + this.doc.on('update', async (update, origin) => { + if (origin !== 'server' && this.isConnected) { + await tick(); // Ensure the DOM is updated before sending + this.socket.emit('yjs:document:update', { + document_id: this.documentId, + user_id: this.user?.id, + socket_id: this.socket.id, + update: Array.from(update), + data: { + content: { + md: mdValue, + html: htmlValue, + json: jsonValue + } + } + }); + } + }); + + // Listen for awareness updates from Yjs + if (awareness) { + awareness.on('change', ({ added, updated, removed }, origin) => { + if (origin !== 'server' && this.isConnected) { + const changedClients = added.concat(updated).concat(removed); + const awarenessUpdate = awareness.encodeUpdate(changedClients); + this.socket.emit('yjs:awareness:update', { + document_id: this.documentId, + user_id: this.socket.id, + update: Array.from(awarenessUpdate) + }); + } + }); + } + + if (this.socket.connected) { + this.isConnected = true; + this.joinDocument(); + } + } + + generateUserColor() { + const colors = [ + '#FF6B6B', + '#4ECDC4', + '#45B7D1', + '#96CEB4', + '#FFEAA7', + '#DDA0DD', + '#98D8C8', + '#F7DC6F', + '#BB8FCE', + '#85C1E9' + ]; + return colors[Math.floor(Math.random() * colors.length)]; + } + + joinDocument() { + const userColor = this.generateUserColor(); + this.socket.emit('yjs:document:join', { + document_id: this.documentId, + user_id: this.user?.id, + user_name: this.user?.name, + user_color: userColor + }); + + // Set user awareness info + if (awareness && this.user) { + awareness.setLocalStateField('user', { + name: `${this.user.name}`, + color: userColor, + id: this.socket.id + }); + } + } + + destroy() { + this.socket.off('yjs:document:update'); + this.socket.off('yjs:document:state'); + this.socket.off('yjs:awareness:update'); + this.socket.off('connect', this.onConnect); + this.socket.off('disconnect', this.onDisconnect); + + if (this.isConnected) { + this.socket.emit('yjs:document:leave', { + document_id: this.documentId, + user_id: this.user?.id + }); } } } - function handleDocumentState(data) { - version = data.version; - if (data.content && editor) { - editor.commands.setContent(data.content); + let provider = null; + + // Simple awareness implementation + class SimpleAwareness { + constructor(yDoc) { + // Yjs awareness expects clientID (not clientId) property + this.clientID = yDoc ? yDoc.clientID : Math.floor(Math.random() * 0xffffffff); + // Map from clientID (number) to state (object) + this._states = new Map(); // _states, not states; will make getStates() for compat + this._updateHandlers = []; + this._localState = {}; + // As in Yjs Awareness, add our local state to the states map from the start: + this._states.set(this.clientID, this._localState); + } + on(event, handler) { + if (event === 'change') this._updateHandlers.push(handler); + } + off(event, handler) { + if (event === 'change') { + const i = this._updateHandlers.indexOf(handler); + if (i !== -1) this._updateHandlers.splice(i, 1); + } + } + getLocalState() { + return this._states.get(this.clientID) || null; + } + getStates() { + // Yjs returns a Map (clientID->state) + return this._states; + } + setLocalStateField(field, value) { + let localState = this._states.get(this.clientID); + if (!localState) { + localState = {}; + this._states.set(this.clientID, localState); + } + localState[field] = value; + // After updating, fire 'update' event to all handlers + for (const cb of this._updateHandlers) { + // Follows Yjs Awareness ({ added, updated, removed }, origin) + cb({ added: [], updated: [this.clientID], removed: [] }, 'local'); + } + } + applyUpdate(update, origin) { + // Very simple: Accepts a serialized JSON state for now as Uint8Array + try { + const str = new TextDecoder().decode(update); + const obj = JSON.parse(str); + // Should be a plain object: { clientID: state, ... } + for (const [k, v] of Object.entries(obj)) { + this._states.set(+k, v); + } + for (const cb of this._updateHandlers) { + cb({ added: [], updated: Array.from(Object.keys(obj)).map(Number), removed: [] }, origin); + } + } catch (e) { + console.warn('SimpleAwareness: Could not decode update:', e); + } + } + encodeUpdate(clients) { + // Encodes the states for the given clientIDs as Uint8Array (JSON) + const obj = {}; + for (const id of clients || Array.from(this._states.keys())) { + const st = this._states.get(id); + if (st) obj[id] = st; + } + const json = JSON.stringify(obj); + return new TextEncoder().encode(json); } - isConnected = true; } - function handleUserJoined(data) { - collaborators.set(data.user_id, { - name: data.user_name, - color: data.user_color - }); - collaborators = collaborators; - } + // Yjs collaboration extension + const YjsCollaboration = Extension.create({ + name: 'yjsCollaboration', - function handleUserLeft(data) { - collaborators.delete(data.user_id); - collaborators = collaborators; + addProseMirrorPlugins() { + if (!collaboration || !yXmlFragment) return []; + + const plugins = [ + ySyncPlugin(yXmlFragment), + yUndoPlugin(), + keymap({ + 'Mod-z': undo, + 'Mod-y': redo, + 'Mod-Shift-z': redo + }) + ]; + + if (awareness) { + plugins.push(yCursorPlugin(awareness)); + } + + return plugins; + } + }); + + function initializeCollaboration() { + if (!collaboration) return; + + // Create Yjs document + ydoc = new Y.Doc(); + yXmlFragment = ydoc.getXmlFragment('prosemirror'); + awareness = new SimpleAwareness(ydoc); + + // Create custom Socket.IO provider + provider = new SocketIOProvider(ydoc, documentId, socket, user); } let floatingMenuElement = null; @@ -538,7 +741,7 @@ }; onMount(async () => { - let content = value; + content = value; if (json) { if (!content) { @@ -655,28 +858,18 @@ }) ] : []), - - ...(collaboration - ? [ - Extension.create({ - name: 'socketCollaboration', - addProseMirrorPlugins() { - return [collaborationPlugin()]; - } - }) - ] - : []) + ...(collaboration ? [YjsCollaboration] : []) ], - content: content, + content: collaboration ? undefined : content, autofocus: messageInput ? true : false, onTransaction: () => { // force re-render so `editor.isActive` works as expected editor = editor; - const htmlValue = editor.getHTML(); - const jsonValue = editor.getJSON(); + htmlValue = editor.getHTML(); + jsonValue = editor.getJSON(); - let mdValue = turndownService + mdValue = turndownService .turndown( htmlValue .replace(/

<\/p>/g, '
') @@ -872,16 +1065,8 @@ }); onDestroy(() => { - if (socket) { - socket.off('document_steps', handleDocumentSteps); - socket.off('document_state', handleDocumentState); - socket.off('user_joined', handleUserJoined); - socket.off('user_left', handleUserLeft); - - socket.emit('leave_document', { - document_id: documentId, - user_id: userId - }); + if (provider) { + provider.destroy(); } if (editor) { @@ -889,7 +1074,7 @@ } }); - $: if (value !== null && editor) { + $: if (value !== null && editor && !collaboration) { onValueChange(); } diff --git a/src/lib/components/notes/NoteEditor.svelte b/src/lib/components/notes/NoteEditor.svelte index 168a21225e..ebcde20060 100644 --- a/src/lib/components/notes/NoteEditor.svelte +++ b/src/lib/components/notes/NoteEditor.svelte @@ -31,7 +31,7 @@ import { uploadFile } from '$lib/apis/files'; import { chatCompletion } from '$lib/apis/openai'; - import { config, models, settings, showSidebar } from '$lib/stores'; + import { config, models, settings, showSidebar, socket, user } from '$lib/stores'; import NotePanel from '$lib/components/notes/NotePanel.svelte'; import MenuLines from '../icons/MenuLines.svelte'; @@ -171,10 +171,6 @@ }, 200); }; - $: if (note) { - changeDebounceHandler(); - } - $: if (id) { init(); } @@ -862,7 +858,7 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings, -

+
{ @@ -906,7 +902,7 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
{#if enhancing} @@ -959,6 +955,10 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings, html={note.data?.content?.html} json={true} link={true} + documentId={`note:${note.id}`} + collaboration={true} + socket={$socket} + user={$user} placeholder={$i18n.t('Write something...')} editable={versionIdx === null && !enhancing} onChange={(content) => {