diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 5609289166..d7a3f6e735 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -623,8 +623,21 @@ async def lifespan(app: FastAPI):
 
     yield
 
     if hasattr(app.state, "redis_task_command_listener"):
-        app.state.redis_task_command_listener.cancel()
+        # Cancel the listener, then await it so any error it raises is logged here.
+        try:
+            app.state.redis_task_command_listener.cancel()
+            await app.state.redis_task_command_listener
+        except ExceptionGroup as eg:
+            log.error(
+                f"Multiple errors during listener shutdown: {len(eg.exceptions)} exceptions"
+            )
+            for exc in eg.exceptions:
+                log.error(f"Shutdown error: {type(exc).__name__}: {exc}")
+        except asyncio.CancelledError:
+            pass  # Expected during shutdown
+        except Exception as e:
+            log.error(f"Error during listener shutdown: {e}")
 
 
 app = FastAPI(
diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py
index 3e31438281..0b5d1e2cd0 100644
--- a/backend/open_webui/tasks.py
+++ b/backend/open_webui/tasks.py
@@ -33,12 +33,32 @@ async def redis_task_command_listener(app):
         if message["type"] != "message":
             continue
         try:
-            command = json.loads(message["data"])
+            # Check if message data is empty or None
+            if not message["data"]:
+                log.warning("Received empty message data from Redis pub/sub")
+                continue
+
+            # Attempt to parse JSON
+            try:
+                command = json.loads(message["data"])
+            except json.JSONDecodeError as json_error:
+                log.warning(
+                    f"Invalid JSON in Redis message: {message['data'][:100]}... Error: {json_error}"
+                )
+                continue
+
             if command.get("action") == "stop":
                 task_id = command.get("task_id")
                 local_task = tasks.get(task_id)
                 if local_task:
                     local_task.cancel()
+        except ExceptionGroup as eg:
+            # Handle multiple concurrent exceptions
+            log.error(
+                f"Multiple errors in task command processing: {len(eg.exceptions)} exceptions"
+            )
+            for i, exc in enumerate(eg.exceptions):
+                log.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
         except Exception as e:
             log.exception(f"Error handling distributed task command: {e}")
 
@@ -80,18 +100,40 @@ async def redis_send_command(redis: Redis, command: dict):
 
 async def cleanup_task(redis, task_id: str, id=None):
     """
-    Remove a completed or canceled task from the global `tasks` dictionary.
+    Remove a completed or canceled task from the global `tasks` dictionary and,
+    when `id` is given, from `item_tasks`; runs `redis_cleanup_task` first if `redis` is set.
+
+    Local cleanup still runs when the Redis cleanup fails. Every failure is logged and
+    re-raised afterwards: a single error as-is, several wrapped in an `ExceptionGroup`.
     """
+    cleanup_errors = []
+
+    # Redis cleanup
     if redis:
-        await redis_cleanup_task(redis, task_id, id)
+        try:
+            await redis_cleanup_task(redis, task_id, id)
+        except Exception as e:
+            cleanup_errors.append(e)
+            log.error(f"Redis cleanup failed for task {task_id}: {e}")
 
-    tasks.pop(task_id, None)  # Remove the task if it exists
+    # Local cleanup
+    try:
+        tasks.pop(task_id, None)  # Remove the task if it exists
+        if id and task_id in item_tasks.get(id, []):
+            item_tasks[id].remove(task_id)
+            if not item_tasks[id]:  # If no tasks left for this ID, remove the entry
+                item_tasks.pop(id, None)
+    except Exception as e:
+        cleanup_errors.append(e)
+        log.error(f"Local cleanup failed for task {task_id}: {e}")
 
-    # If an ID is provided, remove the task from the item_tasks dictionary
-    if id and task_id in item_tasks.get(id, []):
-        item_tasks[id].remove(task_id)
-        if not item_tasks[id]:  # If no tasks left for this ID, remove the entry
-            item_tasks.pop(id, None)
+    # If multiple errors occurred, group them
+    if len(cleanup_errors) > 1:
+        raise ExceptionGroup(
+            f"Multiple cleanup errors for task {task_id}", cleanup_errors
+        )
+    elif cleanup_errors:
+        raise cleanup_errors[0]
 
 
 async def create_task(redis, coroutine, id=None):