mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 06:35:20 +00:00
refac
This commit is contained in:
parent
2453b75ff0
commit
5649a668fa
1 changed files with 33 additions and 28 deletions
|
|
@ -353,14 +353,14 @@ async def get_channel_by_id(
|
|||
)
|
||||
|
||||
user_ids = [
|
||||
member.user_id for member in Channels.get_members_by_channel_id(channel.id)
|
||||
member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)
|
||||
]
|
||||
|
||||
users = [
|
||||
UserIdNameStatusResponse(
|
||||
**{**user.model_dump(), "is_active": Users.is_user_active(user.id)}
|
||||
**{**user.model_dump(), "is_active": Users.is_user_active(user.id, db=db)}
|
||||
)
|
||||
for user in Users.get_users_by_user_ids(user_ids)
|
||||
for user in Users.get_users_by_user_ids(user_ids, db=db)
|
||||
]
|
||||
|
||||
channel_member = Channels.get_member_by_channel_and_user_id(channel.id, user.id, db=db)
|
||||
|
|
@ -453,15 +453,15 @@ async def get_channel_members_by_id(
|
|||
|
||||
if channel.type == "dm":
|
||||
user_ids = [
|
||||
member.user_id for member in Channels.get_members_by_channel_id(channel.id)
|
||||
member.user_id for member in Channels.get_members_by_channel_id(channel.id, db=db)
|
||||
]
|
||||
users = Users.get_users_by_user_ids(user_ids)
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db)
|
||||
total = len(users)
|
||||
|
||||
return {
|
||||
"users": [
|
||||
UserModelResponse(
|
||||
**user.model_dump(), is_active=Users.is_user_active(user.id)
|
||||
**user.model_dump(), is_active=Users.is_user_active(user.id, db=db)
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
|
|
@ -488,7 +488,7 @@ async def get_channel_members_by_id(
|
|||
filter["user_ids"] = permitted_ids.get("user_ids")
|
||||
filter["group_ids"] = permitted_ids.get("group_ids")
|
||||
|
||||
result = Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
result = Users.get_users(filter=filter, skip=skip, limit=limit, db=db)
|
||||
|
||||
users = result["users"]
|
||||
total = result["total"]
|
||||
|
|
@ -496,7 +496,7 @@ async def get_channel_members_by_id(
|
|||
return {
|
||||
"users": [
|
||||
UserModelResponse(
|
||||
**user.model_dump(), is_active=Users.is_user_active(user.id)
|
||||
**user.model_dump(), is_active=Users.is_user_active(user.id, db=db)
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
|
|
@ -533,7 +533,7 @@ async def update_is_active_member_by_id_and_user_id(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
Channels.update_member_active_status(channel.id, user.id, form_data.is_active)
|
||||
Channels.update_member_active_status(channel.id, user.id, form_data.is_active, db=db)
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -577,7 +577,7 @@ async def add_members_by_id(
|
|||
|
||||
try:
|
||||
memberships = Channels.add_members_to_channel(
|
||||
channel.id, user.id, form_data.user_ids, form_data.group_ids
|
||||
channel.id, user.id, form_data.user_ids, form_data.group_ids, db=db
|
||||
)
|
||||
|
||||
return memberships
|
||||
|
|
@ -769,10 +769,10 @@ async def get_channel_messages(
|
|||
)
|
||||
|
||||
channel_member = Channels.join_channel(
|
||||
id, user.id
|
||||
id, user.id, db=db
|
||||
) # Ensure user is a member of the channel
|
||||
|
||||
message_list = Messages.get_messages_by_channel_id(id, skip, limit)
|
||||
message_list = Messages.get_messages_by_channel_id(id, skip, limit, db=db)
|
||||
users = {}
|
||||
|
||||
messages = []
|
||||
|
|
@ -781,7 +781,7 @@ async def get_channel_messages(
|
|||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
users[message.user_id] = user
|
||||
|
||||
thread_replies = Messages.get_thread_replies_by_message_id(message.id)
|
||||
thread_replies = Messages.get_thread_replies_by_message_id(message.id, db=db)
|
||||
latest_thread_reply_at = (
|
||||
thread_replies[0].created_at if thread_replies else None
|
||||
)
|
||||
|
|
@ -792,7 +792,7 @@ async def get_channel_messages(
|
|||
**message.model_dump(),
|
||||
"reply_count": len(thread_replies),
|
||||
"latest_reply_at": latest_thread_reply_at,
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id),
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id, db=db),
|
||||
"user": UserNameResponse(**users[message.user_id].model_dump()),
|
||||
}
|
||||
)
|
||||
|
|
@ -840,7 +840,7 @@ async def get_pinned_channel_messages(
|
|||
skip = (page - 1) * PAGE_ITEM_COUNT_PINNED
|
||||
limit = PAGE_ITEM_COUNT_PINNED
|
||||
|
||||
message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit)
|
||||
message_list = Messages.get_pinned_messages_by_channel_id(id, skip, limit, db=db)
|
||||
users = {}
|
||||
|
||||
messages = []
|
||||
|
|
@ -853,7 +853,7 @@ async def get_pinned_channel_messages(
|
|||
MessageWithReactionsResponse(
|
||||
**{
|
||||
**message.model_dump(),
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id),
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id, db=db),
|
||||
"user": UserNameResponse(**users[message.user_id].model_dump()),
|
||||
}
|
||||
)
|
||||
|
|
@ -867,12 +867,12 @@ async def get_pinned_channel_messages(
|
|||
############################
|
||||
|
||||
|
||||
async def send_notification(name, webui_url, channel, message, active_user_ids):
|
||||
async def send_notification(name, webui_url, channel, message, active_user_ids, db=None):
|
||||
users = get_users_with_access("read", channel.access_control)
|
||||
|
||||
for user in users:
|
||||
if (user.id not in active_user_ids) and Channels.is_user_channel_member(
|
||||
channel.id, user.id
|
||||
channel.id, user.id, db=db
|
||||
):
|
||||
if user.settings:
|
||||
webhook_url = user.settings.ui.get("notifications", {}).get(
|
||||
|
|
@ -894,7 +894,7 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
|
|||
return True
|
||||
|
||||
|
||||
async def model_response_handler(request, channel, message, user):
|
||||
async def model_response_handler(request, channel, message, user, db=None):
|
||||
MODELS = {
|
||||
model["id"]: model
|
||||
for model in get_filtered_models(await get_all_models(request, user=user), user)
|
||||
|
|
@ -932,6 +932,7 @@ async def model_response_handler(request, channel, message, user):
|
|||
thread_messages = Messages.get_messages_by_parent_id(
|
||||
channel.id,
|
||||
message.parent_id if message.parent_id else message.id,
|
||||
db=db,
|
||||
)[::-1]
|
||||
|
||||
response_message, channel = await new_message_handler(
|
||||
|
|
@ -951,6 +952,7 @@ async def model_response_handler(request, channel, message, user):
|
|||
}
|
||||
),
|
||||
user,
|
||||
db,
|
||||
)
|
||||
|
||||
thread_history = []
|
||||
|
|
@ -1051,6 +1053,7 @@ async def model_response_handler(request, channel, message, user):
|
|||
}
|
||||
),
|
||||
user,
|
||||
db,
|
||||
)
|
||||
elif res.get("error", None):
|
||||
await update_message_by_id(
|
||||
|
|
@ -1066,6 +1069,7 @@ async def model_response_handler(request, channel, message, user):
|
|||
}
|
||||
),
|
||||
user,
|
||||
db,
|
||||
)
|
||||
except Exception as e:
|
||||
log.info(e)
|
||||
|
|
@ -1179,13 +1183,14 @@ async def post_new_message(
|
|||
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
||||
|
||||
async def background_handler():
|
||||
await model_response_handler(request, channel, message, user)
|
||||
await model_response_handler(request, channel, message, user, db)
|
||||
await send_notification(
|
||||
request.app.state.WEBUI_NAME,
|
||||
request.app.state.config.WEBUI_URL,
|
||||
channel,
|
||||
message,
|
||||
active_user_ids,
|
||||
db=db,
|
||||
)
|
||||
|
||||
background_tasks.add_task(background_handler)
|
||||
|
|
@ -1354,7 +1359,7 @@ async def pin_channel_message(
|
|||
)
|
||||
|
||||
try:
|
||||
Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id)
|
||||
Messages.update_is_pinned_by_id(message_id, form_data.is_pinned, user.id, db=db)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
return MessageUserResponse(
|
||||
**{
|
||||
|
|
@ -1408,7 +1413,7 @@ async def get_channel_thread_messages(
|
|||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit)
|
||||
message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit, db=db)
|
||||
users = {}
|
||||
|
||||
messages = []
|
||||
|
|
@ -1423,7 +1428,7 @@ async def get_channel_thread_messages(
|
|||
**message.model_dump(),
|
||||
"reply_count": 0,
|
||||
"latest_reply_at": None,
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id),
|
||||
"reactions": Messages.get_reactions_by_message_id(message.id, db=db),
|
||||
"user": UserNameResponse(**users[message.user_id].model_dump()),
|
||||
}
|
||||
)
|
||||
|
|
@ -1484,7 +1489,7 @@ async def update_message_by_id(
|
|||
)
|
||||
|
||||
try:
|
||||
message = Messages.update_message_by_id(message_id, form_data)
|
||||
message = Messages.update_message_by_id(message_id, form_data, db=db)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
|
||||
if message:
|
||||
|
|
@ -1561,7 +1566,7 @@ async def add_reaction_to_message(
|
|||
)
|
||||
|
||||
try:
|
||||
Messages.add_reaction_to_message(message_id, user.id, form_data.name)
|
||||
Messages.add_reaction_to_message(message_id, user.id, form_data.name, db=db)
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
|
||||
await sio.emit(
|
||||
|
|
@ -1637,7 +1642,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
|||
|
||||
try:
|
||||
Messages.remove_reaction_by_id_and_user_id_and_name(
|
||||
message_id, user.id, form_data.name
|
||||
message_id, user.id, form_data.name, db=db
|
||||
)
|
||||
|
||||
message = Messages.get_message_by_id(message_id, db=db)
|
||||
|
|
@ -1721,7 +1726,7 @@ async def delete_message_by_id(
|
|||
)
|
||||
|
||||
try:
|
||||
Messages.delete_message_by_id(message_id)
|
||||
Messages.delete_message_by_id(message_id, db=db)
|
||||
await sio.emit(
|
||||
"events:channel",
|
||||
{
|
||||
|
|
@ -1742,7 +1747,7 @@ async def delete_message_by_id(
|
|||
|
||||
if message.parent_id:
|
||||
# If this message is a reply, emit to the parent message as well
|
||||
parent_message = Messages.get_message_by_id(message.parent_id)
|
||||
parent_message = Messages.get_message_by_id(message.parent_id, db=db)
|
||||
|
||||
if parent_message:
|
||||
await sio.emit(
|
||||
|
|
|
|||
Loading…
Reference in a new issue