From 5d1459df166cce8445eb1556cc26abbd65a3f9f4 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 29 Dec 2025 01:20:04 +0400 Subject: [PATCH] refac --- backend/open_webui/routers/channels.py | 205 ++++++++++++++----------- backend/open_webui/routers/images.py | 5 +- backend/open_webui/routers/scim.py | 50 +++--- 3 files changed, 148 insertions(+), 112 deletions(-) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 954be49721..f3ca874538 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -61,6 +61,8 @@ from open_webui.utils.access_control import ( ) from open_webui.utils.webhook import post_webhook from open_webui.utils.channels import extract_mentions, replace_mentions +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session log = logging.getLogger(__name__) @@ -98,26 +100,27 @@ class ChannelListItemResponse(ChannelModel): async def get_channels( request: Request, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channels = Channels.get_channels_by_user_id(user.id) + channels = Channels.get_channels_by_user_id(user.id, db=db) channel_list = [] for channel in channels: - last_message = Messages.get_last_message_by_channel_id(channel.id) + last_message = Messages.get_last_message_by_channel_id(channel.id, db=db) last_message_at = last_message.created_at if last_message else None - channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id) + channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) unread_count = ( Messages.get_unread_message_count( - channel.id, user.id, channel_member.last_read_at + channel.id, user.id, channel_member.last_read_at, db=db ) if channel_member else 0 @@ -128,13 +131,13 @@ async def get_channels( if channel.type == "dm": user_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(channel.id) + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] users = [ UserIdNameStatusResponse( - **{**user.model_dump(), "is_active": Users.is_user_active(user.id)} + **{**user.model_dump(), "is_active": Users.is_user_active(user.id, db=db)} ) - for user in Users.get_users_by_user_ids(user_ids) + for user in Users.get_users_by_user_ids(user_ids, db=db) ] channel_list.append( @@ -154,11 +157,12 @@ async def get_channels( async def get_all_channels( request: Request, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role == "admin": - return Channels.get_channels() - return Channels.get_channels_by_user_id(user.id) + return Channels.get_channels(db=db) + return Channels.get_channels_by_user_id(user.id, db=db) ############################ @@ -171,10 +175,11 @@ async def get_dm_channel_by_user_id( request: Request, user_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -182,11 +187,11 @@ async def get_dm_channel_by_user_id( ) try: - existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id]) + existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id], db=db) if existing_channel: participant_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(existing_channel.id) + for member in Channels.get_members_by_channel_id(existing_channel.id, db=db) ] await emit_to_users( @@ -198,7 +203,7 @@ async def get_dm_channel_by_user_id( f"channel:{existing_channel.id}", participant_ids ) - Channels.update_member_active_status(existing_channel.id, user.id, True) + Channels.update_member_active_status(existing_channel.id, user.id, True, db=db) return ChannelModel(**existing_channel.model_dump()) channel = Channels.insert_new_channel( @@ -208,12 +213,13 @@ async def get_dm_channel_by_user_id( user_ids=[user_id], ), user.id, + db=db, ) if channel: participant_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(channel.id) + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] await emit_to_users( @@ -243,10 +249,11 @@ async def create_new_channel( request: Request, form_data: CreateChannelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -263,13 +270,13 @@ async def create_new_channel( try: if form_data.type == "dm": existing_channel = Channels.get_dm_channel_by_user_ids( - [user.id, *form_data.user_ids] + [user.id, *form_data.user_ids], db=db ) if existing_channel: participant_ids = [ member.user_id for member in Channels.get_members_by_channel_id( - existing_channel.id + existing_channel.id, db=db ) ] await emit_to_users( @@ -281,15 +288,15 @@ async def create_new_channel( f"channel:{existing_channel.id}", participant_ids ) - Channels.update_member_active_status(existing_channel.id, user.id, True) + Channels.update_member_active_status(existing_channel.id, user.id, True, db=db) return ChannelModel(**existing_channel.model_dump()) - channel = Channels.insert_new_channel(form_data, user.id) + channel = Channels.insert_new_channel(form_data, user.id, db=db) if channel: participant_ids = [ member.user_id - for member in Channels.get_members_by_channel_id(channel.id) + for member in Channels.get_members_by_channel_id(channel.id, db=db) ] await emit_to_users( @@ -327,9 +334,10 @@ async def get_channel_by_id( request: Request, id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -339,7 +347,7 @@ async def get_channel_by_id( users = None if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -355,7 +363,7 @@ async def get_channel_by_id( for user in Users.get_users_by_user_ids(user_ids) ] - channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id) + channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) unread_count = Messages.get_unread_message_count( channel.id, user.id, channel_member.last_read_at if channel_member else None ) @@ -365,7 +373,7 @@ async def get_channel_by_id( **channel.model_dump(), "user_ids": user_ids, "users": users, - "is_manager": Channels.is_user_channel_manager(channel.id, user.id), + "is_manager": Channels.is_user_channel_manager(channel.id, user.id, db=db), "write_access": True, "user_count": len(user_ids), "last_read_at": channel_member.last_read_at if channel_member else None, @@ -386,7 +394,7 @@ async def get_channel_by_id( user_count = len(get_users_with_access("read", channel.access_control)) - channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id) + channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db) unread_count = Messages.get_unread_message_count( channel.id, user.id, channel_member.last_read_at if channel_member else None ) @@ -396,7 +404,7 @@ async def get_channel_by_id( **channel.model_dump(), "user_ids": user_ids, "users": users, - "is_manager": Channels.is_user_channel_manager(channel.id, user.id), + "is_manager": Channels.is_user_channel_manager(channel.id, user.id, db=db), "write_access": write_access or user.role == "admin", "user_count": user_count, "last_read_at": channel_member.last_read_at if channel_member else None, @@ -422,10 +430,11 @@ async def get_channel_members_by_id( direction: Optional[str] = None, page: Optional[int] = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -437,7 +446,7 @@ async def get_channel_members_by_id( skip = (page - 1) * limit if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -510,15 +519,16 @@ async def update_is_active_member_by_id_and_user_id( id: str, form_data: UpdateActiveMemberForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) @@ -543,17 +553,18 @@ async def add_members_by_id( id: str, form_data: UpdateMembersForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -592,17 +603,18 @@ async def remove_members_by_id( id: str, form_data: RemoveMembersForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -614,7 +626,7 @@ async def remove_members_by_id( ) try: - deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids) + deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids, db=db) return deleted except Exception as e: @@ -635,17 +647,18 @@ async def update_channel_by_id( id: str, form_data: ChannelForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -657,7 +670,7 @@ async def update_channel_by_id( ) try: - channel = Channels.update_channel_by_id(id, form_data) + channel = Channels.update_channel_by_id(id, form_data, db=db) return ChannelModel(**channel.model_dump()) except Exception as e: log.exception(e) @@ -676,17 +689,18 @@ async def delete_channel_by_id( request: Request, id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) if user.role != "admin" and not has_permission( - user.id, "features.channels", request.app.state.config.USER_PERMISSIONS + user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED, ) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -698,7 +712,7 @@ async def delete_channel_by_id( ) try: - Channels.delete_channel_by_id(id) + Channels.delete_channel_by_id(id, db=db) return True except Exception as e: log.exception(e) @@ -732,16 +746,17 @@ async def get_channel_messages( skip: int = 0, limit: int = 50, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -763,7 +778,7 @@ async def get_channel_messages( messages = [] for message in message_list: if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) + user = Users.get_user_by_id(message.user_id, db=db) users[message.user_id] = user thread_replies = Messages.get_thread_replies_by_message_id(message.id) @@ -799,16 +814,17 @@ async def get_pinned_channel_messages( id: str, page: int = 1, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -830,7 +846,7 @@ async def get_pinned_channel_messages( messages = [] for message in message_list: if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) + user = Users.get_user_by_id(message.user_id, db=db) users[message.user_id] = user messages.append( @@ -944,7 +960,7 @@ async def model_response_handler(request, channel, message, user): for thread_message in thread_messages: message_user = None if thread_message.user_id not in message_users: - message_user = Users.get_user_by_id(thread_message.user_id) + message_user = Users.get_user_by_id(thread_message.user_id, db=db) message_users[thread_message.user_id] = message_user else: message_user = message_users[thread_message.user_id] @@ -1059,39 +1075,39 @@ async def model_response_handler(request, channel, message, user): async def new_message_handler( - request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user) + request: Request, id: str, form_data: MessageForm, user, db ): - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) else: if user.role != "admin" and not has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, type="write", access_control=channel.access_control, strict=False, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) try: - message = Messages.insert_new_message(form_data, channel.id, user.id) + message = Messages.insert_new_message(form_data, channel.id, user.id, db=db) if message: if channel.type in ["group", "dm"]: - members = Channels.get_members_by_channel_id(channel.id) + members = Channels.get_members_by_channel_id(channel.id, db=db) for member in members: if not member.is_active: Channels.update_member_active_status( - channel.id, member.user_id, True + channel.id, member.user_id, True, db=db ) - message = Messages.get_message_by_id(message.id) + message = Messages.get_message_by_id(message.id, db=db) event_data = { "channel_id": channel.id, "message_id": message.id, @@ -1111,7 +1127,7 @@ async def new_message_handler( if message.parent_id: # If this message is a reply, emit to the parent message as well - parent_message = Messages.get_message_by_id(message.parent_id) + parent_message = Messages.get_message_by_id(message.parent_id, db=db) if parent_message: await sio.emit( @@ -1145,16 +1161,17 @@ async def post_new_message( form_data: MessageForm, background_tasks: BackgroundTasks, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) try: - message, channel = await new_message_handler(request, id, form_data, user) + message, channel = await new_message_handler(request, id, form_data, user, db) try: if files := message.data.get("files", []): for file in files: Channels.set_file_message_id_in_channel_by_id( - channel.id, file.get("id", ""), message.id + channel.id, file.get("id", ""), message.id, db=db ) except Exception as e: log.debug(e) @@ -1195,16 +1212,17 @@ async def get_channel_message( id: str, message_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1216,7 +1234,7 @@ async def get_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1231,7 +1249,7 @@ async def get_channel_message( **{ **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **Users.get_user_by_id(message.user_id, db=db).model_dump() ), } ) @@ -1248,16 +1266,17 @@ async def get_channel_message_data( id: str, message_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1269,7 +1288,7 @@ async def get_channel_message_data( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1301,16 +1320,17 @@ async def pin_channel_message( message_id: str, form_data: PinMessageForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1322,7 +1342,7 @@ async def pin_channel_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1335,12 +1355,12 @@ async def pin_channel_message( try: Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) return MessageUserResponse( **{ **message.model_dump(), "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() + **Users.get_user_by_id(message.user_id, db=db).model_dump() ), } ) @@ -1366,16 +1386,17 @@ async def get_channel_thread_messages( skip: int = 0, limit: int = 50, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1393,7 +1414,7 @@ async def get_channel_thread_messages( messages = [] for message in message_list: if message.user_id not in users: - user = Users.get_user_by_id(message.user_id) + user = Users.get_user_by_id(message.user_id, db=db) users[message.user_id] = user messages.append( @@ -1425,15 +1446,16 @@ async def update_message_by_id( message_id: str, form_data: MessageForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1445,7 +1467,7 @@ async def update_message_by_id( ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1463,7 +1485,7 @@ async def update_message_by_id( try: message = Messages.update_message_by_id(message_id, form_data) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if message: await sio.emit( @@ -1505,16 +1527,17 @@ async def add_reaction_to_message( message_id: str, form_data: ReactionForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1526,7 +1549,7 @@ async def add_reaction_to_message( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1539,7 +1562,7 @@ async def add_reaction_to_message( try: Messages.add_reaction_to_message(message_id, user.id, form_data.name) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) await sio.emit( "events:channel", @@ -1579,16 +1602,17 @@ async def remove_reaction_by_id_and_user_id_and_name( message_id: str, form_data: ReactionForm, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) @@ -1600,7 +1624,7 @@ async def remove_reaction_by_id_and_user_id_and_name( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1616,7 +1640,7 @@ async def remove_reaction_by_id_and_user_id_and_name( message_id, user.id, form_data.name ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) await sio.emit( "events:channel", @@ -1655,15 +1679,16 @@ async def delete_message_by_id( id: str, message_id: str, user=Depends(get_verified_user), + db: Session = Depends(get_session), ): check_channels_access(request) - channel = Channels.get_channel_by_id(id) + channel = Channels.get_channel_by_id(id, db=db) if not channel: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - message = Messages.get_message_by_id(message_id) + message = Messages.get_message_by_id(message_id, db=db) if not message: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND @@ -1675,7 +1700,7 @@ async def delete_message_by_id( ) if channel.type in ["group", "dm"]: - if not Channels.is_user_channel_member(channel.id, user.id): + if not Channels.is_user_channel_member(channel.id, user.id, db=db): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 8037d2077d..11e192e6ba 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -22,6 +22,8 @@ from open_webui.models.chats import Chats from open_webui.routers.files import upload_file_handler, get_file_content_by_id from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.headers import include_user_info_headers +from open_webui.internal.db import get_session +from sqlalchemy.orm import Session from open_webui.utils.images.comfyui import ( ComfyUICreateImageForm, ComfyUIEditImageForm, @@ -496,7 +498,7 @@ def get_image_data(data: str, headers=None): return None, None -def upload_image(request, image_data, content_type, metadata, user): +def upload_image(request, image_data, content_type, metadata, user, db=None): image_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(image_data), @@ -524,6 +526,7 @@ def upload_image(request, image_data, content_type, metadata, user): message_id=message_id, file_ids=[file_item.id], user_id=user.id, + db=db, ) url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index d26b499700..9070256770 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -300,7 +300,7 @@ def get_scim_auth( ) -def user_to_scim(user: UserModel, request: Request) -> SCIMUser: +def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: """Convert internal User model to SCIM User""" # Parse display name into name components name_parts = user.name.split(" ", 1) if user.name else ["", ""] @@ -308,7 +308,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser: family_name = name_parts[1] if len(name_parts) > 1 else "" # Get user's groups - user_groups = Groups.get_groups_by_member_id(user.id) + user_groups = Groups.get_groups_by_member_id(user.id, db=db) groups = [ { "value": group.id, @@ -487,6 +487,7 @@ async def get_users( count: int = Query(20, ge=1, le=100), filter: Optional[str] = None, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """List SCIM Users""" skip = startIndex - 1 @@ -498,20 +499,20 @@ async def get_users( # In production, you'd want a more robust filter parser if "userName eq" in filter: email = filter.split('"')[1] - user = Users.get_user_by_email(email) + user = Users.get_user_by_email(email, db=db) users_list = [user] if user else [] total = 1 if user else 0 else: - response = Users.get_users(skip=skip, limit=limit) + response = Users.get_users(skip=skip, limit=limit, db=db) users_list = response["users"] total = response["total"] else: - response = Users.get_users(skip=skip, limit=limit) + response = Users.get_users(skip=skip, limit=limit, db=db) users_list = response["users"] total = response["total"] # Convert to SCIM format - scim_users = [user_to_scim(user, request) for user in users_list] + scim_users = [user_to_scim(user, request, db=db) for user in users_list] return SCIMListResponse( totalResults=total, @@ -526,15 +527,16 @@ async def get_user( user_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Get SCIM User by ID""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: return scim_error( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" ) - return user_to_scim(user, request) + return user_to_scim(user, request, db=db) @router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) @@ -542,10 +544,11 @@ async def create_user( request: Request, user_data: SCIMUserCreateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Create SCIM User""" # Check if user already exists - existing_user = Users.get_user_by_email(user_data.userName) + existing_user = Users.get_user_by_email(user_data.userName, db=db) if existing_user: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -576,6 +579,7 @@ async def create_user( email=email, profile_image_url=profile_image, role="user" if user_data.active else "pending", + db=db, ) if not new_user: @@ -584,7 +588,7 @@ async def create_user( detail="Failed to create user", ) - return user_to_scim(new_user, request) + return user_to_scim(new_user, request, db=db) @router.put("/Users/{user_id}", response_model=SCIMUser) @@ -593,9 +597,10 @@ async def update_user( request: Request, user_data: SCIMUserUpdateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM User (full update)""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -628,14 +633,14 @@ async def update_user( update_data["profile_image_url"] = user_data.photos[0].value # Update user - updated_user = Users.update_user_by_id(user_id, update_data) + updated_user = Users.update_user_by_id(user_id, update_data, db=db) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update user", ) - return user_to_scim(updated_user, request) + return user_to_scim(updated_user, request, db=db) @router.patch("/Users/{user_id}", response_model=SCIMUser) @@ -644,9 +649,10 @@ async def patch_user( request: Request, patch_data: SCIMPatchRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM User (partial update)""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -674,7 +680,7 @@ async def patch_user( # Update user if update_data: - updated_user = Users.update_user_by_id(user_id, update_data) + updated_user = Users.update_user_by_id(user_id, update_data, db=db) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -683,7 +689,7 @@ async def patch_user( else: updated_user = user - return user_to_scim(updated_user, request) + return user_to_scim(updated_user, request, db=db) @router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -691,16 +697,17 @@ async def delete_user( user_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Delete SCIM User""" - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found", ) - success = Users.delete_user_by_id(user_id) + success = Users.delete_user_by_id(user_id, db=db) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -749,7 +756,7 @@ async def get_group( db: Session = Depends(get_session), ): """Get SCIM Group by ID""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -921,16 +928,17 @@ async def delete_group( group_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Delete SCIM Group""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - success = Groups.delete_group_by_id(group_id) + success = Groups.delete_group_by_id(group_id, db=db) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,