diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index d397471dd9..3d6e3606f9 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -71,7 +71,7 @@ from open_webui.models.models import Models
 
 from open_webui.retrieval.utils import get_sources_from_items
 
-from open_webui.utils.chat import generate_chat_completion
+from open_webui.utils.chat import generate_chat_completion, chat_completed
 from open_webui.utils.task import (
     get_task_model_id,
     rag_template,
@@ -1825,6 +1825,319 @@ async def process_chat_response(
         event_emitter = get_event_emitter(metadata)
         event_caller = get_event_call(metadata)
 
+    model_id = form_data.get("model", "")
+
+    def split_content_and_whitespace(content):
+        content_stripped = content.rstrip()
+        original_whitespace = (
+            content[len(content_stripped) :]
+            if len(content) > len(content_stripped)
+            else ""
+        )
+        return content_stripped, original_whitespace
+
+    def is_opening_code_block(content):
+        backtick_segments = content.split("```")
+        # Even number of segments means the last backticks are opening a new block
+        return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0
+
+    def serialize_content_blocks(content_blocks, raw=False):
+        content = ""
+
+        for block in content_blocks:
+            if block["type"] == "text":
+                block_content = block["content"].strip()
+                if block_content:
+                    content = f"{content}{block_content}\n"
+            elif block["type"] == "tool_calls":
+                attributes = block.get("attributes", {})
+
+                tool_calls = block.get("content", [])
+                results = block.get("results", [])
+
+                if content and not content.endswith("\n"):
+                    content += "\n"
+
+                if results:
+
+                    tool_calls_display_content = ""
+                    for tool_call in tool_calls:
+
+                        tool_call_id = tool_call.get("id", "")
+                        tool_name = tool_call.get("function", {}).get(
+                            "name", ""
+                        )
+                        tool_arguments = tool_call.get("function", {}).get(
+                            "arguments", ""
+                        )
+
+                        tool_result = None
+                        tool_result_files = None
+                        for result in results:
+                            if tool_call_id == result.get("tool_call_id", ""):
+                                tool_result = result.get("content", None)
+                                tool_result_files = result.get("files", None)
+                                break
+
+                        if tool_result is not None:
+                            tool_result_embeds = result.get("embeds", "")
+                            tool_calls_display_content = f'{tool_calls_display_content}
\nTool Executed\n
\n'
+                        else:
+                            tool_calls_display_content = f'{tool_calls_display_content}
\nExecuting...\n
\n'
+
+                    if not raw:
+                        content = f"{content}{tool_calls_display_content}"
+                else:
+                    tool_calls_display_content = ""
+
+                    for tool_call in tool_calls:
+                        tool_call_id = tool_call.get("id", "")
+                        tool_name = tool_call.get("function", {}).get(
+                            "name", ""
+                        )
+                        tool_arguments = tool_call.get("function", {}).get(
+                            "arguments", ""
+                        )
+
+                        tool_calls_display_content = f'{tool_calls_display_content}\n
\nExecuting...\n
\n'
+
+                    if not raw:
+                        content = f"{content}{tool_calls_display_content}"
+
+            elif block["type"] == "reasoning":
+                reasoning_display_content = html.escape(
+                    "\n".join(
+                        (f"> {line}" if not line.startswith(">") else line)
+                        for line in block["content"].splitlines()
+                    )
+                )
+
+                reasoning_duration = block.get("duration", None)
+
+                start_tag = block.get("start_tag", "")
+                end_tag = block.get("end_tag", "")
+
+                if content and not content.endswith("\n"):
+                    content += "\n"
+
+                if reasoning_duration is not None:
+                    if raw:
+                        content = (
+                            f'{content}{start_tag}{block["content"]}{end_tag}\n'
+                        )
+                    else:
+                        content = f'{content}
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n'
+                else:
+                    if raw:
+                        content = (
+                            f'{content}{start_tag}{block["content"]}{end_tag}\n'
+                        )
+                    else:
+                        content = f'{content}
\nThinking…\n{reasoning_display_content}\n
\n'
+
+            elif block["type"] == "code_interpreter":
+                attributes = block.get("attributes", {})
+                output = block.get("output", None)
+                lang = attributes.get("lang", "")
+
+                content_stripped, original_whitespace = (
+                    split_content_and_whitespace(content)
+                )
+                if is_opening_code_block(content_stripped):
+                    # Remove trailing backticks that would open a new block
+                    content = (
+                        content_stripped.rstrip("`").rstrip()
+                        + original_whitespace
+                    )
+                else:
+                    # Keep content as is - either closing backticks or no backticks
+                    content = content_stripped + original_whitespace
+
+                if content and not content.endswith("\n"):
+                    content += "\n"
+
+                if output:
+                    output = html.escape(json.dumps(output))
+
+                    if raw:
+                        content = f'{content}\n{block["content"]}\n\n```output\n{output}\n```\n'
+                    else:
+                        content = f'{content}
\nAnalyzed\n```{lang}\n{block["content"]}\n```\n
\n'
+                else:
+                    if raw:
+                        content = f'{content}\n{block["content"]}\n\n'
+                    else:
+                        content = f'{content}
\nAnalyzing...\n```{lang}\n{block["content"]}\n```\n
\n' + + else: + block_content = str(block["content"]).strip() + if block_content: + content = f"{content}{block['type']}: {block_content}\n" + + return content.strip() + + def convert_content_blocks_to_messages(content_blocks, raw=False): + messages = [] + + temp_blocks = [] + for idx, block in enumerate(content_blocks): + if block["type"] == "tool_calls": + messages.append( + { + "role": "assistant", + "content": serialize_content_blocks(temp_blocks, raw), + "tool_calls": block.get("content"), + } + ) + + results = block.get("results", []) + + for result in results: + messages.append( + { + "role": "tool", + "tool_call_id": result["tool_call_id"], + "content": result.get("content", "") or "", + } + ) + temp_blocks = [] + else: + temp_blocks.append(block) + + if temp_blocks: + content = serialize_content_blocks(temp_blocks, raw) + if content: + messages.append( + { + "role": "assistant", + "content": content, + } + ) + + return messages + + message = Chats.get_message_by_id_and_message_id( + metadata["chat_id"], metadata["message_id"] + ) + + last_assistant_message = None + try: + if form_data["messages"][-1]["role"] == "assistant": + last_assistant_message = get_last_assistant_message( + form_data["messages"] + ) + except Exception as e: + pass + + initial_content = ( + message.get("content", "") + if message + else last_assistant_message if last_assistant_message else "" + ) + + content_blocks = [ + { + "type": "text", + "content": initial_content, + } + ] + + latest_usage = None + completion_dispatched = False + collected_sources = [] + source_hashes = set() + outlet_result_data = None + outlet_content_override = None + + def extend_sources(items): + if not items: + return + for item in items: + try: + key = json.dumps(item, sort_keys=True) + except (TypeError, ValueError): + key = None + if key and key in source_hashes: + continue + if key: + source_hashes.add(key) + collected_sources.append(item) + + async def dispatch_chat_completed(): + nonlocal completion_dispatched, outlet_result_data, outlet_content_override, latest_usage + if completion_dispatched: + return outlet_result_data + + base_messages = [dict(message) for message in form_data.get("messages", [])] + generated_messages = convert_content_blocks_to_messages( + content_blocks, raw=True + ) + final_messages = [*base_messages, *generated_messages] + + if final_messages: + last_message = final_messages[-1] + if isinstance(last_message, dict) and last_message.get("role") == "assistant": + last_message = {**last_message} + if collected_sources: + last_message["sources"] = collected_sources + if latest_usage: + last_message["usage"] = latest_usage + final_messages[-1] = last_message + + payload = { + "model": model_id, + "messages": final_messages, + "chat_id": metadata["chat_id"], + "session_id": metadata["session_id"], + "id": metadata["message_id"], + "model_item": model, + } + if metadata.get("filter_ids"): + payload["filter_ids"] = metadata["filter_ids"] + + try: + outlet_result_data = await chat_completed(request, payload, user) + + if isinstance(outlet_result_data, dict): + extend_sources(outlet_result_data.get("sources")) + message_updates = outlet_result_data.get("messages") + if isinstance(message_updates, list): + for message_update in message_updates: + if not isinstance(message_update, dict): + continue + if message_update.get("id") != metadata["message_id"]: + continue + + if message_update.get("sources"): + extend_sources(message_update.get("sources")) + + usage_update = message_update.get("usage") + if 
usage_update: + try: + latest_usage = dict(usage_update) + except Exception: + latest_usage = usage_update + + outlet_content_override = message_update.get("content") + + try: + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + **message_update, + }, + ) + except Exception as e: + log.debug(f"Failed to upsert outlet message: {e}") + break + except Exception as e: + log.warning(f"chat_completed outlet failed: {e}") + finally: + completion_dispatched = True + + return outlet_result_data + # Non-streaming response if not isinstance(response, StreamingResponse): if event_emitter: @@ -1885,6 +2198,16 @@ async def process_chat_response( content = response_data["choices"][0]["message"]["content"] if content: + + await dispatch_chat_completed() + + if outlet_content_override is not None: + response_data["content"] = outlet_content_override + if collected_sources: + response_data["sources"] = collected_sources + if latest_usage: + response_data["usage"] = latest_usage + await event_emitter( { "type": "chat:completion", @@ -2012,199 +2335,9 @@ async def process_chat_response( # Streaming response if event_emitter and event_caller: task_id = str(uuid4()) # Create a unique task ID. - model_id = form_data.get("model", "") - - def split_content_and_whitespace(content): - content_stripped = content.rstrip() - original_whitespace = ( - content[len(content_stripped) :] - if len(content) > len(content_stripped) - else "" - ) - return content_stripped, original_whitespace - - def is_opening_code_block(content): - backtick_segments = content.split("```") - # Even number of segments means the last backticks are opening a new block - return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 # Handle as a background task async def response_handler(response, events): - def serialize_content_blocks(content_blocks, raw=False): - content = "" - - for block in content_blocks: - if block["type"] == "text": - block_content = block["content"].strip() - if block_content: - content = f"{content}{block_content}\n" - elif block["type"] == "tool_calls": - attributes = block.get("attributes", {}) - - tool_calls = block.get("content", []) - results = block.get("results", []) - - if content and not content.endswith("\n"): - content += "\n" - - if results: - - tool_calls_display_content = "" - for tool_call in tool_calls: - - tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("function", {}).get( - "name", "" - ) - tool_arguments = tool_call.get("function", {}).get( - "arguments", "" - ) - - tool_result = None - tool_result_files = None - for result in results: - if tool_call_id == result.get("tool_call_id", ""): - tool_result = result.get("content", None) - tool_result_files = result.get("files", None) - break - - if tool_result is not None: - tool_result_embeds = result.get("embeds", "") - tool_calls_display_content = f'{tool_calls_display_content}
\nTool Executed\n
\n' - else: - tool_calls_display_content = f'{tool_calls_display_content}
\nExecuting...\n
\n' - - if not raw: - content = f"{content}{tool_calls_display_content}" - else: - tool_calls_display_content = "" - - for tool_call in tool_calls: - tool_call_id = tool_call.get("id", "") - tool_name = tool_call.get("function", {}).get( - "name", "" - ) - tool_arguments = tool_call.get("function", {}).get( - "arguments", "" - ) - - tool_calls_display_content = f'{tool_calls_display_content}\n
\nExecuting...\n
\n' - - if not raw: - content = f"{content}{tool_calls_display_content}" - - elif block["type"] == "reasoning": - reasoning_display_content = html.escape( - "\n".join( - (f"> {line}" if not line.startswith(">") else line) - for line in block["content"].splitlines() - ) - ) - - reasoning_duration = block.get("duration", None) - - start_tag = block.get("start_tag", "") - end_tag = block.get("end_tag", "") - - if content and not content.endswith("\n"): - content += "\n" - - if reasoning_duration is not None: - if raw: - content = ( - f'{content}{start_tag}{block["content"]}{end_tag}\n' - ) - else: - content = f'{content}
\nThought for {reasoning_duration} seconds\n{reasoning_display_content}\n
\n' - else: - if raw: - content = ( - f'{content}{start_tag}{block["content"]}{end_tag}\n' - ) - else: - content = f'{content}
\nThinking…\n{reasoning_display_content}\n
\n' - - elif block["type"] == "code_interpreter": - attributes = block.get("attributes", {}) - output = block.get("output", None) - lang = attributes.get("lang", "") - - content_stripped, original_whitespace = ( - split_content_and_whitespace(content) - ) - if is_opening_code_block(content_stripped): - # Remove trailing backticks that would open a new block - content = ( - content_stripped.rstrip("`").rstrip() - + original_whitespace - ) - else: - # Keep content as is - either closing backticks or no backticks - content = content_stripped + original_whitespace - - if content and not content.endswith("\n"): - content += "\n" - - if output: - output = html.escape(json.dumps(output)) - - if raw: - content = f'{content}\n{block["content"]}\n\n```output\n{output}\n```\n' - else: - content = f'{content}
\nAnalyzed\n```{lang}\n{block["content"]}\n```\n
\n' - else: - if raw: - content = f'{content}\n{block["content"]}\n\n' - else: - content = f'{content}
\nAnalyzing...\n```{lang}\n{block["content"]}\n```\n
\n' - - else: - block_content = str(block["content"]).strip() - if block_content: - content = f"{content}{block['type']}: {block_content}\n" - - return content.strip() - - def convert_content_blocks_to_messages(content_blocks, raw=False): - messages = [] - - temp_blocks = [] - for idx, block in enumerate(content_blocks): - if block["type"] == "tool_calls": - messages.append( - { - "role": "assistant", - "content": serialize_content_blocks(temp_blocks, raw), - "tool_calls": block.get("content"), - } - ) - - results = block.get("results", []) - - for result in results: - messages.append( - { - "role": "tool", - "tool_call_id": result["tool_call_id"], - "content": result.get("content", "") or "", - } - ) - temp_blocks = [] - else: - temp_blocks.append(block) - - if temp_blocks: - content = serialize_content_blocks(temp_blocks, raw) - if content: - messages.append( - { - "role": "assistant", - "content": content, - } - ) - - return messages - def tag_content_handler(content_type, tags, content, content_blocks): end_flag = False @@ -2381,33 +2514,23 @@ async def process_chat_response( return content, content_blocks, end_flag - message = Chats.get_message_by_id_and_message_id( - metadata["chat_id"], metadata["message_id"] - ) - - tool_calls = [] - - last_assistant_message = None - try: - if form_data["messages"][-1]["role"] == "assistant": - last_assistant_message = get_last_assistant_message( - form_data["messages"] - ) - except Exception as e: - pass - - content = ( - message.get("content", "") - if message - else last_assistant_message if last_assistant_message else "" - ) + nonlocal content_blocks, latest_usage, completion_dispatched, collected_sources, source_hashes, outlet_result_data, outlet_content_override content_blocks = [ { "type": "text", - "content": content, + "content": initial_content, } ] + latest_usage = None + completion_dispatched = False + collected_sources = [] + source_hashes = set() + outlet_result_data = None + outlet_content_override = None + + content = initial_content + tool_calls = [] reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags") DETECT_REASONING_TAGS = reasoning_tags_param is not False @@ -2429,6 +2552,7 @@ async def process_chat_response( try: for event in events: + extend_sources(event.get("sources")) await event_emitter( { "type": "chat:completion", @@ -2448,6 +2572,7 @@ async def process_chat_response( async def stream_body_handler(response, form_data): nonlocal content nonlocal content_blocks + nonlocal latest_usage response_tool_calls = [] @@ -2506,6 +2631,7 @@ async def process_chat_response( ) if data: + extend_sources(data.get("sources")) if "event" in data and not getattr( request.state, "direct", False ): @@ -2533,6 +2659,7 @@ async def process_chat_response( usage = data.get("usage", {}) or {} usage.update(data.get("timings", {})) # llama.cpp if usage: + latest_usage = dict(usage) await event_emitter( { "type": "chat:completion", @@ -3247,6 +3374,8 @@ async def process_chat_response( "content": serialize_content_blocks(content_blocks), }, ) + finally: + await dispatch_chat_completed() if response.background is not None: await response.background() @@ -3255,7 +3384,11 @@ async def process_chat_response( else: # Fallback to the original response + latest_usage = None + async def stream_wrapper(original_generator, events): + nonlocal latest_usage + def wrap_item(item): return f"data: {item}\n\n" @@ -3269,19 +3402,40 @@ async def process_chat_response( ) if event: + extend_sources(event.get("sources")) yield 
wrap_item(json.dumps(event)) - async for data in original_generator: - data, _ = await process_filter_functions( - request=request, - filter_functions=filter_functions, - filter_type="stream", - form_data=data, - extra_params=extra_params, - ) + try: + async for data in original_generator: + data, _ = await process_filter_functions( + request=request, + filter_functions=filter_functions, + filter_type="stream", + form_data=data, + extra_params=extra_params, + ) - if data: + if not data: + continue + + if isinstance(data, (bytes, bytearray)): + try: + data = data.decode("utf-8") + except Exception: + data = data.decode("utf-8", "replace") + + if isinstance(data, str): + yield data + continue + + if isinstance(data, dict): + extend_sources(data.get("sources")) + usage = data.get("usage") + if usage: + latest_usage = dict(usage) yield data + finally: + await dispatch_chat_completed() return StreamingResponse( stream_wrapper(response.body_iterator, events), diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index c2ebe59bea..27fc33fcda 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1151,45 +1151,6 @@ } }; const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => { - const res = await chatCompleted(localStorage.token, { - model: modelId, - messages: messages.map((m) => ({ - id: m.id, - role: m.role, - content: m.content, - info: m.info ? m.info : undefined, - timestamp: m.timestamp, - ...(m.usage ? { usage: m.usage } : {}), - ...(m.sources ? { sources: m.sources } : {}) - })), - filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined, - model_item: $models.find((m) => m.id === modelId), - chat_id: _chatId, - session_id: $socket?.id, - id: responseMessageId - }).catch((error) => { - toast.error(`${error}`); - messages.at(-1).error = { content: error }; - - return null; - }); - - if (res !== null && res.messages) { - // Update chat history with the new messages - for (const message of res.messages) { - if (message?.id) { - // Add null check for message and message.id - history.messages[message.id] = { - ...history.messages[message.id], - ...(history.messages[message.id].content !== message.content - ? { originalContent: history.messages[message.id].content } - : {}), - ...message - }; - } - } - } - await tick(); if ($chatId == _chatId) {