diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 5d269513f5..5dad3d7904 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1495,7 +1495,7 @@ async def chat_completion( } if metadata.get("chat_id") and (user and user.role != "admin"): - if metadata["chat_id"] != "local": + if not metadata["chat_id"].startswith("local:"): chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id) if chat is None: raise HTTPException( @@ -1522,13 +1522,14 @@ async def chat_completion( response = await chat_completion_handler(request, form_data, user) if metadata.get("chat_id") and metadata.get("message_id"): try: - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "model": model_id, - }, - ) + if not metadata["chat_id"].startswith("local:"): + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "model": model_id, + }, + ) except: pass @@ -1549,13 +1550,14 @@ async def chat_completion( if metadata.get("chat_id") and metadata.get("message_id"): # Update the chat message with the error try: - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "error": {"content": str(e)}, - }, - ) + if not metadata["chat_id"].startswith("local:"): + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "error": {"content": str(e)}, + }, + ) event_emitter = get_event_emitter(metadata) await event_emitter( diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index e481571df4..657533c714 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -653,12 +653,15 @@ def get_event_emitter(request_info, update_db=True): ) ) + chat_id = request_info.get("chat_id", None) + message_id = request_info.get("message_id", None) + emit_tasks = [ sio.emit( "chat-events", { - "chat_id": request_info.get("chat_id", None), - "message_id": request_info.get("message_id", None), + "chat_id": chat_id, + "message_id": message_id, "data": event_data, }, to=session_id, @@ -667,8 +670,11 @@ def get_event_emitter(request_info, update_db=True): ] await asyncio.gather(*emit_tasks) - - if update_db: + if ( + update_db + and message_id + and not request_info.get("chat_id", "").startswith("local:") + ): if "type" in event_data and event_data["type"] == "status": Chats.add_message_status_to_chat_by_id_and_message_id( request_info["chat_id"], diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 2667252b37..377ba54dc3 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -80,6 +80,7 @@ from open_webui.utils.misc import ( add_or_update_system_message, add_or_update_user_message, get_last_user_message, + get_last_user_message_item, get_last_assistant_message, get_system_message, prepend_to_first_user_message_content, @@ -1418,10 +1419,13 @@ async def process_chat_response( request, response, form_data, user, metadata, model, events, tasks ): async def background_tasks_handler(): - messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"]) - message = messages_map.get(metadata["message_id"]) if messages_map else None + message = None + messages = [] + + if "chat_id" in metadata and not metadata["chat_id"].startswith("local:"): + messages_map = Chats.get_messages_map_by_chat_id(metadata["chat_id"]) + message = messages_map.get(metadata["message_id"]) if messages_map else None - if message: message_list = get_message_list(messages_map, metadata["message_id"]) # Remove details tags and files from the messages. @@ -1454,12 +1458,21 @@ async def process_chat_response( "content": content, } ) + else: + # Local temp chat, get the model and message from the form_data + message = get_last_user_message_item(form_data.get("messages", [])) + messages = form_data.get("messages", []) + if message: + message["model"] = form_data.get("model") + if message and "model" in message: if tasks and messages: if ( TASKS.FOLLOW_UP_GENERATION in tasks and tasks[TASKS.FOLLOW_UP_GENERATION] ): + + print("Generating follow ups") res = await generate_follow_ups( request, { @@ -1490,15 +1503,6 @@ async def process_chat_response( follow_ups = json.loads(follow_ups_string).get( "follow_ups", [] ) - - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "followUps": follow_ups, - }, - ) - await event_emitter( { "type": "chat:message:follow_ups", @@ -1507,17 +1511,93 @@ async def process_chat_response( }, } ) + + if not metadata.get("chat_id", "").startswith("local:"): + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "followUps": follow_ups, + }, + ) + except Exception as e: pass - if TASKS.TITLE_GENERATION in tasks: - user_message = get_last_user_message(messages) - if user_message and len(user_message) > 100: - user_message = user_message[:100] + "..." + if not metadata.get("chat_id", "").startswith( + "local:" + ): # Only update titles and tags for non-temp chats + if ( + TASKS.TITLE_GENERATION in tasks + and tasks[TASKS.TITLE_GENERATION] + ): + user_message = get_last_user_message(messages) + if user_message and len(user_message) > 100: + user_message = user_message[:100] + "..." - if tasks[TASKS.TITLE_GENERATION]: + if tasks[TASKS.TITLE_GENERATION]: - res = await generate_title( + res = await generate_title( + request, + { + "model": message["model"], + "messages": messages, + "chat_id": metadata["chat_id"], + }, + user, + ) + + if res and isinstance(res, dict): + if len(res.get("choices", [])) == 1: + title_string = ( + res.get("choices", [])[0] + .get("message", {}) + .get( + "content", + message.get("content", user_message), + ) + ) + else: + title_string = "" + + title_string = title_string[ + title_string.find("{") : title_string.rfind("}") + 1 + ] + + try: + title = json.loads(title_string).get( + "title", user_message + ) + except Exception as e: + title = "" + + if not title: + title = messages[0].get("content", user_message) + + Chats.update_chat_title_by_id( + metadata["chat_id"], title + ) + + await event_emitter( + { + "type": "chat:title", + "data": title, + } + ) + elif len(messages) == 2: + title = messages[0].get("content", user_message) + + Chats.update_chat_title_by_id(metadata["chat_id"], title) + + await event_emitter( + { + "type": "chat:title", + "data": message.get("content", user_message), + } + ) + + if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: + res = await generate_chat_tags( request, { "model": message["model"], @@ -1529,89 +1609,32 @@ async def process_chat_response( if res and isinstance(res, dict): if len(res.get("choices", [])) == 1: - title_string = ( + tags_string = ( res.get("choices", [])[0] .get("message", {}) - .get( - "content", message.get("content", user_message) - ) + .get("content", "") ) else: - title_string = "" + tags_string = "" - title_string = title_string[ - title_string.find("{") : title_string.rfind("}") + 1 + tags_string = tags_string[ + tags_string.find("{") : tags_string.rfind("}") + 1 ] try: - title = json.loads(title_string).get( - "title", user_message + tags = json.loads(tags_string).get("tags", []) + Chats.update_chat_tags_by_id( + metadata["chat_id"], tags, user + ) + + await event_emitter( + { + "type": "chat:tags", + "data": tags, + } ) except Exception as e: - title = "" - - if not title: - title = messages[0].get("content", user_message) - - Chats.update_chat_title_by_id(metadata["chat_id"], title) - - await event_emitter( - { - "type": "chat:title", - "data": title, - } - ) - elif len(messages) == 2: - title = messages[0].get("content", user_message) - - Chats.update_chat_title_by_id(metadata["chat_id"], title) - - await event_emitter( - { - "type": "chat:title", - "data": message.get("content", user_message), - } - ) - - if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: - res = await generate_chat_tags( - request, - { - "model": message["model"], - "messages": messages, - "chat_id": metadata["chat_id"], - }, - user, - ) - - if res and isinstance(res, dict): - if len(res.get("choices", [])) == 1: - tags_string = ( - res.get("choices", [])[0] - .get("message", {}) - .get("content", "") - ) - else: - tags_string = "" - - tags_string = tags_string[ - tags_string.find("{") : tags_string.rfind("}") + 1 - ] - - try: - tags = json.loads(tags_string).get("tags", []) - Chats.update_chat_tags_by_id( - metadata["chat_id"], tags, user - ) - - await event_emitter( - { - "type": "chat:tags", - "data": tags, - } - ) - except Exception as e: - pass + pass event_emitter = None event_caller = None diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 05f787bcec..ebdb15c4e3 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -2207,8 +2207,8 @@ selectedFolder.set(null); } else { - _chatId = 'local'; - await chatId.set('local'); + _chatId = `local:${$socket?.id}`; // Use socket id for temporary chat + await chatId.set(_chatId); } await tick(); diff --git a/src/lib/components/chat/Navbar.svelte b/src/lib/components/chat/Navbar.svelte index c8939892dd..755e98e7af 100644 --- a/src/lib/components/chat/Navbar.svelte +++ b/src/lib/components/chat/Navbar.svelte @@ -248,7 +248,7 @@ - {#if $temporaryChatEnabled && $chatId === 'local'} + {#if $temporaryChatEnabled && ($chatId ?? '').startsWith('local:')}