diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index b61e820eae..b44d459dd0 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -2,6 +2,7 @@ import json import time import uuid from typing import Optional +from functools import lru_cache from open_webui.internal.db import Base, get_db from open_webui.models.groups import Groups @@ -110,20 +111,66 @@ class NoteTable: return [NoteModel.model_validate(note) for note in notes] def get_notes_by_user_id( + self, + user_id: str, + skip: Optional[int] = None, + limit: Optional[int] = None, + ) -> list[NoteModel]: + with get_db() as db: + query = db.query(Note).filter(Note.user_id == user_id) + query = query.order_by(Note.updated_at.desc()) + + if skip is not None: + query = query.offset(skip) + if limit is not None: + query = query.limit(limit) + + notes = query.all() + return [NoteModel.model_validate(note) for note in notes] + + def get_notes_by_access( self, user_id: str, permission: str = "write", skip: Optional[int] = None, limit: Optional[int] = None, ) -> list[NoteModel]: - notes = self.get_notes(skip=skip, limit=limit) - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} - return [ - note - for note in notes - if note.user_id == user_id - or has_access(user_id, permission, note.access_control, user_group_ids) - ] + with get_db() as db: + user_groups = Groups.get_groups_by_member_id(user_id) + user_group_ids = {group_id for group_id in user_groups} + + query = db.query(Note) + + access_conditions = [Note.user_id == user_id] + + if user_group_ids: + access_conditions.append( + and_( + Note.access_control.isnot(None), + Note.access_control != '{}', + Note.access_control != 'null' + ) + ) + + query = query.filter(or_(*access_conditions)) + + query = query.order_by(Note.updated_at.desc()) + + if skip is not None: + query = query.offset(skip) + if limit is not None: + query = query.limit(limit) + + notes = query.all() + note_models = [NoteModel.model_validate(note) for note in notes] + + filtered_notes = [] + for note in note_models: + if (note.user_id == user_id or + has_access(user_id, permission, note.access_control, user_group_ids)): + filtered_notes.append(note) + + return filtered_notes def get_note_by_id(self, id: str) -> Optional[NoteModel]: with get_db() as db: diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index dff7bc2e7f..b79174b793 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -48,7 +48,7 @@ async def get_notes(request: Request, user=Depends(get_verified_user)): "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), } ) - for note in Notes.get_notes_by_user_id(user.id, "write") + for note in Notes.get_notes_by_access(user.id, "write") ] return notes @@ -81,7 +81,7 @@ async def get_note_list( notes = [ NoteTitleIdResponse(**note.model_dump()) - for note in Notes.get_notes_by_user_id(user.id, "write", skip=skip, limit=limit) + for note in Notes.get_notes_by_access(user.id, "write", skip=skip, limit=limit) ] return notes