diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 3aba302ff0..5f665146ae 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -306,9 +306,9 @@ class ChannelTable: return memberships def insert_new_channel( - self, form_data: CreateChannelForm, user_id: str + self, form_data: CreateChannelForm, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: channel = ChannelModel( **{ **form_data.model_dump(), @@ -389,7 +389,7 @@ class ChannelTable: def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]: with get_db_context(db) as db: user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id) + group.id for group in Groups.get_groups_by_member_id(user_id, db=db) ] membership_channels = ( @@ -423,8 +423,8 @@ class ChannelTable: all_channels = membership_channels + standard_channels return [ChannelModel.model_validate(c) for c in all_channels] - def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]: - with get_db() as db: + def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]: + with get_db_context(db) as db: # Ensure uniqueness in case a list with duplicates is passed unique_user_ids = list(set(user_ids)) @@ -462,8 +462,9 @@ class ChannelTable: invited_by: str, user_ids: Optional[list[str]] = None, group_ids: Optional[list[str]] = None, + db: Optional[Session] = None, ) -> list[ChannelMemberModel]: - with get_db() as db: + with get_db_context(db) as db: # 1. Collect all user_ids including groups + inviter requested_users = self._collect_unique_user_ids( invited_by, user_ids, group_ids @@ -496,8 +497,9 @@ class ChannelTable: self, channel_id: str, user_ids: list[str], + db: Optional[Session] = None, ) -> int: - with get_db() as db: + with get_db_context(db) as db: result = ( db.query(ChannelMember) .filter( @@ -509,8 +511,8 @@ class ChannelTable: db.commit() return result # number of rows deleted - def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: # Check if the user is the creator of the channel # or has a 'manager' role in ChannelMember channel = db.query(Channel).filter(Channel.id == channel_id).first() @@ -529,9 +531,9 @@ class ChannelTable: return membership is not None def join_channel( - self, channel_id: str, user_id: str + self, channel_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelMemberModel]: - with get_db() as db: + with get_db_context(db) as db: # Check if the membership already exists existing_membership = ( db.query(ChannelMember) @@ -567,8 +569,8 @@ class ChannelTable: db.commit() return channel_member - def leave_channel(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -589,9 +591,9 @@ class ChannelTable: return True def get_member_by_channel_and_user_id( - self, channel_id: str, user_id: str + self, channel_id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelMemberModel]: - with get_db() as db: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -602,8 +604,8 @@ class ChannelTable: ) return ChannelMemberModel.model_validate(membership) if membership else None - def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]: - with get_db() as db: + def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]: + with get_db_context(db) as db: memberships = ( db.query(ChannelMember) .filter(ChannelMember.channel_id == channel_id) @@ -614,8 +616,8 @@ class ChannelTable: for membership in memberships ] - def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool: - with get_db() as db: + def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -633,8 +635,8 @@ class ChannelTable: db.commit() return True - def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -653,9 +655,9 @@ class ChannelTable: return True def update_member_active_status( - self, channel_id: str, user_id: str, is_active: bool + self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None ) -> bool: - with get_db() as db: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -673,8 +675,8 @@ class ChannelTable: db.commit() return True - def is_user_channel_member(self, channel_id: str, user_id: str) -> bool: - with get_db() as db: + def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: membership = ( db.query(ChannelMember) .filter( @@ -693,8 +695,8 @@ class ChannelTable: except Exception: return None - def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]: - with get_db() as db: + def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]: + with get_db_context(db) as db: channel_files = ( db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() ) @@ -703,9 +705,9 @@ class ChannelTable: return [ChannelModel.model_validate(channel) for channel in channels] def get_channels_by_file_id_and_user_id( - self, file_id: str, user_id: str + self, file_id: str, user_id: str, db: Optional[Session] = None ) -> list[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: # 1. Determine which channels have this file channel_file_rows = ( db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() @@ -729,7 +731,7 @@ class ChannelTable: return [] # Preload user's group membership - user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)] + user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)] allowed_channels = [] @@ -766,9 +768,9 @@ class ChannelTable: return allowed_channels def get_channel_by_id_and_user_id( - self, id: str, user_id: str + self, id: str, user_id: str, db: Optional[Session] = None ) -> Optional[ChannelModel]: - with get_db() as db: + with get_db_context(db) as db: # Fetch the channel channel: Channel = ( db.query(Channel) diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 387dd5b28c..f734ea62d7 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -135,7 +135,9 @@ class FilesTable: log.exception(f"Error inserting a new file: {e}") return None - def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]: + def get_file_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[FileModel]: try: with get_db_context(db) as db: try: @@ -146,8 +148,10 @@ class FilesTable: except Exception: return None - def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: - with get_db() as db: + def get_file_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id, user_id=user_id).first() if file: @@ -157,8 +161,10 @@ class FilesTable: except Exception: return None - def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: - with get_db() as db: + def get_file_metadata_by_id( + self, id: str, db: Optional[Session] = None + ) -> Optional[FileMetadataResponse]: + with get_db_context(db) as db: try: file = db.get(File, id) return FileMetadataResponse( @@ -175,8 +181,10 @@ class FilesTable: with get_db_context(db) as db: return [FileModel.model_validate(file) for file in db.query(File).all()] - def check_access_by_user_id(self, id, user_id, permission="write") -> bool: - file = self.get_file_by_id(id) + def check_access_by_user_id( + self, id, user_id, permission="write", db: Optional[Session] = None + ) -> bool: + file = self.get_file_by_id(id, db=db) if not file: return False if file.user_id == user_id: @@ -184,8 +192,10 @@ class FilesTable: # Implement additional access control logic here as needed return False - def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: - with get_db() as db: + def get_files_by_ids( + self, ids: list[str], db: Optional[Session] = None + ) -> list[FileModel]: + with get_db_context(db) as db: return [ FileModel.model_validate(file) for file in db.query(File) @@ -194,8 +204,10 @@ class FilesTable: .all() ] - def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: - with get_db() as db: + def get_file_metadatas_by_ids( + self, ids: list[str], db: Optional[Session] = None + ) -> list[FileMetadataResponse]: + with get_db_context(db) as db: return [ FileMetadataResponse( id=file.id, @@ -212,7 +224,9 @@ class FilesTable: .all() ] - def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]: + def get_files_by_user_id( + self, user_id: str, db: Optional[Session] = None + ) -> list[FileModel]: with get_db_context(db) as db: return [ FileModel.model_validate(file) @@ -220,9 +234,9 @@ class FilesTable: ] def update_file_by_id( - self, id: str, form_data: FileUpdateForm + self, id: str, form_data: FileUpdateForm, db: Optional[Session] = None ) -> Optional[FileModel]: - with get_db() as db: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() @@ -242,8 +256,10 @@ class FilesTable: log.exception(f"Error updating file completely by id: {e}") return None - def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: - with get_db() as db: + def update_file_hash_by_id( + self, id: str, hash: str, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() file.hash = hash @@ -254,8 +270,10 @@ class FilesTable: except Exception: return None - def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: - with get_db() as db: + def update_file_data_by_id( + self, id: str, data: dict, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() file.data = {**(file.data if file.data else {}), **data} @@ -266,8 +284,10 @@ class FilesTable: return None - def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]: - with get_db() as db: + def update_file_metadata_by_id( + self, id: str, meta: dict, db: Optional[Session] = None + ) -> Optional[FileModel]: + with get_db_context(db) as db: try: file = db.query(File).filter_by(id=id).first() file.meta = {**(file.meta if file.meta else {}), **meta} @@ -279,5 +299,25 @@ class FilesTable: return False + def delete_file_by_id(self, id: str, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: + try: + db.query(File).filter_by(id=id).delete() + db.commit() + + return True + except Exception: + return False + + def delete_all_files(self, db: Optional[Session] = None) -> bool: + with get_db_context(db) as db: + try: + db.query(File).delete() + db.commit() + + return True + except Exception: + return False + Files = FilesTable()