mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 06:35:20 +00:00
refac
This commit is contained in:
parent
2041ab483e
commit
145c7516f2
2 changed files with 94 additions and 52 deletions
|
|
@ -306,9 +306,9 @@ class ChannelTable:
|
|||
return memberships
|
||||
|
||||
def insert_new_channel(
|
||||
self, form_data: CreateChannelForm, user_id: str
|
||||
self, form_data: CreateChannelForm, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
channel = ChannelModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
|
|
@ -389,7 +389,7 @@ class ChannelTable:
|
|||
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
]
|
||||
|
||||
membership_channels = (
|
||||
|
|
@ -423,8 +423,8 @@ class ChannelTable:
|
|||
all_channels = membership_channels + standard_channels
|
||||
return [ChannelModel.model_validate(c) for c in all_channels]
|
||||
|
||||
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
# Ensure uniqueness in case a list with duplicates is passed
|
||||
unique_user_ids = list(set(user_ids))
|
||||
|
||||
|
|
@ -462,8 +462,9 @@ class ChannelTable:
|
|||
invited_by: str,
|
||||
user_ids: Optional[list[str]] = None,
|
||||
group_ids: Optional[list[str]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[ChannelMemberModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# 1. Collect all user_ids including groups + inviter
|
||||
requested_users = self._collect_unique_user_ids(
|
||||
invited_by, user_ids, group_ids
|
||||
|
|
@ -496,8 +497,9 @@ class ChannelTable:
|
|||
self,
|
||||
channel_id: str,
|
||||
user_ids: list[str],
|
||||
db: Optional[Session] = None,
|
||||
) -> int:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
result = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -509,8 +511,8 @@ class ChannelTable:
|
|||
db.commit()
|
||||
return result # number of rows deleted
|
||||
|
||||
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
# Check if the user is the creator of the channel
|
||||
# or has a 'manager' role in ChannelMember
|
||||
channel = db.query(Channel).filter(Channel.id == channel_id).first()
|
||||
|
|
@ -529,9 +531,9 @@ class ChannelTable:
|
|||
return membership is not None
|
||||
|
||||
def join_channel(
|
||||
self, channel_id: str, user_id: str
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelMemberModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Check if the membership already exists
|
||||
existing_membership = (
|
||||
db.query(ChannelMember)
|
||||
|
|
@ -567,8 +569,8 @@ class ChannelTable:
|
|||
db.commit()
|
||||
return channel_member
|
||||
|
||||
def leave_channel(self, channel_id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -589,9 +591,9 @@ class ChannelTable:
|
|||
return True
|
||||
|
||||
def get_member_by_channel_and_user_id(
|
||||
self, channel_id: str, user_id: str
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelMemberModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -602,8 +604,8 @@ class ChannelTable:
|
|||
)
|
||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
||||
|
||||
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
|
||||
with get_db() as db:
|
||||
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
|
||||
with get_db_context(db) as db:
|
||||
memberships = (
|
||||
db.query(ChannelMember)
|
||||
.filter(ChannelMember.channel_id == channel_id)
|
||||
|
|
@ -614,8 +616,8 @@ class ChannelTable:
|
|||
for membership in memberships
|
||||
]
|
||||
|
||||
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
|
||||
with get_db() as db:
|
||||
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -633,8 +635,8 @@ class ChannelTable:
|
|||
db.commit()
|
||||
return True
|
||||
|
||||
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -653,9 +655,9 @@ class ChannelTable:
|
|||
return True
|
||||
|
||||
def update_member_active_status(
|
||||
self, channel_id: str, user_id: str, is_active: bool
|
||||
self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -673,8 +675,8 @@ class ChannelTable:
|
|||
db.commit()
|
||||
return True
|
||||
|
||||
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
|
||||
with get_db() as db:
|
||||
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
|
|
@ -693,8 +695,8 @@ class ChannelTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
channel_files = (
|
||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
)
|
||||
|
|
@ -703,9 +705,9 @@ class ChannelTable:
|
|||
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||
|
||||
def get_channels_by_file_id_and_user_id(
|
||||
self, file_id: str, user_id: str
|
||||
self, file_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# 1. Determine which channels have this file
|
||||
channel_file_rows = (
|
||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
|
|
@ -729,7 +731,7 @@ class ChannelTable:
|
|||
return []
|
||||
|
||||
# Preload user's group membership
|
||||
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)]
|
||||
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||
|
||||
allowed_channels = []
|
||||
|
||||
|
|
@ -766,9 +768,9 @@ class ChannelTable:
|
|||
return allowed_channels
|
||||
|
||||
def get_channel_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
# Fetch the channel
|
||||
channel: Channel = (
|
||||
db.query(Channel)
|
||||
|
|
|
|||
|
|
@ -135,7 +135,9 @@ class FilesTable:
|
|||
log.exception(f"Error inserting a new file: {e}")
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
def get_file_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
|
|
@ -146,8 +148,10 @@ class FilesTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
def get_file_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id, user_id=user_id).first()
|
||||
if file:
|
||||
|
|
@ -157,8 +161,10 @@ class FilesTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
|
||||
with get_db() as db:
|
||||
def get_file_metadata_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[FileMetadataResponse]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
return FileMetadataResponse(
|
||||
|
|
@ -175,8 +181,10 @@ class FilesTable:
|
|||
with get_db_context(db) as db:
|
||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||
|
||||
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
||||
file = self.get_file_by_id(id)
|
||||
def check_access_by_user_id(
|
||||
self, id, user_id, permission="write", db: Optional[Session] = None
|
||||
) -> bool:
|
||||
file = self.get_file_by_id(id, db=db)
|
||||
if not file:
|
||||
return False
|
||||
if file.user_id == user_id:
|
||||
|
|
@ -184,8 +192,10 @@ class FilesTable:
|
|||
# Implement additional access control logic here as needed
|
||||
return False
|
||||
|
||||
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
|
||||
with get_db() as db:
|
||||
def get_files_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> list[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File)
|
||||
|
|
@ -194,8 +204,10 @@ class FilesTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
|
||||
with get_db() as db:
|
||||
def get_file_metadatas_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> list[FileMetadataResponse]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileMetadataResponse(
|
||||
id=file.id,
|
||||
|
|
@ -212,7 +224,9 @@ class FilesTable:
|
|||
.all()
|
||||
]
|
||||
|
||||
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||
def get_files_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
|
|
@ -220,9 +234,9 @@ class FilesTable:
|
|||
]
|
||||
|
||||
def update_file_by_id(
|
||||
self, id: str, form_data: FileUpdateForm
|
||||
self, id: str, form_data: FileUpdateForm, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
|
||||
|
|
@ -242,8 +256,10 @@ class FilesTable:
|
|||
log.exception(f"Error updating file completely by id: {e}")
|
||||
return None
|
||||
|
||||
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
def update_file_hash_by_id(
|
||||
self, id: str, hash: str, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.hash = hash
|
||||
|
|
@ -254,8 +270,10 @@ class FilesTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
def update_file_data_by_id(
|
||||
self, id: str, data: dict, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.data = {**(file.data if file.data else {}), **data}
|
||||
|
|
@ -266,8 +284,10 @@ class FilesTable:
|
|||
|
||||
return None
|
||||
|
||||
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
|
||||
with get_db() as db:
|
||||
def update_file_metadata_by_id(
|
||||
self, id: str, meta: dict, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
file.meta = {**(file.meta if file.meta else {}), **meta}
|
||||
|
|
@ -279,5 +299,25 @@ class FilesTable:
|
|||
|
||||
return False
|
||||
|
||||
def delete_file_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(File).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_all_files(self, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(File).delete()
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
Files = FilesTable()
|
||||
|
|
|
|||
Loading…
Reference in a new issue