Timothy Jaeryang Baek 2025-11-11 00:19:45 -05:00
parent 5aa2d01c17
commit edfa141ab4
4 changed files with 165 additions and 238 deletions

View file

@@ -64,6 +64,7 @@ from open_webui.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
     get_event_emitter,
+    get_event_call,
     get_models_in_use,
     get_active_user_ids,
 )
@@ -481,7 +482,6 @@ from open_webui.utils.models import (
 )
 from open_webui.utils.chat import (
     generate_chat_completion as chat_completion_handler,
-    chat_completed as chat_completed_handler,
     chat_action as chat_action_handler,
 )
 from open_webui.utils.embeddings import generate_embeddings
@@ -1566,10 +1566,40 @@ async def chat_completion(
             detail=str(e),
         )
 
-    async def process_chat(request, form_data, user, metadata, model):
+    try:
+        event_emitter = get_event_emitter(metadata)
+        event_call = get_event_call(metadata)
+
+        oauth_token = None
+        try:
+            if request.cookies.get("oauth_session_id", None):
+                oauth_token = await request.app.state.oauth_manager.get_oauth_token(
+                    user.id,
+                    request.cookies.get("oauth_session_id", None),
+                )
+        except Exception as e:
+            log.error(f"Error getting OAuth token: {e}")
+
+        extra_params = {
+            "__event_emitter__": event_emitter,
+            "__event_call__": event_call,
+            "__user__": user.model_dump() if isinstance(user, UserModel) else {},
+            "__metadata__": metadata,
+            "__request__": request,
+            "__model__": model,
+            "__oauth_token__": oauth_token,
+        }
+    except Exception as e:
+        log.debug(f"Error setting up extra params: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )
+
+    async def process_chat(request, form_data, user, metadata, extra_params):
         try:
             form_data, metadata, events = await process_chat_payload(
-                request, form_data, user, metadata, model
+                request, form_data, user, metadata, extra_params
             )
 
             response = await chat_completion_handler(request, form_data, user)
@@ -1587,7 +1617,14 @@ async def chat_completion(
                 pass
 
             return await process_chat_response(
-                request, response, form_data, user, metadata, model, events, tasks
+                request,
+                response,
+                form_data,
+                user,
+                metadata,
+                extra_params,
+                events,
+                tasks,
             )
         except asyncio.CancelledError:
             log.info("Chat processing was cancelled")
@@ -1646,12 +1683,12 @@ async def chat_completion(
         # Asynchronous Chat Processing
         task_id, _ = await create_task(
             request.app.state.redis,
-            process_chat(request, form_data, user, metadata, model),
+            process_chat(request, form_data, user, metadata, extra_params),
             id=metadata["chat_id"],
         )
 
         return {"status": True, "task_id": task_id}
     else:
-        return await process_chat(request, form_data, user, metadata, model)
+        return await process_chat(request, form_data, user, metadata, extra_params)
 
 
 # Alias for chat_completion (Legacy)
@@ -1659,25 +1696,6 @@ generate_chat_completions = chat_completion
 generate_chat_completion = chat_completion
 
 
-@app.post("/api/chat/completed")
-async def chat_completed(
-    request: Request, form_data: dict, user=Depends(get_verified_user)
-):
-    try:
-        model_item = form_data.pop("model_item", {})
-
-        if model_item.get("direct", False):
-            request.state.direct = True
-            request.state.model = model_item
-
-        return await chat_completed_handler(request, form_data, user)
-    except Exception as e:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e),
-        )
-
-
 @app.post("/api/chat/actions/{action_id}")
 async def chat_action(
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
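The net effect in this file: the endpoint builds the per-request context (event emitter and caller, OAuth token, model) exactly once and threads it downstream as a single extra_params dict, and the standalone /api/chat/completed route is gone. A minimal, self-contained sketch of that "build the context once, pass it through" pattern, assuming nothing beyond what the diff shows; build_extra_params, process_payload, and the stub emitter are illustrative names, not Open WebUI APIs:

import asyncio


def build_extra_params(metadata, user, model, oauth_token=None):
    # Built once per request in the endpoint; every downstream stage reads
    # from this dict instead of re-deriving emitters or tokens from metadata.
    async def event_emitter(event):
        print(f"[{metadata['session_id']}] emit: {event}")

    return {
        "__event_emitter__": event_emitter,
        "__event_call__": event_emitter,  # stand-in; the real code wires a socket round-trip
        "__user__": user,
        "__metadata__": metadata,
        "__model__": model,
        "__oauth_token__": oauth_token,
    }


async def process_payload(form_data, extra_params):
    # Downstream stages unpack what they need, mirroring process_chat_payload.
    model = extra_params["__model__"]
    form_data["model"] = model["id"]
    return form_data


async def main():
    extra_params = build_extra_params(
        metadata={"session_id": "s1", "chat_id": "c1", "message_id": "m1"},
        user={"id": "u1"},
        model={"id": "demo-model"},
    )
    form_data = await process_payload({"messages": []}, extra_params)
    await extra_params["__event_emitter__"](
        {"type": "chat:completion", "data": form_data}
    )


asyncio.run(main())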

View file

@@ -290,10 +290,9 @@ async def generate_chat_completion(
 chat_completion = generate_chat_completion
 
 
-async def chat_completed(request: Request, form_data: dict, user: Any):
-    if not request.app.state.MODELS:
-        await get_all_models(request, user=user)
-
+async def chat_completed(
+    request: Request, form_data: dict, user, metadata, extra_params
+):
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
         models = {
             request.state.model["id"]: request.state.model,
@@ -301,35 +300,19 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
     else:
         models = request.app.state.MODELS
 
-    data = form_data
-    model_id = data["model"]
+    model_id = form_data["model"]
     if model_id not in models:
         raise Exception("Model not found")
 
     model = models[model_id]
 
     try:
-        data = await process_pipeline_outlet_filter(request, data, user, models)
+        form_data = await process_pipeline_outlet_filter(
+            request, form_data, user, models
+        )
     except Exception as e:
        return Exception(f"Error: {e}")
 
-    metadata = {
-        "chat_id": data["chat_id"],
-        "message_id": data["id"],
-        "filter_ids": data.get("filter_ids", []),
-        "session_id": data["session_id"],
-        "user_id": user.id,
-    }
-
-    extra_params = {
-        "__event_emitter__": get_event_emitter(metadata),
-        "__event_call__": get_event_call(metadata),
-        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
-        "__metadata__": metadata,
-        "__request__": request,
-        "__model__": model,
-    }
-
     try:
         filter_functions = [
             Functions.get_function_by_id(filter_id)
@@ -338,14 +321,15 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
             )
         ]
 
-        result, _ = await process_filter_functions(
+        form_data, _ = await process_filter_functions(
             request=request,
             filter_functions=filter_functions,
             filter_type="outlet",
-            form_data=data,
+            form_data=form_data,
             extra_params=extra_params,
         )
-        return result
+
+        return form_data
     except Exception as e:
         return Exception(f"Error: {e}")
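chat_completed now receives metadata and extra_params from its caller instead of rebuilding them from the form data, and it returns the filtered form_data rather than a separate result. A sketch of the outlet-filter chain it drives, under the assumption that filters are plain async callables; in Open WebUI they are Function modules resolved via Functions.get_function_by_id and get_sorted_filter_ids:

import asyncio


async def process_filter_functions(form_data, filters, extra_params):
    # Each outlet filter may rewrite the completed form_data (redact,
    # annotate, append messages, ...) before it is persisted.
    for f in filters:
        form_data = await f(form_data, **extra_params) or form_data
    return form_data, None


async def redact_outlet(form_data, __user__=None, **_):
    # Hypothetical example filter: scrub a token from assistant output.
    for message in form_data.get("messages", []):
        message["content"] = message["content"].replace("secret", "[redacted]")
    return form_data


async def main():
    form_data = {
        "model": "demo-model",
        "messages": [{"role": "assistant", "content": "a secret"}],
    }
    form_data, _ = await process_filter_functions(
        form_data, [redact_outlet], {"__user__": {}}
    )
    print(form_data["messages"][0]["content"])  # -> "a [redacted]"


asyncio.run(main())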

View file

@@ -71,7 +71,10 @@ from open_webui.models.models import Models
 from open_webui.retrieval.utils import get_sources_from_items
 
-from open_webui.utils.chat import generate_chat_completion
+from open_webui.utils.chat import (
+    generate_chat_completion,
+    chat_completed,
+)
 from open_webui.utils.task import (
     get_task_model_id,
     rag_template,
@@ -1079,11 +1082,17 @@ def apply_params_to_form_data(form_data, model):
     return form_data
 
 
-async def process_chat_payload(request, form_data, user, metadata, model):
+async def process_chat_payload(request, form_data, user, metadata, extra_params):
     # Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
     # -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
     # -> Chat Files
 
+    event_emitter = extra_params.get("__event_emitter__", None)
+    event_caller = extra_params.get("__event_call__", None)
+    oauth_token = extra_params.get("__oauth_token__", None)
+    model = extra_params.get("__model__", None)
+
     form_data = apply_params_to_form_data(form_data, model)
     log.debug(f"form_data: {form_data}")
@@ -1096,29 +1105,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     except:
         pass
 
-    event_emitter = get_event_emitter(metadata)
-    event_call = get_event_call(metadata)
-
-    oauth_token = None
-    try:
-        if request.cookies.get("oauth_session_id", None):
-            oauth_token = await request.app.state.oauth_manager.get_oauth_token(
-                user.id,
-                request.cookies.get("oauth_session_id", None),
-            )
-    except Exception as e:
-        log.error(f"Error getting OAuth token: {e}")
-
-    extra_params = {
-        "__event_emitter__": event_emitter,
-        "__event_call__": event_call,
-        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
-        "__metadata__": metadata,
-        "__request__": request,
-        "__model__": model,
-        "__oauth_token__": oauth_token,
-    }
-
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
 
     if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
@@ -1529,7 +1515,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
 
 
 async def process_chat_response(
-    request, response, form_data, user, metadata, model, events, tasks
+    request, response, form_data, user, metadata, extra_params, events, tasks
 ):
     async def background_tasks_handler():
         message = None
@@ -1752,18 +1738,9 @@ async def process_chat_response(
         except Exception as e:
             pass
 
-    event_emitter = None
-    event_caller = None
-    if (
-        "session_id" in metadata
-        and metadata["session_id"]
-        and "chat_id" in metadata
-        and metadata["chat_id"]
-        and "message_id" in metadata
-        and metadata["message_id"]
-    ):
-        event_emitter = get_event_emitter(metadata)
-        event_caller = get_event_call(metadata)
+    model = extra_params.get("__model__", None)
+    event_emitter = extra_params.get("__event_emitter__", None)
+    event_caller = extra_params.get("__event_call__", None)
 
     # Non-streaming response
     if not isinstance(response, StreamingResponse):
@@ -1832,8 +1809,18 @@
                     }
                 )
 
-            title = Chats.get_chat_title_by_id(metadata["chat_id"])
+            # Save message in the database
+            Chats.upsert_message_to_chat_by_id_and_message_id(
+                metadata["chat_id"],
+                metadata["message_id"],
+                {
+                    "role": "assistant",
+                    "content": content,
+                },
+            )
+
+            title = Chats.get_chat_title_by_id(metadata["chat_id"])
 
             await event_emitter(
                 {
                     "type": "chat:completion",
@@ -1845,16 +1832,6 @@
                     }
                 )
 
-            # Save message in the database
-            Chats.upsert_message_to_chat_by_id_and_message_id(
-                metadata["chat_id"],
-                metadata["message_id"],
-                {
-                    "role": "assistant",
-                    "content": content,
-                },
-            )
-
             # Send a webhook notification if the user is not active
             if not get_active_status_by_user_id(user.id):
                 webhook_url = Users.get_user_webhook_url_by_id(user.id)
@@ -1923,32 +1900,12 @@
     ):
         return response
 
-    oauth_token = None
-    try:
-        if request.cookies.get("oauth_session_id", None):
-            oauth_token = await request.app.state.oauth_manager.get_oauth_token(
-                user.id,
-                request.cookies.get("oauth_session_id", None),
-            )
-    except Exception as e:
-        log.error(f"Error getting OAuth token: {e}")
-
-    extra_params = {
-        "__event_emitter__": event_emitter,
-        "__event_call__": event_caller,
-        "__user__": user.model_dump() if isinstance(user, UserModel) else {},
-        "__metadata__": metadata,
-        "__oauth_token__": oauth_token,
-        "__request__": request,
-        "__model__": model,
-    }
-
     filter_functions = [
         Functions.get_function_by_id(filter_id)
         for filter_id in get_sorted_filter_ids(
             request, model, metadata.get("filter_ids", [])
         )
     ]
 
     # Streaming response
     if event_emitter and event_caller:
         task_id = str(uuid4())  # Create a unique task ID.
@@ -3163,12 +3120,35 @@
                         },
                     )
 
+                completed_res = await chat_completed(
+                    request,
+                    {
+                        "id": metadata.get("message_id"),
+                        "chat_id": metadata.get("chat_id"),
+                        "session_id": metadata.get("session_id"),
+                        "filter_ids": metadata.get("filter_ids", []),
+                        "model": form_data.get("model"),
+                        "messages": [*form_data.get("messages", []), response_message],
+                    },
+                    user,
+                    metadata,
+                    extra_params,
+                )
+
+                if completed_res and completed_res.get("messages"):
+                    for message in completed_res["messages"]:
 
                 if response.background is not None:
                     await response.background()
 
             return await response_handler(response, events)
         else:
+            response_message = {}
+
             # Fallback to the original response
             async def stream_wrapper(original_generator, events):
                 def wrap_item(item):
@@ -3198,6 +3178,22 @@
                 if data:
                     yield data
 
+            await chat_completed(
+                request,
+                {
+                    "id": metadata.get("message_id"),
+                    "chat_id": metadata.get("chat_id"),
+                    "session_id": metadata.get("session_id"),
+                    "filter_ids": metadata.get("filter_ids", []),
+                    "model": form_data.get("model"),
+                    "messages": [*form_data.get("messages", []), response_message],
+                },
+                user,
+                metadata,
+                extra_params,
+            )
+
         return StreamingResponse(
             stream_wrapper(response.body_iterator, events),
             headers=dict(response.headers),
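With the HTTP endpoint removed, both streaming paths now invoke chat_completed server-side once output is complete: the task-based path captures completed_res, and the fallback path awaits it after the wrapped generator is drained. A runnable sketch of that "yield everything, then run the hook exactly once" shape; the chat_completed stub and fake_model_stream are illustrative, not the real implementations:

import asyncio


async def chat_completed(payload):
    # Stub for the server-side completion hook (outlet filters, persistence).
    print(f"outlet filters ran for chat {payload['chat_id']}")


async def stream_wrapper(original_generator, payload):
    # Forward every chunk to the client first...
    async for data in original_generator:
        yield data
    # ...then run the completion hook exactly once, server-side.
    await chat_completed(payload)


async def fake_model_stream():
    for chunk in ("Hel", "lo"):
        yield chunk


async def main():
    async for data in stream_wrapper(fake_model_stream(), {"chat_id": "c1"}):
        print(data, end="")
    print()


asyncio.run(main())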

View file

@@ -1155,65 +1155,6 @@
			});
		}
	};

-	const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => {
-		const res = await chatCompleted(localStorage.token, {
-			model: modelId,
-			messages: messages.map((m) => ({
-				id: m.id,
-				role: m.role,
-				content: m.content,
-				info: m.info ? m.info : undefined,
-				timestamp: m.timestamp,
-				...(m.usage ? { usage: m.usage } : {}),
-				...(m.sources ? { sources: m.sources } : {})
-			})),
-			filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
-			model_item: $models.find((m) => m.id === modelId),
-			chat_id: _chatId,
-			session_id: $socket?.id,
-			id: responseMessageId
-		}).catch((error) => {
-			toast.error(`${error}`);
-			messages.at(-1).error = { content: error };
-
-			return null;
-		});
-
-		if (res !== null && res.messages) {
-			// Update chat history with the new messages
-			for (const message of res.messages) {
-				if (message?.id) {
-					// Add null check for message and message.id
-					history.messages[message.id] = {
-						...history.messages[message.id],
-						...(history.messages[message.id].content !== message.content
-							? { originalContent: history.messages[message.id].content }
-							: {}),
-						...message
-					};
-				}
-			}
-		}
-
-		await tick();
-
-		if ($chatId == _chatId) {
-			if (!$temporaryChatEnabled) {
-				chat = await updateChatById(localStorage.token, _chatId, {
-					models: selectedModels,
-					messages: messages,
-					history: history,
-					params: params,
-					files: chatFiles
-				});
-
-				currentChatPage.set(1);
-				await chats.set(await getChatList(localStorage.token, $currentChatPage));
-			}
-		}
-
-		taskIds = null;
-	};

	const chatActionHandler = async (_chatId, actionId, modelId, responseMessageId, event = null) => {
		const messages = createMessagesList(history, responseMessageId);
@@ -1401,17 +1342,53 @@
		}
	};

-	const chatCompletionEventHandler = async (data, message, chatId) => {
+	const emitChatTTSEvents = (message) => {
+		const messageContentParts = getMessageContentParts(
+			removeAllDetails(message.content),
+			$config?.audio?.tts?.split_on ?? 'punctuation'
+		);
+		messageContentParts.pop();
+
+		// dispatch only last sentence and make sure it hasn't been dispatched before
+		if (
+			messageContentParts.length > 0 &&
+			messageContentParts[messageContentParts.length - 1] !== message.lastSentence
+		) {
+			message.lastSentence = messageContentParts[messageContentParts.length - 1];
+			eventTarget.dispatchEvent(
+				new CustomEvent('chat', {
+					detail: {
+						id: message.id,
+						content: messageContentParts[messageContentParts.length - 1]
+					}
+				})
+			);
+		}
+
+		return message;
+	};
+
+	const chatCompletionEventHandler = async (data, message, _chatId) => {
		const { id, done, choices, content, sources, selected_model_id, error, usage } = data;

		if (error) {
			await handleOpenAIError(error, message);
		}

+		if (usage) {
+			message.usage = usage;
+		}
+
+		if (sources && !message?.sources) {
+			message.sources = sources;
+		}
+
+		if (selected_model_id) {
+			message.selectedModelId = selected_model_id;
+			message.arena = true;
+		}
+
+		// Raw response handling
		if (choices) {
			if (choices[0]?.message?.content) {
				// Non-stream response
@@ -1429,31 +1406,12 @@
					}

					// Emit chat event for TTS
-					const messageContentParts = getMessageContentParts(
-						removeAllDetails(message.content),
-						$config?.audio?.tts?.split_on ?? 'punctuation'
-					);
-					messageContentParts.pop();
-
-					// dispatch only last sentence and make sure it hasn't been dispatched before
-					if (
-						messageContentParts.length > 0 &&
-						messageContentParts[messageContentParts.length - 1] !== message.lastSentence
-					) {
-						message.lastSentence = messageContentParts[messageContentParts.length - 1];
-						eventTarget.dispatchEvent(
-							new CustomEvent('chat', {
-								detail: {
-									id: message.id,
-									content: messageContentParts[messageContentParts.length - 1]
-								}
-							})
-						);
-					}
+					message = emitChatTTSEvents(message);
				}
			}
		}

+		// Normal response handling
		if (content) {
			// REALTIME_CHAT_SAVE is disabled
			message.content = content;
@@ -1463,36 +1421,7 @@
			}

			// Emit chat event for TTS
-			const messageContentParts = getMessageContentParts(
-				removeAllDetails(message.content),
-				$config?.audio?.tts?.split_on ?? 'punctuation'
-			);
-			messageContentParts.pop();
-
-			// dispatch only last sentence and make sure it hasn't been dispatched before
-			if (
-				messageContentParts.length > 0 &&
-				messageContentParts[messageContentParts.length - 1] !== message.lastSentence
-			) {
-				message.lastSentence = messageContentParts[messageContentParts.length - 1];
-				eventTarget.dispatchEvent(
-					new CustomEvent('chat', {
-						detail: {
-							id: message.id,
-							content: messageContentParts[messageContentParts.length - 1]
-						}
-					})
-				);
-			}
-		}
-
-		if (selected_model_id) {
-			message.selectedModelId = selected_model_id;
-			message.arena = true;
-		}
-
-		if (usage) {
-			message.usage = usage;
+			message = emitChatTTSEvents(message);
		}

		history.messages[message.id] = message;
@@ -1538,12 +1467,12 @@
				scrollToBottom();
			}

-			await chatCompletedHandler(
-				chatId,
-				message.model,
-				message.id,
-				createMessagesList(history, message.id)
-			);
+			if ($chatId == _chatId) {
+				if (!$temporaryChatEnabled) {
+					currentChatPage.set(1);
+					await chats.set(await getChatList(localStorage.token, $currentChatPage));
+				}
+			}
		}

		console.log(data);
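On the client, the completed-handler round trip to /api/chat/completed is gone, and the repeated TTS dispatch logic is factored into emitChatTTSEvents, which emits only the newest finished sentence and remembers it in message.lastSentence so successive deltas never re-emit it. The same dedup rule in a runnable Python sketch; split_on_punctuation is a simplified stand-in for getMessageContentParts, and dispatch stands in for the CustomEvent dispatch:

import re


def split_on_punctuation(text):
    # Simplified sentence splitter standing in for getMessageContentParts.
    return [p.strip() for p in re.split(r"(?<=[.!?])\s+", text) if p.strip()]


def emit_tts_events(message, dispatch):
    parts = split_on_punctuation(message["content"])
    if parts:
        parts.pop()  # the trailing part may still be mid-sentence; drop it
    # Dispatch only the last complete sentence, and only if it differs from
    # the one dispatched on the previous delta.
    if parts and parts[-1] != message.get("last_sentence"):
        message["last_sentence"] = parts[-1]
        dispatch({"id": message["id"], "content": parts[-1]})
    return message


message = {"id": "m1", "content": ""}
for delta in ["Hello ", "world. How ", "are you? I", " am fine."]:
    message["content"] += delta
    message = emit_tts_events(message, dispatch=print)
# Dispatches "Hello world." once, then "How are you?" once.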