fix: Function/Pipeline valves not triggered for async chat completion
parent f47214314b
commit fe178e6d33
2 changed files with 131 additions and 48 deletions
@@ -71,7 +71,7 @@ from open_webui.models.models import Models
 from open_webui.retrieval.utils import get_sources_from_items
 
 
-from open_webui.utils.chat import generate_chat_completion
+from open_webui.utils.chat import generate_chat_completion, chat_completed
 from open_webui.utils.task import (
     get_task_model_id,
     rag_template,
@@ -1730,6 +1730,16 @@ async def process_chat_response(
         content = response_data["choices"][0]["message"]["content"]
 
         if content:
+
+            await dispatch_chat_completed()
+
+            if outlet_content_override is not None:
+                data["content"] = outlet_content_override
+            if collected_sources:
+                data["sources"] = collected_sources
+            if latest_usage:
+                data["usage"] = latest_usage
+
             await event_emitter(
                 {
                     "type": "chat:completion",
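
A note on this hunk: in the non-streaming path, the outlet pipeline (`dispatch_chat_completed`, defined in the next hunk) now runs before the final event is emitted, so any content override, collected sources, or usage stats it produces are folded into the payload the client sees. A minimal sketch of that merge step, assuming plain dicts; `apply_outlet_result` is a hypothetical helper, not part of the codebase:

# Sketch: merge outlet-filter results into the outgoing payload before it is
# emitted. The helper and the literal dicts here are illustrative only.

def apply_outlet_result(data, content_override=None, sources=None, usage=None):
    """Apply outlet-produced overrides; untouched fields pass through as-is."""
    if content_override is not None:
        data["content"] = content_override  # the outlet rewrote the reply text
    if sources:
        data["sources"] = sources  # citations gathered during the run
    if usage:
        data["usage"] = usage  # token/timing stats captured from the stream
    return data

data = {"content": "original answer"}
print(apply_outlet_result(data, content_override="filtered answer",
                          usage={"total_tokens": 42}))
# -> {'content': 'filtered answer', 'usage': {'total_tokens': 42}}
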
@@ -2251,6 +2261,102 @@ async def process_chat_response(
                 "content": content,
             }
         ]
 
+        latest_usage = None
+        completion_dispatched = False
+        collected_sources = []
+        source_hashes = set()
+        outlet_result_data = None
+        outlet_content_override = None
+
+        def extend_sources(items):
+            if not items:
+                return
+            for item in items:
+                try:
+                    key = json.dumps(item, sort_keys=True)
+                except (TypeError, ValueError):
+                    key = None
+                if key and key in source_hashes:
+                    continue
+                if key:
+                    source_hashes.add(key)
+                collected_sources.append(item)
+
+        async def dispatch_chat_completed():
+            nonlocal completion_dispatched, outlet_result_data, outlet_content_override, latest_usage
+            if completion_dispatched:
+                return outlet_result_data
+
+            base_messages = [dict(message) for message in form_data.get("messages", [])]
+            generated_messages = convert_content_blocks_to_messages(
+                content_blocks, raw=True
+            )
+            final_messages = [*base_messages, *generated_messages]
+
+            if final_messages:
+                last_message = final_messages[-1]
+                if isinstance(last_message, dict) and last_message.get("role") == "assistant":
+                    last_message = {**last_message}
+                    if collected_sources:
+                        last_message["sources"] = collected_sources
+                    if latest_usage:
+                        last_message["usage"] = latest_usage
+                    final_messages[-1] = last_message
+
+            payload = {
+                "model": model_id,
+                "messages": final_messages,
+                "chat_id": metadata["chat_id"],
+                "session_id": metadata["session_id"],
+                "id": metadata["message_id"],
+                "model_item": model,
+            }
+            if metadata.get("filter_ids"):
+                payload["filter_ids"] = metadata["filter_ids"]
+
+            try:
+                outlet_result_data = await chat_completed(request, payload, user)
+
+                if isinstance(outlet_result_data, dict):
+                    extend_sources(outlet_result_data.get("sources"))
+                    message_updates = outlet_result_data.get("messages")
+                    if isinstance(message_updates, list):
+                        for message_update in message_updates:
+                            if not isinstance(message_update, dict):
+                                continue
+                            if message_update.get("id") != metadata["message_id"]:
+                                continue
+
+                            if message_update.get("sources"):
+                                extend_sources(message_update.get("sources"))
+
+                            usage_update = message_update.get("usage")
+                            if usage_update:
+                                try:
+                                    latest_usage = dict(usage_update)
+                                except Exception:
+                                    latest_usage = usage_update
+
+                            outlet_content_override = message_update.get("content")
+
+                            try:
+                                Chats.upsert_message_to_chat_by_id_and_message_id(
+                                    metadata["chat_id"],
+                                    metadata["message_id"],
+                                    {
+                                        **message_update,
+                                    },
+                                )
+                            except Exception as e:
+                                log.debug(f"Failed to upsert outlet message: {e}")
+                            break
+            except Exception as e:
+                log.warning(f"chat_completed outlet failed: {e}")
+            finally:
+                completion_dispatched = True
+
+            return outlet_result_data
+
         reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
         DETECT_REASONING_TAGS = reasoning_tags_param is not False
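
Two patterns in this hunk are worth isolating. `extend_sources` deduplicates citations by their canonical JSON encoding (`sort_keys=True`), so dicts differing only in key order collapse to one entry, while unserializable items are kept rather than dropped. And `dispatch_chat_completed` sets its guard flag in `finally`, so the outlet runs at most once no matter how many code paths reach it, even when the `chat_completed` call itself raises. A standalone sketch of both ideas, with simplified names (`extend`, `dispatch_once`) that are not from the codebase:

import asyncio
import json

collected, seen = [], set()

def extend(items):
    """Dedupe items by canonical JSON; keep unserializable ones unconditionally."""
    for item in items or []:
        try:
            key = json.dumps(item, sort_keys=True)  # key order no longer matters
        except (TypeError, ValueError):
            key = None  # e.g. the item contains a set or other non-JSON type
        if key and key in seen:
            continue
        if key:
            seen.add(key)
        collected.append(item)

dispatched = False

async def dispatch_once():
    """Idempotent dispatch: later callers see the flag and skip the outlet."""
    global dispatched
    if dispatched:
        return "skipped"
    try:
        return "outlet ran"
    finally:
        dispatched = True  # set even if the outlet call raised

extend([{"a": 1, "b": 2}, {"b": 2, "a": 1}])  # second dict is a key-order dup
print(collected)                              # [{'a': 1, 'b': 2}]
print(asyncio.run(dispatch_once()))           # outlet ran
print(asyncio.run(dispatch_once()))           # skipped
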
@@ -2272,6 +2378,7 @@ async def process_chat_response(
 
         try:
             for event in events:
+                extend_sources(event.get("sources"))
                 await event_emitter(
                     {
                         "type": "chat:completion",
@@ -2291,6 +2398,7 @@ async def process_chat_response(
         async def stream_body_handler(response, form_data):
             nonlocal content
             nonlocal content_blocks
+            nonlocal latest_usage
 
             response_tool_calls = []
 
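
The added `nonlocal latest_usage` is what allows the nested stream handler to rebind a variable owned by `process_chat_response`; without it, assignment inside the coroutine would silently create a new local and the outer value would stay `None`. A minimal illustration:

def outer():
    latest_usage = None

    def handler_plain(usage):
        latest_usage = usage  # creates a new local; the outer binding is untouched

    def handler_nonlocal(usage):
        nonlocal latest_usage
        latest_usage = usage  # rebinds the variable in the enclosing scope

    handler_plain({"total_tokens": 10})
    print(latest_usage)  # None
    handler_nonlocal({"total_tokens": 10})
    print(latest_usage)  # {'total_tokens': 10}

outer()
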
@@ -2349,6 +2457,7 @@ async def process_chat_response(
                             )
 
                         if data:
+                            extend_sources(data.get("sources"))
                             if "event" in data and not getattr(
                                 request.state, "direct", False
                             ):
@@ -2376,6 +2485,7 @@ async def process_chat_response(
                             usage = data.get("usage", {}) or {}
                             usage.update(data.get("timings", {}))  # llama.cpp
                             if usage:
+                                latest_usage = dict(usage)
                                 await event_emitter(
                                     {
                                         "type": "chat:completion",
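
`latest_usage = dict(usage)` takes a shallow snapshot rather than keeping a reference, so later chunks that mutate the same dict (such as the llama.cpp `timings` merge just above) cannot retroactively change what the outlet sees. Sketch:

usage = {"total_tokens": 10}
alias = usage           # still points at the live dict
snapshot = dict(usage)  # shallow copy, frozen at this point in the stream

usage.update({"predicted_ms": 123.4})  # llama.cpp-style timings arrive later

print(alias)     # {'total_tokens': 10, 'predicted_ms': 123.4}  (drifted)
print(snapshot)  # {'total_tokens': 10}  (stable for the outlet to consume)
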
@@ -3056,6 +3166,8 @@ async def process_chat_response(
                             "content": serialize_content_blocks(content_blocks),
                         },
                     )
+            finally:
+                await dispatch_chat_completed()
 
             if response.background is not None:
                 await response.background()
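
Moving the dispatch into `finally` is the core of the streaming-path fix: whether the stream completes normally, raises, or is cancelled when the client disconnects, the outlet still fires, and the once-only guard above keeps it from firing twice. A reduced sketch of that guarantee, using a hypothetical `stream_then_dispatch`:

import asyncio

async def stream_then_dispatch(chunks, fail_at=None):
    try:
        for i, chunk in enumerate(chunks):
            if i == fail_at:
                raise ConnectionError("client went away")
            await asyncio.sleep(0)  # hand control back, as a real stream would
    finally:
        print("outlet dispatched")  # reached on success, error, or cancellation

asyncio.run(stream_then_dispatch(["a", "b"]))  # prints: outlet dispatched
try:
    asyncio.run(stream_then_dispatch(["a", "b"], fail_at=1))
except ConnectionError:
    pass  # the finally block already fired before the error propagated
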
@@ -3064,7 +3176,9 @@ async def process_chat_response(
 
         else:
             # Fallback to the original response
+            latest_usage = None
             async def stream_wrapper(original_generator, events):
+                nonlocal latest_usage
                 def wrap_item(item):
                     return f"data: {item}\n\n"
 
@@ -3078,19 +3192,27 @@ async def process_chat_response(
                     )
 
                     if event:
+                        extend_sources(event.get("sources"))
                         yield wrap_item(json.dumps(event))
 
-                async for data in original_generator:
-                    data, _ = await process_filter_functions(
-                        request=request,
-                        filter_functions=filter_functions,
-                        filter_type="stream",
-                        form_data=data,
-                        extra_params=extra_params,
-                    )
+                try:
+                    async for data in original_generator:
+                        data, _ = await process_filter_functions(
+                            request=request,
+                            filter_functions=filter_functions,
+                            filter_type="stream",
+                            form_data=data,
+                            extra_params=extra_params,
+                        )
 
-                    if data:
-                        yield data
+                        if data:
+                            extend_sources(data.get("sources"))
+                            usage = data.get("usage")
+                            if usage:
+                                latest_usage = dict(usage)
+                            yield data
+                finally:
+                    await dispatch_chat_completed()
 
             return StreamingResponse(
                 stream_wrapper(response.body_iterator, events),
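
The fallback wrapper gets the same `try`/`finally` treatment, with one asyncio subtlety: when the consumer stops early, the async generator is closed by having `GeneratorExit` thrown into it at its suspension point, and the `finally` clause is where cleanup reliably runs. A reduced sketch:

import asyncio

async def source_stream():
    for i in range(100):
        yield i
        await asyncio.sleep(0)

async def wrapper(gen):
    try:
        async for item in gen:
            yield item
    finally:
        print("dispatch would run here")  # also fires when the consumer stops early

async def main():
    agen = wrapper(source_stream())
    async for item in agen:
        if item == 2:
            break              # simulate the client disconnecting mid-stream
    await agen.aclose()        # throws GeneratorExit into wrapper; finally runs

asyncio.run(main())  # prints the dispatch line exactly once

The remaining hunk below is the second changed file, the frontend chat component (presumably src/lib/components/chat/Chat.svelte), where the now-redundant client-side outlet call is removed.
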
@@ -1140,45 +1140,6 @@
 		}
 	};
 	const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => {
-		const res = await chatCompleted(localStorage.token, {
-			model: modelId,
-			messages: messages.map((m) => ({
-				id: m.id,
-				role: m.role,
-				content: m.content,
-				info: m.info ? m.info : undefined,
-				timestamp: m.timestamp,
-				...(m.usage ? { usage: m.usage } : {}),
-				...(m.sources ? { sources: m.sources } : {})
-			})),
-			filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
-			model_item: $models.find((m) => m.id === modelId),
-			chat_id: _chatId,
-			session_id: $socket?.id,
-			id: responseMessageId
-		}).catch((error) => {
-			toast.error(`${error}`);
-			messages.at(-1).error = { content: error };
-
-			return null;
-		});
-
-		if (res !== null && res.messages) {
-			// Update chat history with the new messages
-			for (const message of res.messages) {
-				if (message?.id) {
-					// Add null check for message and message.id
-					history.messages[message.id] = {
-						...history.messages[message.id],
-						...(history.messages[message.id].content !== message.content
-							? { originalContent: history.messages[message.id].content }
-							: {}),
-						...message
-					};
-				}
-			}
-		}
-
		await tick();
 
		if ($chatId == _chatId) {
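
With the outlet dispatched server-side, including for async/background completions where this handler was never reached (the bug named in the commit title), the client round trip through `chatCompleted` becomes redundant: the handler now only awaits `tick()` and continues. Outlet rewrites of `content`, `sources`, and `usage` reach the UI through the regular `chat:completion` events emitted by the backend instead.
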