diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index fc8cc1feb5..e1b0bed853 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -71,7 +71,7 @@
 from open_webui.models.models import Models
 from open_webui.retrieval.utils import get_sources_from_items
 
-from open_webui.utils.chat import generate_chat_completion
+from open_webui.utils.chat import generate_chat_completion, chat_completed
 from open_webui.utils.task import (
     get_task_model_id,
     rag_template,
@@ -1730,6 +1730,16 @@ async def process_chat_response(
                 content = response_data["choices"][0]["message"]["content"]
 
                 if content:
+
+                    await dispatch_chat_completed()
+
+                    if outlet_content_override is not None:
+                        data["content"] = outlet_content_override
+                    if collected_sources:
+                        data["sources"] = collected_sources
+                    if latest_usage:
+                        data["usage"] = latest_usage
+
                     await event_emitter(
                         {
                             "type": "chat:completion",
@@ -2251,6 +2261,102 @@ async def process_chat_response(
                 "content": content,
             }
         ]
+
+        latest_usage = None
+        completion_dispatched = False
+        collected_sources = []
+        source_hashes = set()
+        outlet_result_data = None
+        outlet_content_override = None
+
+        def extend_sources(items):
+            if not items:
+                return
+            for item in items:
+                try:
+                    key = json.dumps(item, sort_keys=True)
+                except (TypeError, ValueError):
+                    key = None
+                if key and key in source_hashes:
+                    continue
+                if key:
+                    source_hashes.add(key)
+                collected_sources.append(item)
+
+        async def dispatch_chat_completed():
+            nonlocal completion_dispatched, outlet_result_data, outlet_content_override, latest_usage
+            if completion_dispatched:
+                return outlet_result_data
+
+            base_messages = [dict(message) for message in form_data.get("messages", [])]
+            generated_messages = convert_content_blocks_to_messages(
+                content_blocks, raw=True
+            )
+            final_messages = [*base_messages, *generated_messages]
+
+            if final_messages:
+                last_message = final_messages[-1]
+                if isinstance(last_message, dict) and last_message.get("role") == "assistant":
+                    last_message = {**last_message}
+                    if collected_sources:
+                        last_message["sources"] = collected_sources
+                    if latest_usage:
+                        last_message["usage"] = latest_usage
+                    final_messages[-1] = last_message
+
+            payload = {
+                "model": model_id,
+                "messages": final_messages,
+                "chat_id": metadata["chat_id"],
+                "session_id": metadata["session_id"],
+                "id": metadata["message_id"],
+                "model_item": model,
+            }
+            if metadata.get("filter_ids"):
+                payload["filter_ids"] = metadata["filter_ids"]
+
+            try:
+                outlet_result_data = await chat_completed(request, payload, user)
+
+                if isinstance(outlet_result_data, dict):
+                    extend_sources(outlet_result_data.get("sources"))
+                    message_updates = outlet_result_data.get("messages")
+                    if isinstance(message_updates, list):
+                        for message_update in message_updates:
+                            if not isinstance(message_update, dict):
+                                continue
+                            if message_update.get("id") != metadata["message_id"]:
+                                continue
+
+                            if message_update.get("sources"):
+                                extend_sources(message_update.get("sources"))
+
+                            usage_update = message_update.get("usage")
+                            if usage_update:
+                                try:
+                                    latest_usage = dict(usage_update)
+                                except Exception:
+                                    latest_usage = usage_update
+
+                            outlet_content_override = message_update.get("content")
+
+                            try:
+                                Chats.upsert_message_to_chat_by_id_and_message_id(
+                                    metadata["chat_id"],
+                                    metadata["message_id"],
+                                    {
+                                        **message_update,
+                                    },
+                                )
+                            except Exception as e:
+                                log.debug(f"Failed to upsert outlet message: {e}")
+                            break
+            except Exception as e:
+                log.warning(f"chat_completed outlet failed: {e}")
+            finally:
+                completion_dispatched = True
+
+            return outlet_result_data
 
         reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
         DETECT_REASONING_TAGS = reasoning_tags_param is not False
@@ -2272,6 +2378,7 @@ async def process_chat_response(
 
         try:
             for event in events:
+                extend_sources(event.get("sources"))
                 await event_emitter(
                     {
                         "type": "chat:completion",
@@ -2291,6 +2398,7 @@ async def process_chat_response(
         async def stream_body_handler(response, form_data):
             nonlocal content
             nonlocal content_blocks
+            nonlocal latest_usage
 
             response_tool_calls = []
 
@@ -2349,6 +2457,7 @@ async def process_chat_response(
                     )
 
                     if data:
+                        extend_sources(data.get("sources"))
                         if "event" in data and not getattr(
                             request.state, "direct", False
                         ):
@@ -2376,6 +2485,7 @@ async def process_chat_response(
                         usage = data.get("usage", {}) or {}
                         usage.update(data.get("timings", {}))  # llama.cpp
                         if usage:
+                            latest_usage = dict(usage)
                             await event_emitter(
                                 {
                                     "type": "chat:completion",
@@ -3056,6 +3166,8 @@ async def process_chat_response(
                     "content": serialize_content_blocks(content_blocks),
                 },
             )
+        finally:
+            await dispatch_chat_completed()
 
         if response.background is not None:
             await response.background()
@@ -3064,7 +3176,9 @@ async def process_chat_response(
     else:
         # Fallback to the original response
+        latest_usage = None
 
         async def stream_wrapper(original_generator, events):
+            nonlocal latest_usage
 
             def wrap_item(item):
                 return f"data: {item}\n\n"
@@ -3078,19 +3192,27 @@ async def process_chat_response(
                 )
 
                 if event:
+                    extend_sources(event.get("sources"))
                     yield wrap_item(json.dumps(event))
 
-            async for data in original_generator:
-                data, _ = await process_filter_functions(
-                    request=request,
-                    filter_functions=filter_functions,
-                    filter_type="stream",
-                    form_data=data,
-                    extra_params=extra_params,
-                )
+            try:
+                async for data in original_generator:
+                    data, _ = await process_filter_functions(
+                        request=request,
+                        filter_functions=filter_functions,
+                        filter_type="stream",
+                        form_data=data,
+                        extra_params=extra_params,
+                    )
 
-                if data:
-                    yield data
+                    if data:
+                        extend_sources(data.get("sources"))
+                        usage = data.get("usage")
+                        if usage:
+                            latest_usage = dict(usage)
+                        yield data
+            finally:
+                await dispatch_chat_completed()
 
         return StreamingResponse(
             stream_wrapper(response.body_iterator, events),
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index dae3ff91e3..0d1e62c23e 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -1140,45 +1140,6 @@
         }
     };
 
     const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => {
-        const res = await chatCompleted(localStorage.token, {
-            model: modelId,
-            messages: messages.map((m) => ({
-                id: m.id,
-                role: m.role,
-                content: m.content,
-                info: m.info ? m.info : undefined,
-                timestamp: m.timestamp,
-                ...(m.usage ? { usage: m.usage } : {}),
-                ...(m.sources ? { sources: m.sources } : {})
-            })),
-            filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
-            model_item: $models.find((m) => m.id === modelId),
-            chat_id: _chatId,
-            session_id: $socket?.id,
-            id: responseMessageId
-        }).catch((error) => {
-            toast.error(`${error}`);
-            messages.at(-1).error = { content: error };
-
-            return null;
-        });
-
-        if (res !== null && res.messages) {
-            // Update chat history with the new messages
-            for (const message of res.messages) {
-                if (message?.id) {
-                    // Add null check for message and message.id
-                    history.messages[message.id] = {
-                        ...history.messages[message.id],
-                        ...(history.messages[message.id].content !== message.content
-                            ? { originalContent: history.messages[message.id].content }
-                            : {}),
-                        ...message
-                    };
-                }
-            }
-        }
-
         await tick();
 
         if ($chatId == _chatId) {
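
Reviewer note: the sketch below illustrates the contract this patch creates between the new server-side dispatch_chat_completed() and outlet filters. It is a hypothetical filter written for illustration only (the Filter class with an outlet method follows Open WebUI's usual filter-function shape; the filter itself is not part of this change). Any content, sources, or usage that an outlet sets on the message matching metadata["message_id"] is merged back on the server: content becomes outlet_content_override, sources pass through extend_sources() for de-duplication, and usage replaces latest_usage before the final "chat:completion" event is emitted.

# Hypothetical outlet filter (illustration only; not part of this patch).
class Filter:
    async def outlet(self, body: dict, __user__: dict | None = None) -> dict:
        for message in body.get("messages", []):
            if message.get("role") != "assistant":
                continue
            # Picked up as outlet_content_override and emitted with the
            # final "chat:completion" event in place of the streamed text.
            message["content"] = (message.get("content") or "") + "\n\n(reviewed)"
            # Merged via extend_sources(), which de-duplicates entries by
            # their JSON serialization; the entry shape here is illustrative.
            message.setdefault("sources", []).append(
                {"source": {"name": "outlet-example"}, "document": [], "metadata": []}
            )
        return body

Because the backend now runs the outlet itself when the stream finishes (see the finally: blocks above), the frontend chatCompletedHandler no longer needs to call chatCompleted(), which is why the corresponding Chat.svelte block is removed.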