This commit is contained in:
Timothy Jaeryang Baek 2025-12-28 22:26:35 +04:00
parent 2041ab483e
commit 145c7516f2
2 changed files with 94 additions and 52 deletions

View file

@ -306,9 +306,9 @@ class ChannelTable:
return memberships return memberships
def insert_new_channel( def insert_new_channel(
self, form_data: CreateChannelForm, user_id: str self, form_data: CreateChannelForm, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]: ) -> Optional[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
channel = ChannelModel( channel = ChannelModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
@ -389,7 +389,7 @@ class ChannelTable:
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]: def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
user_group_ids = [ user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id) group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
] ]
membership_channels = ( membership_channels = (
@ -423,8 +423,8 @@ class ChannelTable:
all_channels = membership_channels + standard_channels all_channels = membership_channels + standard_channels
return [ChannelModel.model_validate(c) for c in all_channels] return [ChannelModel.model_validate(c) for c in all_channels]
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]: def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
# Ensure uniqueness in case a list with duplicates is passed # Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids)) unique_user_ids = list(set(user_ids))
@ -462,8 +462,9 @@ class ChannelTable:
invited_by: str, invited_by: str,
user_ids: Optional[list[str]] = None, user_ids: Optional[list[str]] = None,
group_ids: Optional[list[str]] = None, group_ids: Optional[list[str]] = None,
db: Optional[Session] = None,
) -> list[ChannelMemberModel]: ) -> list[ChannelMemberModel]:
with get_db() as db: with get_db_context(db) as db:
# 1. Collect all user_ids including groups + inviter # 1. Collect all user_ids including groups + inviter
requested_users = self._collect_unique_user_ids( requested_users = self._collect_unique_user_ids(
invited_by, user_ids, group_ids invited_by, user_ids, group_ids
@ -496,8 +497,9 @@ class ChannelTable:
self, self,
channel_id: str, channel_id: str,
user_ids: list[str], user_ids: list[str],
db: Optional[Session] = None,
) -> int: ) -> int:
with get_db() as db: with get_db_context(db) as db:
result = ( result = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -509,8 +511,8 @@ class ChannelTable:
db.commit() db.commit()
return result # number of rows deleted return result # number of rows deleted
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool: def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
# Check if the user is the creator of the channel # Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember # or has a 'manager' role in ChannelMember
channel = db.query(Channel).filter(Channel.id == channel_id).first() channel = db.query(Channel).filter(Channel.id == channel_id).first()
@ -529,9 +531,9 @@ class ChannelTable:
return membership is not None return membership is not None
def join_channel( def join_channel(
self, channel_id: str, user_id: str self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelMemberModel]: ) -> Optional[ChannelMemberModel]:
with get_db() as db: with get_db_context(db) as db:
# Check if the membership already exists # Check if the membership already exists
existing_membership = ( existing_membership = (
db.query(ChannelMember) db.query(ChannelMember)
@ -567,8 +569,8 @@ class ChannelTable:
db.commit() db.commit()
return channel_member return channel_member
def leave_channel(self, channel_id: str, user_id: str) -> bool: def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -589,9 +591,9 @@ class ChannelTable:
return True return True
def get_member_by_channel_and_user_id( def get_member_by_channel_and_user_id(
self, channel_id: str, user_id: str self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelMemberModel]: ) -> Optional[ChannelMemberModel]:
with get_db() as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -602,8 +604,8 @@ class ChannelTable:
) )
return ChannelMemberModel.model_validate(membership) if membership else None return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]: def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
with get_db() as db: with get_db_context(db) as db:
memberships = ( memberships = (
db.query(ChannelMember) db.query(ChannelMember)
.filter(ChannelMember.channel_id == channel_id) .filter(ChannelMember.channel_id == channel_id)
@ -614,8 +616,8 @@ class ChannelTable:
for membership in memberships for membership in memberships
] ]
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool: def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -633,8 +635,8 @@ class ChannelTable:
db.commit() db.commit()
return True return True
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool: def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -653,9 +655,9 @@ class ChannelTable:
return True return True
def update_member_active_status( def update_member_active_status(
self, channel_id: str, user_id: str, is_active: bool self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None
) -> bool: ) -> bool:
with get_db() as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -673,8 +675,8 @@ class ChannelTable:
db.commit() db.commit()
return True return True
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool: def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db() as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@ -693,8 +695,8 @@ class ChannelTable:
except Exception: except Exception:
return None return None
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]: def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
channel_files = ( channel_files = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
) )
@ -703,9 +705,9 @@ class ChannelTable:
return [ChannelModel.model_validate(channel) for channel in channels] return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_file_id_and_user_id( def get_channels_by_file_id_and_user_id(
self, file_id: str, user_id: str self, file_id: str, user_id: str, db: Optional[Session] = None
) -> list[ChannelModel]: ) -> list[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
# 1. Determine which channels have this file # 1. Determine which channels have this file
channel_file_rows = ( channel_file_rows = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all() db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
@ -729,7 +731,7 @@ class ChannelTable:
return [] return []
# Preload user's group membership # Preload user's group membership
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)] user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
allowed_channels = [] allowed_channels = []
@ -766,9 +768,9 @@ class ChannelTable:
return allowed_channels return allowed_channels
def get_channel_by_id_and_user_id( def get_channel_by_id_and_user_id(
self, id: str, user_id: str self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]: ) -> Optional[ChannelModel]:
with get_db() as db: with get_db_context(db) as db:
# Fetch the channel # Fetch the channel
channel: Channel = ( channel: Channel = (
db.query(Channel) db.query(Channel)

View file

@ -135,7 +135,9 @@ class FilesTable:
log.exception(f"Error inserting a new file: {e}") log.exception(f"Error inserting a new file: {e}")
return None return None
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]: def get_file_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
@ -146,8 +148,10 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: def get_file_by_id_and_user_id(
with get_db() as db: self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id, user_id=user_id).first() file = db.query(File).filter_by(id=id, user_id=user_id).first()
if file: if file:
@ -157,8 +161,10 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: def get_file_metadata_by_id(
with get_db() as db: self, id: str, db: Optional[Session] = None
) -> Optional[FileMetadataResponse]:
with get_db_context(db) as db:
try: try:
file = db.get(File, id) file = db.get(File, id)
return FileMetadataResponse( return FileMetadataResponse(
@ -175,8 +181,10 @@ class FilesTable:
with get_db_context(db) as db: with get_db_context(db) as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id(self, id, user_id, permission="write") -> bool: def check_access_by_user_id(
file = self.get_file_by_id(id) self, id, user_id, permission="write", db: Optional[Session] = None
) -> bool:
file = self.get_file_by_id(id, db=db)
if not file: if not file:
return False return False
if file.user_id == user_id: if file.user_id == user_id:
@ -184,8 +192,10 @@ class FilesTable:
# Implement additional access control logic here as needed # Implement additional access control logic here as needed
return False return False
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: def get_files_by_ids(
with get_db() as db: self, ids: list[str], db: Optional[Session] = None
) -> list[FileModel]:
with get_db_context(db) as db:
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in db.query(File) for file in db.query(File)
@ -194,8 +204,10 @@ class FilesTable:
.all() .all()
] ]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]: def get_file_metadatas_by_ids(
with get_db() as db: self, ids: list[str], db: Optional[Session] = None
) -> list[FileMetadataResponse]:
with get_db_context(db) as db:
return [ return [
FileMetadataResponse( FileMetadataResponse(
id=file.id, id=file.id,
@ -212,7 +224,9 @@ class FilesTable:
.all() .all()
] ]
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]: def get_files_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
@ -220,9 +234,9 @@ class FilesTable:
] ]
def update_file_by_id( def update_file_by_id(
self, id: str, form_data: FileUpdateForm self, id: str, form_data: FileUpdateForm, db: Optional[Session] = None
) -> Optional[FileModel]: ) -> Optional[FileModel]:
with get_db() as db: with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
@ -242,8 +256,10 @@ class FilesTable:
log.exception(f"Error updating file completely by id: {e}") log.exception(f"Error updating file completely by id: {e}")
return None return None
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]: def update_file_hash_by_id(
with get_db() as db: self, id: str, hash: str, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
file.hash = hash file.hash = hash
@ -254,8 +270,10 @@ class FilesTable:
except Exception: except Exception:
return None return None
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]: def update_file_data_by_id(
with get_db() as db: self, id: str, data: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
file.data = {**(file.data if file.data else {}), **data} file.data = {**(file.data if file.data else {}), **data}
@ -266,8 +284,10 @@ class FilesTable:
return None return None
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]: def update_file_metadata_by_id(
with get_db() as db: self, id: str, meta: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
file.meta = {**(file.meta if file.meta else {}), **meta} file.meta = {**(file.meta if file.meta else {}), **meta}
@ -279,5 +299,25 @@ class FilesTable:
return False return False
def delete_file_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
db.query(File).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_files(self, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
db.query(File).delete()
db.commit()
return True
except Exception:
return False
Files = FilesTable() Files = FilesTable()