fix: Function/Pipeline valves not triggered for async chat completion

Daniel Elišák 2025-10-27 12:16:04 +01:00
parent f47214314b
commit fe178e6d33
2 changed files with 131 additions and 48 deletions
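In short: the client-side `chatCompletedHandler` is removed and the backend now awaits `chat_completed` itself when an async chat completion finishes, so Function/Pipeline outlet valves fire even for streamed responses. Below is a minimal, self-contained sketch of the dispatch-once pattern the commit adds, with `run_outlets` standing in for `chat_completed` (all names and signatures here are illustrative, not the actual Open WebUI API):

```python
import asyncio


async def run_outlets(payload: dict) -> dict:
    # Stand-in for chat_completed(): runs Function/Pipeline outlet valves.
    print(f"outlet valves triggered for {len(payload['messages'])} messages")
    return payload


async def stream_with_outlets(original_generator, messages: list[dict]):
    dispatched = False  # mirrors completion_dispatched in the commit

    async def dispatch_once():
        nonlocal dispatched
        if dispatched:  # guard: outlets run exactly once per completion
            return
        dispatched = True
        await run_outlets({"messages": messages})

    try:
        async for chunk in original_generator:
            messages[-1]["content"] += chunk
            yield chunk
    finally:
        # Reached on normal exhaustion, cancellation, or error,
        # i.e. the async paths where valves previously never fired.
        await dispatch_once()


async def main():
    async def fake_model_stream():
        for chunk in ("Hel", "lo"):
            yield chunk

    messages = [{"role": "assistant", "content": ""}]
    async for _ in stream_with_outlets(fake_model_stream(), messages):
        pass
    print(messages[-1]["content"])  # -> "Hello"


asyncio.run(main())
```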


@@ -71,7 +71,7 @@ from open_webui.models.models import Models
from open_webui.retrieval.utils import get_sources_from_items
from open_webui.utils.chat import generate_chat_completion
from open_webui.utils.chat import generate_chat_completion, chat_completed
from open_webui.utils.task import (
    get_task_model_id,
    rag_template,
@@ -1730,6 +1730,16 @@ async def process_chat_response(
content = response_data["choices"][0]["message"]["content"]

if content:
    await dispatch_chat_completed()
    if outlet_content_override is not None:
        data["content"] = outlet_content_override
    if collected_sources:
        data["sources"] = collected_sources
    if latest_usage:
        data["usage"] = latest_usage
    await event_emitter(
        {
            "type": "chat:completion",
@@ -2251,6 +2261,102 @@ async def process_chat_response(
        "content": content,
    }
]

latest_usage = None
completion_dispatched = False
collected_sources = []
source_hashes = set()
outlet_result_data = None
outlet_content_override = None

def extend_sources(items):
    if not items:
        return
    for item in items:
        try:
            key = json.dumps(item, sort_keys=True)
        except (TypeError, ValueError):
            key = None
        if key and key in source_hashes:
            continue
        if key:
            source_hashes.add(key)
        collected_sources.append(item)

async def dispatch_chat_completed():
    nonlocal completion_dispatched, outlet_result_data, outlet_content_override, latest_usage

    if completion_dispatched:
        return outlet_result_data

    base_messages = [dict(message) for message in form_data.get("messages", [])]
    generated_messages = convert_content_blocks_to_messages(
        content_blocks, raw=True
    )
    final_messages = [*base_messages, *generated_messages]

    if final_messages:
        last_message = final_messages[-1]
        if isinstance(last_message, dict) and last_message.get("role") == "assistant":
            last_message = {**last_message}
            if collected_sources:
                last_message["sources"] = collected_sources
            if latest_usage:
                last_message["usage"] = latest_usage
            final_messages[-1] = last_message

    payload = {
        "model": model_id,
        "messages": final_messages,
        "chat_id": metadata["chat_id"],
        "session_id": metadata["session_id"],
        "id": metadata["message_id"],
        "model_item": model,
    }
    if metadata.get("filter_ids"):
        payload["filter_ids"] = metadata["filter_ids"]

    try:
        outlet_result_data = await chat_completed(request, payload, user)
        if isinstance(outlet_result_data, dict):
            extend_sources(outlet_result_data.get("sources"))
            message_updates = outlet_result_data.get("messages")
            if isinstance(message_updates, list):
                for message_update in message_updates:
                    if not isinstance(message_update, dict):
                        continue
                    if message_update.get("id") != metadata["message_id"]:
                        continue
                    if message_update.get("sources"):
                        extend_sources(message_update.get("sources"))
                    usage_update = message_update.get("usage")
                    if usage_update:
                        try:
                            latest_usage = dict(usage_update)
                        except Exception:
                            latest_usage = usage_update
                    outlet_content_override = message_update.get("content")
                    try:
                        Chats.upsert_message_to_chat_by_id_and_message_id(
                            metadata["chat_id"],
                            metadata["message_id"],
                            {
                                **message_update,
                            },
                        )
                    except Exception as e:
                        log.debug(f"Failed to upsert outlet message: {e}")
                    break
    except Exception as e:
        log.warning(f"chat_completed outlet failed: {e}")
    finally:
        completion_dispatched = True

    return outlet_result_data

reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
DETECT_REASONING_TAGS = reasoning_tags_param is not False
@@ -2272,6 +2378,7 @@ async def process_chat_response(
try:
    for event in events:
        extend_sources(event.get("sources"))
        await event_emitter(
            {
                "type": "chat:completion",
@@ -2291,6 +2398,7 @@ async def process_chat_response(
async def stream_body_handler(response, form_data):
    nonlocal content
    nonlocal content_blocks
    nonlocal latest_usage

    response_tool_calls = []
@@ -2349,6 +2457,7 @@ async def process_chat_response(
    )

if data:
    extend_sources(data.get("sources"))
    if "event" in data and not getattr(
        request.state, "direct", False
    ):
@@ -2376,6 +2485,7 @@ async def process_chat_response(
usage = data.get("usage", {}) or {}
usage.update(data.get("timings", {}))  # llama.cpp

if usage:
    latest_usage = dict(usage)
    await event_emitter(
        {
            "type": "chat:completion",
@@ -3056,6 +3166,8 @@ async def process_chat_response(
            "content": serialize_content_blocks(content_blocks),
        },
    )
finally:
    await dispatch_chat_completed()

if response.background is not None:
    await response.background()
@@ -3064,7 +3176,9 @@ async def process_chat_response(
else:
    # Fallback to the original response
    latest_usage = None

    async def stream_wrapper(original_generator, events):
        nonlocal latest_usage

        def wrap_item(item):
            return f"data: {item}\n\n"
@@ -3078,19 +3192,27 @@ async def process_chat_response(
        )
        if event:
            extend_sources(event.get("sources"))
            yield wrap_item(json.dumps(event))

    async for data in original_generator:
        data, _ = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="stream",
            form_data=data,
            extra_params=extra_params,
        )
    try:
        async for data in original_generator:
            data, _ = await process_filter_functions(
                request=request,
                filter_functions=filter_functions,
                filter_type="stream",
                form_data=data,
                extra_params=extra_params,
            )

            if data:
                extend_sources(data.get("sources"))
                usage = data.get("usage")
                if usage:
                    latest_usage = dict(usage)

            yield data
    finally:
        await dispatch_chat_completed()

return StreamingResponse(
    stream_wrapper(response.body_iterator, events),

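Aside: the `extend_sources` helper above deduplicates source entries by their canonical JSON form, so the same citation arriving from both the stream and the outlet result is kept only once. A standalone sketch of that technique (the sample data is made up):

```python
import json

collected_sources = []
source_hashes = set()


def extend_sources(items):
    # Dedupe by canonical JSON; sort_keys makes key order irrelevant.
    if not items:
        return
    for item in items:
        try:
            key = json.dumps(item, sort_keys=True)
        except (TypeError, ValueError):
            key = None  # unserializable items get no dedup key
        if key and key in source_hashes:
            continue
        if key:
            source_hashes.add(key)
        collected_sources.append(item)


extend_sources([{"source": "doc.pdf", "page": 1}])
extend_sources([{"page": 1, "source": "doc.pdf"}])  # same dict, different key order
print(collected_sources)  # one entry
```

Items that fail to serialize get no key and are appended unconditionally, trading possible repeats for never dropping a source.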

@@ -1140,45 +1140,6 @@
    }
};

const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => {
    const res = await chatCompleted(localStorage.token, {
        model: modelId,
        messages: messages.map((m) => ({
            id: m.id,
            role: m.role,
            content: m.content,
            info: m.info ? m.info : undefined,
            timestamp: m.timestamp,
            ...(m.usage ? { usage: m.usage } : {}),
            ...(m.sources ? { sources: m.sources } : {})
        })),
        filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
        model_item: $models.find((m) => m.id === modelId),
        chat_id: _chatId,
        session_id: $socket?.id,
        id: responseMessageId
    }).catch((error) => {
        toast.error(`${error}`);
        messages.at(-1).error = { content: error };
        return null;
    });

    if (res !== null && res.messages) {
        // Update chat history with the new messages
        for (const message of res.messages) {
            if (message?.id) {
                // Add null check for message and message.id
                history.messages[message.id] = {
                    ...history.messages[message.id],
                    ...(history.messages[message.id].content !== message.content
                        ? { originalContent: history.messages[message.id].content }
                        : {}),
                    ...message
                };
            }
        }
    }

    await tick();

    if ($chatId == _chatId) {
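For reference, the kind of code this fix re-enables: an Open WebUI Function whose `outlet` valve post-processes the finished chat. This is a sketch assuming the usual filter shape (a `Valves` pydantic model plus an `outlet` hook); the `suffix` valve is purely illustrative:

```python
from pydantic import BaseModel


class Filter:
    class Valves(BaseModel):
        # Hypothetical valve, configurable from the admin UI.
        suffix: str = " [checked]"

    def __init__(self):
        self.valves = self.Valves()

    def outlet(self, body: dict) -> dict:
        # Runs after the completion; with this fix it also fires for
        # async/streamed chats, not only the synchronous path.
        for message in body.get("messages", []):
            if message.get("role") == "assistant" and isinstance(
                message.get("content"), str
            ):
                message["content"] += self.valves.suffix
        return body
```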