feat: channel/thread @ model

This commit is contained in:
Timothy Jaeryang Baek 2025-09-17 00:49:44 -05:00
parent 9738ddfd99
commit 4fe97d8794
7 changed files with 277 additions and 37 deletions

View file

@ -201,8 +201,14 @@ class MessageTable:
with get_db() as db: with get_db() as db:
message = db.get(Message, id) message = db.get(Message, id)
message.content = form_data.content message.content = form_data.content
message.data = form_data.data message.data = {
message.meta = form_data.meta **(message.data if message.data else {}),
**(form_data.data if form_data.data else {}),
}
message.meta = {
**(message.meta if message.meta else {}),
**(form_data.meta if form_data.meta else {}),
}
message.updated_at = int(time.time_ns()) message.updated_at = int(time.time_ns())
db.commit() db.commit()
db.refresh(message) db.refresh(message)

View file

@ -24,9 +24,17 @@ from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS from open_webui.env import SRC_LOG_LEVELS
from open_webui.utils.models import (
get_all_models,
get_filtered_models,
)
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, get_users_with_access from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.utils.webhook import post_webhook from open_webui.utils.webhook import post_webhook
from open_webui.utils.channels import extract_mentions, replace_mentions
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"]) log.setLevel(SRC_LOG_LEVELS["MODELS"])
@ -221,13 +229,131 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
return True return True
@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) async def model_response_handler(request, channel, message, user):
async def post_new_message( MODELS = {
request: Request, model["id"]: model
id: str, for model in get_filtered_models(await get_all_models(request, user=user), user)
form_data: MessageForm, }
background_tasks: BackgroundTasks,
user=Depends(get_verified_user), mentions = extract_mentions(message.content)
message_content = replace_mentions(message.content)
# check if any of the mentions are models
model_mentions = [mention for mention in mentions if mention["id_type"] == "M"]
if not model_mentions:
return False
for mention in model_mentions:
model_id = mention["id"]
model = MODELS.get(model_id, None)
if model:
try:
# reverse to get in chronological order
thread_messages = Messages.get_messages_by_parent_id(
channel.id,
message.parent_id if message.parent_id else message.id,
)[::-1]
response_message, channel = await new_message_handler(
request,
channel.id,
MessageForm(
**{
"parent_id": (
message.parent_id if message.parent_id else message.id
),
"content": f"",
"data": {},
"meta": {
"model_id": model_id,
"model_name": model.get("name", model_id),
},
}
),
user,
)
thread_history = []
message_users = {}
for thread_message in thread_messages:
message_user = None
if thread_message.user_id not in message_users:
message_user = Users.get_user_by_id(thread_message.user_id)
message_users[thread_message.user_id] = message_user
else:
message_user = message_users[thread_message.user_id]
if thread_message.meta and thread_message.meta.get(
"model_id", None
):
# If the message was sent by a model, use the model name
message_model_id = thread_message.meta.get("model_id", None)
message_model = MODELS.get(message_model_id, None)
username = (
message_model.get("name", message_model_id)
if message_model
else message_model_id
)
else:
username = message_user.name if message_user else "Unknown"
thread_history.append(
f"{username}: {replace_mentions(thread_message.content)}"
)
system_message = {
"role": "system",
"content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational."
+ (
f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally, addressing the most recent message while being aware of the full context."
if thread_history
else ""
),
}
form_data = {
"model": model_id,
"messages": [
system_message,
{
"role": "user",
"content": f"{user.name if user else 'User'}: {message_content}",
},
],
"stream": False,
}
res = await generate_chat_completion(
request,
form_data=form_data,
user=user,
)
if res:
await update_message_by_id(
channel.id,
response_message.id,
MessageForm(
**{
"content": res["choices"][0]["message"]["content"],
"meta": {
"done": True,
},
}
),
user,
)
except Exception as e:
log.info(e)
pass
return True
async def new_message_handler(
request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user)
): ):
channel = Channels.get_channel_by_id(id) channel = Channels.get_channel_by_id(id)
if not channel: if not channel:
@ -301,10 +427,29 @@ async def post_new_message(
}, },
to=f"channel:{channel.id}", to=f"channel:{channel.id}",
) )
return MessageModel(**message.model_dump()), channel
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
async def post_new_message(
request: Request,
id: str,
form_data: MessageForm,
background_tasks: BackgroundTasks,
user=Depends(get_verified_user),
):
try:
message, channel = await new_message_handler(request, id, form_data, user)
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
async def background_handler(): async def background_handler():
await model_response_handler(request, channel, message, user)
await send_notification( await send_notification(
request.app.state.WEBUI_NAME, request.app.state.WEBUI_NAME,
request.app.state.config.WEBUI_URL, request.app.state.config.WEBUI_URL,
@ -315,7 +460,10 @@ async def post_new_message(
background_tasks.add_task(background_handler) background_tasks.add_task(background_handler)
return MessageModel(**message.model_dump()) return message
except HTTPException as e:
raise e
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(

View file

@ -1,4 +1,6 @@
from typing import Optional from typing import Optional
import io
import base64
from open_webui.models.models import ( from open_webui.models.models import (
ModelForm, ModelForm,
@ -10,12 +12,13 @@ from open_webui.models.models import (
from pydantic import BaseModel from pydantic import BaseModel
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
from fastapi.responses import FileResponse, StreamingResponse
from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission from open_webui.utils.access_control import has_access, has_permission
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
router = APIRouter() router = APIRouter()
@ -129,6 +132,39 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
) )
###########################
# GetModelById
###########################
@router.get("/model/profile/image")
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
if model.meta.profile_image_url:
if model.meta.profile_image_url.startswith("http"):
return Response(
status_code=status.HTTP_302_FOUND,
headers={"Location": model.meta.profile_image_url},
)
elif model.meta.profile_image_url.startswith("data:image"):
try:
header, base64_data = model.meta.profile_image_url.split(",", 1)
image_data = base64.b64decode(base64_data)
image_buffer = io.BytesIO(image_data)
return StreamingResponse(
image_buffer,
media_type="image/png",
headers={"Content-Disposition": "inline; filename=image.png"},
)
except Exception as e:
pass
return FileResponse(f"{STATIC_DIR}/favicon.png")
else:
return FileResponse(f"{STATIC_DIR}/favicon.png")
############################ ############################
# ToggleModelById # ToggleModelById
############################ ############################

View file

@ -0,0 +1,31 @@
import re
def extract_mentions(message: str, triggerChar: str = "@"):
# Escape triggerChar in case it's a regex special character
triggerChar = re.escape(triggerChar)
pattern = rf"<{triggerChar}([A-Z]):([^|>]+)"
matches = re.findall(pattern, message)
return [{"id_type": id_type, "id": id_value} for id_type, id_value in matches]
def replace_mentions(message: str, triggerChar: str = "@", use_label: bool = True):
"""
Replace mentions in the message with either their label (after the pipe `|`)
or their id if no label exists.
Example:
"<@M:gpt-4.1|GPT-4>" -> "GPT-4" (if use_label=True)
"<@M:gpt-4.1|GPT-4>" -> "gpt-4.1" (if use_label=False)
"""
# Escape triggerChar
triggerChar = re.escape(triggerChar)
def replacer(match):
id_type, id_value, label = match.groups()
return label if use_label and label else id_value
# Regex captures: idType, id, optional label
pattern = rf"<{triggerChar}([A-Z]):([^|>]+)(?:\|([^>]+))?>"
return re.sub(pattern, replacer, message)

View file

@ -95,7 +95,8 @@
{message} {message}
{thread} {thread}
showUserProfile={messageIdx === 0 || showUserProfile={messageIdx === 0 ||
messageList.at(messageIdx - 1)?.user_id !== message.user_id} messageList.at(messageIdx - 1)?.user_id !== message.user_id ||
messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id}
onDelete={() => { onDelete={() => {
messages = messages.filter((m) => m.id !== message.id); messages = messages.filter((m) => m.id !== message.id);

View file

@ -15,7 +15,7 @@
import { settings, user, shortCodesToEmojis } from '$lib/stores'; import { settings, user, shortCodesToEmojis } from '$lib/stores';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
import Markdown from '$lib/components/chat/Messages/Markdown.svelte'; import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
import ProfileImage from '$lib/components/chat/Messages/ProfileImage.svelte'; import ProfileImage from '$lib/components/chat/Messages/ProfileImage.svelte';
@ -34,6 +34,8 @@
import ChevronRight from '$lib/components/icons/ChevronRight.svelte'; import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
import { formatDate } from '$lib/utils'; import { formatDate } from '$lib/utils';
import Emoji from '$lib/components/common/Emoji.svelte'; import Emoji from '$lib/components/common/Emoji.svelte';
import { t } from 'i18next';
import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte';
export let message; export let message;
export let showUserProfile = true; export let showUserProfile = true;
@ -138,12 +140,20 @@
> >
<div class={`shrink-0 mr-3 w-9`}> <div class={`shrink-0 mr-3 w-9`}>
{#if showUserProfile} {#if showUserProfile}
{#if message?.meta?.model_id}
<img
src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${message.meta.model_id}`}
alt={message.meta.model_name ?? message.meta.model_id}
class="size-8 translate-y-1 ml-0.5 object-cover rounded-full"
/>
{:else}
<ProfilePreview user={message.user}> <ProfilePreview user={message.user}>
<ProfileImage <ProfileImage
src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`} src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
className={'size-8 translate-y-1 ml-0.5'} className={'size-8 translate-y-1 ml-0.5'}
/> />
</ProfilePreview> </ProfilePreview>
{/if}
{:else} {:else}
<!-- <div class="w-7 h-7 rounded-full bg-transparent" /> --> <!-- <div class="w-7 h-7 rounded-full bg-transparent" /> -->
@ -163,7 +173,11 @@
{#if showUserProfile} {#if showUserProfile}
<Name> <Name>
<div class=" self-end text-base shrink-0 font-medium truncate"> <div class=" self-end text-base shrink-0 font-medium truncate">
{#if message?.meta?.model_id}
{message?.meta?.model_name ?? message?.meta?.model_id}
{:else}
{message?.user?.name} {message?.user?.name}
{/if}
</div> </div>
{#if message.created_at} {#if message.created_at}
@ -251,12 +265,16 @@
</div> </div>
{:else} {:else}
<div class=" min-w-full markdown-prose"> <div class=" min-w-full markdown-prose">
{#if (message?.content ?? '').trim() === '' && message?.meta?.model_id}
<Skeleton />
{:else}
<Markdown <Markdown
id={message.id} id={message.id}
content={message.content} content={message.content}
/>{#if message.created_at !== message.updated_at}<span class="text-gray-500 text-[10px]" />{#if message.created_at !== message.updated_at && (message?.meta?.model_id ?? null) === null}<span
>(edited)</span class="text-gray-500 text-[10px]">({$i18n.t('edited')})</span
>{/if} >{/if}
{/if}
</div> </div>
{#if (message?.reactions ?? []).length > 0} {#if (message?.reactions ?? []).length > 0}

View file

@ -27,7 +27,7 @@
{#if user} {#if user}
<LinkPreview.Content <LinkPreview.Content
class="w-full max-w-[260px] rounded-2xl border border-gray-100 dark:border-gray-800 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-lg transition" class="w-full max-w-[260px] rounded-2xl border border-gray-100 dark:border-gray-800 z-999 bg-white dark:bg-gray-850 dark:text-white shadow-lg transition"
{side} {side}
{align} {align}
{sideOffset} {sideOffset}