diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 3b3762f31f..c68f3a8705 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -57,6 +57,7 @@ from open_webui.utils.task import (
tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
+ deep_update,
get_message_list,
add_or_update_system_message,
add_or_update_user_message,
@@ -1126,8 +1127,18 @@ async def process_chat_response(
for block in content_blocks:
if block["type"] == "text":
content = f"{content}{block['content'].strip()}\n"
- elif block["type"] == "tool":
- pass
+ elif block["type"] == "tool_calls":
+ attributes = block.get("attributes", {})
+
+ block_content = block.get("content", [])
+ results = block.get("results", [])
+
+ if results:
+ if not raw:
+                                    content = f'{content}\n<details>\n<summary>Tool Executed</summary>\n```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
+ else:
+ if not raw:
+                                    content = f'{content}\n<details>\n<summary>Tool Executing...</summary>\n```json\n{block_content}\n```\n</details>\n'
elif block["type"] == "reasoning":
reasoning_display_content = "\n".join(
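For reference, a sketch of the block shapes the serializer above consumes; field values are hypothetical, and the shapes follow the OpenAI tool-call format used elsewhere in this diff:

```python
# Hypothetical "tool_calls" content block while the tool is still running;
# serialize_content_blocks renders this as "Tool Executing..." with the
# call payload in a json fence.
pending_block = {
    "type": "tool_calls",
    "content": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
        }
    ],
}

# Once results are attached (keyed by tool_call_id), the same block is
# rendered as "Tool Executed" with both the calls and the results.
finished_block = {
    **pending_block,
    "results": [{"tool_call_id": "call_abc123", "content": "12°C, cloudy"}],
}
```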
@@ -1254,6 +1265,7 @@ async def process_chat_response(
metadata["chat_id"], metadata["message_id"]
)
+ tool_calls = []
content = message.get("content", "") if message else ""
content_blocks = [
{
@@ -1293,6 +1305,8 @@ async def process_chat_response(
nonlocal content
nonlocal content_blocks
+ response_tool_calls = []
+
async for line in response.body_iterator:
line = line.decode("utf-8") if isinstance(line, bytes) else line
data = line
@@ -1326,7 +1340,42 @@ async def process_chat_response(
if not choices:
continue
- value = choices[0].get("delta", {}).get("content")
+ delta = choices[0].get("delta", {})
+ delta_tool_calls = delta.get("tool_calls", None)
+
+ if delta_tool_calls:
+ for delta_tool_call in delta_tool_calls:
+ tool_call_index = delta_tool_call.get("index")
+
+ if tool_call_index is not None:
+ if (
+ len(response_tool_calls)
+ <= tool_call_index
+ ):
+ response_tool_calls.append(
+ delta_tool_call
+ )
+ else:
+ delta_name = delta_tool_call.get(
+ "function", {}
+ ).get("name")
+ delta_arguments = delta_tool_call.get(
+ "function", {}
+ ).get("arguments")
+
+ if delta_name:
+ response_tool_calls[
+ tool_call_index
+ ]["function"]["name"] += delta_name
+
+ if delta_arguments:
+ response_tool_calls[
+ tool_call_index
+ ]["function"][
+ "arguments"
+ ] += delta_arguments
+
+ value = delta.get("content")
if value:
content = f"{content}{value}"
@@ -1398,6 +1447,29 @@ async def process_chat_response(
if not content_blocks[-1]["content"]:
content_blocks.pop()
+ if response_tool_calls:
+ tool_calls.append(response_tool_calls)
+
+ if response.background:
+ await response.background()
+
+ await stream_body_handler(response)
+
+ MAX_TOOL_CALL_RETRIES = 5
+ tool_call_retries = 0
+
+ while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
+ tool_call_retries += 1
+
+ response_tool_calls = tool_calls.pop(0)
+
+ content_blocks.append(
+ {
+ "type": "tool_calls",
+ "content": response_tool_calls,
+ }
+ )
+
await event_emitter(
{
"type": "chat:completion",
@@ -1407,10 +1479,103 @@ async def process_chat_response(
}
)
- if response.background:
- await response.background()
+ tools = metadata.get("tools", {})
- await stream_body_handler(response)
+ results = []
+ for tool_call in response_tool_calls:
+ tool_call_id = tool_call.get("id", "")
+ tool_name = tool_call.get("function", {}).get("name", "")
+
+ tool_function_params = {}
+ try:
+ tool_function_params = json.loads(
+ tool_call.get("function", {}).get("arguments", "{}")
+ )
+ except Exception as e:
+ log.debug(e)
+
+ tool_result = None
+
+ if tool_name in tools:
+ tool = tools[tool_name]
+ spec = tool.get("spec", {})
+
+ try:
+ required_params = spec.get("parameters", {}).get(
+ "required", []
+ )
+ tool_function = tool["callable"]
+ tool_function_params = {
+ k: v
+ for k, v in tool_function_params.items()
+ if k in required_params
+ }
+ tool_result = await tool_function(
+ **tool_function_params
+ )
+ except Exception as e:
+ tool_result = str(e)
+
+ results.append(
+ {
+ "tool_call_id": tool_call_id,
+ "content": tool_result,
+ }
+ )
+
+ content_blocks[-1]["results"] = results
+
+ content_blocks.append(
+ {
+ "type": "text",
+ "content": "",
+ }
+ )
+
+ await event_emitter(
+ {
+ "type": "chat:completion",
+ "data": {
+ "content": serialize_content_blocks(content_blocks),
+ },
+ }
+ )
+
+ try:
+ res = await generate_chat_completion(
+ request,
+ {
+ "model": model_id,
+ "stream": True,
+ "messages": [
+ *form_data["messages"],
+ {
+ "role": "assistant",
+ "content": serialize_content_blocks(
+ content_blocks, raw=True
+ ),
+ "tool_calls": response_tool_calls,
+ },
+ *[
+ {
+ "role": "tool",
+ "tool_call_id": result["tool_call_id"],
+ "content": result["content"],
+ }
+ for result in results
+ ],
+ ],
+ },
+ user,
+ )
+
+ if isinstance(res, StreamingResponse):
+ await stream_body_handler(res)
+ else:
+ break
+ except Exception as e:
+ log.debug(e)
+ break
if DETECT_CODE_INTERPRETER:
MAX_RETRIES = 5
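The follow-up request above replays the conversation in the standard OpenAI tool-calling shape: one `assistant` message carrying the accumulated `tool_calls`, then one `tool` message per result, matched by `tool_call_id`. Roughly, with hypothetical values:

```python
# Hypothetical message list as built for the follow-up completion above.
messages = [
    {"role": "user", "content": "What's the weather in Berlin?"},
    {
        "role": "assistant",
        "content": "",  # serialized content blocks so far (raw=True)
        "tool_calls": [
            {
                "id": "call_abc123",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
            }
        ],
    },
    # One "tool" message per executed call, matched by tool_call_id.
    {"role": "tool", "tool_call_id": "call_abc123", "content": "12°C, cloudy"},
]
```

One caveat in the execution path above: arguments are filtered down to the spec's `required` list before the call, so optional parameters the model supplies are silently dropped.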
@@ -1472,6 +1637,7 @@ async def process_chat_response(
output = str(e)
content_blocks[-1]["output"] = output
+
content_blocks.append(
{
"type": "text",
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index c3655bf073..b073939219 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -7,6 +7,18 @@ from pathlib import Path
from typing import Callable, Optional
+import collections.abc
+
+
+def deep_update(d, u):
+ for k, v in u.items():
+ if isinstance(v, collections.abc.Mapping):
+ d[k] = deep_update(d.get(k, {}), v)
+ else:
+ d[k] = v
+ return d
+
+
def get_message_list(messages, message_id):
"""
Reconstructs a list of messages in order up to the specified message_id.
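Unlike `dict.update`, which replaces nested dictionaries wholesale, `deep_update` merges them key by key. A quick usage sketch with made-up values:

```python
base = {"params": {"temperature": 0.7, "stream": True}, "model": "gpt-4"}
deep_update(base, {"params": {"temperature": 0.2}})
# Nested keys are merged, not replaced: "stream" survives.
assert base == {"params": {"temperature": 0.2, "stream": True}, "model": "gpt-4"}
```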
@@ -187,6 +199,7 @@ def openai_chat_chunk_message_template(
template = openai_chat_message_template(model)
template["object"] = "chat.completion.chunk"
+ template["choices"][0]["index"] = 0
template["choices"][0]["delta"] = {}
if content:
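With the explicit `index`, every emitted chunk now matches the OpenAI chat-completion-chunk schema, where each choice carries its position. Illustratively (field values made up):

```python
chunk = {
    "object": "chat.completion.chunk",
    "choices": [
        {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None},
    ],
}
```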