From edfa141ab4adee5cb00e30edfafce61584747145 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 11 Nov 2025 00:19:45 -0500 Subject: [PATCH] refac --- backend/open_webui/main.py | 68 +++++++---- backend/open_webui/utils/chat.py | 38 ++---- backend/open_webui/utils/middleware.py | 134 ++++++++++---------- src/lib/components/chat/Chat.svelte | 163 +++++++------------------ 4 files changed, 165 insertions(+), 238 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index f0aeeab02a..1ea97d24e2 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -64,6 +64,7 @@ from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, get_event_emitter, + get_event_call, get_models_in_use, get_active_user_ids, ) @@ -481,7 +482,6 @@ from open_webui.utils.models import ( ) from open_webui.utils.chat import ( generate_chat_completion as chat_completion_handler, - chat_completed as chat_completed_handler, chat_action as chat_action_handler, ) from open_webui.utils.embeddings import generate_embeddings @@ -1566,10 +1566,40 @@ async def chat_completion( detail=str(e), ) - async def process_chat(request, form_data, user, metadata, model): + try: + event_emitter = get_event_emitter(metadata) + event_call = get_event_call(metadata) + + oauth_token = None + try: + if request.cookies.get("oauth_session_id", None): + oauth_token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + + extra_params = { + "__event_emitter__": event_emitter, + "__event_call__": event_call, + "__user__": user.model_dump() if isinstance(user, UserModel) else {}, + "__metadata__": metadata, + "__request__": request, + "__model__": model, + "__oauth_token__": oauth_token, + } + except Exception as e: + log.debug(f"Error setting up extra params: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + async def process_chat(request, form_data, user, metadata, extra_params): try: form_data, metadata, events = await process_chat_payload( - request, form_data, user, metadata, model + request, form_data, user, metadata, extra_params ) response = await chat_completion_handler(request, form_data, user) @@ -1587,7 +1617,14 @@ async def chat_completion( pass return await process_chat_response( - request, response, form_data, user, metadata, model, events, tasks + request, + response, + form_data, + user, + metadata, + extra_params, + events, + tasks, ) except asyncio.CancelledError: log.info("Chat processing was cancelled") @@ -1646,12 +1683,12 @@ async def chat_completion( # Asynchronous Chat Processing task_id, _ = await create_task( request.app.state.redis, - process_chat(request, form_data, user, metadata, model), + process_chat(request, form_data, user, metadata, extra_params), id=metadata["chat_id"], ) return {"status": True, "task_id": task_id} else: - return await process_chat(request, form_data, user, metadata, model) + return await process_chat(request, form_data, user, metadata, extra_params) # Alias for chat_completion (Legacy) @@ -1659,25 +1696,6 @@ generate_chat_completions = chat_completion generate_chat_completion = chat_completion -@app.post("/api/chat/completed") -async def chat_completed( - request: Request, form_data: dict, user=Depends(get_verified_user) -): - try: - model_item = form_data.pop("model_item", {}) - - if model_item.get("direct", False): - request.state.direct = True - request.state.model = model_item - - return await chat_completed_handler(request, form_data, user) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) - - @app.post("/api/chat/actions/{action_id}") async def chat_action( request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 8b6a0b9da2..2ea83986a3 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -290,10 +290,9 @@ async def generate_chat_completion( chat_completion = generate_chat_completion -async def chat_completed(request: Request, form_data: dict, user: Any): - if not request.app.state.MODELS: - await get_all_models(request, user=user) - +async def chat_completed( + request: Request, form_data: dict, user, metadata, extra_params +): if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, @@ -301,35 +300,19 @@ async def chat_completed(request: Request, form_data: dict, user: Any): else: models = request.app.state.MODELS - data = form_data - model_id = data["model"] + model_id = form_data["model"] if model_id not in models: raise Exception("Model not found") model = models[model_id] try: - data = await process_pipeline_outlet_filter(request, data, user, models) + form_data = await process_pipeline_outlet_filter( + request, form_data, user, models + ) except Exception as e: return Exception(f"Error: {e}") - metadata = { - "chat_id": data["chat_id"], - "message_id": data["id"], - "filter_ids": data.get("filter_ids", []), - "session_id": data["session_id"], - "user_id": user.id, - } - - extra_params = { - "__event_emitter__": get_event_emitter(metadata), - "__event_call__": get_event_call(metadata), - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__request__": request, - "__model__": model, - } - try: filter_functions = [ Functions.get_function_by_id(filter_id) @@ -338,14 +321,15 @@ async def chat_completed(request: Request, form_data: dict, user: Any): ) ] - result, _ = await process_filter_functions( + form_data, _ = await process_filter_functions( request=request, filter_functions=filter_functions, filter_type="outlet", - form_data=data, + form_data=form_data, extra_params=extra_params, ) - return result + + return form_data except Exception as e: return Exception(f"Error: {e}") diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e5b84a3d79..117b839f13 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -71,7 +71,10 @@ from open_webui.models.models import Models from open_webui.retrieval.utils import get_sources_from_items -from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.chat import ( + generate_chat_completion, + chat_completed, +) from open_webui.utils.task import ( get_task_model_id, rag_template, @@ -1079,11 +1082,17 @@ def apply_params_to_form_data(form_data, model): return form_data -async def process_chat_payload(request, form_data, user, metadata, model): +async def process_chat_payload(request, form_data, user, metadata, extra_params): # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling # -> Chat Files + event_emitter = extra_params.get("__event_emitter__", None) + event_caller = extra_params.get("__event_call__", None) + + oauth_token = extra_params.get("__oauth_token__", None) + model = extra_params.get("__model__", None) + form_data = apply_params_to_form_data(form_data, model) log.debug(f"form_data: {form_data}") @@ -1096,29 +1105,6 @@ async def process_chat_payload(request, form_data, user, metadata, model): except: pass - event_emitter = get_event_emitter(metadata) - event_call = get_event_call(metadata) - - oauth_token = None - try: - if request.cookies.get("oauth_session_id", None): - oauth_token = await request.app.state.oauth_manager.get_oauth_token( - user.id, - request.cookies.get("oauth_session_id", None), - ) - except Exception as e: - log.error(f"Error getting OAuth token: {e}") - - extra_params = { - "__event_emitter__": event_emitter, - "__event_call__": event_call, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__request__": request, - "__model__": model, - "__oauth_token__": oauth_token, - } - # Initialize events to store additional event to be sent to the client # Initialize contexts and citation if getattr(request.state, "direct", False) and hasattr(request.state, "model"): @@ -1529,7 +1515,7 @@ async def process_chat_payload(request, form_data, user, metadata, model): async def process_chat_response( - request, response, form_data, user, metadata, model, events, tasks + request, response, form_data, user, metadata, extra_params, events, tasks ): async def background_tasks_handler(): message = None @@ -1752,18 +1738,9 @@ async def process_chat_response( except Exception as e: pass - event_emitter = None - event_caller = None - if ( - "session_id" in metadata - and metadata["session_id"] - and "chat_id" in metadata - and metadata["chat_id"] - and "message_id" in metadata - and metadata["message_id"] - ): - event_emitter = get_event_emitter(metadata) - event_caller = get_event_call(metadata) + model = extra_params.get("__model__", None) + event_emitter = extra_params.get("__event_emitter__", None) + event_caller = extra_params.get("__event_call__", None) # Non-streaming response if not isinstance(response, StreamingResponse): @@ -1832,8 +1809,18 @@ async def process_chat_response( } ) - title = Chats.get_chat_title_by_id(metadata["chat_id"]) + # Save message in the database + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "role": "assistant", + "content": content, + }, + ) + + title = Chats.get_chat_title_by_id(metadata["chat_id"]) await event_emitter( { "type": "chat:completion", @@ -1845,16 +1832,6 @@ async def process_chat_response( } ) - # Save message in the database - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "role": "assistant", - "content": content, - }, - ) - # Send a webhook notification if the user is not active if not get_active_status_by_user_id(user.id): webhook_url = Users.get_user_webhook_url_by_id(user.id) @@ -1923,32 +1900,12 @@ async def process_chat_response( ): return response - oauth_token = None - try: - if request.cookies.get("oauth_session_id", None): - oauth_token = await request.app.state.oauth_manager.get_oauth_token( - user.id, - request.cookies.get("oauth_session_id", None), - ) - except Exception as e: - log.error(f"Error getting OAuth token: {e}") - - extra_params = { - "__event_emitter__": event_emitter, - "__event_call__": event_caller, - "__user__": user.model_dump() if isinstance(user, UserModel) else {}, - "__metadata__": metadata, - "__oauth_token__": oauth_token, - "__request__": request, - "__model__": model, - } filter_functions = [ Functions.get_function_by_id(filter_id) for filter_id in get_sorted_filter_ids( request, model, metadata.get("filter_ids", []) ) ] - # Streaming response if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. @@ -3163,12 +3120,35 @@ async def process_chat_response( }, ) + + completed_res = await chat_completed( + request, + { + "id": metadata.get("message_id"), + "chat_id": metadata.get("chat_id"), + "session_id": metadata.get("session_id"), + "filter_ids": metadata.get("filter_ids", []), + + "model": form_data.get("model"), + "messages": [*form_data.get("messages", []), response_message], + }, + user, + metadata, + extra_params, + ) + + if completed_res and completed_res.get("messages"): + for message in completed_res["messages"]: + + + if response.background is not None: await response.background() return await response_handler(response, events) else: + response_message = {} # Fallback to the original response async def stream_wrapper(original_generator, events): def wrap_item(item): @@ -3198,6 +3178,22 @@ async def process_chat_response( if data: yield data + await chat_completed( + request, + { + "id": metadata.get("message_id"), + "chat_id": metadata.get("chat_id"), + "session_id": metadata.get("session_id"), + "filter_ids": metadata.get("filter_ids", []), + + "model": form_data.get("model"), + "messages": [*form_data.get("messages", []), response_message], + }, + user, + metadata, + extra_params, + ) + return StreamingResponse( stream_wrapper(response.body_iterator, events), headers=dict(response.headers), diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 56b93f3643..6b850ec670 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1155,65 +1155,6 @@ }); } }; - const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => { - const res = await chatCompleted(localStorage.token, { - model: modelId, - messages: messages.map((m) => ({ - id: m.id, - role: m.role, - content: m.content, - info: m.info ? m.info : undefined, - timestamp: m.timestamp, - ...(m.usage ? { usage: m.usage } : {}), - ...(m.sources ? { sources: m.sources } : {}) - })), - filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined, - model_item: $models.find((m) => m.id === modelId), - chat_id: _chatId, - session_id: $socket?.id, - id: responseMessageId - }).catch((error) => { - toast.error(`${error}`); - messages.at(-1).error = { content: error }; - - return null; - }); - - if (res !== null && res.messages) { - // Update chat history with the new messages - for (const message of res.messages) { - if (message?.id) { - // Add null check for message and message.id - history.messages[message.id] = { - ...history.messages[message.id], - ...(history.messages[message.id].content !== message.content - ? { originalContent: history.messages[message.id].content } - : {}), - ...message - }; - } - } - } - - await tick(); - - if ($chatId == _chatId) { - if (!$temporaryChatEnabled) { - chat = await updateChatById(localStorage.token, _chatId, { - models: selectedModels, - messages: messages, - history: history, - params: params, - files: chatFiles - }); - - currentChatPage.set(1); - await chats.set(await getChatList(localStorage.token, $currentChatPage)); - } - } - - taskIds = null; - }; const chatActionHandler = async (_chatId, actionId, modelId, responseMessageId, event = null) => { const messages = createMessagesList(history, responseMessageId); @@ -1401,17 +1342,53 @@ } }; - const chatCompletionEventHandler = async (data, message, chatId) => { + const emitChatTTSEvents = (message) => { + const messageContentParts = getMessageContentParts( + removeAllDetails(message.content), + $config?.audio?.tts?.split_on ?? 'punctuation' + ); + messageContentParts.pop(); + + // dispatch only last sentence and make sure it hasn't been dispatched before + if ( + messageContentParts.length > 0 && + messageContentParts[messageContentParts.length - 1] !== message.lastSentence + ) { + message.lastSentence = messageContentParts[messageContentParts.length - 1]; + eventTarget.dispatchEvent( + new CustomEvent('chat', { + detail: { + id: message.id, + content: messageContentParts[messageContentParts.length - 1] + } + }) + ); + } + + return message; + }; + + const chatCompletionEventHandler = async (data, message, _chatId) => { const { id, done, choices, content, sources, selected_model_id, error, usage } = data; if (error) { await handleOpenAIError(error, message); } + if (usage) { + message.usage = usage; + } + if (sources && !message?.sources) { message.sources = sources; } + if (selected_model_id) { + message.selectedModelId = selected_model_id; + message.arena = true; + } + + // Raw response handling if (choices) { if (choices[0]?.message?.content) { // Non-stream response @@ -1429,31 +1406,12 @@ } // Emit chat event for TTS - const messageContentParts = getMessageContentParts( - removeAllDetails(message.content), - $config?.audio?.tts?.split_on ?? 'punctuation' - ); - messageContentParts.pop(); - - // dispatch only last sentence and make sure it hasn't been dispatched before - if ( - messageContentParts.length > 0 && - messageContentParts[messageContentParts.length - 1] !== message.lastSentence - ) { - message.lastSentence = messageContentParts[messageContentParts.length - 1]; - eventTarget.dispatchEvent( - new CustomEvent('chat', { - detail: { - id: message.id, - content: messageContentParts[messageContentParts.length - 1] - } - }) - ); - } + message = emitChatTTSEvents(message); } } } + // Normal response handling if (content) { // REALTIME_CHAT_SAVE is disabled message.content = content; @@ -1463,36 +1421,7 @@ } // Emit chat event for TTS - const messageContentParts = getMessageContentParts( - removeAllDetails(message.content), - $config?.audio?.tts?.split_on ?? 'punctuation' - ); - messageContentParts.pop(); - - // dispatch only last sentence and make sure it hasn't been dispatched before - if ( - messageContentParts.length > 0 && - messageContentParts[messageContentParts.length - 1] !== message.lastSentence - ) { - message.lastSentence = messageContentParts[messageContentParts.length - 1]; - eventTarget.dispatchEvent( - new CustomEvent('chat', { - detail: { - id: message.id, - content: messageContentParts[messageContentParts.length - 1] - } - }) - ); - } - } - - if (selected_model_id) { - message.selectedModelId = selected_model_id; - message.arena = true; - } - - if (usage) { - message.usage = usage; + message = emitChatTTSEvents(message); } history.messages[message.id] = message; @@ -1538,12 +1467,12 @@ scrollToBottom(); } - await chatCompletedHandler( - chatId, - message.model, - message.id, - createMessagesList(history, message.id) - ); + if ($chatId == _chatId) { + if (!$temporaryChatEnabled) { + currentChatPage.set(1); + await chats.set(await getChatList(localStorage.token, $currentChatPage)); + } + } } console.log(data);