diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 6eb5c1bbdb..d8f2a61257 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -47,7 +47,7 @@ from open_webui.utils.misc import ( ) from open_webui.utils.payload import ( apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, + apply_system_prompt_to_body, ) @@ -253,9 +253,7 @@ async def generate_function_chat_completion( if params: system = params.pop("system", None) form_data = apply_model_params_to_body_openai(params, form_data) - form_data = apply_model_system_prompt_to_body( - system, form_data, metadata, user - ) + form_data = apply_system_prompt_to_body(system, form_data, metadata, user) pipe_id = get_pipe_id(form_data) function_module = get_function_module_by_id(request, pipe_id) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 0bc043c6f1..11bf5b914f 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -47,7 +47,7 @@ from open_webui.utils.misc import ( from open_webui.utils.payload import ( apply_model_params_to_body_ollama, apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, + apply_system_prompt_to_body, ) from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access @@ -1330,7 +1330,7 @@ async def generate_chat_completion( system = params.pop("system", None) payload = apply_model_params_to_body_ollama(params, payload) - payload = apply_model_system_prompt_to_body(system, payload, metadata, user) + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if not bypass_filter and user.role == "user": @@ -1519,7 +1519,7 @@ async def generate_openai_chat_completion( system = params.pop("system", None) payload = apply_model_params_to_body_openai(params, payload) - payload = apply_model_system_prompt_to_body(system, payload, metadata, user) + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if user.role == "user": diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 875b9c3b75..7ba0c5f68a 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -39,7 +39,7 @@ from open_webui.env import SRC_LOG_LEVELS from open_webui.utils.payload import ( apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, + apply_system_prompt_to_body, ) from open_webui.utils.misc import ( convert_logit_bias_input_to_json, @@ -763,7 +763,7 @@ async def generate_chat_completion( system = params.pop("system", None) payload = apply_model_params_to_body_openai(params, payload) - payload = apply_model_system_prompt_to_body(system, payload, metadata, user) + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if not bypass_filter and user.role == "user": diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 2dec218d92..e49602094f 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -198,14 +198,7 @@ async def generate_title( else: template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE - content = title_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) + content = title_generation_template(template, form_data["messages"], user) max_tokens = ( models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000) @@ -289,14 +282,7 @@ async def generate_follow_ups( else: template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE - content = follow_up_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) + content = follow_up_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -369,9 +355,7 @@ async def generate_chat_tags( else: template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE - content = tags_generation_template( - template, form_data["messages"], {"name": user.name} - ) + content = tags_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -437,13 +421,7 @@ async def generate_image_prompt( else: template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE - content = image_prompt_generation_template( - template, - form_data["messages"], - user={ - "name": user.name, - }, - ) + content = image_prompt_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -524,9 +502,7 @@ async def generate_queries( else: template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE - content = query_generation_template( - template, form_data["messages"], {"name": user.name} - ) + content = query_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -611,9 +587,7 @@ async def generate_autocompletion( else: template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - content = autocomplete_generation_template( - template, prompt, messages, type, {"name": user.name} - ) + content = autocomplete_generation_template(template, prompt, messages, type, user) payload = { "model": task_model_id, @@ -675,14 +649,7 @@ async def generate_emoji( template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE - content = emoji_generation_template( - template, - form_data["prompt"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) + content = emoji_generation_template(template, form_data["prompt"], user) payload = { "model": task_model_id, diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 48ed2cd489..d9bcce9272 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -74,6 +74,7 @@ from open_webui.utils.misc import ( add_or_update_user_message, get_last_user_message, get_last_assistant_message, + get_system_message, prepend_to_first_user_message_content, convert_logit_bias_input_to_json, ) @@ -84,7 +85,7 @@ from open_webui.utils.filter import ( process_filter_functions, ) from open_webui.utils.code_interpreter import execute_code_jupyter -from open_webui.utils.payload import apply_model_system_prompt_to_body +from open_webui.utils.payload import apply_system_prompt_to_body from open_webui.config import ( @@ -737,6 +738,12 @@ async def process_chat_payload(request, form_data, user, metadata, model): form_data = apply_params_to_form_data(form_data, model) log.debug(f"form_data: {form_data}") + system_message = get_system_message(form_data.get("messages", [])) + if system_message: + form_data = apply_system_prompt_to_body( + system_message.get("content"), form_data, metadata, user + ) + event_emitter = get_event_emitter(metadata) event_call = get_event_call(metadata) @@ -778,7 +785,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): if folder and folder.data: if "system_prompt" in folder.data: - form_data = apply_model_system_prompt_to_body( + form_data = apply_system_prompt_to_body( folder.data["system_prompt"], form_data, metadata, user ) if "files" in folder.data: diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 316e61c34c..811ba75c9f 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -9,7 +9,7 @@ import json # inplace function: form_data is modified -def apply_model_system_prompt_to_body( +def apply_system_prompt_to_body( system: Optional[str], form_data: dict, metadata: Optional[dict] = None, user=None ) -> dict: if not system: @@ -22,15 +22,7 @@ def apply_model_system_prompt_to_body( system = prompt_variables_template(system, variables) # Legacy (API Usage) - if user: - template_params = { - "user_name": user.name, - "user_location": user.info.get("location") if user.info else None, - } - else: - template_params = {} - - system = prompt_template(system, **template_params) + system = prompt_template(system, user) form_data["messages"] = add_or_update_system_message( system, form_data.get("messages", []) diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 42b44d5167..ee08a97ad2 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -2,7 +2,7 @@ import logging import math import re from datetime import datetime -from typing import Optional +from typing import Optional, Any import uuid @@ -38,9 +38,42 @@ def prompt_variables_template(template: str, variables: dict[str, str]) -> str: return template -def prompt_template( - template: str, user_name: Optional[str] = None, user_location: Optional[str] = None -) -> str: +def prompt_template(template: str, user: Optional[Any] = None) -> str: + if hasattr(user, "model_dump"): + user = user.model_dump() + + USER_VARIABLES = {} + + if isinstance(user, dict): + birth_date = user.get("date_of_birth") + age = None + + if birth_date: + try: + # If birth_date is str, convert to datetime + if isinstance(birth_date, str): + birth_date = datetime.strptime(birth_date, "%Y-%m-%d") + + today = datetime.now() + age = ( + today.year + - birth_date.year + - ((today.month, today.day) < (birth_date.month, birth_date.day)) + ) + except Exception as e: + pass + + USER_VARIABLES = { + "name": str(user.get("name")), + "location": str(user.get("info", {}).get("location")), + "bio": str(user.get("bio")), + "gender": str(user.get("gender")), + "birth_date": str(birth_date), + "age": str(age), + } + + print(USER_VARIABLES) + # Get the current date current_date = datetime.now() @@ -56,19 +89,20 @@ def prompt_template( ) template = template.replace("{{CURRENT_WEEKDAY}}", formatted_weekday) - if user_name: - # Replace {{USER_NAME}} in the template with the user's name - template = template.replace("{{USER_NAME}}", user_name) - else: - # Replace {{USER_NAME}} in the template with "Unknown" - template = template.replace("{{USER_NAME}}", "Unknown") - - if user_location: - # Replace {{USER_LOCATION}} in the template with the current location - template = template.replace("{{USER_LOCATION}}", user_location) - else: - # Replace {{USER_LOCATION}} in the template with "Unknown" - template = template.replace("{{USER_LOCATION}}", "Unknown") + template = template.replace("{{USER_NAME}}", USER_VARIABLES.get("name", "Unknown")) + template = template.replace("{{USER_BIO}}", USER_VARIABLES.get("bio", "Unknown")) + template = template.replace( + "{{USER_GENDER}}", USER_VARIABLES.get("gender", "Unknown") + ) + template = template.replace( + "{{USER_BIRTH_DATE}}", USER_VARIABLES.get("birth_date", "Unknown") + ) + template = template.replace( + "{{USER_AGE}}", str(USER_VARIABLES.get("age", "Unknown")) + ) + template = template.replace( + "{{USER_LOCATION}}", USER_VARIABLES.get("location", "Unknown") + ) return template @@ -189,90 +223,56 @@ def rag_template(template: str, context: str, query: str): def title_generation_template( - template: str, messages: list[dict], user: Optional[dict] = None + template: str, messages: list[dict], user: Optional[Any] = None ) -> str: + prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template def follow_up_generation_template( - template: str, messages: list[dict], user: Optional[dict] = None + template: str, messages: list[dict], user: Optional[Any] = None ) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template def tags_generation_template( - template: str, messages: list[dict], user: Optional[dict] = None + template: str, messages: list[dict], user: Optional[Any] = None ) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template def image_prompt_generation_template( - template: str, messages: list[dict], user: Optional[dict] = None + template: str, messages: list[dict], user: Optional[Any] = None ) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template def emoji_generation_template( - template: str, prompt: str, user: Optional[dict] = None + template: str, prompt: str, user: Optional[Any] = None ) -> str: template = replace_prompt_variable(template, prompt) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template @@ -282,38 +282,24 @@ def autocomplete_generation_template( prompt: str, messages: Optional[list[dict]] = None, type: Optional[str] = None, - user: Optional[dict] = None, + user: Optional[Any] = None, ) -> str: template = template.replace("{{TYPE}}", type if type else "") template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template def query_generation_template( - template: str, messages: list[dict], user: Optional[dict] = None + template: str, messages: list[dict], user: Optional[Any] = None ) -> str: prompt = get_last_user_message(messages) template = replace_prompt_variable(template, prompt) template = replace_messages_variable(template, messages) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) + template = prompt_template(template, user) return template diff --git a/src/lib/components/channel/MessageInput.svelte b/src/lib/components/channel/MessageInput.svelte index f0b9ba0514..86503b381a 100644 --- a/src/lib/components/channel/MessageInput.svelte +++ b/src/lib/components/channel/MessageInput.svelte @@ -12,6 +12,7 @@ blobToFile, compressImage, extractInputVariables, + getAge, getCurrentDateTime, getFormattedDate, getFormattedTime, @@ -31,6 +32,7 @@ import FilesOverlay from '../chat/MessageInput/FilesOverlay.svelte'; import Commands from '../chat/MessageInput/Commands.svelte'; import InputVariablesModal from '../chat/MessageInput/InputVariablesModal.svelte'; + import { getSessionUser } from '$lib/apis/auths'; export let placeholder = $i18n.t('Send a Message'); @@ -116,11 +118,47 @@ text = text.replaceAll('{{USER_LOCATION}}', String(location)); } + const sessionUser = getSessionUser(localStorage.token); + if (text.includes('{{USER_NAME}}')) { - const name = $user?.name || 'User'; + const name = sessionUser?.name || 'User'; text = text.replaceAll('{{USER_NAME}}', name); } + if (text.includes('{{USER_BIO}}')) { + const bio = sessionUser?.bio || ''; + + if (bio) { + text = text.replaceAll('{{USER_BIO}}', bio); + } + } + + if (text.includes('{{USER_GENDER}}')) { + const gender = sessionUser?.gender || ''; + + if (gender) { + text = text.replaceAll('{{USER_GENDER}}', gender); + } + } + + if (text.includes('{{USER_BIRTH_DATE}}')) { + const birthDate = sessionUser?.date_of_birth || ''; + + if (birthDate) { + text = text.replaceAll('{{USER_BIRTH_DATE}}', birthDate); + } + } + + if (text.includes('{{USER_AGE}}')) { + const birthDate = sessionUser?.date_of_birth || ''; + + if (birthDate) { + // calculate age using date + const age = getAge(birthDate); + text = text.replaceAll('{{USER_AGE}}', age); + } + } + if (text.includes('{{USER_LANGUAGE}}')) { const language = localStorage.getItem('locale') || 'en-US'; text = text.replaceAll('{{USER_LANGUAGE}}', language); diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index a081ecd768..720d9ebbc6 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -44,11 +44,6 @@ copyToClipboard, getMessageContentParts, createMessagesList, - extractSentencesForAudio, - promptTemplate, - splitStream, - sleep, - removeDetails, getPromptVariables, processDetails, removeAllDetails @@ -1655,6 +1650,14 @@ ); await tick(); + let userLocation; + if ($settings?.userLocation) { + userLocation = await getAndUpdateUserLocation(localStorage.token).catch((err) => { + console.error(err); + return undefined; + }); + } + const stream = model?.info?.params?.stream_response ?? $settings?.params?.stream_response ?? @@ -1665,16 +1668,7 @@ params?.system || $settings.system ? { role: 'system', - content: `${promptTemplate( - params?.system ?? $settings?.system ?? '', - $user?.name, - $settings?.userLocation - ? await getAndUpdateUserLocation(localStorage.token).catch((err) => { - console.error(err); - return undefined; - }) - : undefined - )}` + content: `${params?.system ?? $settings?.system ?? ''}` } : undefined, ..._messages.map((message) => ({ @@ -1752,15 +1746,7 @@ memory: $settings?.memory ?? false }, variables: { - ...getPromptVariables( - $user?.name, - $settings?.userLocation - ? await getAndUpdateUserLocation(localStorage.token).catch((err) => { - console.error(err); - return undefined; - }) - : undefined - ) + ...getPromptVariables($user?.name, $settings?.userLocation ? userLocation : undefined) }, model_item: $models.find((m) => m.id === model.id), diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 99a755db61..c5f30a8c4e 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -38,6 +38,7 @@ extractContentFromFile, extractCurlyBraceWords, extractInputVariables, + getAge, getCurrentDateTime, getFormattedDate, getFormattedTime, @@ -73,6 +74,7 @@ import { KokoroWorker } from '$lib/workers/KokoroWorker'; import InputVariablesModal from './MessageInput/InputVariablesModal.svelte'; import Voice from '../icons/Voice.svelte'; + import { getSessionUser } from '$lib/apis/auths'; const i18n = getContext('i18n'); export let onChange: Function = () => {}; @@ -176,11 +178,47 @@ text = text.replaceAll('{{USER_LOCATION}}', String(location)); } + const sessionUser = getSessionUser(localStorage.token); + if (text.includes('{{USER_NAME}}')) { - const name = $_user?.name || 'User'; + const name = sessionUser?.name || 'User'; text = text.replaceAll('{{USER_NAME}}', name); } + if (text.includes('{{USER_BIO}}')) { + const bio = sessionUser?.bio || ''; + + if (bio) { + text = text.replaceAll('{{USER_BIO}}', bio); + } + } + + if (text.includes('{{USER_GENDER}}')) { + const gender = sessionUser?.gender || ''; + + if (gender) { + text = text.replaceAll('{{USER_GENDER}}', gender); + } + } + + if (text.includes('{{USER_BIRTH_DATE}}')) { + const birthDate = sessionUser?.date_of_birth || ''; + + if (birthDate) { + text = text.replaceAll('{{USER_BIRTH_DATE}}', birthDate); + } + } + + if (text.includes('{{USER_AGE}}')) { + const birthDate = sessionUser?.date_of_birth || ''; + + if (birthDate) { + // calculate age using date + const age = getAge(birthDate); + text = text.replaceAll('{{USER_AGE}}', age); + } + } + if (text.includes('{{USER_LANGUAGE}}')) { const language = localStorage.getItem('locale') || 'en-US'; text = text.replaceAll('{{USER_LANGUAGE}}', language); @@ -872,7 +910,8 @@ : `${WEBUI_BASE_URL}/static/favicon.png`)} />