mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
feat: native tool calling support
This commit is contained in:
parent
7766a08b70
commit
314b674f32
2 changed files with 185 additions and 6 deletions
|
|
@ -57,6 +57,7 @@ from open_webui.utils.task import (
|
||||||
tools_function_calling_generation_template,
|
tools_function_calling_generation_template,
|
||||||
)
|
)
|
||||||
from open_webui.utils.misc import (
|
from open_webui.utils.misc import (
|
||||||
|
deep_update,
|
||||||
get_message_list,
|
get_message_list,
|
||||||
add_or_update_system_message,
|
add_or_update_system_message,
|
||||||
add_or_update_user_message,
|
add_or_update_user_message,
|
||||||
|
|
@ -1126,8 +1127,18 @@ async def process_chat_response(
|
||||||
for block in content_blocks:
|
for block in content_blocks:
|
||||||
if block["type"] == "text":
|
if block["type"] == "text":
|
||||||
content = f"{content}{block['content'].strip()}\n"
|
content = f"{content}{block['content'].strip()}\n"
|
||||||
elif block["type"] == "tool":
|
elif block["type"] == "tool_calls":
|
||||||
pass
|
attributes = block.get("attributes", {})
|
||||||
|
|
||||||
|
block_content = block.get("content", [])
|
||||||
|
results = block.get("results", [])
|
||||||
|
|
||||||
|
if results:
|
||||||
|
if not raw:
|
||||||
|
content = f'{content}\n<details type="tool_calls" done="true" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
|
||||||
|
else:
|
||||||
|
if not raw:
|
||||||
|
content = f'{content}\n<details type="tool_calls" done="false">\n<summary>Tool Executing...</summary>\n```json\n{block_content}\n```\n</details>\n'
|
||||||
|
|
||||||
elif block["type"] == "reasoning":
|
elif block["type"] == "reasoning":
|
||||||
reasoning_display_content = "\n".join(
|
reasoning_display_content = "\n".join(
|
||||||
|
|
@ -1254,6 +1265,7 @@ async def process_chat_response(
|
||||||
metadata["chat_id"], metadata["message_id"]
|
metadata["chat_id"], metadata["message_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
content = message.get("content", "") if message else ""
|
content = message.get("content", "") if message else ""
|
||||||
content_blocks = [
|
content_blocks = [
|
||||||
{
|
{
|
||||||
|
|
@ -1293,6 +1305,8 @@ async def process_chat_response(
|
||||||
nonlocal content
|
nonlocal content
|
||||||
nonlocal content_blocks
|
nonlocal content_blocks
|
||||||
|
|
||||||
|
response_tool_calls = []
|
||||||
|
|
||||||
async for line in response.body_iterator:
|
async for line in response.body_iterator:
|
||||||
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
line = line.decode("utf-8") if isinstance(line, bytes) else line
|
||||||
data = line
|
data = line
|
||||||
|
|
@ -1326,7 +1340,42 @@ async def process_chat_response(
|
||||||
if not choices:
|
if not choices:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
value = choices[0].get("delta", {}).get("content")
|
delta = choices[0].get("delta", {})
|
||||||
|
delta_tool_calls = delta.get("tool_calls", None)
|
||||||
|
|
||||||
|
if delta_tool_calls:
|
||||||
|
for delta_tool_call in delta_tool_calls:
|
||||||
|
tool_call_index = delta_tool_call.get("index")
|
||||||
|
|
||||||
|
if tool_call_index is not None:
|
||||||
|
if (
|
||||||
|
len(response_tool_calls)
|
||||||
|
<= tool_call_index
|
||||||
|
):
|
||||||
|
response_tool_calls.append(
|
||||||
|
delta_tool_call
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delta_name = delta_tool_call.get(
|
||||||
|
"function", {}
|
||||||
|
).get("name")
|
||||||
|
delta_arguments = delta_tool_call.get(
|
||||||
|
"function", {}
|
||||||
|
).get("arguments")
|
||||||
|
|
||||||
|
if delta_name:
|
||||||
|
response_tool_calls[
|
||||||
|
tool_call_index
|
||||||
|
]["function"]["name"] += delta_name
|
||||||
|
|
||||||
|
if delta_arguments:
|
||||||
|
response_tool_calls[
|
||||||
|
tool_call_index
|
||||||
|
]["function"][
|
||||||
|
"arguments"
|
||||||
|
] += delta_arguments
|
||||||
|
|
||||||
|
value = delta.get("content")
|
||||||
|
|
||||||
if value:
|
if value:
|
||||||
content = f"{content}{value}"
|
content = f"{content}{value}"
|
||||||
|
|
@ -1398,6 +1447,29 @@ async def process_chat_response(
|
||||||
if not content_blocks[-1]["content"]:
|
if not content_blocks[-1]["content"]:
|
||||||
content_blocks.pop()
|
content_blocks.pop()
|
||||||
|
|
||||||
|
if response_tool_calls:
|
||||||
|
tool_calls.append(response_tool_calls)
|
||||||
|
|
||||||
|
if response.background:
|
||||||
|
await response.background()
|
||||||
|
|
||||||
|
await stream_body_handler(response)
|
||||||
|
|
||||||
|
MAX_TOOL_CALL_RETRIES = 5
|
||||||
|
tool_call_retries = 0
|
||||||
|
|
||||||
|
while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
|
||||||
|
tool_call_retries += 1
|
||||||
|
|
||||||
|
response_tool_calls = tool_calls.pop(0)
|
||||||
|
|
||||||
|
content_blocks.append(
|
||||||
|
{
|
||||||
|
"type": "tool_calls",
|
||||||
|
"content": response_tool_calls,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await event_emitter(
|
await event_emitter(
|
||||||
{
|
{
|
||||||
"type": "chat:completion",
|
"type": "chat:completion",
|
||||||
|
|
@ -1407,10 +1479,103 @@ async def process_chat_response(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.background:
|
tools = metadata.get("tools", {})
|
||||||
await response.background()
|
|
||||||
|
|
||||||
await stream_body_handler(response)
|
results = []
|
||||||
|
for tool_call in response_tool_calls:
|
||||||
|
tool_call_id = tool_call.get("id", "")
|
||||||
|
tool_name = tool_call.get("function", {}).get("name", "")
|
||||||
|
|
||||||
|
tool_function_params = {}
|
||||||
|
try:
|
||||||
|
tool_function_params = json.loads(
|
||||||
|
tool_call.get("function", {}).get("arguments", "{}")
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(e)
|
||||||
|
|
||||||
|
tool_result = None
|
||||||
|
|
||||||
|
if tool_name in tools:
|
||||||
|
tool = tools[tool_name]
|
||||||
|
spec = tool.get("spec", {})
|
||||||
|
|
||||||
|
try:
|
||||||
|
required_params = spec.get("parameters", {}).get(
|
||||||
|
"required", []
|
||||||
|
)
|
||||||
|
tool_function = tool["callable"]
|
||||||
|
tool_function_params = {
|
||||||
|
k: v
|
||||||
|
for k, v in tool_function_params.items()
|
||||||
|
if k in required_params
|
||||||
|
}
|
||||||
|
tool_result = await tool_function(
|
||||||
|
**tool_function_params
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
tool_result = str(e)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"content": tool_result,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
content_blocks[-1]["results"] = results
|
||||||
|
|
||||||
|
content_blocks.append(
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"content": "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await event_emitter(
|
||||||
|
{
|
||||||
|
"type": "chat:completion",
|
||||||
|
"data": {
|
||||||
|
"content": serialize_content_blocks(content_blocks),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = await generate_chat_completion(
|
||||||
|
request,
|
||||||
|
{
|
||||||
|
"model": model_id,
|
||||||
|
"stream": True,
|
||||||
|
"messages": [
|
||||||
|
*form_data["messages"],
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": serialize_content_blocks(
|
||||||
|
content_blocks, raw=True
|
||||||
|
),
|
||||||
|
"tool_calls": response_tool_calls,
|
||||||
|
},
|
||||||
|
*[
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": result["tool_call_id"],
|
||||||
|
"content": result["content"],
|
||||||
|
}
|
||||||
|
for result in results
|
||||||
|
],
|
||||||
|
],
|
||||||
|
},
|
||||||
|
user,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(res, StreamingResponse):
|
||||||
|
await stream_body_handler(res)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(e)
|
||||||
|
break
|
||||||
|
|
||||||
if DETECT_CODE_INTERPRETER:
|
if DETECT_CODE_INTERPRETER:
|
||||||
MAX_RETRIES = 5
|
MAX_RETRIES = 5
|
||||||
|
|
@ -1472,6 +1637,7 @@ async def process_chat_response(
|
||||||
output = str(e)
|
output = str(e)
|
||||||
|
|
||||||
content_blocks[-1]["output"] = output
|
content_blocks[-1]["output"] = output
|
||||||
|
|
||||||
content_blocks.append(
|
content_blocks.append(
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,18 @@ from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
import collections.abc
|
||||||
|
|
||||||
|
|
||||||
|
def deep_update(d, u):
|
||||||
|
for k, v in u.items():
|
||||||
|
if isinstance(v, collections.abc.Mapping):
|
||||||
|
d[k] = deep_update(d.get(k, {}), v)
|
||||||
|
else:
|
||||||
|
d[k] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def get_message_list(messages, message_id):
|
def get_message_list(messages, message_id):
|
||||||
"""
|
"""
|
||||||
Reconstructs a list of messages in order up to the specified message_id.
|
Reconstructs a list of messages in order up to the specified message_id.
|
||||||
|
|
@ -187,6 +199,7 @@ def openai_chat_chunk_message_template(
|
||||||
template = openai_chat_message_template(model)
|
template = openai_chat_message_template(model)
|
||||||
template["object"] = "chat.completion.chunk"
|
template["object"] = "chat.completion.chunk"
|
||||||
|
|
||||||
|
template["choices"][0]["index"] = 0
|
||||||
template["choices"][0]["delta"] = {}
|
template["choices"][0]["delta"] = {}
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue