mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: channel/thread @ model
This commit is contained in:
parent
9738ddfd99
commit
4fe97d8794
7 changed files with 277 additions and 37 deletions
|
|
@ -201,8 +201,14 @@ class MessageTable:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
message = db.get(Message, id)
|
message = db.get(Message, id)
|
||||||
message.content = form_data.content
|
message.content = form_data.content
|
||||||
message.data = form_data.data
|
message.data = {
|
||||||
message.meta = form_data.meta
|
**(message.data if message.data else {}),
|
||||||
|
**(form_data.data if form_data.data else {}),
|
||||||
|
}
|
||||||
|
message.meta = {
|
||||||
|
**(message.meta if message.meta else {}),
|
||||||
|
**(form_data.meta if form_data.meta else {}),
|
||||||
|
}
|
||||||
message.updated_at = int(time.time_ns())
|
message.updated_at = int(time.time_ns())
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(message)
|
db.refresh(message)
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,17 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
||||||
|
from open_webui.utils.models import (
|
||||||
|
get_all_models,
|
||||||
|
get_filtered_models,
|
||||||
|
)
|
||||||
|
from open_webui.utils.chat import generate_chat_completion
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, get_users_with_access
|
from open_webui.utils.access_control import has_access, get_users_with_access
|
||||||
from open_webui.utils.webhook import post_webhook
|
from open_webui.utils.webhook import post_webhook
|
||||||
|
from open_webui.utils.channels import extract_mentions, replace_mentions
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||||
|
|
@ -221,13 +229,131 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
|
async def model_response_handler(request, channel, message, user):
|
||||||
async def post_new_message(
|
MODELS = {
|
||||||
request: Request,
|
model["id"]: model
|
||||||
id: str,
|
for model in get_filtered_models(await get_all_models(request, user=user), user)
|
||||||
form_data: MessageForm,
|
}
|
||||||
background_tasks: BackgroundTasks,
|
|
||||||
user=Depends(get_verified_user),
|
mentions = extract_mentions(message.content)
|
||||||
|
message_content = replace_mentions(message.content)
|
||||||
|
|
||||||
|
# check if any of the mentions are models
|
||||||
|
model_mentions = [mention for mention in mentions if mention["id_type"] == "M"]
|
||||||
|
if not model_mentions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for mention in model_mentions:
|
||||||
|
model_id = mention["id"]
|
||||||
|
model = MODELS.get(model_id, None)
|
||||||
|
|
||||||
|
if model:
|
||||||
|
try:
|
||||||
|
# reverse to get in chronological order
|
||||||
|
thread_messages = Messages.get_messages_by_parent_id(
|
||||||
|
channel.id,
|
||||||
|
message.parent_id if message.parent_id else message.id,
|
||||||
|
)[::-1]
|
||||||
|
|
||||||
|
response_message, channel = await new_message_handler(
|
||||||
|
request,
|
||||||
|
channel.id,
|
||||||
|
MessageForm(
|
||||||
|
**{
|
||||||
|
"parent_id": (
|
||||||
|
message.parent_id if message.parent_id else message.id
|
||||||
|
),
|
||||||
|
"content": f"",
|
||||||
|
"data": {},
|
||||||
|
"meta": {
|
||||||
|
"model_id": model_id,
|
||||||
|
"model_name": model.get("name", model_id),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_history = []
|
||||||
|
message_users = {}
|
||||||
|
|
||||||
|
for thread_message in thread_messages:
|
||||||
|
message_user = None
|
||||||
|
if thread_message.user_id not in message_users:
|
||||||
|
message_user = Users.get_user_by_id(thread_message.user_id)
|
||||||
|
message_users[thread_message.user_id] = message_user
|
||||||
|
else:
|
||||||
|
message_user = message_users[thread_message.user_id]
|
||||||
|
|
||||||
|
if thread_message.meta and thread_message.meta.get(
|
||||||
|
"model_id", None
|
||||||
|
):
|
||||||
|
# If the message was sent by a model, use the model name
|
||||||
|
message_model_id = thread_message.meta.get("model_id", None)
|
||||||
|
message_model = MODELS.get(message_model_id, None)
|
||||||
|
username = (
|
||||||
|
message_model.get("name", message_model_id)
|
||||||
|
if message_model
|
||||||
|
else message_model_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
username = message_user.name if message_user else "Unknown"
|
||||||
|
|
||||||
|
thread_history.append(
|
||||||
|
f"{username}: {replace_mentions(thread_message.content)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
system_message = {
|
||||||
|
"role": "system",
|
||||||
|
"content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational."
|
||||||
|
+ (
|
||||||
|
f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally, addressing the most recent message while being aware of the full context."
|
||||||
|
if thread_history
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
form_data = {
|
||||||
|
"model": model_id,
|
||||||
|
"messages": [
|
||||||
|
system_message,
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"{user.name if user else 'User'}: {message_content}",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
res = await generate_chat_completion(
|
||||||
|
request,
|
||||||
|
form_data=form_data,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
if res:
|
||||||
|
await update_message_by_id(
|
||||||
|
channel.id,
|
||||||
|
response_message.id,
|
||||||
|
MessageForm(
|
||||||
|
**{
|
||||||
|
"content": res["choices"][0]["message"]["content"],
|
||||||
|
"meta": {
|
||||||
|
"done": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.info(e)
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def new_message_handler(
|
||||||
|
request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
channel = Channels.get_channel_by_id(id)
|
channel = Channels.get_channel_by_id(id)
|
||||||
if not channel:
|
if not channel:
|
||||||
|
|
@ -301,10 +427,29 @@ async def post_new_message(
|
||||||
},
|
},
|
||||||
to=f"channel:{channel.id}",
|
to=f"channel:{channel.id}",
|
||||||
)
|
)
|
||||||
|
return MessageModel(**message.model_dump()), channel
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
|
||||||
|
async def post_new_message(
|
||||||
|
request: Request,
|
||||||
|
id: str,
|
||||||
|
form_data: MessageForm,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
user=Depends(get_verified_user),
|
||||||
|
):
|
||||||
|
|
||||||
|
try:
|
||||||
|
message, channel = await new_message_handler(request, id, form_data, user)
|
||||||
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
||||||
|
|
||||||
async def background_handler():
|
async def background_handler():
|
||||||
|
await model_response_handler(request, channel, message, user)
|
||||||
await send_notification(
|
await send_notification(
|
||||||
request.app.state.WEBUI_NAME,
|
request.app.state.WEBUI_NAME,
|
||||||
request.app.state.config.WEBUI_URL,
|
request.app.state.config.WEBUI_URL,
|
||||||
|
|
@ -315,7 +460,10 @@ async def post_new_message(
|
||||||
|
|
||||||
background_tasks.add_task(background_handler)
|
background_tasks.add_task(background_handler)
|
||||||
|
|
||||||
return MessageModel(**message.model_dump())
|
return message
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
|
||||||
from open_webui.models.models import (
|
from open_webui.models.models import (
|
||||||
ModelForm,
|
ModelForm,
|
||||||
|
|
@ -10,12 +12,13 @@ from open_webui.models.models import (
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status, Response
|
||||||
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.utils.access_control import has_access, has_permission
|
from open_webui.utils.access_control import has_access, has_permission
|
||||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -129,6 +132,39 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
###########################
|
||||||
|
# GetModelById
|
||||||
|
###########################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/model/profile/image")
|
||||||
|
async def get_model_profile_image(id: str, user=Depends(get_verified_user)):
|
||||||
|
model = Models.get_model_by_id(id)
|
||||||
|
if model:
|
||||||
|
if model.meta.profile_image_url:
|
||||||
|
if model.meta.profile_image_url.startswith("http"):
|
||||||
|
return Response(
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
headers={"Location": model.meta.profile_image_url},
|
||||||
|
)
|
||||||
|
elif model.meta.profile_image_url.startswith("data:image"):
|
||||||
|
try:
|
||||||
|
header, base64_data = model.meta.profile_image_url.split(",", 1)
|
||||||
|
image_data = base64.b64decode(base64_data)
|
||||||
|
image_buffer = io.BytesIO(image_data)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
image_buffer,
|
||||||
|
media_type="image/png",
|
||||||
|
headers={"Content-Disposition": "inline; filename=image.png"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
return FileResponse(f"{STATIC_DIR}/favicon.png")
|
||||||
|
else:
|
||||||
|
return FileResponse(f"{STATIC_DIR}/favicon.png")
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# ToggleModelById
|
# ToggleModelById
|
||||||
############################
|
############################
|
||||||
|
|
|
||||||
31
backend/open_webui/utils/channels.py
Normal file
31
backend/open_webui/utils/channels.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def extract_mentions(message: str, triggerChar: str = "@"):
|
||||||
|
# Escape triggerChar in case it's a regex special character
|
||||||
|
triggerChar = re.escape(triggerChar)
|
||||||
|
pattern = rf"<{triggerChar}([A-Z]):([^|>]+)"
|
||||||
|
|
||||||
|
matches = re.findall(pattern, message)
|
||||||
|
return [{"id_type": id_type, "id": id_value} for id_type, id_value in matches]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mentions(message: str, triggerChar: str = "@", use_label: bool = True):
|
||||||
|
"""
|
||||||
|
Replace mentions in the message with either their label (after the pipe `|`)
|
||||||
|
or their id if no label exists.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
"<@M:gpt-4.1|GPT-4>" -> "GPT-4" (if use_label=True)
|
||||||
|
"<@M:gpt-4.1|GPT-4>" -> "gpt-4.1" (if use_label=False)
|
||||||
|
"""
|
||||||
|
# Escape triggerChar
|
||||||
|
triggerChar = re.escape(triggerChar)
|
||||||
|
|
||||||
|
def replacer(match):
|
||||||
|
id_type, id_value, label = match.groups()
|
||||||
|
return label if use_label and label else id_value
|
||||||
|
|
||||||
|
# Regex captures: idType, id, optional label
|
||||||
|
pattern = rf"<{triggerChar}([A-Z]):([^|>]+)(?:\|([^>]+))?>"
|
||||||
|
return re.sub(pattern, replacer, message)
|
||||||
|
|
@ -95,7 +95,8 @@
|
||||||
{message}
|
{message}
|
||||||
{thread}
|
{thread}
|
||||||
showUserProfile={messageIdx === 0 ||
|
showUserProfile={messageIdx === 0 ||
|
||||||
messageList.at(messageIdx - 1)?.user_id !== message.user_id}
|
messageList.at(messageIdx - 1)?.user_id !== message.user_id ||
|
||||||
|
messageList.at(messageIdx - 1)?.meta?.model_id !== message?.meta?.model_id}
|
||||||
onDelete={() => {
|
onDelete={() => {
|
||||||
messages = messages.filter((m) => m.id !== message.id);
|
messages = messages.filter((m) => m.id !== message.id);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import { settings, user, shortCodesToEmojis } from '$lib/stores';
|
import { settings, user, shortCodesToEmojis } from '$lib/stores';
|
||||||
|
|
||||||
import { WEBUI_BASE_URL } from '$lib/constants';
|
import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
|
||||||
|
|
||||||
import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
|
import Markdown from '$lib/components/chat/Messages/Markdown.svelte';
|
||||||
import ProfileImage from '$lib/components/chat/Messages/ProfileImage.svelte';
|
import ProfileImage from '$lib/components/chat/Messages/ProfileImage.svelte';
|
||||||
|
|
@ -34,6 +34,8 @@
|
||||||
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
|
import ChevronRight from '$lib/components/icons/ChevronRight.svelte';
|
||||||
import { formatDate } from '$lib/utils';
|
import { formatDate } from '$lib/utils';
|
||||||
import Emoji from '$lib/components/common/Emoji.svelte';
|
import Emoji from '$lib/components/common/Emoji.svelte';
|
||||||
|
import { t } from 'i18next';
|
||||||
|
import Skeleton from '$lib/components/chat/Messages/Skeleton.svelte';
|
||||||
|
|
||||||
export let message;
|
export let message;
|
||||||
export let showUserProfile = true;
|
export let showUserProfile = true;
|
||||||
|
|
@ -138,12 +140,20 @@
|
||||||
>
|
>
|
||||||
<div class={`shrink-0 mr-3 w-9`}>
|
<div class={`shrink-0 mr-3 w-9`}>
|
||||||
{#if showUserProfile}
|
{#if showUserProfile}
|
||||||
|
{#if message?.meta?.model_id}
|
||||||
|
<img
|
||||||
|
src={`${WEBUI_API_BASE_URL}/models/model/profile/image?id=${message.meta.model_id}`}
|
||||||
|
alt={message.meta.model_name ?? message.meta.model_id}
|
||||||
|
class="size-8 translate-y-1 ml-0.5 object-cover rounded-full"
|
||||||
|
/>
|
||||||
|
{:else}
|
||||||
<ProfilePreview user={message.user}>
|
<ProfilePreview user={message.user}>
|
||||||
<ProfileImage
|
<ProfileImage
|
||||||
src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
|
src={message.user?.profile_image_url ?? `${WEBUI_BASE_URL}/static/favicon.png`}
|
||||||
className={'size-8 translate-y-1 ml-0.5'}
|
className={'size-8 translate-y-1 ml-0.5'}
|
||||||
/>
|
/>
|
||||||
</ProfilePreview>
|
</ProfilePreview>
|
||||||
|
{/if}
|
||||||
{:else}
|
{:else}
|
||||||
<!-- <div class="w-7 h-7 rounded-full bg-transparent" /> -->
|
<!-- <div class="w-7 h-7 rounded-full bg-transparent" /> -->
|
||||||
|
|
||||||
|
|
@ -163,7 +173,11 @@
|
||||||
{#if showUserProfile}
|
{#if showUserProfile}
|
||||||
<Name>
|
<Name>
|
||||||
<div class=" self-end text-base shrink-0 font-medium truncate">
|
<div class=" self-end text-base shrink-0 font-medium truncate">
|
||||||
|
{#if message?.meta?.model_id}
|
||||||
|
{message?.meta?.model_name ?? message?.meta?.model_id}
|
||||||
|
{:else}
|
||||||
{message?.user?.name}
|
{message?.user?.name}
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if message.created_at}
|
{#if message.created_at}
|
||||||
|
|
@ -251,12 +265,16 @@
|
||||||
</div>
|
</div>
|
||||||
{:else}
|
{:else}
|
||||||
<div class=" min-w-full markdown-prose">
|
<div class=" min-w-full markdown-prose">
|
||||||
|
{#if (message?.content ?? '').trim() === '' && message?.meta?.model_id}
|
||||||
|
<Skeleton />
|
||||||
|
{:else}
|
||||||
<Markdown
|
<Markdown
|
||||||
id={message.id}
|
id={message.id}
|
||||||
content={message.content}
|
content={message.content}
|
||||||
/>{#if message.created_at !== message.updated_at}<span class="text-gray-500 text-[10px]"
|
/>{#if message.created_at !== message.updated_at && (message?.meta?.model_id ?? null) === null}<span
|
||||||
>(edited)</span
|
class="text-gray-500 text-[10px]">({$i18n.t('edited')})</span
|
||||||
>{/if}
|
>{/if}
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if (message?.reactions ?? []).length > 0}
|
{#if (message?.reactions ?? []).length > 0}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@
|
||||||
|
|
||||||
{#if user}
|
{#if user}
|
||||||
<LinkPreview.Content
|
<LinkPreview.Content
|
||||||
class="w-full max-w-[260px] rounded-2xl border border-gray-100 dark:border-gray-800 z-50 bg-white dark:bg-gray-850 dark:text-white shadow-lg transition"
|
class="w-full max-w-[260px] rounded-2xl border border-gray-100 dark:border-gray-800 z-999 bg-white dark:bg-gray-850 dark:text-white shadow-lg transition"
|
||||||
{side}
|
{side}
|
||||||
{align}
|
{align}
|
||||||
{sideOffset}
|
{sideOffset}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue