mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 14:45:18 +00:00
refac
This commit is contained in:
parent
2041ab483e
commit
145c7516f2
2 changed files with 94 additions and 52 deletions
|
|
@ -306,9 +306,9 @@ class ChannelTable:
|
||||||
return memberships
|
return memberships
|
||||||
|
|
||||||
def insert_new_channel(
|
def insert_new_channel(
|
||||||
self, form_data: CreateChannelForm, user_id: str
|
self, form_data: CreateChannelForm, user_id: str, db: Optional[Session] = None
|
||||||
) -> Optional[ChannelModel]:
|
) -> Optional[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
channel = ChannelModel(
|
channel = ChannelModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
|
|
@ -389,7 +389,7 @@ class ChannelTable:
|
||||||
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user_group_ids = [
|
user_group_ids = [
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||||
]
|
]
|
||||||
|
|
||||||
membership_channels = (
|
membership_channels = (
|
||||||
|
|
@ -423,8 +423,8 @@ class ChannelTable:
|
||||||
all_channels = membership_channels + standard_channels
|
all_channels = membership_channels + standard_channels
|
||||||
return [ChannelModel.model_validate(c) for c in all_channels]
|
return [ChannelModel.model_validate(c) for c in all_channels]
|
||||||
|
|
||||||
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
|
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
# Ensure uniqueness in case a list with duplicates is passed
|
# Ensure uniqueness in case a list with duplicates is passed
|
||||||
unique_user_ids = list(set(user_ids))
|
unique_user_ids = list(set(user_ids))
|
||||||
|
|
||||||
|
|
@ -462,8 +462,9 @@ class ChannelTable:
|
||||||
invited_by: str,
|
invited_by: str,
|
||||||
user_ids: Optional[list[str]] = None,
|
user_ids: Optional[list[str]] = None,
|
||||||
group_ids: Optional[list[str]] = None,
|
group_ids: Optional[list[str]] = None,
|
||||||
|
db: Optional[Session] = None,
|
||||||
) -> list[ChannelMemberModel]:
|
) -> list[ChannelMemberModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
# 1. Collect all user_ids including groups + inviter
|
# 1. Collect all user_ids including groups + inviter
|
||||||
requested_users = self._collect_unique_user_ids(
|
requested_users = self._collect_unique_user_ids(
|
||||||
invited_by, user_ids, group_ids
|
invited_by, user_ids, group_ids
|
||||||
|
|
@ -496,8 +497,9 @@ class ChannelTable:
|
||||||
self,
|
self,
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
user_ids: list[str],
|
user_ids: list[str],
|
||||||
|
db: Optional[Session] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
result = (
|
result = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -509,8 +511,8 @@ class ChannelTable:
|
||||||
db.commit()
|
db.commit()
|
||||||
return result # number of rows deleted
|
return result # number of rows deleted
|
||||||
|
|
||||||
def is_user_channel_manager(self, channel_id: str, user_id: str) -> bool:
|
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
# Check if the user is the creator of the channel
|
# Check if the user is the creator of the channel
|
||||||
# or has a 'manager' role in ChannelMember
|
# or has a 'manager' role in ChannelMember
|
||||||
channel = db.query(Channel).filter(Channel.id == channel_id).first()
|
channel = db.query(Channel).filter(Channel.id == channel_id).first()
|
||||||
|
|
@ -529,9 +531,9 @@ class ChannelTable:
|
||||||
return membership is not None
|
return membership is not None
|
||||||
|
|
||||||
def join_channel(
|
def join_channel(
|
||||||
self, channel_id: str, user_id: str
|
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||||
) -> Optional[ChannelMemberModel]:
|
) -> Optional[ChannelMemberModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
# Check if the membership already exists
|
# Check if the membership already exists
|
||||||
existing_membership = (
|
existing_membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
|
|
@ -567,8 +569,8 @@ class ChannelTable:
|
||||||
db.commit()
|
db.commit()
|
||||||
return channel_member
|
return channel_member
|
||||||
|
|
||||||
def leave_channel(self, channel_id: str, user_id: str) -> bool:
|
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -589,9 +591,9 @@ class ChannelTable:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_member_by_channel_and_user_id(
|
def get_member_by_channel_and_user_id(
|
||||||
self, channel_id: str, user_id: str
|
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||||
) -> Optional[ChannelMemberModel]:
|
) -> Optional[ChannelMemberModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -602,8 +604,8 @@ class ChannelTable:
|
||||||
)
|
)
|
||||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
return ChannelMemberModel.model_validate(membership) if membership else None
|
||||||
|
|
||||||
def get_members_by_channel_id(self, channel_id: str) -> list[ChannelMemberModel]:
|
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
memberships = (
|
memberships = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(ChannelMember.channel_id == channel_id)
|
.filter(ChannelMember.channel_id == channel_id)
|
||||||
|
|
@ -614,8 +616,8 @@ class ChannelTable:
|
||||||
for membership in memberships
|
for membership in memberships
|
||||||
]
|
]
|
||||||
|
|
||||||
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool) -> bool:
|
def pin_channel(self, channel_id: str, user_id: str, is_pinned: bool, db: Optional[Session] = None) -> bool:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -633,8 +635,8 @@ class ChannelTable:
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def update_member_last_read_at(self, channel_id: str, user_id: str) -> bool:
|
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -653,9 +655,9 @@ class ChannelTable:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def update_member_active_status(
|
def update_member_active_status(
|
||||||
self, channel_id: str, user_id: str, is_active: bool
|
self, channel_id: str, user_id: str, is_active: bool, db: Optional[Session] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -673,8 +675,8 @@ class ChannelTable:
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_user_channel_member(self, channel_id: str, user_id: str) -> bool:
|
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
|
|
@ -693,8 +695,8 @@ class ChannelTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_channels_by_file_id(self, file_id: str) -> list[ChannelModel]:
|
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
channel_files = (
|
channel_files = (
|
||||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||||
)
|
)
|
||||||
|
|
@ -703,9 +705,9 @@ class ChannelTable:
|
||||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||||
|
|
||||||
def get_channels_by_file_id_and_user_id(
|
def get_channels_by_file_id_and_user_id(
|
||||||
self, file_id: str, user_id: str
|
self, file_id: str, user_id: str, db: Optional[Session] = None
|
||||||
) -> list[ChannelModel]:
|
) -> list[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
# 1. Determine which channels have this file
|
# 1. Determine which channels have this file
|
||||||
channel_file_rows = (
|
channel_file_rows = (
|
||||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||||
|
|
@ -729,7 +731,7 @@ class ChannelTable:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Preload user's group membership
|
# Preload user's group membership
|
||||||
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id)]
|
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||||
|
|
||||||
allowed_channels = []
|
allowed_channels = []
|
||||||
|
|
||||||
|
|
@ -766,9 +768,9 @@ class ChannelTable:
|
||||||
return allowed_channels
|
return allowed_channels
|
||||||
|
|
||||||
def get_channel_by_id_and_user_id(
|
def get_channel_by_id_and_user_id(
|
||||||
self, id: str, user_id: str
|
self, id: str, user_id: str, db: Optional[Session] = None
|
||||||
) -> Optional[ChannelModel]:
|
) -> Optional[ChannelModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
# Fetch the channel
|
# Fetch the channel
|
||||||
channel: Channel = (
|
channel: Channel = (
|
||||||
db.query(Channel)
|
db.query(Channel)
|
||||||
|
|
|
||||||
|
|
@ -135,7 +135,9 @@ class FilesTable:
|
||||||
log.exception(f"Error inserting a new file: {e}")
|
log.exception(f"Error inserting a new file: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
def get_file_by_id(
|
||||||
|
self, id: str, db: Optional[Session] = None
|
||||||
|
) -> Optional[FileModel]:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
|
|
@ -146,8 +148,10 @@ class FilesTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]:
|
def get_file_by_id_and_user_id(
|
||||||
with get_db() as db:
|
self, id: str, user_id: str, db: Optional[Session] = None
|
||||||
|
) -> Optional[FileModel]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id, user_id=user_id).first()
|
file = db.query(File).filter_by(id=id, user_id=user_id).first()
|
||||||
if file:
|
if file:
|
||||||
|
|
@ -157,8 +161,10 @@ class FilesTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]:
|
def get_file_metadata_by_id(
|
||||||
with get_db() as db:
|
self, id: str, db: Optional[Session] = None
|
||||||
|
) -> Optional[FileMetadataResponse]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.get(File, id)
|
file = db.get(File, id)
|
||||||
return FileMetadataResponse(
|
return FileMetadataResponse(
|
||||||
|
|
@ -175,8 +181,10 @@ class FilesTable:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||||
|
|
||||||
def check_access_by_user_id(self, id, user_id, permission="write") -> bool:
|
def check_access_by_user_id(
|
||||||
file = self.get_file_by_id(id)
|
self, id, user_id, permission="write", db: Optional[Session] = None
|
||||||
|
) -> bool:
|
||||||
|
file = self.get_file_by_id(id, db=db)
|
||||||
if not file:
|
if not file:
|
||||||
return False
|
return False
|
||||||
if file.user_id == user_id:
|
if file.user_id == user_id:
|
||||||
|
|
@ -184,8 +192,10 @@ class FilesTable:
|
||||||
# Implement additional access control logic here as needed
|
# Implement additional access control logic here as needed
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_files_by_ids(self, ids: list[str]) -> list[FileModel]:
|
def get_files_by_ids(
|
||||||
with get_db() as db:
|
self, ids: list[str], db: Optional[Session] = None
|
||||||
|
) -> list[FileModel]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FileModel.model_validate(file)
|
FileModel.model_validate(file)
|
||||||
for file in db.query(File)
|
for file in db.query(File)
|
||||||
|
|
@ -194,8 +204,10 @@ class FilesTable:
|
||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse]:
|
def get_file_metadatas_by_ids(
|
||||||
with get_db() as db:
|
self, ids: list[str], db: Optional[Session] = None
|
||||||
|
) -> list[FileMetadataResponse]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FileMetadataResponse(
|
FileMetadataResponse(
|
||||||
id=file.id,
|
id=file.id,
|
||||||
|
|
@ -212,7 +224,9 @@ class FilesTable:
|
||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
def get_files_by_user_id(
|
||||||
|
self, user_id: str, db: Optional[Session] = None
|
||||||
|
) -> list[FileModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FileModel.model_validate(file)
|
FileModel.model_validate(file)
|
||||||
|
|
@ -220,9 +234,9 @@ class FilesTable:
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_file_by_id(
|
def update_file_by_id(
|
||||||
self, id: str, form_data: FileUpdateForm
|
self, id: str, form_data: FileUpdateForm, db: Optional[Session] = None
|
||||||
) -> Optional[FileModel]:
|
) -> Optional[FileModel]:
|
||||||
with get_db() as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
|
|
||||||
|
|
@ -242,8 +256,10 @@ class FilesTable:
|
||||||
log.exception(f"Error updating file completely by id: {e}")
|
log.exception(f"Error updating file completely by id: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_file_hash_by_id(self, id: str, hash: str) -> Optional[FileModel]:
|
def update_file_hash_by_id(
|
||||||
with get_db() as db:
|
self, id: str, hash: str, db: Optional[Session] = None
|
||||||
|
) -> Optional[FileModel]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
file.hash = hash
|
file.hash = hash
|
||||||
|
|
@ -254,8 +270,10 @@ class FilesTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_file_data_by_id(self, id: str, data: dict) -> Optional[FileModel]:
|
def update_file_data_by_id(
|
||||||
with get_db() as db:
|
self, id: str, data: dict, db: Optional[Session] = None
|
||||||
|
) -> Optional[FileModel]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
file.data = {**(file.data if file.data else {}), **data}
|
file.data = {**(file.data if file.data else {}), **data}
|
||||||
|
|
@ -266,8 +284,10 @@ class FilesTable:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_file_metadata_by_id(self, id: str, meta: dict) -> Optional[FileModel]:
|
def update_file_metadata_by_id(
|
||||||
with get_db() as db:
|
self, id: str, meta: dict, db: Optional[Session] = None
|
||||||
|
) -> Optional[FileModel]:
|
||||||
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
file.meta = {**(file.meta if file.meta else {}), **meta}
|
file.meta = {**(file.meta if file.meta else {}), **meta}
|
||||||
|
|
@ -279,5 +299,25 @@ class FilesTable:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def delete_file_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
|
with get_db_context(db) as db:
|
||||||
|
try:
|
||||||
|
db.query(File).filter_by(id=id).delete()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_all_files(self, db: Optional[Session] = None) -> bool:
|
||||||
|
with get_db_context(db) as db:
|
||||||
|
try:
|
||||||
|
db.query(File).delete()
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
Files = FilesTable()
|
Files = FilesTable()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue