mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 06:35:20 +00:00
refac
This commit is contained in:
parent
88dbc14abc
commit
5d1459df16
3 changed files with 148 additions and 112 deletions
|
|
@ -61,6 +61,8 @@ from open_webui.utils.access_control import (
|
|||
)
|
||||
from open_webui.utils.webhook import post_webhook
|
||||
from open_webui.utils.channels import extract_mentions, replace_mentions
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -98,26 +100,27 @@ class ChannelListItemResponse(ChannelModel):
|
|||
async def get_channels(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
channels = Channels.get_channels_by_user_id(user.id)
|
||||
channels = Channels.get_channels_by_user_id(user.id, db=db)
|
||||
channel_list = []
|
||||
for channel in channels:
|
||||
last_message = Messages.get_last_message_by_channel_id(channel.id)
|
||||
last_message = Messages.get_last_message_by_channel_id(channel.id, db=db)
|
||||
last_message_at = last_message.created_at if last_message else None
|
||||
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
|
||||
unread_count = (
|
||||
Messages.get_unread_message_count(
|
||||
channel.id, user.id, channel_member.last_read_at
|
||||
channel.id, user.id, channel_member.last_read_at, db=db
|
||||
)
|
||||
if channel_member
|
||||
else 0
|
||||
|
|
@ -128,13 +131,13 @@ async def get_channels(
|
|||
if channel.type == "dm":
|
||||
user_ids = [
|
||||
member.user_id
|
||||
for member in Channels.get_members_by_channel_id(channel.id)
|
||||
for member in Channels.get_members_by_channel_id(channel.id, db=db)
|
||||
]
|
||||
users = [
|
||||
UserIdNameStatusResponse(
|
||||
**{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
|
||||
**{**user.model_dump(), "is_active": Users.is_user_active(user.id, db=db)}
|
||||
)
|
||||
for user in Users.get_users_by_user_ids(user_ids)
|
||||
for user in Users.get_users_by_user_ids(user_ids, db=db)
|
||||
]
|
||||
|
||||
channel_list.append(
|
||||
|
|
@ -154,11 +157,12 @@ async def get_channels(
|
|||
async def get_all_channels(
|
||||
request: Request,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role == "admin":
|
||||
return Channels.get_channels()
|
||||
return Channels.get_channels_by_user_id(user.id)
|
||||
return Channels.get_channels(db=db)
|
||||
return Channels.get_channels_by_user_id(user.id, db=db)
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -171,10 +175,11 @@ async def get_dm_channel_by_user_id(
|
|||
request: Request,
|
||||
user_id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -182,11 +187,11 @@ async def get_dm_channel_by_user_id(
|
|||
)
|
||||
|
||||
try:
|
||||
existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id])
|
||||
existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id], db=db)
|
||||
if existing_channel:
|
||||
participant_ids = [
|
||||
member.user_id
|
||||
for member in Channels.get_members_by_channel_id(existing_channel.id)
|
||||
for member in Channels.get_members_by_channel_id(existing_channel.id, db=db)
|
||||
]
|
||||
|
||||
await emit_to_users(
|
||||
|
|
@ -198,7 +203,7 @@ async def get_dm_channel_by_user_id(
|
|||
f"channel:{existing_channel.id}", participant_ids
|
||||
)
|
||||
|
||||
Channels.update_member_active_status(existing_channel.id, user.id, True)
|
||||
Channels.update_member_active_status(existing_channel.id, user.id, True, db=db)
|
||||
return ChannelModel(**existing_channel.model_dump())
|
||||
|
||||
channel = Channels.insert_new_channel(
|
||||
|
|
@ -208,12 +213,13 @@ async def get_dm_channel_by_user_id(
|
|||
user_ids=[user_id],
|
||||
),
|
||||
user.id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
if channel:
|
||||
participant_ids = [
|
||||
member.user_id
|
||||
for member in Channels.get_members_by_channel_id(channel.id)
|
||||
for member in Channels.get_members_by_channel_id(channel.id, db=db)
|
||||
]
|
||||
|
||||
await emit_to_users(
|
||||
|
|
@ -243,10 +249,11 @@ async def create_new_channel(
|
|||
request: Request,
|
||||
form_data: CreateChannelForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -263,13 +270,13 @@ async def create_new_channel(
|
|||
try:
|
||||
if form_data.type == "dm":
|
||||
existing_channel = Channels.get_dm_channel_by_user_ids(
|
||||
[user.id, *form_data.user_ids]
|
||||
[user.id, *form_data.user_ids], db=db
|
||||
)
|
||||
if existing_channel:
|
||||
participant_ids = [
|
||||
member.user_id
|
||||
for member in Channels.get_members_by_channel_id(
|
||||
existing_channel.id
|
||||
existing_channel.id, db=db
|
||||
)
|
||||
]
|
||||
await emit_to_users(
|
||||
|
|
@ -281,15 +288,15 @@ async def create_new_channel(
|
|||
f"channel:{existing_channel.id}", participant_ids
|
||||
)
|
||||
|
||||
Channels.update_member_active_status(existing_channel.id, user.id, True)
|
||||
Channels.update_member_active_status(existing_channel.id, user.id, True, db=db)
|
||||
return ChannelModel(**existing_channel.model_dump())
|
||||
|
||||
channel = Channels.insert_new_channel(form_data, user.id)
|
||||
channel = Channels.insert_new_channel(form_data, user.id, db=db)
|
||||
|
||||
if channel:
|
||||
participant_ids = [
|
||||
member.user_id
|
||||
for member in Channels.get_members_by_channel_id(channel.id)
|
||||
for member in Channels.get_members_by_channel_id(channel.id, db=db)
|
||||
]
|
||||
|
||||
await emit_to_users(
|
||||
|
|
@ -327,9 +334,10 @@ async def get_channel_by_id(
|
|||
request: Request,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -339,7 +347,7 @@ async def get_channel_by_id(
|
|||
users = None
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -355,7 +363,7 @@ async def get_channel_by_id(
|
|||
for user in Users.get_users_by_user_ids(user_ids)
|
||||
]
|
||||
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
|
||||
unread_count = Messages.get_unread_message_count(
|
||||
channel.id, user.id, channel_member.last_read_at if channel_member else None
|
||||
)
|
||||
|
|
@ -365,7 +373,7 @@ async def get_channel_by_id(
|
|||
**channel.model_dump(),
|
||||
"user_ids": user_ids,
|
||||
"users": users,
|
||||
"is_manager": Channels.is_user_channel_manager(channel.id, user.id),
|
||||
"is_manager": Channels.is_user_channel_manager(channel.id, user.id, db=db),
|
||||
"write_access": True,
|
||||
"user_count": len(user_ids),
|
||||
"last_read_at": channel_member.last_read_at if channel_member else None,
|
||||
|
|
@ -386,7 +394,7 @@ async def get_channel_by_id(
|
|||
|
||||
user_count = len(get_users_with_access("read", channel.access_control))
|
||||
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
|
||||
unread_count = Messages.get_unread_message_count(
|
||||
channel.id, user.id, channel_member.last_read_at if channel_member else None
|
||||
)
|
||||
|
|
@ -396,7 +404,7 @@ async def get_channel_by_id(
|
|||
**channel.model_dump(),
|
||||
"user_ids": user_ids,
|
||||
"users": users,
|
||||
"is_manager": Channels.is_user_channel_manager(channel.id, user.id),
|
||||
"is_manager": Channels.is_user_channel_manager(channel.id, user.id, db=db),
|
||||
"write_access": write_access or user.role == "admin",
|
||||
"user_count": user_count,
|
||||
"last_read_at": channel_member.last_read_at if channel_member else None,
|
||||
|
|
@ -422,10 +430,11 @@ async def get_channel_members_by_id(
|
|||
direction: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -437,7 +446,7 @@ async def get_channel_members_by_id(
|
|||
skip = (page - 1) * limit
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -510,15 +519,16 @@ async def update_is_active_member_by_id_and_user_id(
|
|||
id: str,
|
||||
form_data: UpdateActiveMemberForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
|
@ -543,17 +553,18 @@ async def add_members_by_id(
|
|||
id: str,
|
||||
form_data: UpdateMembersForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -592,17 +603,18 @@ async def remove_members_by_id(
|
|||
id: str,
|
||||
form_data: RemoveMembersForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -614,7 +626,7 @@ async def remove_members_by_id(
|
|||
)
|
||||
|
||||
try:
|
||||
deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids)
|
||||
deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids, db=db)
|
||||
|
||||
return deleted
|
||||
except Exception as e:
|
||||
|
|
@ -635,17 +647,18 @@ async def update_channel_by_id(
|
|||
id: str,
|
||||
form_data: ChannelForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -657,7 +670,7 @@ async def update_channel_by_id(
|
|||
)
|
||||
|
||||
try:
|
||||
channel = Channels.update_channel_by_id(id, form_data)
|
||||
channel = Channels.update_channel_by_id(id, form_data, db=db)
|
||||
return ChannelModel(**channel.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -676,17 +689,18 @@ async def delete_channel_by_id(
|
|||
request: Request,
|
||||
id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
if user.role != "admin" and not has_permission(
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
|
||||
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -698,7 +712,7 @@ async def delete_channel_by_id(
|
|||
)
|
||||
|
||||
try:
|
||||
Channels.delete_channel_by_id(id)
|
||||
Channels.delete_channel_by_id(id, db=db)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -732,16 +746,17 @@ async def get_channel_messages(
|
|||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -763,7 +778,7 @@ async def get_channel_messages(
|
|||
messages = []
|
||||
for message in message_list:
|
||||
if message.user_id not in users:
|
||||
user = Users.get_user_by_id(message.user_id)
|
||||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
users[message.user_id] = user
|
||||
|
||||
thread_replies = Messages.get_thread_replies_by_message_id(message.id)
|
||||
|
|
@ -799,16 +814,17 @@ async def get_pinned_channel_messages(
|
|||
id: str,
|
||||
page: int = 1,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -830,7 +846,7 @@ async def get_pinned_channel_messages(
|
|||
messages = []
|
||||
for message in message_list:
|
||||
if message.user_id not in users:
|
||||
user = Users.get_user_by_id(message.user_id)
|
||||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
users[message.user_id] = user
|
||||
|
||||
messages.append(
|
||||
|
|
@ -944,7 +960,7 @@ async def model_response_handler(request, channel, message, user):
|
|||
for thread_message in thread_messages:
|
||||
message_user = None
|
||||
if thread_message.user_id not in message_users:
|
||||
message_user = Users.get_user_by_id(thread_message.user_id)
|
||||
message_user = Users.get_user_by_id(thread_message.user_id, db=db)
|
||||
message_users[thread_message.user_id] = message_user
|
||||
else:
|
||||
message_user = message_users[thread_message.user_id]
|
||||
|
|
@ -1059,39 +1075,39 @@ async def model_response_handler(request, channel, message, user):
|
|||
|
||||
|
||||
async def new_message_handler(
|
||||
request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
||||
request: Request, id: str, form_data: MessageForm, user, db
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
else:
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="write", access_control=channel.access_control, strict=False
|
||||
user.id, type="write", access_control=channel.access_control, strict=False, db=db
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
message = Messages.insert_new_message(form_data, channel.id, user.id)
|
||||
message = Messages.insert_new_message(form_data, channel.id, user.id, db=db)
|
||||
if message:
|
||||
if channel.type in ["group", "dm"]:
|
||||
members = Channels.get_members_by_channel_id(channel.id)
|
||||
members = Channels.get_members_by_channel_id(channel.id, db=db)
|
||||
for member in members:
|
||||
if not member.is_active:
|
||||
Channels.update_member_active_status(
|
||||
channel.id, member.user_id, True
|
||||
channel.id, member.user_id, True, db=db
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message.id)
|
||||
message = Messages.get_message_by_id(message.id, db=db)
|
||||
event_data = {
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
|
|
@ -1111,7 +1127,7 @@ async def new_message_handler(
|
|||
|
||||
if message.parent_id:
|
||||
# If this message is a reply, emit to the parent message as well
|
||||
parent_message = Messages.get_message_by_id(message.parent_id)
|
||||
parent_message = Messages.get_message_by_id(message.parent_id, db=db)
|
||||
|
||||
if parent_message:
|
||||
await sio.emit(
|
||||
|
|
@ -1145,16 +1161,17 @@ async def post_new_message(
|
|||
form_data: MessageForm,
|
||||
background_tasks: BackgroundTasks,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
|
||||
try:
|
||||
message, channel = await new_message_handler(request, id, form_data, user)
|
||||
message, channel = await new_message_handler(request, id, form_data, user, db)
|
||||
try:
|
||||
if files := message.data.get("files", []):
|
||||
for file in files:
|
||||
Channels.set_file_message_id_in_channel_by_id(
|
||||
channel.id, file.get("id", ""), message.id
|
||||
channel.id, file.get("id", ""), message.id, db=db
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
|
|
@ -1195,16 +1212,17 @@ async def get_channel_message(
|
|||
id: str,
|
||||
message_id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1216,7 +1234,7 @@ async def get_channel_message(
|
|||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1231,7 +1249,7 @@ async def get_channel_message(
|
|||
**{
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(message.user_id).model_dump()
|
||||
**Users.get_user_by_id(message.user_id, db=db).model_dump()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
|
@ -1248,16 +1266,17 @@ async def get_channel_message_data(
|
|||
id: str,
|
||||
message_id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1269,7 +1288,7 @@ async def get_channel_message_data(
|
|||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1301,16 +1320,17 @@ async def pin_channel_message(
|
|||
message_id: str,
|
||||
form_data: PinMessageForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1322,7 +1342,7 @@ async def pin_channel_message(
|
|||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1335,12 +1355,12 @@ async def pin_channel_message(
|
|||
|
||||
try:
|
||||
Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id)
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
return MessageUserResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(
|
||||
**Users.get_user_by_id(message.user_id).model_dump()
|
||||
**Users.get_user_by_id(message.user_id, db=db).model_dump()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
|
@ -1366,16 +1386,17 @@ async def get_channel_thread_messages(
|
|||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1393,7 +1414,7 @@ async def get_channel_thread_messages(
|
|||
messages = []
|
||||
for message in message_list:
|
||||
if message.user_id not in users:
|
||||
user = Users.get_user_by_id(message.user_id)
|
||||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
users[message.user_id] = user
|
||||
|
||||
messages.append(
|
||||
|
|
@ -1425,15 +1446,16 @@ async def update_message_by_id(
|
|||
message_id: str,
|
||||
form_data: MessageForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1445,7 +1467,7 @@ async def update_message_by_id(
|
|||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1463,7 +1485,7 @@ async def update_message_by_id(
|
|||
|
||||
try:
|
||||
message = Messages.update_message_by_id(message_id, form_data)
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
|
||||
if message:
|
||||
await sio.emit(
|
||||
|
|
@ -1505,16 +1527,17 @@ async def add_reaction_to_message(
|
|||
message_id: str,
|
||||
form_data: ReactionForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1526,7 +1549,7 @@ async def add_reaction_to_message(
|
|||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1539,7 +1562,7 @@ async def add_reaction_to_message(
|
|||
|
||||
try:
|
||||
Messages.add_reaction_to_message(message_id, user.id, form_data.name)
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
|
||||
await sio.emit(
|
||||
"events:channel",
|
||||
|
|
@ -1579,16 +1602,17 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
|||
message_id: str,
|
||||
form_data: ReactionForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
@ -1600,7 +1624,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
|||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1616,7 +1640,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
|||
message_id, user.id, form_data.name
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
|
||||
await sio.emit(
|
||||
"events:channel",
|
||||
|
|
@ -1655,15 +1679,16 @@ async def delete_message_by_id(
|
|||
id: str,
|
||||
message_id: str,
|
||||
user=Depends(get_verified_user),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
check_channels_access(request)
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
channel = Channels.get_channel_by_id(id, db=db)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
if not message:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
|
@ -1675,7 +1700,7 @@ async def delete_message_by_id(
|
|||
)
|
||||
|
||||
if channel.type in ["group", "dm"]:
|
||||
if not Channels.is_user_channel_member(channel.id, user.id):
|
||||
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from open_webui.models.chats import Chats
|
|||
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.utils.headers import include_user_info_headers
|
||||
from open_webui.internal.db import get_session
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.utils.images.comfyui import (
|
||||
ComfyUICreateImageForm,
|
||||
ComfyUIEditImageForm,
|
||||
|
|
@ -496,7 +498,7 @@ def get_image_data(data: str, headers=None):
|
|||
return None, None
|
||||
|
||||
|
||||
def upload_image(request, image_data, content_type, metadata, user):
|
||||
def upload_image(request, image_data, content_type, metadata, user, db=None):
|
||||
image_format = mimetypes.guess_extension(content_type)
|
||||
file = UploadFile(
|
||||
file=io.BytesIO(image_data),
|
||||
|
|
@ -524,6 +526,7 @@ def upload_image(request, image_data, content_type, metadata, user):
|
|||
message_id=message_id,
|
||||
file_ids=[file_item.id],
|
||||
user_id=user.id,
|
||||
db=db,
|
||||
)
|
||||
|
||||
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)
|
||||
|
|
|
|||
|
|
@ -300,7 +300,7 @@ def get_scim_auth(
|
|||
)
|
||||
|
||||
|
||||
def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
||||
def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser:
|
||||
"""Convert internal User model to SCIM User"""
|
||||
# Parse display name into name components
|
||||
name_parts = user.name.split(" ", 1) if user.name else ["", ""]
|
||||
|
|
@ -308,7 +308,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
|||
family_name = name_parts[1] if len(name_parts) > 1 else ""
|
||||
|
||||
# Get user's groups
|
||||
user_groups = Groups.get_groups_by_member_id(user.id)
|
||||
user_groups = Groups.get_groups_by_member_id(user.id, db=db)
|
||||
groups = [
|
||||
{
|
||||
"value": group.id,
|
||||
|
|
@ -487,6 +487,7 @@ async def get_users(
|
|||
count: int = Query(20, ge=1, le=100),
|
||||
filter: Optional[str] = None,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""List SCIM Users"""
|
||||
skip = startIndex - 1
|
||||
|
|
@ -498,20 +499,20 @@ async def get_users(
|
|||
# In production, you'd want a more robust filter parser
|
||||
if "userName eq" in filter:
|
||||
email = filter.split('"')[1]
|
||||
user = Users.get_user_by_email(email)
|
||||
user = Users.get_user_by_email(email, db=db)
|
||||
users_list = [user] if user else []
|
||||
total = 1 if user else 0
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit)
|
||||
response = Users.get_users(skip=skip, limit=limit, db=db)
|
||||
users_list = response["users"]
|
||||
total = response["total"]
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit)
|
||||
response = Users.get_users(skip=skip, limit=limit, db=db)
|
||||
users_list = response["users"]
|
||||
total = response["total"]
|
||||
|
||||
# Convert to SCIM format
|
||||
scim_users = [user_to_scim(user, request) for user in users_list]
|
||||
scim_users = [user_to_scim(user, request, db=db) for user in users_list]
|
||||
|
||||
return SCIMListResponse(
|
||||
totalResults=total,
|
||||
|
|
@ -526,15 +527,16 @@ async def get_user(
|
|||
user_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Get SCIM User by ID"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if not user:
|
||||
return scim_error(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
|
||||
)
|
||||
|
||||
return user_to_scim(user, request)
|
||||
return user_to_scim(user, request, db=db)
|
||||
|
||||
|
||||
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
|
||||
|
|
@ -542,10 +544,11 @@ async def create_user(
|
|||
request: Request,
|
||||
user_data: SCIMUserCreateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Create SCIM User"""
|
||||
# Check if user already exists
|
||||
existing_user = Users.get_user_by_email(user_data.userName)
|
||||
existing_user = Users.get_user_by_email(user_data.userName, db=db)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
|
|
@ -576,6 +579,7 @@ async def create_user(
|
|||
email=email,
|
||||
profile_image_url=profile_image,
|
||||
role="user" if user_data.active else "pending",
|
||||
db=db,
|
||||
)
|
||||
|
||||
if not new_user:
|
||||
|
|
@ -584,7 +588,7 @@ async def create_user(
|
|||
detail="Failed to create user",
|
||||
)
|
||||
|
||||
return user_to_scim(new_user, request)
|
||||
return user_to_scim(new_user, request, db=db)
|
||||
|
||||
|
||||
@router.put("/Users/{user_id}", response_model=SCIMUser)
|
||||
|
|
@ -593,9 +597,10 @@ async def update_user(
|
|||
request: Request,
|
||||
user_data: SCIMUserUpdateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Update SCIM User (full update)"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -628,14 +633,14 @@ async def update_user(
|
|||
update_data["profile_image_url"] = user_data.photos[0].value
|
||||
|
||||
# Update user
|
||||
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||
updated_user = Users.update_user_by_id(user_id, update_data, db=db)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user",
|
||||
)
|
||||
|
||||
return user_to_scim(updated_user, request)
|
||||
return user_to_scim(updated_user, request, db=db)
|
||||
|
||||
|
||||
@router.patch("/Users/{user_id}", response_model=SCIMUser)
|
||||
|
|
@ -644,9 +649,10 @@ async def patch_user(
|
|||
request: Request,
|
||||
patch_data: SCIMPatchRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Update SCIM User (partial update)"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -674,7 +680,7 @@ async def patch_user(
|
|||
|
||||
# Update user
|
||||
if update_data:
|
||||
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||
updated_user = Users.update_user_by_id(user_id, update_data, db=db)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -683,7 +689,7 @@ async def patch_user(
|
|||
else:
|
||||
updated_user = user
|
||||
|
||||
return user_to_scim(updated_user, request)
|
||||
return user_to_scim(updated_user, request, db=db)
|
||||
|
||||
|
||||
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
|
@ -691,16 +697,17 @@ async def delete_user(
|
|||
user_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Delete SCIM User"""
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"User {user_id} not found",
|
||||
)
|
||||
|
||||
success = Users.delete_user_by_id(user_id)
|
||||
success = Users.delete_user_by_id(user_id, db=db)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -749,7 +756,7 @@ async def get_group(
|
|||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Get SCIM Group by ID"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = Groups.get_group_by_id(group_id, db=db)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -921,16 +928,17 @@ async def delete_group(
|
|||
group_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Delete SCIM Group"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = Groups.get_group_by_id(group_id, db=db)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
success = Groups.delete_group_by_id(group_id)
|
||||
success = Groups.delete_group_by_id(group_id, db=db)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
|
|||
Loading…
Reference in a new issue