fix: Function/Pipeline valves not triggered for async chat completion
This commit is contained in: parent f47214314b, commit fe178e6d33
2 changed files with 131 additions and 48 deletions
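
Summary: the change moves the chat_completed outlet call (which triggers Function/Pipeline valves) into the server's process_chat_response flow, so it also runs for async and background completions. Outlet sources, usage, and content overrides are collected once and dispatched exactly once from every exit path, and the now-redundant client-side call is removed.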
@@ -71,7 +71,7 @@ from open_webui.models.models import Models
 from open_webui.retrieval.utils import get_sources_from_items
-from open_webui.utils.chat import generate_chat_completion
+from open_webui.utils.chat import generate_chat_completion, chat_completed
 from open_webui.utils.task import (
     get_task_model_id,
     rag_template,
@@ -1730,6 +1730,16 @@ async def process_chat_response(
         content = response_data["choices"][0]["message"]["content"]

         if content:
+            await dispatch_chat_completed()
+
+            if outlet_content_override is not None:
+                data["content"] = outlet_content_override
+
+            if collected_sources:
+                data["sources"] = collected_sources
+            if latest_usage:
+                data["usage"] = latest_usage
+
             await event_emitter(
                 {
                     "type": "chat:completion",
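
This hunk covers the non-streaming path: the outlet runs before the final chat:completion event, and any rewrite the filters return replaces the raw model output. A minimal, self-contained sketch of that dispatch-then-override ordering; run_outlet and emit are hypothetical stand-ins for chat_completed and the event emitter, not the open-webui API:

    import asyncio

    # Hypothetical stand-ins: run_outlet plays the role of chat_completed,
    # emit plays the role of the websocket event emitter.
    async def run_outlet(payload: dict) -> dict:
        # An outlet valve that rewrites the reply, as a filter's outlet() might.
        return {"content": payload["content"].upper()}

    async def emit(event: dict) -> None:
        print(event)

    async def finish(content: str) -> None:
        outlet = await run_outlet({"content": content})
        override = outlet.get("content")
        if override is not None:
            # The filter's rewrite wins over the raw completion, mirroring
            # outlet_content_override above.
            content = override
        await emit({"type": "chat:completion", "content": content})

    asyncio.run(finish("hello"))  # {'type': 'chat:completion', 'content': 'HELLO'}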
@@ -2251,6 +2261,102 @@ async def process_chat_response(
                 "content": content,
             }
         ]

+        latest_usage = None
+        completion_dispatched = False
+        collected_sources = []
+        source_hashes = set()
+        outlet_result_data = None
+        outlet_content_override = None
+
+        def extend_sources(items):
+            if not items:
+                return
+            for item in items:
+                try:
+                    key = json.dumps(item, sort_keys=True)
+                except (TypeError, ValueError):
+                    key = None
+                if key and key in source_hashes:
+                    continue
+                if key:
+                    source_hashes.add(key)
+                collected_sources.append(item)
+
+        async def dispatch_chat_completed():
+            nonlocal completion_dispatched, outlet_result_data, outlet_content_override, latest_usage
+            if completion_dispatched:
+                return outlet_result_data
+
+            base_messages = [dict(message) for message in form_data.get("messages", [])]
+            generated_messages = convert_content_blocks_to_messages(
+                content_blocks, raw=True
+            )
+            final_messages = [*base_messages, *generated_messages]
+
+            if final_messages:
+                last_message = final_messages[-1]
+                if isinstance(last_message, dict) and last_message.get("role") == "assistant":
+                    last_message = {**last_message}
+                    if collected_sources:
+                        last_message["sources"] = collected_sources
+                    if latest_usage:
+                        last_message["usage"] = latest_usage
+                    final_messages[-1] = last_message
+
+            payload = {
+                "model": model_id,
+                "messages": final_messages,
+                "chat_id": metadata["chat_id"],
+                "session_id": metadata["session_id"],
+                "id": metadata["message_id"],
+                "model_item": model,
+            }
+            if metadata.get("filter_ids"):
+                payload["filter_ids"] = metadata["filter_ids"]
+
+            try:
+                outlet_result_data = await chat_completed(request, payload, user)
+
+                if isinstance(outlet_result_data, dict):
+                    extend_sources(outlet_result_data.get("sources"))
+                    message_updates = outlet_result_data.get("messages")
+                    if isinstance(message_updates, list):
+                        for message_update in message_updates:
+                            if not isinstance(message_update, dict):
+                                continue
+                            if message_update.get("id") != metadata["message_id"]:
+                                continue
+
+                            if message_update.get("sources"):
+                                extend_sources(message_update.get("sources"))
+
+                            usage_update = message_update.get("usage")
+                            if usage_update:
+                                try:
+                                    latest_usage = dict(usage_update)
+                                except Exception:
+                                    latest_usage = usage_update
+
+                            outlet_content_override = message_update.get("content")
+
+                            try:
+                                Chats.upsert_message_to_chat_by_id_and_message_id(
+                                    metadata["chat_id"],
+                                    metadata["message_id"],
+                                    {
+                                        **message_update,
+                                    },
+                                )
+                            except Exception as e:
+                                log.debug(f"Failed to upsert outlet message: {e}")
+                            break
+            except Exception as e:
+                log.warning(f"chat_completed outlet failed: {e}")
+            finally:
+                completion_dispatched = True
+
+            return outlet_result_data
+
         reasoning_tags_param = metadata.get("params", {}).get("reasoning_tags")
         DETECT_REASONING_TAGS = reasoning_tags_param is not False
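
Two patterns carry this hunk. extend_sources deduplicates by serializing each source item to a canonical JSON string (sort_keys=True makes key order irrelevant), and keeps unserializable items unconditionally. dispatch_chat_completed guards on completion_dispatched, so the several call sites added below (non-streaming path, background finally, stream-wrapper finally) trigger the outlet exactly once. A standalone sketch of the dedup behavior:

    import json

    collected_sources = []
    source_hashes = set()

    def extend_sources(items):
        """Merge source items, deduplicating by canonical JSON form."""
        for item in items or []:
            try:
                key = json.dumps(item, sort_keys=True)  # order-insensitive identity
            except (TypeError, ValueError):
                key = None  # unserializable items are always appended
            if key is not None and key in source_hashes:
                continue
            if key is not None:
                source_hashes.add(key)
            collected_sources.append(item)

    extend_sources([{"a": 1, "b": 2}])
    extend_sources([{"b": 2, "a": 1}])  # same item, different key order: skipped
    print(collected_sources)            # [{'a': 1, 'b': 2}]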
@@ -2272,6 +2378,7 @@ async def process_chat_response(

             try:
                 for event in events:
+                    extend_sources(event.get("sources"))
                     await event_emitter(
                         {
                             "type": "chat:completion",
@@ -2291,6 +2398,7 @@ async def process_chat_response(
         async def stream_body_handler(response, form_data):
             nonlocal content
             nonlocal content_blocks
+            nonlocal latest_usage

             response_tool_calls = []

@@ -2349,6 +2457,7 @@ async def process_chat_response(
                         )

                         if data:
+                            extend_sources(data.get("sources"))
                             if "event" in data and not getattr(
                                 request.state, "direct", False
                             ):
@@ -2376,6 +2485,7 @@ async def process_chat_response(
                             usage = data.get("usage", {}) or {}
                             usage.update(data.get("timings", {}))  # llama.cpp
                             if usage:
+                                latest_usage = dict(usage)
                                 await event_emitter(
                                     {
                                         "type": "chat:completion",
@@ -3056,6 +3166,8 @@ async def process_chat_response(
                         "content": serialize_content_blocks(content_blocks),
                     },
                 )
+            finally:
+                await dispatch_chat_completed()

             if response.background is not None:
                 await response.background()
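
Because dispatch_chat_completed is idempotent, it is safe to await from a finally: on every exit path, whether the response completed, errored, or was cancelled. A minimal sketch of the pattern; dispatch_once and respond are illustrative names, not the open-webui API:

    import asyncio

    dispatched = False

    async def dispatch_once() -> None:
        """Idempotent dispatch: repeat calls are no-ops, like completion_dispatched."""
        global dispatched
        if dispatched:
            return
        dispatched = True
        print("outlet filters triggered")

    async def respond(fail: bool) -> None:
        try:
            if fail:
                raise RuntimeError("delivery failed")
            print("response delivered")
        finally:
            await dispatch_once()  # runs on success and on error alike

    asyncio.run(respond(fail=False))  # response delivered / outlet filters triggered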
@@ -3064,7 +3176,9 @@ async def process_chat_response(

         else:
             # Fallback to the original response
+            latest_usage = None
             async def stream_wrapper(original_generator, events):
+                nonlocal latest_usage
                 def wrap_item(item):
                     return f"data: {item}\n\n"

@@ -3078,19 +3192,27 @@ async def process_chat_response(
                     )

                     if event:
+                        extend_sources(event.get("sources"))
                         yield wrap_item(json.dumps(event))

-                async for data in original_generator:
-                    data, _ = await process_filter_functions(
-                        request=request,
-                        filter_functions=filter_functions,
-                        filter_type="stream",
-                        form_data=data,
-                        extra_params=extra_params,
-                    )
-
-                    yield data
+                try:
+                    async for data in original_generator:
+                        data, _ = await process_filter_functions(
+                            request=request,
+                            filter_functions=filter_functions,
+                            filter_type="stream",
+                            form_data=data,
+                            extra_params=extra_params,
+                        )
+
+                        if data:
+                            extend_sources(data.get("sources"))
+                            usage = data.get("usage")
+                            if usage:
+                                latest_usage = dict(usage)
+                        yield data
+                finally:
+                    await dispatch_chat_completed()

             return StreamingResponse(
                 stream_wrapper(response.body_iterator, events),
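
Wrapping the body iterator in try/finally means the outlet also fires when the consumer stops reading the stream early: closing an async generator throws GeneratorExit at the suspended yield, and the finally block still executes. A self-contained sketch of that guarantee; numbers and the print stand in for the real stream and the real dispatch:

    import asyncio

    async def numbers():
        for i in range(5):
            yield i

    async def stream_wrapper(gen):
        try:
            async for item in gen:
                yield item
        finally:
            # Reached on normal exhaustion, on error, and on early close.
            print("dispatching chat_completed")

    async def main():
        agen = stream_wrapper(numbers())
        async for item in agen:
            if item == 2:
                break            # simulated client disconnect
        await agen.aclose()      # throws GeneratorExit into the wrapper

    asyncio.run(main())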
@@ -1140,45 +1140,6 @@
         }
     };
-    const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => {
-        const res = await chatCompleted(localStorage.token, {
-            model: modelId,
-            messages: messages.map((m) => ({
-                id: m.id,
-                role: m.role,
-                content: m.content,
-                info: m.info ? m.info : undefined,
-                timestamp: m.timestamp,
-                ...(m.usage ? { usage: m.usage } : {}),
-                ...(m.sources ? { sources: m.sources } : {})
-            })),
-            filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
-            model_item: $models.find((m) => m.id === modelId),
-            chat_id: _chatId,
-            session_id: $socket?.id,
-            id: responseMessageId
-        }).catch((error) => {
-            toast.error(`${error}`);
-            messages.at(-1).error = { content: error };
-
-            return null;
-        });
-
-        if (res !== null && res.messages) {
-            // Update chat history with the new messages
-            for (const message of res.messages) {
-                if (message?.id) {
-                    // Add null check for message and message.id
-                    history.messages[message.id] = {
-                        ...history.messages[message.id],
-                        ...(history.messages[message.id].content !== message.content
-                            ? { originalContent: history.messages[message.id].content }
-                            : {}),
-                        ...message
-                    };
-                }
-            }
-        }
-
         await tick();

         if ($chatId == _chatId) {
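
On the client, the chatCompletedHandler body that posted to the chatCompleted endpoint is deleted. That request only fired from a connected browser after streaming finished, which is presumably what left Function/Pipeline valves untriggered for async background completions; with dispatch_chat_completed now running server-side, keeping it would also risk running the outlet twice.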