diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 3b3762f31f..c68f3a8705 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -57,6 +57,7 @@ from open_webui.utils.task import (
     tools_function_calling_generation_template,
 )
 from open_webui.utils.misc import (
+    deep_update,
     get_message_list,
     add_or_update_system_message,
     add_or_update_user_message,
@@ -1126,8 +1127,18 @@ async def process_chat_response(
 
             for block in content_blocks:
                 if block["type"] == "text":
                     content = f"{content}{block['content'].strip()}\n"
-                elif block["type"] == "tool":
-                    pass
+                elif block["type"] == "tool_calls":
+                    attributes = block.get("attributes", {})
+
+                    block_content = block.get("content", [])
+                    results = block.get("results", [])
+
+                    if results:
+                        if not raw:
+                            content = f'{content}\n<details type="tool_calls" done="true">\n<summary>Tool Executed</summary>\n```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
+                    else:
+                        if not raw:
+                            content = f'{content}\n<details type="tool_calls" done="false">\n<summary>Tool Executing...</summary>\n```json\n{block_content}\n```\n</details>\n'
                 elif block["type"] == "reasoning":
                     reasoning_display_content = "\n".join(
@@ -1254,6 +1265,7 @@ async def process_chat_response(
                 metadata["chat_id"], metadata["message_id"]
             )
 
+            tool_calls = []
             content = message.get("content", "") if message else ""
             content_blocks = [
                 {
@@ -1293,6 +1305,8 @@ async def process_chat_response(
                 nonlocal content
                 nonlocal content_blocks
 
+                response_tool_calls = []
+
                 async for line in response.body_iterator:
                     line = line.decode("utf-8") if isinstance(line, bytes) else line
                     data = line
@@ -1326,7 +1340,42 @@ async def process_chat_response(
                            if not choices:
                                continue
 
-                            value = choices[0].get("delta", {}).get("content")
+                            delta = choices[0].get("delta", {})
+                            delta_tool_calls = delta.get("tool_calls", None)
+
+                            if delta_tool_calls:
+                                for delta_tool_call in delta_tool_calls:
+                                    tool_call_index = delta_tool_call.get("index")
+
+                                    if tool_call_index is not None:
+                                        if (
+                                            len(response_tool_calls)
+                                            <= tool_call_index
+                                        ):
+                                            response_tool_calls.append(
+                                                delta_tool_call
+                                            )
+                                        else:
+                                            delta_name = delta_tool_call.get(
+                                                "function", {}
+                                            ).get("name")
+                                            delta_arguments = delta_tool_call.get(
+                                                "function", {}
+                                            ).get("arguments")
+
+                                            if delta_name:
+                                                response_tool_calls[
+                                                    tool_call_index
+                                                ]["function"]["name"] += delta_name
+
+                                            if delta_arguments:
+                                                response_tool_calls[
+                                                    tool_call_index
+                                                ]["function"][
+                                                    "arguments"
+                                                ] += delta_arguments
+
+                            value = delta.get("content")
 
                             if value:
                                 content = f"{content}{value}"
@@ -1398,6 +1447,29 @@ async def process_chat_response(
                        if not content_blocks[-1]["content"]:
                            content_blocks.pop()
 
+                if response_tool_calls:
+                    tool_calls.append(response_tool_calls)
+
+                if response.background:
+                    await response.background()
+
+            await stream_body_handler(response)
+
+            MAX_TOOL_CALL_RETRIES = 5
+            tool_call_retries = 0
+
+            while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
+                tool_call_retries += 1
+
+                response_tool_calls = tool_calls.pop(0)
+
+                content_blocks.append(
+                    {
+                        "type": "tool_calls",
+                        "content": response_tool_calls,
+                    }
+                )
+
                 await event_emitter(
                     {
                         "type": "chat:completion",
@@ -1407,10 +1479,103 @@ async def process_chat_response(
                         "data": {
                             "content": serialize_content_blocks(content_blocks),
                         },
                     }
                 )
 
-            if response.background:
-                await response.background()
+                tools = metadata.get("tools", {})
 
-            await stream_body_handler(response)
+                results = []
+                for tool_call in response_tool_calls:
+                    tool_call_id = tool_call.get("id", "")
+                    tool_name = tool_call.get("function", {}).get("name", "")
+
+                    tool_function_params = {}
+                    try:
+                        tool_function_params = json.loads(
+                            tool_call.get("function", {}).get("arguments", "{}")
+                        )
+                    except Exception as e:
+                        log.debug(e)
+
+                    tool_result = None
+
+                    if tool_name in tools:
+                        tool = tools[tool_name]
+                        spec = tool.get("spec", {})
+
+                        try:
+                            required_params = spec.get("parameters", {}).get(
+                                "required", []
+                            )
+                            tool_function = tool["callable"]
+                            tool_function_params = {
+                                k: v
+                                for k, v in tool_function_params.items()
+                                if k in required_params
+                            }
+                            tool_result = await tool_function(
+                                **tool_function_params
+                            )
+                        except Exception as e:
+                            tool_result = str(e)
+
+                    results.append(
+                        {
+                            "tool_call_id": tool_call_id,
+                            "content": tool_result,
+                        }
+                    )
+
+                content_blocks[-1]["results"] = results
+
+                content_blocks.append(
+                    {
+                        "type": "text",
+                        "content": "",
+                    }
+                )
+
+                await event_emitter(
+                    {
+                        "type": "chat:completion",
+                        "data": {
+                            "content": serialize_content_blocks(content_blocks),
+                        },
+                    }
+                )
+
+                try:
+                    res = await generate_chat_completion(
+                        request,
+                        {
+                            "model": model_id,
+                            "stream": True,
+                            "messages": [
+                                *form_data["messages"],
+                                {
+                                    "role": "assistant",
+                                    "content": serialize_content_blocks(
+                                        content_blocks, raw=True
+                                    ),
+                                    "tool_calls": response_tool_calls,
+                                },
+                                *[
+                                    {
+                                        "role": "tool",
+                                        "tool_call_id": result["tool_call_id"],
+                                        "content": result["content"],
+                                    }
+                                    for result in results
+                                ],
+                            ],
+                        },
+                        user,
+                    )
+
+                    if isinstance(res, StreamingResponse):
+                        await stream_body_handler(res)
+                    else:
+                        break
+                except Exception as e:
+                    log.debug(e)
+                    break
 
         if DETECT_CODE_INTERPRETER:
             MAX_RETRIES = 5
@@ -1472,6 +1637,7 @@ async def process_chat_response(
                            output = str(e)
 
                        content_blocks[-1]["output"] = output
+
                        content_blocks.append(
                            {
                                "type": "text",
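A note on the `stream_body_handler` hunk above: OpenAI-compatible backends stream tool calls as incremental `delta.tool_calls` entries, where the first chunk for a given `index` carries the call `id` and function `name`, and later chunks append fragments of the JSON `arguments` string. The following standalone sketch replays that merge logic outside the middleware; the chunk payloads are hypothetical examples, not captured API output:

```python
# Hypothetical streaming deltas for one tool call (index 0); in the patch these
# come from choices[0]["delta"]["tool_calls"] in the SSE stream.
chunks = [
    {"index": 0, "id": "call_abc", "type": "function",
     "function": {"name": "get_weather", "arguments": ""}},
    {"index": 0, "function": {"name": "", "arguments": '{"city": '}},
    {"index": 0, "function": {"name": "", "arguments": '"Berlin"}'}},
]

response_tool_calls = []
for delta_tool_call in chunks:
    tool_call_index = delta_tool_call.get("index")
    if tool_call_index is None:
        continue

    if len(response_tool_calls) <= tool_call_index:
        # First chunk for this index seeds the accumulated entry.
        response_tool_calls.append(delta_tool_call)
    else:
        # Later chunks extend the name/arguments strings in place.
        function = delta_tool_call.get("function", {})
        if function.get("name"):
            response_tool_calls[tool_call_index]["function"]["name"] += function["name"]
        if function.get("arguments"):
            response_tool_calls[tool_call_index]["function"]["arguments"] += function["arguments"]

print(response_tool_calls[0]["function"]["arguments"])  # {"city": "Berlin"}
```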
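Once a stream finishes with accumulated calls, the new `while tool_calls` loop executes each one and issues a follow-up completion in the standard OpenAI tool-calling shape: the assistant turn echoes its `tool_calls`, followed by one `role: "tool"` message per result. A sketch of the payload built for `generate_chat_completion`, with all values illustrative rather than taken from a real run:

```python
import json

# Illustrative stand-ins; in the middleware these come from the accumulated
# response_tool_calls and the results list built by the execution loop.
response_tool_calls = [
    {"id": "call_abc", "type": "function",
     "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'}},
]
results = [{"tool_call_id": "call_abc", "content": "12°C, overcast"}]

payload = {
    "model": "my-model",  # model_id in the middleware
    "stream": True,
    "messages": [
        {"role": "user", "content": "What's the weather in Berlin?"},
        {
            # In the patch, content is serialize_content_blocks(content_blocks, raw=True).
            "role": "assistant",
            "content": "",
            "tool_calls": response_tool_calls,
        },
        *[
            {
                "role": "tool",
                "tool_call_id": result["tool_call_id"],
                "content": result["content"],
            }
            for result in results
        ],
    ],
}

print(json.dumps(payload, indent=2, ensure_ascii=False))
```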
"assistant", + "content": serialize_content_blocks( + content_blocks, raw=True + ), + "tool_calls": response_tool_calls, + }, + *[ + { + "role": "tool", + "tool_call_id": result["tool_call_id"], + "content": result["content"], + } + for result in results + ], + ], + }, + user, + ) + + if isinstance(res, StreamingResponse): + await stream_body_handler(res) + else: + break + except Exception as e: + log.debug(e) + break if DETECT_CODE_INTERPRETER: MAX_RETRIES = 5 @@ -1472,6 +1637,7 @@ async def process_chat_response( output = str(e) content_blocks[-1]["output"] = output + content_blocks.append( { "type": "text", diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index c3655bf073..b073939219 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -7,6 +7,18 @@ from pathlib import Path from typing import Callable, Optional +import collections.abc + + +def deep_update(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = deep_update(d.get(k, {}), v) + else: + d[k] = v + return d + + def get_message_list(messages, message_id): """ Reconstructs a list of messages in order up to the specified message_id. @@ -187,6 +199,7 @@ def openai_chat_chunk_message_template( template = openai_chat_message_template(model) template["object"] = "chat.completion.chunk" + template["choices"][0]["index"] = 0 template["choices"][0]["delta"] = {} if content: