refac/fix: proper notes db operations

This commit is contained in:
Timothy Jaeryang Baek 2025-09-25 13:47:43 -05:00
parent 5b1f9e3e21
commit da661756fa
2 changed files with 35 additions and 29 deletions

View file

@ -128,7 +128,7 @@ class NoteTable:
notes = query.all() notes = query.all()
return [NoteModel.model_validate(note) for note in notes] return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_access( def get_notes_by_permission(
self, self,
user_id: str, user_id: str,
permission: str = "write", permission: str = "write",
@ -137,40 +137,44 @@ class NoteTable:
) -> list[NoteModel]: ) -> list[NoteModel]:
with get_db() as db: with get_db() as db:
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = Groups.get_groups_by_member_id(user_id)
user_group_ids = {group_id for group_id in user_groups} user_group_ids = {group.id for group in user_groups}
query = db.query(Note) # Order newest-first. We stream to keep memory usage low.
query = (
access_conditions = [Note.user_id == user_id] db.query(Note)
.order_by(Note.updated_at.desc())
if user_group_ids: .execution_options(stream_results=True)
access_conditions.append( .yield_per(256)
and_(
Note.access_control.isnot(None),
Note.access_control != '{}',
Note.access_control != 'null'
)
) )
query = query.filter(or_(*access_conditions)) results: list[NoteModel] = []
n_skipped = 0
query = query.order_by(Note.updated_at.desc()) for note in query:
# Fast-pass #1: owner
if note.user_id == user_id:
permitted = True
# Fast-pass #2: public/open
elif note.access_control is None:
permitted = True
else:
permitted = has_access(
user_id, permission, note.access_control, user_group_ids
)
if skip is not None: if not permitted:
query = query.offset(skip) continue
if limit is not None:
query = query.limit(limit)
notes = query.all() # Apply skip AFTER permission filtering so it counts only accessible notes
note_models = [NoteModel.model_validate(note) for note in notes] if skip and n_skipped < skip:
n_skipped += 1
continue
filtered_notes = [] results.append(NoteModel.model_validate(note))
for note in note_models: if limit is not None and len(results) >= limit:
if (note.user_id == user_id or break
has_access(user_id, permission, note.access_control, user_group_ids)):
filtered_notes.append(note)
return filtered_notes return results
def get_note_by_id(self, id: str) -> Optional[NoteModel]: def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db: with get_db() as db:

View file

@ -48,7 +48,7 @@ async def get_notes(request: Request, user=Depends(get_verified_user)):
"user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()),
} }
) )
for note in Notes.get_notes_by_access(user.id, "write") for note in Notes.get_notes_by_permission(user.id, "write")
] ]
return notes return notes
@ -81,7 +81,9 @@ async def get_note_list(
notes = [ notes = [
NoteTitleIdResponse(**note.model_dump()) NoteTitleIdResponse(**note.model_dump())
for note in Notes.get_notes_by_access(user.id, "write", skip=skip, limit=limit) for note in Notes.get_notes_by_permission(
user.id, "write", skip=skip, limit=limit
)
] ]
return notes return notes