This commit is contained in:
Timothy Jaeryang Baek 2025-12-29 01:20:04 +04:00
parent 88dbc14abc
commit 5d1459df16
3 changed files with 148 additions and 112 deletions

View file

@ -61,6 +61,8 @@ from open_webui.utils.access_control import (
)
from open_webui.utils.webhook import post_webhook
from open_webui.utils.channels import extract_mentions, replace_mentions
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
log = logging.getLogger(__name__)
@ -98,26 +100,27 @@ class ChannelListItemResponse(ChannelModel):
async def get_channels(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
channels = Channels.get_channels_by_user_id(user.id)
channels = Channels.get_channels_by_user_id(user.id, db=db)
channel_list = []
for channel in channels:
last_message = Messages.get_last_message_by_channel_id(channel.id)
last_message = Messages.get_last_message_by_channel_id(channel.id, db=db)
last_message_at = last_message.created_at if last_message else None
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
unread_count = (
Messages.get_unread_message_count(
channel.id, user.id, channel_member.last_read_at
channel.id, user.id, channel_member.last_read_at, db=db
)
if channel_member
else 0
@ -128,13 +131,13 @@ async def get_channels(
if channel.type == "dm":
user_ids = [
member.user_id
for member in Channels.get_members_by_channel_id(channel.id)
for member in Channels.get_members_by_channel_id(channel.id, db=db)
]
users = [
UserIdNameStatusResponse(
**{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
**{**user.model_dump(), "is_active": Users.is_user_active(user.id, db=db)}
)
for user in Users.get_users_by_user_ids(user_ids)
for user in Users.get_users_by_user_ids(user_ids, db=db)
]
channel_list.append(
@ -154,11 +157,12 @@ async def get_channels(
async def get_all_channels(
request: Request,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role == "admin":
return Channels.get_channels()
return Channels.get_channels_by_user_id(user.id)
return Channels.get_channels(db=db)
return Channels.get_channels_by_user_id(user.id, db=db)
############################
@ -171,10 +175,11 @@ async def get_dm_channel_by_user_id(
request: Request,
user_id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -182,11 +187,11 @@ async def get_dm_channel_by_user_id(
)
try:
existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id])
existing_channel = Channels.get_dm_channel_by_user_ids([user.id, user_id], db=db)
if existing_channel:
participant_ids = [
member.user_id
for member in Channels.get_members_by_channel_id(existing_channel.id)
for member in Channels.get_members_by_channel_id(existing_channel.id, db=db)
]
await emit_to_users(
@ -198,7 +203,7 @@ async def get_dm_channel_by_user_id(
f"channel:{existing_channel.id}", participant_ids
)
Channels.update_member_active_status(existing_channel.id, user.id, True)
Channels.update_member_active_status(existing_channel.id, user.id, True, db=db)
return ChannelModel(**existing_channel.model_dump())
channel = Channels.insert_new_channel(
@ -208,12 +213,13 @@ async def get_dm_channel_by_user_id(
user_ids=[user_id],
),
user.id,
db=db,
)
if channel:
participant_ids = [
member.user_id
for member in Channels.get_members_by_channel_id(channel.id)
for member in Channels.get_members_by_channel_id(channel.id, db=db)
]
await emit_to_users(
@ -243,10 +249,11 @@ async def create_new_channel(
request: Request,
form_data: CreateChannelForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -263,13 +270,13 @@ async def create_new_channel(
try:
if form_data.type == "dm":
existing_channel = Channels.get_dm_channel_by_user_ids(
[user.id, *form_data.user_ids]
[user.id, *form_data.user_ids], db=db
)
if existing_channel:
participant_ids = [
member.user_id
for member in Channels.get_members_by_channel_id(
existing_channel.id
existing_channel.id, db=db
)
]
await emit_to_users(
@ -281,15 +288,15 @@ async def create_new_channel(
f"channel:{existing_channel.id}", participant_ids
)
Channels.update_member_active_status(existing_channel.id, user.id, True)
Channels.update_member_active_status(existing_channel.id, user.id, True, db=db)
return ChannelModel(**existing_channel.model_dump())
channel = Channels.insert_new_channel(form_data, user.id)
channel = Channels.insert_new_channel(form_data, user.id, db=db)
if channel:
participant_ids = [
member.user_id
for member in Channels.get_members_by_channel_id(channel.id)
for member in Channels.get_members_by_channel_id(channel.id, db=db)
]
await emit_to_users(
@ -327,9 +334,10 @@ async def get_channel_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -339,7 +347,7 @@ async def get_channel_by_id(
users = None
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -355,7 +363,7 @@ async def get_channel_by_id(
for user in Users.get_users_by_user_ids(user_ids)
]
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
unread_count = Messages.get_unread_message_count(
channel.id, user.id, channel_member.last_read_at if channel_member else None
)
@ -365,7 +373,7 @@ async def get_channel_by_id(
**channel.model_dump(),
"user_ids": user_ids,
"users": users,
"is_manager": Channels.is_user_channel_manager(channel.id, user.id),
"is_manager": Channels.is_user_channel_manager(channel.id, user.id, db=db),
"write_access": True,
"user_count": len(user_ids),
"last_read_at": channel_member.last_read_at if channel_member else None,
@ -386,7 +394,7 @@ async def get_channel_by_id(
user_count = len(get_users_with_access("read", channel.access_control))
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id)
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
unread_count = Messages.get_unread_message_count(
channel.id, user.id, channel_member.last_read_at if channel_member else None
)
@ -396,7 +404,7 @@ async def get_channel_by_id(
**channel.model_dump(),
"user_ids": user_ids,
"users": users,
"is_manager": Channels.is_user_channel_manager(channel.id, user.id),
"is_manager": Channels.is_user_channel_manager(channel.id, user.id, db=db),
"write_access": write_access or user.role == "admin",
"user_count": user_count,
"last_read_at": channel_member.last_read_at if channel_member else None,
@ -422,10 +430,11 @@ async def get_channel_members_by_id(
direction: Optional[str] = None,
page: Optional[int] = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -437,7 +446,7 @@ async def get_channel_members_by_id(
skip = (page - 1) * limit
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -510,15 +519,16 @@ async def update_is_active_member_by_id_and_user_id(
id: str,
form_data: UpdateActiveMemberForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
@ -543,17 +553,18 @@ async def add_members_by_id(
id: str,
form_data: UpdateMembersForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -592,17 +603,18 @@ async def remove_members_by_id(
id: str,
form_data: RemoveMembersForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -614,7 +626,7 @@ async def remove_members_by_id(
)
try:
deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids)
deleted = Channels.remove_members_from_channel(channel.id, form_data.user_ids, db=db)
return deleted
except Exception as e:
@ -635,17 +647,18 @@ async def update_channel_by_id(
id: str,
form_data: ChannelForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -657,7 +670,7 @@ async def update_channel_by_id(
)
try:
channel = Channels.update_channel_by_id(id, form_data)
channel = Channels.update_channel_by_id(id, form_data, db=db)
return ChannelModel(**channel.model_dump())
except Exception as e:
log.exception(e)
@ -676,17 +689,18 @@ async def delete_channel_by_id(
request: Request,
id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
if user.role != "admin" and not has_permission(
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS
user.id, "features.channels", request.app.state.config.USER_PERMISSIONS, db=db
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -698,7 +712,7 @@ async def delete_channel_by_id(
)
try:
Channels.delete_channel_by_id(id)
Channels.delete_channel_by_id(id, db=db)
return True
except Exception as e:
log.exception(e)
@ -732,16 +746,17 @@ async def get_channel_messages(
skip: int = 0,
limit: int = 50,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -763,7 +778,7 @@ async def get_channel_messages(
messages = []
for message in message_list:
if message.user_id not in users:
user = Users.get_user_by_id(message.user_id)
user = Users.get_user_by_id(message.user_id, db=db)
users[message.user_id] = user
thread_replies = Messages.get_thread_replies_by_message_id(message.id)
@ -799,16 +814,17 @@ async def get_pinned_channel_messages(
id: str,
page: int = 1,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -830,7 +846,7 @@ async def get_pinned_channel_messages(
messages = []
for message in message_list:
if message.user_id not in users:
user = Users.get_user_by_id(message.user_id)
user = Users.get_user_by_id(message.user_id, db=db)
users[message.user_id] = user
messages.append(
@ -944,7 +960,7 @@ async def model_response_handler(request, channel, message, user):
for thread_message in thread_messages:
message_user = None
if thread_message.user_id not in message_users:
message_user = Users.get_user_by_id(thread_message.user_id)
message_user = Users.get_user_by_id(thread_message.user_id, db=db)
message_users[thread_message.user_id] = message_user
else:
message_user = message_users[thread_message.user_id]
@ -1059,39 +1075,39 @@ async def model_response_handler(request, channel, message, user):
async def new_message_handler(
request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user)
request: Request, id: str, form_data: MessageForm, user, db
):
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
else:
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
user.id, type="write", access_control=channel.access_control, strict=False, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
try:
message = Messages.insert_new_message(form_data, channel.id, user.id)
message = Messages.insert_new_message(form_data, channel.id, user.id, db=db)
if message:
if channel.type in ["group", "dm"]:
members = Channels.get_members_by_channel_id(channel.id)
members = Channels.get_members_by_channel_id(channel.id, db=db)
for member in members:
if not member.is_active:
Channels.update_member_active_status(
channel.id, member.user_id, True
channel.id, member.user_id, True, db=db
)
message = Messages.get_message_by_id(message.id)
message = Messages.get_message_by_id(message.id, db=db)
event_data = {
"channel_id": channel.id,
"message_id": message.id,
@ -1111,7 +1127,7 @@ async def new_message_handler(
if message.parent_id:
# If this message is a reply, emit to the parent message as well
parent_message = Messages.get_message_by_id(message.parent_id)
parent_message = Messages.get_message_by_id(message.parent_id, db=db)
if parent_message:
await sio.emit(
@ -1145,16 +1161,17 @@ async def post_new_message(
form_data: MessageForm,
background_tasks: BackgroundTasks,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
try:
message, channel = await new_message_handler(request, id, form_data, user)
message, channel = await new_message_handler(request, id, form_data, user, db)
try:
if files := message.data.get("files", []):
for file in files:
Channels.set_file_message_id_in_channel_by_id(
channel.id, file.get("id", ""), message.id
channel.id, file.get("id", ""), message.id, db=db
)
except Exception as e:
log.debug(e)
@ -1195,16 +1212,17 @@ async def get_channel_message(
id: str,
message_id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1216,7 +1234,7 @@ async def get_channel_message(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1231,7 +1249,7 @@ async def get_channel_message(
**{
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
**Users.get_user_by_id(message.user_id, db=db).model_dump()
),
}
)
@ -1248,16 +1266,17 @@ async def get_channel_message_data(
id: str,
message_id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1269,7 +1288,7 @@ async def get_channel_message_data(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1301,16 +1320,17 @@ async def pin_channel_message(
message_id: str,
form_data: PinMessageForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1322,7 +1342,7 @@ async def pin_channel_message(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1335,12 +1355,12 @@ async def pin_channel_message(
try:
Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
return MessageUserResponse(
**{
**message.model_dump(),
"user": UserNameResponse(
**Users.get_user_by_id(message.user_id).model_dump()
**Users.get_user_by_id(message.user_id, db=db).model_dump()
),
}
)
@ -1366,16 +1386,17 @@ async def get_channel_thread_messages(
skip: int = 0,
limit: int = 50,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1393,7 +1414,7 @@ async def get_channel_thread_messages(
messages = []
for message in message_list:
if message.user_id not in users:
user = Users.get_user_by_id(message.user_id)
user = Users.get_user_by_id(message.user_id, db=db)
users[message.user_id] = user
messages.append(
@ -1425,15 +1446,16 @@ async def update_message_by_id(
message_id: str,
form_data: MessageForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1445,7 +1467,7 @@ async def update_message_by_id(
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1463,7 +1485,7 @@ async def update_message_by_id(
try:
message = Messages.update_message_by_id(message_id, form_data)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if message:
await sio.emit(
@ -1505,16 +1527,17 @@ async def add_reaction_to_message(
message_id: str,
form_data: ReactionForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1526,7 +1549,7 @@ async def add_reaction_to_message(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1539,7 +1562,7 @@ async def add_reaction_to_message(
try:
Messages.add_reaction_to_message(message_id, user.id, form_data.name)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
await sio.emit(
"events:channel",
@ -1579,16 +1602,17 @@ async def remove_reaction_by_id_and_user_id_and_name(
message_id: str,
form_data: ReactionForm,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
@ -1600,7 +1624,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1616,7 +1640,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
message_id, user.id, form_data.name
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
await sio.emit(
"events:channel",
@ -1655,15 +1679,16 @@ async def delete_message_by_id(
id: str,
message_id: str,
user=Depends(get_verified_user),
db: Session = Depends(get_session),
):
check_channels_access(request)
channel = Channels.get_channel_by_id(id)
channel = Channels.get_channel_by_id(id, db=db)
if not channel:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
message = Messages.get_message_by_id(message_id)
message = Messages.get_message_by_id(message_id, db=db)
if not message:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@ -1675,7 +1700,7 @@ async def delete_message_by_id(
)
if channel.type in ["group", "dm"]:
if not Channels.is_user_channel_member(channel.id, user.id):
if not Channels.is_user_channel_member(channel.id, user.id, db=db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)

View file

@ -22,6 +22,8 @@ from open_webui.models.chats import Chats
from open_webui.routers.files import upload_file_handler, get_file_content_by_id
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.headers import include_user_info_headers
from open_webui.internal.db import get_session
from sqlalchemy.orm import Session
from open_webui.utils.images.comfyui import (
ComfyUICreateImageForm,
ComfyUIEditImageForm,
@ -496,7 +498,7 @@ def get_image_data(data: str, headers=None):
return None, None
def upload_image(request, image_data, content_type, metadata, user):
def upload_image(request, image_data, content_type, metadata, user, db=None):
image_format = mimetypes.guess_extension(content_type)
file = UploadFile(
file=io.BytesIO(image_data),
@ -524,6 +526,7 @@ def upload_image(request, image_data, content_type, metadata, user):
message_id=message_id,
file_ids=[file_item.id],
user_id=user.id,
db=db,
)
url = request.app.url_path_for("get_file_content_by_id", id=file_item.id)

View file

@ -300,7 +300,7 @@ def get_scim_auth(
)
def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser:
"""Convert internal User model to SCIM User"""
# Parse display name into name components
name_parts = user.name.split(" ", 1) if user.name else ["", ""]
@ -308,7 +308,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
family_name = name_parts[1] if len(name_parts) > 1 else ""
# Get user's groups
user_groups = Groups.get_groups_by_member_id(user.id)
user_groups = Groups.get_groups_by_member_id(user.id, db=db)
groups = [
{
"value": group.id,
@ -487,6 +487,7 @@ async def get_users(
count: int = Query(20, ge=1, le=100),
filter: Optional[str] = None,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""List SCIM Users"""
skip = startIndex - 1
@ -498,20 +499,20 @@ async def get_users(
# In production, you'd want a more robust filter parser
if "userName eq" in filter:
email = filter.split('"')[1]
user = Users.get_user_by_email(email)
user = Users.get_user_by_email(email, db=db)
users_list = [user] if user else []
total = 1 if user else 0
else:
response = Users.get_users(skip=skip, limit=limit)
response = Users.get_users(skip=skip, limit=limit, db=db)
users_list = response["users"]
total = response["total"]
else:
response = Users.get_users(skip=skip, limit=limit)
response = Users.get_users(skip=skip, limit=limit, db=db)
users_list = response["users"]
total = response["total"]
# Convert to SCIM format
scim_users = [user_to_scim(user, request) for user in users_list]
scim_users = [user_to_scim(user, request, db=db) for user in users_list]
return SCIMListResponse(
totalResults=total,
@ -526,15 +527,16 @@ async def get_user(
user_id: str,
request: Request,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Get SCIM User by ID"""
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if not user:
return scim_error(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
)
return user_to_scim(user, request)
return user_to_scim(user, request, db=db)
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
@ -542,10 +544,11 @@ async def create_user(
request: Request,
user_data: SCIMUserCreateRequest,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Create SCIM User"""
# Check if user already exists
existing_user = Users.get_user_by_email(user_data.userName)
existing_user = Users.get_user_by_email(user_data.userName, db=db)
if existing_user:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@ -576,6 +579,7 @@ async def create_user(
email=email,
profile_image_url=profile_image,
role="user" if user_data.active else "pending",
db=db,
)
if not new_user:
@ -584,7 +588,7 @@ async def create_user(
detail="Failed to create user",
)
return user_to_scim(new_user, request)
return user_to_scim(new_user, request, db=db)
@router.put("/Users/{user_id}", response_model=SCIMUser)
@ -593,9 +597,10 @@ async def update_user(
request: Request,
user_data: SCIMUserUpdateRequest,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Update SCIM User (full update)"""
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -628,14 +633,14 @@ async def update_user(
update_data["profile_image_url"] = user_data.photos[0].value
# Update user
updated_user = Users.update_user_by_id(user_id, update_data)
updated_user = Users.update_user_by_id(user_id, update_data, db=db)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update user",
)
return user_to_scim(updated_user, request)
return user_to_scim(updated_user, request, db=db)
@router.patch("/Users/{user_id}", response_model=SCIMUser)
@ -644,9 +649,10 @@ async def patch_user(
request: Request,
patch_data: SCIMPatchRequest,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Update SCIM User (partial update)"""
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -674,7 +680,7 @@ async def patch_user(
# Update user
if update_data:
updated_user = Users.update_user_by_id(user_id, update_data)
updated_user = Users.update_user_by_id(user_id, update_data, db=db)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -683,7 +689,7 @@ async def patch_user(
else:
updated_user = user
return user_to_scim(updated_user, request)
return user_to_scim(updated_user, request, db=db)
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@ -691,16 +697,17 @@ async def delete_user(
user_id: str,
request: Request,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Delete SCIM User"""
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {user_id} not found",
)
success = Users.delete_user_by_id(user_id)
success = Users.delete_user_by_id(user_id, db=db)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -749,7 +756,7 @@ async def get_group(
db: Session = Depends(get_session),
):
"""Get SCIM Group by ID"""
group = Groups.get_group_by_id(group_id)
group = Groups.get_group_by_id(group_id, db=db)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -921,16 +928,17 @@ async def delete_group(
group_id: str,
request: Request,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Delete SCIM Group"""
group = Groups.get_group_by_id(group_id)
group = Groups.get_group_by_id(group_id, db=db)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group {group_id} not found",
)
success = Groups.delete_group_by_id(group_id)
success = Groups.delete_group_by_id(group_id, db=db)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,