Merge pull request #17607 from sihyeonn/perf/sh-notes

perf: optimize notes query and separate access control logic
This commit is contained in:
Tim Jaeryang Baek 2025-09-25 12:13:42 -05:00 committed by GitHub
commit b76d234f97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 57 additions and 10 deletions

View file

@ -2,6 +2,7 @@ import json
import time import time
import uuid import uuid
from typing import Optional from typing import Optional
from functools import lru_cache
from open_webui.internal.db import Base, get_db from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
@ -110,20 +111,66 @@ class NoteTable:
return [NoteModel.model_validate(note) for note in notes] return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_user_id( def get_notes_by_user_id(
self,
user_id: str,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[NoteModel]:
with get_db() as db:
query = db.query(Note).filter(Note.user_id == user_id)
query = query.order_by(Note.updated_at.desc())
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
notes = query.all()
return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_access(
self, self,
user_id: str, user_id: str,
permission: str = "write", permission: str = "write",
skip: Optional[int] = None, skip: Optional[int] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> list[NoteModel]: ) -> list[NoteModel]:
notes = self.get_notes(skip=skip, limit=limit) with get_db() as db:
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} user_groups = Groups.get_groups_by_member_id(user_id)
return [ user_group_ids = {group_id for group_id in user_groups}
note
for note in notes query = db.query(Note)
if note.user_id == user_id
or has_access(user_id, permission, note.access_control, user_group_ids) access_conditions = [Note.user_id == user_id]
]
if user_group_ids:
access_conditions.append(
and_(
Note.access_control.isnot(None),
Note.access_control != '{}',
Note.access_control != 'null'
)
)
query = query.filter(or_(*access_conditions))
query = query.order_by(Note.updated_at.desc())
if skip is not None:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
notes = query.all()
note_models = [NoteModel.model_validate(note) for note in notes]
filtered_notes = []
for note in note_models:
if (note.user_id == user_id or
has_access(user_id, permission, note.access_control, user_group_ids)):
filtered_notes.append(note)
return filtered_notes
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db: with get_db() as db:

View file

@ -48,7 +48,7 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
} }
) )
for note in Notes.get_notes_by_user_id(user.id, "write") for note in Notes.get_notes_by_access(user.id, "write")
] ]
return notes return notes
@ -81,7 +81,7 @@ async def get_note_list(
notes = [ notes = [
NoteTitleIdResponse(**note.model_dump()) NoteTitleIdResponse(**note.model_dump())
for note in Notes.get_notes_by_user_id(user.id, "write", skip=skip, limit=limit) for note in Notes.get_notes_by_access(user.id, "write", skip=skip, limit=limit)
] ]
return notes return notes