mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
refac
This commit is contained in:
parent
5aa2d01c17
commit
edfa141ab4
4 changed files with 165 additions and 238 deletions
|
|
@ -64,6 +64,7 @@ from open_webui.socket.main import (
|
|||
app as socket_app,
|
||||
periodic_usage_pool_cleanup,
|
||||
get_event_emitter,
|
||||
get_event_call,
|
||||
get_models_in_use,
|
||||
get_active_user_ids,
|
||||
)
|
||||
|
|
@ -481,7 +482,6 @@ from open_webui.utils.models import (
|
|||
)
|
||||
from open_webui.utils.chat import (
|
||||
generate_chat_completion as chat_completion_handler,
|
||||
chat_completed as chat_completed_handler,
|
||||
chat_action as chat_action_handler,
|
||||
)
|
||||
from open_webui.utils.embeddings import generate_embeddings
|
||||
|
|
@ -1566,10 +1566,40 @@ async def chat_completion(
|
|||
detail=str(e),
|
||||
)
|
||||
|
||||
async def process_chat(request, form_data, user, metadata, model):
|
||||
try:
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
"__oauth_token__": oauth_token,
|
||||
}
|
||||
except Exception as e:
|
||||
log.debug(f"Error setting up extra params: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
async def process_chat(request, form_data, user, metadata, extra_params):
|
||||
try:
|
||||
form_data, metadata, events = await process_chat_payload(
|
||||
request, form_data, user, metadata, model
|
||||
request, form_data, user, metadata, extra_params
|
||||
)
|
||||
|
||||
response = await chat_completion_handler(request, form_data, user)
|
||||
|
|
@ -1587,7 +1617,14 @@ async def chat_completion(
|
|||
pass
|
||||
|
||||
return await process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
request,
|
||||
response,
|
||||
form_data,
|
||||
user,
|
||||
metadata,
|
||||
extra_params,
|
||||
events,
|
||||
tasks,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
log.info("Chat processing was cancelled")
|
||||
|
|
@ -1646,12 +1683,12 @@ async def chat_completion(
|
|||
# Asynchronous Chat Processing
|
||||
task_id, _ = await create_task(
|
||||
request.app.state.redis,
|
||||
process_chat(request, form_data, user, metadata, model),
|
||||
process_chat(request, form_data, user, metadata, extra_params),
|
||||
id=metadata["chat_id"],
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
else:
|
||||
return await process_chat(request, form_data, user, metadata, model)
|
||||
return await process_chat(request, form_data, user, metadata, extra_params)
|
||||
|
||||
|
||||
# Alias for chat_completion (Legacy)
|
||||
|
|
@ -1659,25 +1696,6 @@ generate_chat_completions = chat_completion
|
|||
generate_chat_completion = chat_completion
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
try:
|
||||
model_item = form_data.pop("model_item", {})
|
||||
|
||||
if model_item.get("direct", False):
|
||||
request.state.direct = True
|
||||
request.state.model = model_item
|
||||
|
||||
return await chat_completed_handler(request, form_data, user)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/chat/actions/{action_id}")
|
||||
async def chat_action(
|
||||
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
|
|
|
|||
|
|
@ -290,10 +290,9 @@ async def generate_chat_completion(
|
|||
chat_completion = generate_chat_completion
|
||||
|
||||
|
||||
async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
if not request.app.state.MODELS:
|
||||
await get_all_models(request, user=user)
|
||||
|
||||
async def chat_completed(
|
||||
request: Request, form_data: dict, user, metadata, extra_params
|
||||
):
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
models = {
|
||||
request.state.model["id"]: request.state.model,
|
||||
|
|
@ -301,35 +300,19 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||
else:
|
||||
models = request.app.state.MODELS
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
|
||||
model = models[model_id]
|
||||
|
||||
try:
|
||||
data = await process_pipeline_outlet_filter(request, data, user, models)
|
||||
form_data = await process_pipeline_outlet_filter(
|
||||
request, form_data, user, models
|
||||
)
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
metadata = {
|
||||
"chat_id": data["chat_id"],
|
||||
"message_id": data["id"],
|
||||
"filter_ids": data.get("filter_ids", []),
|
||||
"session_id": data["session_id"],
|
||||
"user_id": user.id,
|
||||
}
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
try:
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
|
|
@ -338,14 +321,15 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
|||
)
|
||||
]
|
||||
|
||||
result, _ = await process_filter_functions(
|
||||
form_data, _ = await process_filter_functions(
|
||||
request=request,
|
||||
filter_functions=filter_functions,
|
||||
filter_type="outlet",
|
||||
form_data=data,
|
||||
form_data=form_data,
|
||||
extra_params=extra_params,
|
||||
)
|
||||
return result
|
||||
|
||||
return form_data
|
||||
except Exception as e:
|
||||
return Exception(f"Error: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -71,7 +71,10 @@ from open_webui.models.models import Models
|
|||
from open_webui.retrieval.utils import get_sources_from_items
|
||||
|
||||
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.utils.chat import (
|
||||
generate_chat_completion,
|
||||
chat_completed,
|
||||
)
|
||||
from open_webui.utils.task import (
|
||||
get_task_model_id,
|
||||
rag_template,
|
||||
|
|
@ -1079,11 +1082,17 @@ def apply_params_to_form_data(form_data, model):
|
|||
return form_data
|
||||
|
||||
|
||||
async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
async def process_chat_payload(request, form_data, user, metadata, extra_params):
|
||||
# Pipeline Inlet -> Filter Inlet -> Chat Memory -> Chat Web Search -> Chat Image Generation
|
||||
# -> Chat Code Interpreter (Form Data Update) -> (Default) Chat Tools Function Calling
|
||||
# -> Chat Files
|
||||
|
||||
event_emitter = extra_params.get("__event_emitter__", None)
|
||||
event_caller = extra_params.get("__event_call__", None)
|
||||
|
||||
oauth_token = extra_params.get("__oauth_token__", None)
|
||||
model = extra_params.get("__model__", None)
|
||||
|
||||
form_data = apply_params_to_form_data(form_data, model)
|
||||
log.debug(f"form_data: {form_data}")
|
||||
|
||||
|
|
@ -1096,29 +1105,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
except:
|
||||
pass
|
||||
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_call = get_event_call(metadata)
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
"__oauth_token__": oauth_token,
|
||||
}
|
||||
|
||||
# Initialize events to store additional event to be sent to the client
|
||||
# Initialize contexts and citation
|
||||
if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
||||
|
|
@ -1529,7 +1515,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
|
||||
|
||||
async def process_chat_response(
|
||||
request, response, form_data, user, metadata, model, events, tasks
|
||||
request, response, form_data, user, metadata, extra_params, events, tasks
|
||||
):
|
||||
async def background_tasks_handler():
|
||||
message = None
|
||||
|
|
@ -1752,18 +1738,9 @@ async def process_chat_response(
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
event_emitter = None
|
||||
event_caller = None
|
||||
if (
|
||||
"session_id" in metadata
|
||||
and metadata["session_id"]
|
||||
and "chat_id" in metadata
|
||||
and metadata["chat_id"]
|
||||
and "message_id" in metadata
|
||||
and metadata["message_id"]
|
||||
):
|
||||
event_emitter = get_event_emitter(metadata)
|
||||
event_caller = get_event_call(metadata)
|
||||
model = extra_params.get("__model__", None)
|
||||
event_emitter = extra_params.get("__event_emitter__", None)
|
||||
event_caller = extra_params.get("__event_call__", None)
|
||||
|
||||
# Non-streaming response
|
||||
if not isinstance(response, StreamingResponse):
|
||||
|
|
@ -1832,8 +1809,18 @@ async def process_chat_response(
|
|||
}
|
||||
)
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
title = Chats.get_chat_title_by_id(metadata["chat_id"])
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:completion",
|
||||
|
|
@ -1845,16 +1832,6 @@ async def process_chat_response(
|
|||
}
|
||||
)
|
||||
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
# Send a webhook notification if the user is not active
|
||||
if not get_active_status_by_user_id(user.id):
|
||||
webhook_url = Users.get_user_webhook_url_by_id(user.id)
|
||||
|
|
@ -1923,32 +1900,12 @@ async def process_chat_response(
|
|||
):
|
||||
return response
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_caller,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__oauth_token__": oauth_token,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
}
|
||||
filter_functions = [
|
||||
Functions.get_function_by_id(filter_id)
|
||||
for filter_id in get_sorted_filter_ids(
|
||||
request, model, metadata.get("filter_ids", [])
|
||||
)
|
||||
]
|
||||
|
||||
# Streaming response
|
||||
if event_emitter and event_caller:
|
||||
task_id = str(uuid4()) # Create a unique task ID.
|
||||
|
|
@ -3163,12 +3120,35 @@ async def process_chat_response(
|
|||
},
|
||||
)
|
||||
|
||||
|
||||
completed_res = await chat_completed(
|
||||
request,
|
||||
{
|
||||
"id": metadata.get("message_id"),
|
||||
"chat_id": metadata.get("chat_id"),
|
||||
"session_id": metadata.get("session_id"),
|
||||
"filter_ids": metadata.get("filter_ids", []),
|
||||
|
||||
"model": form_data.get("model"),
|
||||
"messages": [*form_data.get("messages", []), response_message],
|
||||
},
|
||||
user,
|
||||
metadata,
|
||||
extra_params,
|
||||
)
|
||||
|
||||
if completed_res and completed_res.get("messages"):
|
||||
for message in completed_res["messages"]:
|
||||
|
||||
|
||||
|
||||
if response.background is not None:
|
||||
await response.background()
|
||||
|
||||
return await response_handler(response, events)
|
||||
|
||||
else:
|
||||
response_message = {}
|
||||
# Fallback to the original response
|
||||
async def stream_wrapper(original_generator, events):
|
||||
def wrap_item(item):
|
||||
|
|
@ -3198,6 +3178,22 @@ async def process_chat_response(
|
|||
if data:
|
||||
yield data
|
||||
|
||||
await chat_completed(
|
||||
request,
|
||||
{
|
||||
"id": metadata.get("message_id"),
|
||||
"chat_id": metadata.get("chat_id"),
|
||||
"session_id": metadata.get("session_id"),
|
||||
"filter_ids": metadata.get("filter_ids", []),
|
||||
|
||||
"model": form_data.get("model"),
|
||||
"messages": [*form_data.get("messages", []), response_message],
|
||||
},
|
||||
user,
|
||||
metadata,
|
||||
extra_params,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
stream_wrapper(response.body_iterator, events),
|
||||
headers=dict(response.headers),
|
||||
|
|
|
|||
|
|
@ -1155,65 +1155,6 @@
|
|||
});
|
||||
}
|
||||
};
|
||||
const chatCompletedHandler = async (_chatId, modelId, responseMessageId, messages) => {
|
||||
const res = await chatCompleted(localStorage.token, {
|
||||
model: modelId,
|
||||
messages: messages.map((m) => ({
|
||||
id: m.id,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
info: m.info ? m.info : undefined,
|
||||
timestamp: m.timestamp,
|
||||
...(m.usage ? { usage: m.usage } : {}),
|
||||
...(m.sources ? { sources: m.sources } : {})
|
||||
})),
|
||||
filter_ids: selectedFilterIds.length > 0 ? selectedFilterIds : undefined,
|
||||
model_item: $models.find((m) => m.id === modelId),
|
||||
chat_id: _chatId,
|
||||
session_id: $socket?.id,
|
||||
id: responseMessageId
|
||||
}).catch((error) => {
|
||||
toast.error(`${error}`);
|
||||
messages.at(-1).error = { content: error };
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
if (res !== null && res.messages) {
|
||||
// Update chat history with the new messages
|
||||
for (const message of res.messages) {
|
||||
if (message?.id) {
|
||||
// Add null check for message and message.id
|
||||
history.messages[message.id] = {
|
||||
...history.messages[message.id],
|
||||
...(history.messages[message.id].content !== message.content
|
||||
? { originalContent: history.messages[message.id].content }
|
||||
: {}),
|
||||
...message
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await tick();
|
||||
|
||||
if ($chatId == _chatId) {
|
||||
if (!$temporaryChatEnabled) {
|
||||
chat = await updateChatById(localStorage.token, _chatId, {
|
||||
models: selectedModels,
|
||||
messages: messages,
|
||||
history: history,
|
||||
params: params,
|
||||
files: chatFiles
|
||||
});
|
||||
|
||||
currentChatPage.set(1);
|
||||
await chats.set(await getChatList(localStorage.token, $currentChatPage));
|
||||
}
|
||||
}
|
||||
|
||||
taskIds = null;
|
||||
};
|
||||
|
||||
const chatActionHandler = async (_chatId, actionId, modelId, responseMessageId, event = null) => {
|
||||
const messages = createMessagesList(history, responseMessageId);
|
||||
|
|
@ -1401,17 +1342,53 @@
|
|||
}
|
||||
};
|
||||
|
||||
const chatCompletionEventHandler = async (data, message, chatId) => {
|
||||
const emitChatTTSEvents = (message) => {
|
||||
const messageContentParts = getMessageContentParts(
|
||||
removeAllDetails(message.content),
|
||||
$config?.audio?.tts?.split_on ?? 'punctuation'
|
||||
);
|
||||
messageContentParts.pop();
|
||||
|
||||
// dispatch only last sentence and make sure it hasn't been dispatched before
|
||||
if (
|
||||
messageContentParts.length > 0 &&
|
||||
messageContentParts[messageContentParts.length - 1] !== message.lastSentence
|
||||
) {
|
||||
message.lastSentence = messageContentParts[messageContentParts.length - 1];
|
||||
eventTarget.dispatchEvent(
|
||||
new CustomEvent('chat', {
|
||||
detail: {
|
||||
id: message.id,
|
||||
content: messageContentParts[messageContentParts.length - 1]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return message;
|
||||
};
|
||||
|
||||
const chatCompletionEventHandler = async (data, message, _chatId) => {
|
||||
const { id, done, choices, content, sources, selected_model_id, error, usage } = data;
|
||||
|
||||
if (error) {
|
||||
await handleOpenAIError(error, message);
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
message.usage = usage;
|
||||
}
|
||||
|
||||
if (sources && !message?.sources) {
|
||||
message.sources = sources;
|
||||
}
|
||||
|
||||
if (selected_model_id) {
|
||||
message.selectedModelId = selected_model_id;
|
||||
message.arena = true;
|
||||
}
|
||||
|
||||
// Raw response handling
|
||||
if (choices) {
|
||||
if (choices[0]?.message?.content) {
|
||||
// Non-stream response
|
||||
|
|
@ -1429,31 +1406,12 @@
|
|||
}
|
||||
|
||||
// Emit chat event for TTS
|
||||
const messageContentParts = getMessageContentParts(
|
||||
removeAllDetails(message.content),
|
||||
$config?.audio?.tts?.split_on ?? 'punctuation'
|
||||
);
|
||||
messageContentParts.pop();
|
||||
|
||||
// dispatch only last sentence and make sure it hasn't been dispatched before
|
||||
if (
|
||||
messageContentParts.length > 0 &&
|
||||
messageContentParts[messageContentParts.length - 1] !== message.lastSentence
|
||||
) {
|
||||
message.lastSentence = messageContentParts[messageContentParts.length - 1];
|
||||
eventTarget.dispatchEvent(
|
||||
new CustomEvent('chat', {
|
||||
detail: {
|
||||
id: message.id,
|
||||
content: messageContentParts[messageContentParts.length - 1]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
message = emitChatTTSEvents(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normal response handling
|
||||
if (content) {
|
||||
// REALTIME_CHAT_SAVE is disabled
|
||||
message.content = content;
|
||||
|
|
@ -1463,36 +1421,7 @@
|
|||
}
|
||||
|
||||
// Emit chat event for TTS
|
||||
const messageContentParts = getMessageContentParts(
|
||||
removeAllDetails(message.content),
|
||||
$config?.audio?.tts?.split_on ?? 'punctuation'
|
||||
);
|
||||
messageContentParts.pop();
|
||||
|
||||
// dispatch only last sentence and make sure it hasn't been dispatched before
|
||||
if (
|
||||
messageContentParts.length > 0 &&
|
||||
messageContentParts[messageContentParts.length - 1] !== message.lastSentence
|
||||
) {
|
||||
message.lastSentence = messageContentParts[messageContentParts.length - 1];
|
||||
eventTarget.dispatchEvent(
|
||||
new CustomEvent('chat', {
|
||||
detail: {
|
||||
id: message.id,
|
||||
content: messageContentParts[messageContentParts.length - 1]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected_model_id) {
|
||||
message.selectedModelId = selected_model_id;
|
||||
message.arena = true;
|
||||
}
|
||||
|
||||
if (usage) {
|
||||
message.usage = usage;
|
||||
message = emitChatTTSEvents(message);
|
||||
}
|
||||
|
||||
history.messages[message.id] = message;
|
||||
|
|
@ -1538,12 +1467,12 @@
|
|||
scrollToBottom();
|
||||
}
|
||||
|
||||
await chatCompletedHandler(
|
||||
chatId,
|
||||
message.model,
|
||||
message.id,
|
||||
createMessagesList(history, message.id)
|
||||
);
|
||||
if ($chatId == _chatId) {
|
||||
if (!$temporaryChatEnabled) {
|
||||
currentChatPage.set(1);
|
||||
await chats.set(await getChatList(localStorage.token, $currentChatPage));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
console.log(data);
|
||||
|
|
|
|||
Loading…
Reference in a new issue