From e2d4a6975019fb63ff0230a33580922e98515a5b Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Mon, 21 Oct 2024 04:24:17 -0700 Subject: [PATCH] refac: title generation --- backend/open_webui/main.py | 8 +++--- backend/open_webui/utils/task.py | 38 +++++++++++++++++------------ src/lib/apis/index.ts | 4 +-- src/lib/components/chat/Chat.svelte | 12 +++++---- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index aeb622e819..c5c1dd8f07 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1472,7 +1472,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE else: - template = """Create a concise, 3-5 word title with an emoji as a title for the prompt in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. + template = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT. Examples of titles: 📉 Stock Market Trends @@ -1482,11 +1482,13 @@ Remote Work Productivity Tips Artificial Intelligence in Healthcare 🎮 Video Game Development Insights -Prompt: {{prompt:middletruncate:8000}}""" + +{{MESSAGES:END:2}} +""" content = title_generation_template( template, - form_data["prompt"], + form_data["messages"], { "name": user.name, "location": user.info.get("location") if user.info else None, diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 7f7876fc56..2c8ec462e8 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -70,22 +70,6 @@ def replace_prompt_variable(template: str, prompt: str) -> str: return template -def title_generation_template( - template: str, prompt: str, user: Optional[dict] = None -) -> str: - template = replace_prompt_variable(template, prompt) - template = prompt_template( - template, - **( - {"user_name": user.get("name"), "user_location": user.get("location")} - if user - else {} - ), - ) - - return template - - def replace_messages_variable(template: str, messages: list[str]) -> str: def replacement_function(match): full_match = match.group(0) @@ -123,6 +107,28 @@ def replace_messages_variable(template: str, messages: list[str]) -> str: return template +# {{prompt:middletruncate:8000}} + + +def title_generation_template( + template: str, messages: list[dict], user: Optional[dict] = None +) -> str: + prompt = get_last_user_message(messages) + template = replace_prompt_variable(template, prompt) + template = replace_messages_variable(template, messages) + + template = prompt_template( + template, + **( + {"user_name": user.get("name"), "user_location": user.get("location")} + if user + else {} + ), + ) + + return template + + def tags_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 2b3218cb16..40d0e03924 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -208,7 +208,7 @@ export const updateTaskConfig = async (token: string, config: object) => { export const generateTitle = async ( token: string = '', model: string, - prompt: string, + messages: string[], chat_id?: string ) => { let error = null; @@ -222,7 +222,7 @@ export const generateTitle = async ( }, body: JSON.stringify({ model: model, - prompt: prompt, + messages: messages, ...(chat_id && { chat_id: chat_id }) }) }) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 566a58ddbf..d3a2cd4f81 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1408,7 +1408,8 @@ const messages = createMessagesList(responseMessageId); if (messages.length == 2 && messages.at(-1).content !== '' && selectedModels[0] === model.id) { window.history.replaceState(history.state, '', `/c/${_chatId}`); - const title = await generateChatTitle(userPrompt); + + const title = await generateChatTitle(messages); await setChatTitle(_chatId, title); if ($settings?.autoTags ?? true) { @@ -1726,7 +1727,8 @@ const messages = createMessagesList(responseMessageId); if (messages.length == 2 && selectedModels[0] === model.id) { window.history.replaceState(history.state, '', `/c/${_chatId}`); - const title = await generateChatTitle(userPrompt); + + const title = await generateChatTitle(messages); await setChatTitle(_chatId, title); if ($settings?.autoTags ?? true) { @@ -1887,12 +1889,12 @@ } }; - const generateChatTitle = async (userPrompt) => { + const generateChatTitle = async (messages) => { if ($settings?.title?.auto ?? true) { const title = await generateTitle( localStorage.token, selectedModels[0], - userPrompt, + messages, $chatId ).catch((error) => { console.error(error); @@ -1901,7 +1903,7 @@ return title; } else { - return `${userPrompt}`; + return 'New Chat'; } };