feat: collaborative note

This commit is contained in:
Timothy Jaeryang Baek 2025-07-11 23:59:48 +04:00
parent d3b14ff827
commit 2fbff741da
3 changed files with 520 additions and 114 deletions

View file

@ -5,11 +5,14 @@ import socketio
import logging import logging
import sys import sys
import time import time
from typing import Dict, Set
from redis import asyncio as aioredis from redis import asyncio as aioredis
import pycrdt as Y
from open_webui.models.users import Users, UserNameResponse from open_webui.models.users import Users, UserNameResponse
from open_webui.models.channels import Channels from open_webui.models.channels import Channels
from open_webui.models.chats import Chats from open_webui.models.chats import Chats
from open_webui.models.notes import Notes, NoteUpdateForm
from open_webui.utils.redis import ( from open_webui.utils.redis import (
get_sentinels_from_env, get_sentinels_from_env,
get_sentinel_url_from_env, get_sentinel_url_from_env,
@ -25,6 +28,10 @@ from open_webui.env import (
) )
from open_webui.utils.auth import decode_token from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock from open_webui.socket.utils import RedisDict, RedisLock
from open_webui.tasks import create_task, stop_item_tasks
from open_webui.utils.redis import get_redis_connection
from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.env import ( from open_webui.env import (
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
@ -37,6 +44,14 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"]) log.setLevel(SRC_LOG_LEVELS["SOCKET"])
REDIS = get_redis_connection(
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
),
async_mode=True,
)
if WEBSOCKET_MANAGER == "redis": if WEBSOCKET_MANAGER == "redis":
if WEBSOCKET_SENTINEL_HOSTS: if WEBSOCKET_SENTINEL_HOSTS:
mgr = socketio.AsyncRedisManager( mgr = socketio.AsyncRedisManager(
@ -90,6 +105,9 @@ if WEBSOCKET_MANAGER == "redis":
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
) )
DOCUMENTS = {}
DOCUMENT_USERS = {}
clean_up_lock = RedisLock( clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock", lock_name="usage_cleanup_lock",
@ -103,6 +121,9 @@ else:
SESSION_POOL = {} SESSION_POOL = {}
USER_POOL = {} USER_POOL = {}
USAGE_POOL = {} USAGE_POOL = {}
DOCUMENTS = {} # document_id -> Y.YDoc instance
DOCUMENT_USERS = {} # document_id -> set of user sids
aquire_func = release_func = renew_func = lambda: True aquire_func = release_func = renew_func = lambda: True
@ -316,6 +337,206 @@ async def channel_events(sid, data):
) )
@sio.on("yjs:document:join")
async def yjs_document_join(sid, data):
"""Handle user joining a document"""
user = SESSION_POOL.get(sid)
try:
document_id = data["document_id"]
if document_id.startswith("note:"):
note_id = document_id.split(":")[1]
note = Notes.get_note_by_id(note_id)
if not note:
log.error(f"Note {note_id} not found")
return
if user.get("role") != "admin" and has_access(
user.get("id"), type="read", access_control=note.access_control
):
log.error(
f"User {user.get('id')} does not have access to note {note_id}"
)
return
user_id = data.get("user_id", sid)
user_name = data.get("user_name", "Anonymous")
user_color = data.get("user_color", "#000000")
log.info(f"User {user_id} joining document {document_id}")
# Initialize document if it doesn't exist
if document_id not in DOCUMENTS:
DOCUMENTS[document_id] = {
"ydoc": Y.Doc(), # Create actual Yjs document
"users": set(),
}
DOCUMENT_USERS[document_id] = set()
# Add user to document
DOCUMENTS[document_id]["users"].add(sid)
DOCUMENT_USERS[document_id].add(sid)
# Join Socket.IO room
await sio.enter_room(sid, f"doc_{document_id}")
# Send current document state as a proper Yjs update
ydoc = DOCUMENTS[document_id]["ydoc"]
# Encode the entire document state as an update
state_update = ydoc.get_update()
await sio.emit(
"yjs:document:state",
{
"document_id": document_id,
"state": list(state_update), # Convert bytes to list for JSON
},
room=sid,
)
# Notify other users about the new user
await sio.emit(
"yjs:user:joined",
{
"document_id": document_id,
"user_id": user_id,
"user_name": user_name,
"user_color": user_color,
},
room=f"doc_{document_id}",
skip_sid=sid,
)
log.info(f"User {user_id} successfully joined document {document_id}")
except Exception as e:
log.error(f"Error in yjs_document_join: {e}")
await sio.emit("error", {"message": "Failed to join document"}, room=sid)
async def document_save_handler(document_id, data, user):
if document_id.startswith("note:"):
note_id = document_id.split(":")[1]
note = Notes.get_note_by_id(note_id)
if not note:
log.error(f"Note {note_id} not found")
return
if user.get("role") != "admin" and has_access(
user.get("id"), type="read", access_control=note.access_control
):
log.error(f"User {user.get('id')} does not have access to note {note_id}")
return
Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
@sio.on("yjs:document:update")
async def yjs_document_update(sid, data):
"""Handle Yjs document updates"""
try:
document_id = data["document_id"]
await stop_item_tasks(REDIS, document_id)
user_id = data.get("user_id", sid)
update = data["update"] # List of bytes from frontend
if document_id not in DOCUMENTS:
log.warning(f"Document {document_id} not found")
return
# Apply the update to the server's Yjs document
ydoc = DOCUMENTS[document_id]["ydoc"]
update_bytes = bytes(update)
try:
ydoc.apply_update(update_bytes)
except Exception as e:
log.error(f"Failed to apply Yjs update: {e}")
return
# Broadcast update to all other users in the document
await sio.emit(
"yjs:document:update",
{
"document_id": document_id,
"user_id": user_id,
"update": update,
"socket_id": sid, # Add socket_id to match frontend filtering
},
room=f"doc_{document_id}",
skip_sid=sid,
)
async def debounced_save():
await asyncio.sleep(0.5)
await document_save_handler(
document_id, data.get("data", {}), SESSION_POOL.get(sid)
)
await stop_item_tasks(REDIS, document_id) # Cancel previous in-flight save
await create_task(REDIS, debounced_save(), document_id)
except Exception as e:
log.error(f"Error in yjs_document_update: {e}")
@sio.on("yjs:document:leave")
async def yjs_document_leave(sid, data):
"""Handle user leaving a document"""
try:
document_id = data["document_id"]
user_id = data.get("user_id", sid)
log.info(f"User {user_id} leaving document {document_id}")
if document_id in DOCUMENTS:
DOCUMENTS[document_id]["users"].discard(sid)
if document_id in DOCUMENT_USERS:
DOCUMENT_USERS[document_id].discard(sid)
# Leave Socket.IO room
await sio.leave_room(sid, f"doc_{document_id}")
# Notify other users
await sio.emit(
"yjs:user:left",
{"document_id": document_id, "user_id": user_id},
room=f"doc_{document_id}",
)
if document_id in DOCUMENTS and not DOCUMENTS[document_id]["users"]:
# If no users left, clean up the document
log.info(f"Cleaning up document {document_id} as no users are left")
del DOCUMENTS[document_id]
del DOCUMENT_USERS[document_id]
except Exception as e:
log.error(f"Error in yjs_document_leave: {e}")
@sio.on("yjs:awareness:update")
async def yjs_awareness_update(sid, data):
"""Handle awareness updates (cursors, selections, etc.)"""
try:
document_id = data["document_id"]
user_id = data.get("user_id", sid)
update = data["update"]
# Broadcast awareness update to all other users in the document
await sio.emit(
"yjs:awareness:update",
{"document_id": document_id, "user_id": user_id, "update": update},
room=f"doc_{document_id}",
skip_sid=sid,
)
except Exception as e:
log.error(f"Error in yjs_awareness_update: {e}")
@sio.event @sio.event
async def disconnect(sid): async def disconnect(sid):
if sid in SESSION_POOL: if sid in SESSION_POOL:

View file

@ -56,13 +56,22 @@
import { Fragment, DOMParser } from 'prosemirror-model'; import { Fragment, DOMParser } from 'prosemirror-model';
import { EditorState, Plugin, PluginKey, TextSelection, Selection } from 'prosemirror-state'; import { EditorState, Plugin, PluginKey, TextSelection, Selection } from 'prosemirror-state';
import { receiveTransaction, sendableSteps, getVersion } from 'prosemirror-collab';
import { Step } from 'prosemirror-transform';
import { Decoration, DecorationSet } from 'prosemirror-view';
import { Editor, Extension } from '@tiptap/core'; import { Editor, Extension } from '@tiptap/core';
// Yjs imports
import * as Y from 'yjs';
import {
ySyncPlugin,
yCursorPlugin,
yUndoPlugin,
undo,
redo,
prosemirrorJSONToYDoc,
yDocToProsemirrorJSON
} from 'y-prosemirror';
import { keymap } from 'prosemirror-keymap';
import { AIAutocompletion } from './RichTextInput/AutoCompletion.js'; import { AIAutocompletion } from './RichTextInput/AutoCompletion.js';
import History from '@tiptap/extension-history';
import Table from '@tiptap/extension-table'; import Table from '@tiptap/extension-table';
import TableRow from '@tiptap/extension-table-row'; import TableRow from '@tiptap/extension-table-row';
import TableHeader from '@tiptap/extension-table-header'; import TableHeader from '@tiptap/extension-table-header';
@ -126,98 +135,292 @@
export let largeTextAsFile = false; export let largeTextAsFile = false;
export let insertPromptAsRichText = false; export let insertPromptAsRichText = false;
let isConnected = false; let content = null;
let collaborators = new Map(); let htmlValue = '';
let version = 0; let jsonValue = '';
let mdValue = '';
// Custom collaboration plugin // Yjs setup
const collaborationPlugin = () => { let ydoc = null;
return new Plugin({ let yXmlFragment = null;
key: new PluginKey('collaboration'), let awareness = null;
state: {
init: () => ({ version: 0 }),
apply: (tr, pluginState) => {
const newState = { ...pluginState };
if (tr.getMeta('collaboration')) { // Custom Yjs Socket.IO provider
newState.version = tr.getMeta('collaboration').version; class SocketIOProvider {
} constructor(doc, documentId, socket, user) {
this.doc = doc;
this.documentId = documentId;
this.socket = socket;
this.user = user;
this.isConnected = false;
this.synced = false;
return newState; this.setupEventListeners();
}
},
view: () => ({
update: (view, prevState) => {
const sendable = sendableSteps(view.state);
if (sendable) {
socket.emit('document_steps', {
document_id: documentId,
user_id: user?.id,
version: sendable.version,
steps: sendable.steps.map((step) => step.toJSON()),
clientID: sendable.clientID
});
}
}
})
});
};
function initializeCollaboration() {
if (!socket || !user || !documentId) {
console.warn('Collaboration not initialized: missing socket, user, or documentId');
return;
} }
socket.emit('join_document', { onConnect() {
document_id: documentId, this.isConnected = true;
user_id: user?.id, this.joinDocument();
user_name: user?.name, }
user_color: user?.color
});
socket.on('document_steps', handleDocumentSteps); onDisconnect() {
socket.on('document_state', handleDocumentState); this.isConnected = false;
socket.on('user_joined', handleUserJoined); this.synced = false;
socket.on('user_left', handleUserLeft); }
socket.on('connect', () => {
isConnected = true;
});
socket.on('disconnect', () => {
isConnected = false;
});
}
function handleDocumentSteps(data) { setupEventListeners() {
if (data.user_id !== user?.id && editor) { // Listen for document updates from server
const steps = data.steps.map((stepJSON) => Step.fromJSON(editor.schema, stepJSON)); this.socket.on('yjs:document:update', (data) => {
const tr = receiveTransaction(editor.state, steps, data.clientID); if (data.document_id === this.documentId && data.socket_id !== this.socket.id) {
try {
const update = new Uint8Array(data.update);
Y.applyUpdate(this.doc, update);
} catch (error) {
console.error('Error applying Yjs update:', error);
}
}
});
if (tr) { // Listen for document state from server
editor.view.dispatch(tr); this.socket.on('yjs:document:state', async (data) => {
if (data.document_id === this.documentId) {
try {
if (data.state) {
const state = new Uint8Array(data.state);
if (state.length === 2 && state[0] === 0 && state[1] === 0) {
// Empty state, check if we have content to initialize
if (content) {
const pydoc = prosemirrorJSONToYDoc(editor.schema, content);
if (pydoc) {
Y.applyUpdate(this.doc, Y.encodeStateAsUpdate(pydoc));
}
}
} else {
Y.applyUpdate(this.doc, state);
}
}
this.synced = true;
} catch (error) {
console.error('Error applying Yjs state:', error);
}
}
});
// Listen for awareness updates
this.socket.on('yjs:awareness:update', (data) => {
if (data.document_id === this.documentId && awareness) {
try {
const awarenessUpdate = new Uint8Array(data.update);
awareness.applyUpdate(awarenessUpdate, 'server');
} catch (error) {
console.error('Error applying awareness update:', error);
}
}
});
// Handle connection events
this.socket.on('connect', this.onConnect);
this.socket.on('disconnect', this.onDisconnect);
// Listen for document updates from Yjs
this.doc.on('update', async (update, origin) => {
if (origin !== 'server' && this.isConnected) {
await tick(); // Ensure the DOM is updated before sending
this.socket.emit('yjs:document:update', {
document_id: this.documentId,
user_id: this.user?.id,
socket_id: this.socket.id,
update: Array.from(update),
data: {
content: {
md: mdValue,
html: htmlValue,
json: jsonValue
}
}
});
}
});
// Listen for awareness updates from Yjs
if (awareness) {
awareness.on('change', ({ added, updated, removed }, origin) => {
if (origin !== 'server' && this.isConnected) {
const changedClients = added.concat(updated).concat(removed);
const awarenessUpdate = awareness.encodeUpdate(changedClients);
this.socket.emit('yjs:awareness:update', {
document_id: this.documentId,
user_id: this.socket.id,
update: Array.from(awarenessUpdate)
});
}
});
}
if (this.socket.connected) {
this.isConnected = true;
this.joinDocument();
}
}
generateUserColor() {
const colors = [
'#FF6B6B',
'#4ECDC4',
'#45B7D1',
'#96CEB4',
'#FFEAA7',
'#DDA0DD',
'#98D8C8',
'#F7DC6F',
'#BB8FCE',
'#85C1E9'
];
return colors[Math.floor(Math.random() * colors.length)];
}
joinDocument() {
const userColor = this.generateUserColor();
this.socket.emit('yjs:document:join', {
document_id: this.documentId,
user_id: this.user?.id,
user_name: this.user?.name,
user_color: userColor
});
// Set user awareness info
if (awareness && this.user) {
awareness.setLocalStateField('user', {
name: `${this.user.name}`,
color: userColor,
id: this.socket.id
});
}
}
destroy() {
this.socket.off('yjs:document:update');
this.socket.off('yjs:document:state');
this.socket.off('yjs:awareness:update');
this.socket.off('connect', this.onConnect);
this.socket.off('disconnect', this.onDisconnect);
if (this.isConnected) {
this.socket.emit('yjs:document:leave', {
document_id: this.documentId,
user_id: this.user?.id
});
} }
} }
} }
function handleDocumentState(data) { let provider = null;
version = data.version;
if (data.content && editor) { // Simple awareness implementation
editor.commands.setContent(data.content); class SimpleAwareness {
constructor(yDoc) {
// Yjs awareness expects clientID (not clientId) property
this.clientID = yDoc ? yDoc.clientID : Math.floor(Math.random() * 0xffffffff);
// Map from clientID (number) to state (object)
this._states = new Map(); // _states, not states; will make getStates() for compat
this._updateHandlers = [];
this._localState = {};
// As in Yjs Awareness, add our local state to the states map from the start:
this._states.set(this.clientID, this._localState);
}
on(event, handler) {
if (event === 'change') this._updateHandlers.push(handler);
}
off(event, handler) {
if (event === 'change') {
const i = this._updateHandlers.indexOf(handler);
if (i !== -1) this._updateHandlers.splice(i, 1);
}
}
getLocalState() {
return this._states.get(this.clientID) || null;
}
getStates() {
// Yjs returns a Map (clientID->state)
return this._states;
}
setLocalStateField(field, value) {
let localState = this._states.get(this.clientID);
if (!localState) {
localState = {};
this._states.set(this.clientID, localState);
}
localState[field] = value;
// After updating, fire 'update' event to all handlers
for (const cb of this._updateHandlers) {
// Follows Yjs Awareness ({ added, updated, removed }, origin)
cb({ added: [], updated: [this.clientID], removed: [] }, 'local');
}
}
applyUpdate(update, origin) {
// Very simple: Accepts a serialized JSON state for now as Uint8Array
try {
const str = new TextDecoder().decode(update);
const obj = JSON.parse(str);
// Should be a plain object: { clientID: state, ... }
for (const [k, v] of Object.entries(obj)) {
this._states.set(+k, v);
}
for (const cb of this._updateHandlers) {
cb({ added: [], updated: Array.from(Object.keys(obj)).map(Number), removed: [] }, origin);
}
} catch (e) {
console.warn('SimpleAwareness: Could not decode update:', e);
}
}
encodeUpdate(clients) {
// Encodes the states for the given clientIDs as Uint8Array (JSON)
const obj = {};
for (const id of clients || Array.from(this._states.keys())) {
const st = this._states.get(id);
if (st) obj[id] = st;
}
const json = JSON.stringify(obj);
return new TextEncoder().encode(json);
} }
isConnected = true;
} }
function handleUserJoined(data) { // Yjs collaboration extension
collaborators.set(data.user_id, { const YjsCollaboration = Extension.create({
name: data.user_name, name: 'yjsCollaboration',
color: data.user_color
});
collaborators = collaborators;
}
function handleUserLeft(data) { addProseMirrorPlugins() {
collaborators.delete(data.user_id); if (!collaboration || !yXmlFragment) return [];
collaborators = collaborators;
const plugins = [
ySyncPlugin(yXmlFragment),
yUndoPlugin(),
keymap({
'Mod-z': undo,
'Mod-y': redo,
'Mod-Shift-z': redo
})
];
if (awareness) {
plugins.push(yCursorPlugin(awareness));
}
return plugins;
}
});
function initializeCollaboration() {
if (!collaboration) return;
// Create Yjs document
ydoc = new Y.Doc();
yXmlFragment = ydoc.getXmlFragment('prosemirror');
awareness = new SimpleAwareness(ydoc);
// Create custom Socket.IO provider
provider = new SocketIOProvider(ydoc, documentId, socket, user);
} }
let floatingMenuElement = null; let floatingMenuElement = null;
@ -538,7 +741,7 @@
}; };
onMount(async () => { onMount(async () => {
let content = value; content = value;
if (json) { if (json) {
if (!content) { if (!content) {
@ -655,28 +858,18 @@
}) })
] ]
: []), : []),
...(collaboration ? [YjsCollaboration] : [])
...(collaboration
? [
Extension.create({
name: 'socketCollaboration',
addProseMirrorPlugins() {
return [collaborationPlugin()];
}
})
]
: [])
], ],
content: content, content: collaboration ? undefined : content,
autofocus: messageInput ? true : false, autofocus: messageInput ? true : false,
onTransaction: () => { onTransaction: () => {
// force re-render so `editor.isActive` works as expected // force re-render so `editor.isActive` works as expected
editor = editor; editor = editor;
const htmlValue = editor.getHTML(); htmlValue = editor.getHTML();
const jsonValue = editor.getJSON(); jsonValue = editor.getJSON();
let mdValue = turndownService mdValue = turndownService
.turndown( .turndown(
htmlValue htmlValue
.replace(/<p><\/p>/g, '<br/>') .replace(/<p><\/p>/g, '<br/>')
@ -872,16 +1065,8 @@
}); });
onDestroy(() => { onDestroy(() => {
if (socket) { if (provider) {
socket.off('document_steps', handleDocumentSteps); provider.destroy();
socket.off('document_state', handleDocumentState);
socket.off('user_joined', handleUserJoined);
socket.off('user_left', handleUserLeft);
socket.emit('leave_document', {
document_id: documentId,
user_id: userId
});
} }
if (editor) { if (editor) {
@ -889,7 +1074,7 @@
} }
}); });
$: if (value !== null && editor) { $: if (value !== null && editor && !collaboration) {
onValueChange(); onValueChange();
} }

View file

@ -31,7 +31,7 @@
import { uploadFile } from '$lib/apis/files'; import { uploadFile } from '$lib/apis/files';
import { chatCompletion } from '$lib/apis/openai'; import { chatCompletion } from '$lib/apis/openai';
import { config, models, settings, showSidebar } from '$lib/stores'; import { config, models, settings, showSidebar, socket, user } from '$lib/stores';
import NotePanel from '$lib/components/notes/NotePanel.svelte'; import NotePanel from '$lib/components/notes/NotePanel.svelte';
import MenuLines from '../icons/MenuLines.svelte'; import MenuLines from '../icons/MenuLines.svelte';
@ -171,10 +171,6 @@
}, 200); }, 200);
}; };
$: if (note) {
changeDebounceHandler();
}
$: if (id) { $: if (id) {
init(); init();
} }
@ -862,7 +858,7 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
</div> </div>
</div> </div>
<div class=" mb-2.5 px-2.5"> <div class=" px-2.5">
<div <div
class=" flex w-full bg-transparent overflow-x-auto scrollbar-none" class=" flex w-full bg-transparent overflow-x-auto scrollbar-none"
on:wheel={(e) => { on:wheel={(e) => {
@ -906,7 +902,7 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
</div> </div>
<div <div
class=" flex-1 w-full h-full overflow-auto px-3.5 pb-20 relative" class=" flex-1 w-full h-full overflow-auto px-3.5 pb-20 relative z-40 pt-2.5"
id="note-content-container" id="note-content-container"
> >
{#if enhancing} {#if enhancing}
@ -959,6 +955,10 @@ Provide the enhanced notes in markdown format. Use markdown syntax for headings,
html={note.data?.content?.html} html={note.data?.content?.html}
json={true} json={true}
link={true} link={true}
documentId={`note:${note.id}`}
collaboration={true}
socket={$socket}
user={$user}
placeholder={$i18n.t('Write something...')} placeholder={$i18n.t('Write something...')}
editable={versionIdx === null && !enhancing} editable={versionIdx === null && !enhancing}
onChange={(content) => { onChange={(content) => {