diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index a27ae52519..ff4553ee9d 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -201,8 +201,14 @@ class MessageTable: with get_db() as db: message = db.get(Message, id) message.content = form_data.content - message.data = form_data.data - message.meta = form_data.meta + message.data = { + **(message.data if message.data else {}), + **(form_data.data if form_data.data else {}), + } + message.meta = { + **(message.meta if message.meta else {}), + **(form_data.meta if form_data.meta else {}), + } message.updated_at = int(time.time_ns()) db.commit() db.refresh(message) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 0a4c7e5d84..da52be6e79 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -24,9 +24,17 @@ from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.models import ( + get_all_models, + get_filtered_models, +) +from open_webui.utils.chat import generate_chat_completion + + from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, get_users_with_access from open_webui.utils.webhook import post_webhook +from open_webui.utils.channels import extract_mentions, replace_mentions log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -221,13 +229,131 @@ async def send_notification(name, webui_url, channel, message, active_user_ids): return True -@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) -async def post_new_message( - request: Request, - id: str, - form_data: MessageForm, - background_tasks: BackgroundTasks, - user=Depends(get_verified_user), +async def model_response_handler(request, channel, message, user): + MODELS = { + model["id"]: model + for model in get_filtered_models(await get_all_models(request, user=user), user) + } + + mentions = extract_mentions(message.content) + message_content = replace_mentions(message.content) + + # check if any of the mentions are models + model_mentions = [mention for mention in mentions if mention["id_type"] == "M"] + if not model_mentions: + return False + + for mention in model_mentions: + model_id = mention["id"] + model = MODELS.get(model_id, None) + + if model: + try: + # reverse to get in chronological order + thread_messages = Messages.get_messages_by_parent_id( + channel.id, + message.parent_id if message.parent_id else message.id, + )[::-1] + + response_message, channel = await new_message_handler( + request, + channel.id, + MessageForm( + **{ + "parent_id": ( + message.parent_id if message.parent_id else message.id + ), + "content": f"", + "data": {}, + "meta": { + "model_id": model_id, + "model_name": model.get("name", model_id), + }, + } + ), + user, + ) + + thread_history = [] + message_users = {} + + for thread_message in thread_messages: + message_user = None + if thread_message.user_id not in message_users: + message_user = Users.get_user_by_id(thread_message.user_id) + message_users[thread_message.user_id] = message_user + else: + message_user = message_users[thread_message.user_id] + + if thread_message.meta and thread_message.meta.get( + "model_id", None + ): + # If the message was sent by a model, use the model name + message_model_id = thread_message.meta.get("model_id", None) + message_model = MODELS.get(message_model_id, None) + username = ( + message_model.get("name", message_model_id) + if message_model + else message_model_id + ) + else: + username = message_user.name if message_user else "Unknown" + + thread_history.append( + f"{username}: {replace_mentions(thread_message.content)}" + ) + + system_message = { + "role": "system", + "content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational." + + ( + f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally, addressing the most recent message while being aware of the full context." + if thread_history + else "" + ), + } + + form_data = { + "model": model_id, + "messages": [ + system_message, + { + "role": "user", + "content": f"{user.name if user else 'User'}: {message_content}", + }, + ], + "stream": False, + } + + res = await generate_chat_completion( + request, + form_data=form_data, + user=user, + ) + + if res: + await update_message_by_id( + channel.id, + response_message.id, + MessageForm( + **{ + "content": res["choices"][0]["message"]["content"], + "meta": { + "done": True, + }, + } + ), + user, + ) + except Exception as e: + log.info(e) + pass + + return True + + +async def new_message_handler( + request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user) ): channel = Channels.get_channel_by_id(id) if not channel: @@ -301,21 +427,43 @@ async def post_new_message( }, to=f"channel:{channel.id}", ) + return MessageModel(**message.model_dump()), channel + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) - active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") - async def background_handler(): - await send_notification( - request.app.state.WEBUI_NAME, - request.app.state.config.WEBUI_URL, - channel, - message, - active_user_ids, - ) +@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) +async def post_new_message( + request: Request, + id: str, + form_data: MessageForm, + background_tasks: BackgroundTasks, + user=Depends(get_verified_user), +): - background_tasks.add_task(background_handler) + try: + message, channel = await new_message_handler(request, id, form_data, user) + active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") - return MessageModel(**message.model_dump()) + async def background_handler(): + await model_response_handler(request, channel, message, user) + await send_notification( + request.app.state.WEBUI_NAME, + request.app.state.config.WEBUI_URL, + channel, + message, + active_user_ids, + ) + + background_tasks.add_task(background_handler) + + return message + + except HTTPException as e: + raise e except Exception as e: log.exception(e) raise HTTPException( diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index a4d4e3668e..05d7c68006 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,4 +1,6 @@ from typing import Optional +import io +import base64 from open_webui.models.models import ( ModelForm, @@ -10,12 +12,13 @@ from open_webui.models.models import ( from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request, status, Response +from fastapi.responses import FileResponse, StreamingResponse from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission -from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR router = APIRouter() @@ -129,6 +132,39 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ) +########################### +# GetModelById +########################### + + +@router.get("/model/profile/image") +async def get_model_profile_image(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if model.meta.profile_image_url: + if model.meta.profile_image_url.startswith("http"): + return Response( + status_code=status.HTTP_302_FOUND, + headers={"Location": model.meta.profile_image_url}, + ) + elif model.meta.profile_image_url.startswith("data:image"): + try: + header, base64_data = model.meta.profile_image_url.split(",", 1) + image_data = base64.b64decode(base64_data) + image_buffer = io.BytesIO(image_data) + + return StreamingResponse( + image_buffer, + media_type="image/png", + headers={"Content-Disposition": "inline; filename=image.png"}, + ) + except Exception as e: + pass + return FileResponse(f"{STATIC_DIR}/favicon.png") + else: + return FileResponse(f"{STATIC_DIR}/favicon.png") + + ############################ # ToggleModelById ############################ diff --git a/backend/open_webui/utils/channels.py b/backend/open_webui/utils/channels.py new file mode 100644 index 0000000000..312b5ea24c --- /dev/null +++ b/backend/open_webui/utils/channels.py @@ -0,0 +1,31 @@ +import re + + +def extract_mentions(message: str, triggerChar: str = "@"): + # Escape triggerChar in case it's a regex special character + triggerChar = re.escape(triggerChar) + pattern = rf"<{triggerChar}([A-Z]):([^|>]+)" + + matches = re.findall(pattern, message) + return [{"id_type": id_type, "id": id_value} for id_type, id_value in matches] + + +def replace_mentions(message: str, triggerChar: str = "@", use_label: bool = True): + """ + Replace mentions in the message with either their label (after the pipe `|`) + or their id if no label exists. + + Example: + "<@M:gpt-4.1|GPT-4>" -> "GPT-4" (if use_label=True) + "<@M:gpt-4.1|GPT-4>" -> "gpt-4.1" (if use_label=False) + """ + # Escape triggerChar + triggerChar = re.escape(triggerChar) + + def replacer(match): + id_type, id_value, label = match.groups() + return label if use_label and label else id_value + + # Regex captures: idType, id, optional label + pattern = rf"<{triggerChar}([A-Z]):([^|>]+)(?:\|([^>]+))?>" + return re.sub(pattern, replacer, message) diff --git a/src/lib/components/channel/Messages.svelte b/src/lib/components/channel/Messages.svelte index a2a35b2001..540891b500 100644 --- a/src/lib/components/channel/Messages.svelte +++ b/src/lib/components/channel/Messages.svelte @@ -95,7 +95,8 @@ {message} {thread} showUserProfile={messageIdx === 0 || - messageList.at(messageIdx - 1)?.user_id !== message.user_id} + messageList.at(messageIdx - 1)?.user_id !== message.user_id || + messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id} onDelete={() => { messages = messages.filter((m) => m.id !== message.id); diff --git a/src/lib/components/channel/Messages/Message.svelte b/src/lib/components/channel/Messages/Message.svelte index 541d4f3450..4ea6a67aea 100644 --- a/src/lib/components/channel/Messages/Message.svelte +++ b/src/lib/components/channel/Messages/Message.svelte @@ -15,7 +15,7 @@ import { settings, user, shortCodesToEmojis } from '$lib/stores'; - import { WEBUI_BASE_URL } from '$lib/constants'; + import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import Markdown from '$lib/components/chat/Messages/Markdown.svelte'; import ProfileImage from '$lib/components/chat/Messages/ProfileImage.svelte'; @@ -34,6 +34,8 @@ import ChevronRight from '$lib/components/icons/ChevronRight.svelte'; import { formatDate } from '$lib/utils'; import Emoji from '$lib/components/common/Emoji.svelte'; + import { t } from 'i18next'; + import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte'; export let message; export let showUserProfile = true; @@ -138,12 +140,20 @@ >
{#if showUserProfile} - - - + {:else} + + + + {/if} {:else} @@ -163,7 +173,11 @@ {#if showUserProfile}
- {message?.user?.name} + {#if message?.meta?.model_id} + {message?.meta?.model_name ?? message?.meta?.model_id} + {:else} + {message?.user?.name} + {/if}
{#if message.created_at} @@ -251,12 +265,16 @@
{:else}
- {#if message.created_at !== message.updated_at}(edited){/if} + {#if (message?.content ?? '').trim() === '' && message?.meta?.model_id} + + {:else} + {#if message.created_at !== message.updated_at && (message?.meta?.model_id ?? null) === null}({$i18n.t('edited')}){/if} + {/if}
{#if (message?.reactions ?? []).length > 0} diff --git a/src/lib/components/channel/Messages/Message/UserStatusLinkPreview.svelte b/src/lib/components/channel/Messages/Message/UserStatusLinkPreview.svelte index 23dad6e00e..0660548891 100644 --- a/src/lib/components/channel/Messages/Message/UserStatusLinkPreview.svelte +++ b/src/lib/components/channel/Messages/Message/UserStatusLinkPreview.svelte @@ -27,7 +27,7 @@ {#if user}