diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index e04ff6c308..d5b89c8d50 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -57,6 +57,7 @@ from open_webui.utils.logger import start_logger
 from open_webui.socket.main import (
     app as socket_app,
     periodic_usage_pool_cleanup,
+    get_event_emitter,
     get_models_in_use,
     get_active_user_ids,
 )
@@ -466,6 +467,7 @@ from open_webui.utils.redis import get_redis_connection
 from open_webui.tasks import (
     redis_task_command_listener,
     list_task_ids_by_item_id,
+    create_task,
     stop_task,
     list_tasks,
 )  # Import from tasks.py
@@ -1473,65 +1475,78 @@ async def chat_completion(
         request.state.metadata = metadata
         form_data["metadata"] = metadata
 
-        form_data, metadata, events = await process_chat_payload(
-            request, form_data, user, metadata, model
-        )
     except Exception as e:
-        log.debug(f"Error processing chat payload: {e}")
-        if metadata.get("chat_id") and metadata.get("message_id"):
-            # Update the chat message with the error
-            try:
-                Chats.upsert_message_to_chat_by_id_and_message_id(
-                    metadata["chat_id"],
-                    metadata["message_id"],
-                    {
-                        "error": {"content": str(e)},
-                    },
-                )
-            except:
-                pass
-
+        log.debug(f"Error processing chat metadata: {e}")
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
         )
 
-    try:
-        response = await chat_completion_handler(request, form_data, user)
-        if metadata.get("chat_id") and metadata.get("message_id"):
-            try:
-                Chats.upsert_message_to_chat_by_id_and_message_id(
-                    metadata["chat_id"],
-                    metadata["message_id"],
-                    {
-                        "model": model_id,
-                    },
-                )
-            except:
-                pass
+    async def process_chat(request, form_data, user, metadata, model):
+        try:
+            form_data, metadata, events = await process_chat_payload(
+                request, form_data, user, metadata, model
+            )
 
-        return await process_chat_response(
-            request, response, form_data, user, metadata, model, events, tasks
-        )
-    except Exception as e:
-        log.debug(f"Error in chat completion: {e}")
-        if metadata.get("chat_id") and metadata.get("message_id"):
-            # Update the chat message with the error
-            try:
-                Chats.upsert_message_to_chat_by_id_and_message_id(
-                    metadata["chat_id"],
-                    metadata["message_id"],
-                    {
-                        "error": {"content": str(e)},
-                    },
-                )
-            except:
-                pass
+            response = await chat_completion_handler(request, form_data, user)
+            if metadata.get("chat_id") and metadata.get("message_id"):
+                try:
+                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                        metadata["chat_id"],
+                        metadata["message_id"],
+                        {
+                            "model": model_id,
+                        },
+                    )
+                except:
+                    pass
 
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e),
+            return await process_chat_response(
+                request, response, form_data, user, metadata, model, events, tasks
+            )
+        except asyncio.CancelledError:
+            log.info("Chat processing was cancelled")
+            try:
+                event_emitter = get_event_emitter(metadata)
+                await event_emitter(
+                    {"type": "task-cancelled"},
+                )
+            except Exception as e:
+                pass
+        except Exception as e:
+            log.debug(f"Error processing chat payload: {e}")
+            if metadata.get("chat_id") and metadata.get("message_id"):
+                # Update the chat message with the error
+                try:
+                    Chats.upsert_message_to_chat_by_id_and_message_id(
+                        metadata["chat_id"],
+                        metadata["message_id"],
+                        {
+                            "error": {"content": str(e)},
+                        },
+                    )
+                except:
+                    pass
+
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=str(e),
+            )
+
+    if (
+        metadata.get("session_id")
+        and metadata.get("chat_id")
+        and metadata.get("message_id")
+    ):
+        # Asynchronous Chat Processing
+        task_id, _ = await create_task(
+            request.app.state.redis,
+            process_chat(request, form_data, user, metadata, model),
+            id=metadata["chat_id"],
         )
+        return {"status": True, "task_id": task_id}
+    else:
+        return await process_chat(request, form_data, user, metadata, model)
 
 
 # Alias for chat_completion (Legacy)
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 0eb3aa8853..48ed2cd489 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -86,7 +86,6 @@ from open_webui.utils.filter import (
 from open_webui.utils.code_interpreter import execute_code_jupyter
 
 from open_webui.utils.payload import apply_model_system_prompt_to_body
-from open_webui.tasks import create_task
 
 from open_webui.config import (
     CACHE_DIR,
@@ -2600,13 +2599,7 @@ async def process_chat_response(
             if response.background is not None:
                 await response.background()
 
-            # background_tasks.add_task(response_handler, response, events)
-            task_id, _ = await create_task(
-                request.app.state.redis,
-                response_handler(response, events),
-                id=metadata["chat_id"],
-            )
-            return {"status": True, "task_id": task_id}
+            return await response_handler(response, events)
 
         else:
             # Fallback to the original response
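
For context, the change above hands the whole `process_chat` coroutine to `create_task` keyed by `chat_id`, so the endpoint can return `{"status": True, "task_id": task_id}` immediately and the work can later be cancelled via `stop_task`; cancellation surfaces as the `asyncio.CancelledError` branch that emits the `task-cancelled` event. The sketch below is a minimal, in-memory illustration of that registry pattern only; it is not open_webui's actual `open_webui.tasks` implementation, which also takes the Redis connection shown above and coordinates task ids across workers (see `redis_task_command_listener`).

# Illustrative sketch only -- simplified, in-memory stand-in for the
# create_task/stop_task pattern used by the diff above.
import asyncio
import uuid

_tasks: dict[str, asyncio.Task] = {}        # task_id -> running task
_item_to_tasks: dict[str, set[str]] = {}    # item id (e.g. chat_id) -> task_ids

async def create_task(coro, id: str | None = None):
    """Register a coroutine as a cancellable background task and return its id."""
    task_id = str(uuid.uuid4())
    task = asyncio.create_task(coro)
    _tasks[task_id] = task
    if id is not None:
        _item_to_tasks.setdefault(id, set()).add(task_id)

    def _cleanup(_: asyncio.Task):
        # Drop bookkeeping once the task finishes, errors, or is cancelled.
        _tasks.pop(task_id, None)
        if id is not None:
            _item_to_tasks.get(id, set()).discard(task_id)

    task.add_done_callback(_cleanup)
    return task_id, task

async def stop_task(task_id: str):
    """Cancel a running task; inside process_chat this raises asyncio.CancelledError."""
    task = _tasks.get(task_id)
    if task is not None:
        task.cancel()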