This commit is contained in:
Timothy Jaeryang Baek 2025-12-28 22:26:35 +04:00
parent 2041ab483e
commit 145c7516f2
2 changed files with 94 additions and 52 deletions

View file

@ -306,9 +306,9 @@ class ChannelTable:
return memberships
def insert_new_channel(
self, form_data: CreateChannelForm, user_id: str
self, form_data: CreateChannelForm, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]:
with get_db() as db:
with get_db_context(db) as db:
channel = ChannelModel(
**{
**form_data.model_dump(),
@ -389,7 +389,7 @@ class ChannelTable:
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id)
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
membership_channels = (
@ -423,8 +423,8 @@ class ChannelTable:
all_channels = membership_channels + standard_channels
return [ChannelModel.model_validate(c) for c in all_channels]
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
with get_db() as db:
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
with get_db_context(db) as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
@ -462,8 +462,9 @@ class ChannelTable:
invited_by: str,
user_ids: Optional[list[str]] = None,
group_ids: Optional[list[str]] = None,
db: Optional[Session] = None,
) -> list[ChannelMemberModel]:
with get_db() as db:
with get_db_context(db) as db:
# 1. Collect all user_ids including groups + inviter
requested_users = self._collect_unique_user_ids(
invited_by, user_ids, group_ids
@ -496,8 +497,9 @@ class ChannelTable:
self,
channel_id: str,
user_ids: list[str],
db: Optional[Session] = None,
) -> int:
with get_db() as db:
with get_db_context(db) as db:
result = (
db.query(ChannelMember)
.filter(
@ -509,8 +511,8 @@ class ChannelTable:
db.commit()
return result # number of rows deleted
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
# Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember
channel = db.query(Channel).filter(Channel.id == channel_id).first()
@ -529,9 +531,9 @@ class ChannelTable:
return membership is not None
def join_channel(
self, channel_id: str, user_id: str
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelMemberModel]:
with get_db() as db:
with get_db_context(db) as db:
# Check if the membership already exists
existing_membership = (
db.query(ChannelMember)
@ -567,8 +569,8 @@ class ChannelTable:
db.commit()
return channel_member
def leave_channel(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
.filter(
@ -589,9 +591,9 @@ class ChannelTable:
return True
def get_member_by_channel_and_user_id(
self, channel_id: str, user_id: str
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelMemberModel]:
with get_db() as db:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
.filter(
@ -602,8 +604,8 @@ class ChannelTable:
)
return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
with get_db() as db:
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
with get_db_context(db) as db:
memberships = (
db.query(ChannelMember)
.filter(ChannelMember.channel_id == channel_id)
@ -614,8 +616,8 @@ class ChannelTable:
for membership in memberships
]
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
with get_db() as db:
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
.filter(
@ -633,8 +635,8 @@ class ChannelTable:
db.commit()
return True
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
.filter(
@ -653,9 +655,9 @@ class ChannelTable:
return True
def update_member_active_status(
self, channel_id: str, user_id: str, is_active: bool
self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None
) -> bool:
with get_db() as db:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
.filter(
@ -673,8 +675,8 @@ class ChannelTable:
db.commit()
return True
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
with get_db() as db:
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
.filter(
@ -693,8 +695,8 @@ class ChannelTable:
except Exception:
return None
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
with get_db() as db:
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db:
channel_files = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
@ -703,9 +705,9 @@ class ChannelTable:
return [ChannelModel.model_validate(channel) for channel in channels]
def get_channels_by_file_id_and_user_id(
self, file_id: str, user_id: str
self, file_id: str, user_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
with get_db() as db:
with get_db_context(db) as db:
# 1. Determine which channels have this file
channel_file_rows = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
@ -729,7 +731,7 @@ class ChannelTable:
return []
# Preload user's group membership
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)]
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
allowed_channels = []
@ -766,9 +768,9 @@ class ChannelTable:
return allowed_channels
def get_channel_by_id_and_user_id(
self, id: str, user_id: str
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]:
with get_db() as db:
with get_db_context(db) as db:
# Fetch the channel
channel: Channel = (
db.query(Channel)

View file

@ -135,7 +135,9 @@ class FilesTable:
log.exception(f"Error inserting a new file: {e}")
return None
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
def get_file_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
try:
with get_db_context(db) as db:
try:
@ -146,8 +148,10 @@ class FilesTable:
except Exception:
return None
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
with get_db() as db:
def get_file_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id, user_id=user_id).first()
if file:
@ -157,8 +161,10 @@ class FilesTable:
except Exception:
return None
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
with get_db() as db:
def get_file_metadata_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FileMetadataResponse]:
with get_db_context(db) as db:
try:
file = db.get(File, id)
return FileMetadataResponse(
@ -175,8 +181,10 @@ class FilesTable:
with get_db_context(db) as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
file = self.get_file_by_id(id)
def check_access_by_user_id(
self, id, user_id, permission="write", db: Optional[Session] = None
) -> bool:
file = self.get_file_by_id(id, db=db)
if not file:
return False
if file.user_id == user_id:
@ -184,8 +192,10 @@ class FilesTable:
# Implement additional access control logic here as needed
return False
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
with get_db() as db:
def get_files_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> list[FileModel]:
with get_db_context(db) as db:
return [
FileModel.model_validate(file)
for file in db.query(File)
@ -194,8 +204,10 @@ class FilesTable:
.all()
]
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
with get_db() as db:
def get_file_metadatas_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> list[FileMetadataResponse]:
with get_db_context(db) as db:
return [
FileMetadataResponse(
id=file.id,
@ -212,7 +224,9 @@ class FilesTable:
.all()
]
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
def get_files_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[FileModel]:
with get_db_context(db) as db:
return [
FileModel.model_validate(file)
@ -220,9 +234,9 @@ class FilesTable:
]
def update_file_by_id(
self, id: str, form_data: FileUpdateForm
self, id: str, form_data: FileUpdateForm, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db() as db:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()
@ -242,8 +256,10 @@ class FilesTable:
log.exception(f"Error updating file completely by id: {e}")
return None
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
with get_db() as db:
def update_file_hash_by_id(
self, id: str, hash: str, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()
file.hash = hash
@ -254,8 +270,10 @@ class FilesTable:
except Exception:
return None
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
with get_db() as db:
def update_file_data_by_id(
self, id: str, data: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()
file.data = {**(file.data if file.data else {}), **data}
@ -266,8 +284,10 @@ class FilesTable:
return None
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
with get_db() as db:
def update_file_metadata_by_id(
self, id: str, meta: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()
file.meta = {**(file.meta if file.meta else {}), **meta}
@ -279,5 +299,25 @@ class FilesTable:
return False
def delete_file_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
db.query(File).filter_by(id=id).delete()
db.commit()
return True
except Exception:
return False
def delete_all_files(self, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
db.query(File).delete()
db.commit()
return True
except Exception:
return False
Files = FilesTable()