_00_ 2025-12-09 08:22:59 +01:00 committed by GitHub
commit 5c067cc193
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 68 additions and 11 deletions


@@ -623,8 +623,21 @@ async def lifespan(app: FastAPI):
     yield
+    # In the lifespan shutdown
     if hasattr(app.state, "redis_task_command_listener"):
-        app.state.redis_task_command_listener.cancel()
+        try:
+            app.state.redis_task_command_listener.cancel()
+            await app.state.redis_task_command_listener
+        except ExceptionGroup as eg:
+            log.error(
+                f"Multiple errors during listener shutdown: {len(eg.exceptions)} exceptions"
+            )
+            for exc in eg.exceptions:
+                log.error(f"Shutdown error: {type(exc).__name__}: {exc}")
+        except asyncio.CancelledError:
+            pass  # Expected during shutdown
+        except Exception as e:
+            log.error(f"Error during listener shutdown: {e}")

 app = FastAPI(
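
For reference, a minimal standalone sketch of the shutdown pattern in this hunk (Python 3.11+, where ExceptionGroup is a builtin). The `listener` coroutine below is a placeholder standing in for the real Redis listener, not code from this repository: cancel the background task, await it so cancellation actually completes, treat CancelledError as expected, and log anything else.

import asyncio
import logging

log = logging.getLogger(__name__)

async def listener() -> None:
    # Stand-in for redis_task_command_listener(app); runs until cancelled.
    while True:
        await asyncio.sleep(1)

async def main() -> None:
    task = asyncio.create_task(listener())
    await asyncio.sleep(0)  # let the listener start

    # Shutdown: cancel, then await so the task is actually torn down.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass  # Expected during shutdown
    except ExceptionGroup as eg:  # e.g. if the task used a TaskGroup internally
        for exc in eg.exceptions:
            log.error(f"Shutdown error: {type(exc).__name__}: {exc}")
    except Exception as e:
        log.error(f"Error during listener shutdown: {e}")

asyncio.run(main())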


@@ -7,6 +7,7 @@ import logging
 from redis.asyncio import Redis
 from fastapi import Request
 from typing import Dict, List, Optional
+from builtins import ExceptionGroup
 from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX
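
Note that `from builtins import ExceptionGroup` only resolves on Python 3.11+, where ExceptionGroup is already a builtin; on older interpreters it raises ImportError at import time. A hedged sketch of a fallback guard (an assumption for illustration, not what this commit does), which is the kind of setup a truthiness check like `... and ExceptionGroup` in the cleanup hunk below would rely on:

import sys

try:
    from builtins import ExceptionGroup  # resolves only on Python 3.11+
except ImportError:
    ExceptionGroup = None  # older interpreters: feature-detect before raising groups

print(f"Python {sys.version_info.major}.{sys.version_info.minor}, "
      f"ExceptionGroup available: {ExceptionGroup is not None}")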
@@ -33,12 +34,37 @@ async def redis_task_command_listener(app):
         if message["type"] != "message":
             continue
         try:
-            command = json.loads(message["data"])
+            # Check if message data is empty or None
+            if not message["data"]:
+                log.warning("Received empty message data from Redis pub/sub")
+                continue
+            # Attempt to parse JSON
+            try:
+                command = json.loads(message["data"])
+            except json.JSONDecodeError as json_error:
+                log.warning(
+                    f"Invalid JSON in Redis message: {message['data'][:100]}... Error: {json_error}"
+                )
+                continue
             if command.get("action") == "stop":
                 task_id = command.get("task_id")
                 local_task = tasks.get(task_id)
                 if local_task:
-                    local_task.cancel()
+                    try:
+                        local_task.cancel()
+                        # Wait briefly for cancellation to complete
+                        await asyncio.sleep(0.1)
+                    except Exception as cancel_error:
+                        log.error(f"Error cancelling task {task_id}: {cancel_error}")
+        except ExceptionGroup as eg:
+            # Handle multiple concurrent exceptions
+            log.error(
+                f"Multiple errors in task command processing: {len(eg.exceptions)} exceptions"
+            )
+            for i, exc in enumerate(eg.exceptions):
+                log.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
         except Exception as e:
             log.exception(f"Error handling distributed task command: {e}")
@@ -80,18 +106,36 @@ async def redis_send_command(redis: Redis, command: dict):
 async def cleanup_task(redis, task_id: str, id=None):
     """
-    Remove a completed or canceled task from the global `tasks` dictionary.
+    Remove a completed or canceled task from the global `tasks` dictionary with proper exception handling.
     """
+    cleanup_errors = []
+
+    # Redis cleanup
     if redis:
-        await redis_cleanup_task(redis, task_id, id)
+        try:
+            await redis_cleanup_task(redis, task_id, id)
+        except Exception as e:
+            cleanup_errors.append(e)
+            log.error(f"Redis cleanup failed for task {task_id}: {e}")

-    tasks.pop(task_id, None)  # Remove the task if it exists
+    # Local cleanup
+    try:
+        tasks.pop(task_id, None)
+        if id and task_id in item_tasks.get(id, []):
+            item_tasks[id].remove(task_id)
+            if not item_tasks[id]:
+                item_tasks.pop(id, None)
+    except Exception as e:
+        cleanup_errors.append(e)
+        log.error(f"Local cleanup failed for task {task_id}: {e}")

-    # If an ID is provided, remove the task from the item_tasks dictionary
-    if id and task_id in item_tasks.get(id, []):
-        item_tasks[id].remove(task_id)
-        if not item_tasks[id]:  # If no tasks left for this ID, remove the entry
-            item_tasks.pop(id, None)
+    # If multiple errors occurred, group them
+    if len(cleanup_errors) > 1 and ExceptionGroup:
+        raise ExceptionGroup(
+            f"Multiple cleanup errors for task {task_id}", cleanup_errors
+        )
+    elif cleanup_errors:
+        raise cleanup_errors[0]

 async def create_task(redis, coroutine, id=None):
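
A minimal sketch (Python 3.11+) of the collect-then-group error pattern that the new cleanup_task follows, with placeholder cleanup steps in place of the Redis and local cleanup: gather every failure, raise a single error directly, or wrap several into an ExceptionGroup that callers can unpack with except* or by iterating eg.exceptions, as the shutdown handler above does.

def run_cleanup_steps(steps) -> None:
    errors: list[Exception] = []
    for step in steps:
        try:
            step()
        except Exception as e:
            errors.append(e)

    if len(errors) > 1:
        raise ExceptionGroup("Multiple cleanup errors", errors)
    elif errors:
        raise errors[0]

def fail_redis() -> None:
    raise ConnectionError("redis unavailable")

def fail_local() -> None:
    raise KeyError("task_id")

try:
    run_cleanup_steps([fail_redis, fail_local])
except* ConnectionError as eg:
    print("redis errors:", [str(e) for e in eg.exceptions])
except* KeyError as eg:
    print("local errors:", [str(e) for e in eg.exceptions])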