This commit is contained in:
_00_ 2025-12-09 08:22:59 +01:00 committed by GitHub
commit 5c067cc193
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 68 additions and 11 deletions

View file

@ -623,8 +623,21 @@ async def lifespan(app: FastAPI):
yield yield
# In the lifespan shutdown
if hasattr(app.state, "redis_task_command_listener"): if hasattr(app.state, "redis_task_command_listener"):
try:
app.state.redis_task_command_listener.cancel() app.state.redis_task_command_listener.cancel()
await app.state.redis_task_command_listener
except ExceptionGroup as eg:
log.error(
f"Multiple errors during listener shutdown: {len(eg.exceptions)} exceptions"
)
for exc in eg.exceptions:
log.error(f"Shutdown error: {type(exc).__name__}: {exc}")
except asyncio.CancelledError:
pass # Expected during shutdown
except Exception as e:
log.error(f"Error during listener shutdown: {e}")
app = FastAPI( app = FastAPI(

View file

@ -7,6 +7,7 @@ import logging
from redis.asyncio import Redis from redis.asyncio import Redis
from fastapi import Request from fastapi import Request
from typing import Dict, List, Optional from typing import Dict, List, Optional
from builtins import ExceptionGroup
from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX from open_webui.env import SRC_LOG_LEVELS, REDIS_KEY_PREFIX
@ -32,13 +33,38 @@ async def redis_task_command_listener(app):
async for message in pubsub.listen(): async for message in pubsub.listen():
if message["type"] != "message": if message["type"] != "message":
continue continue
try:
# Check if message data is empty or None
if not message["data"]:
log.warning("Received empty message data from Redis pub/sub")
continue
# Attempt to parse JSON
try: try:
command = json.loads(message["data"]) command = json.loads(message["data"])
except json.JSONDecodeError as json_error:
log.warning(
f"Invalid JSON in Redis message: {message['data'][:100]}... Error: {json_error}"
)
continue
if command.get("action") == "stop": if command.get("action") == "stop":
task_id = command.get("task_id") task_id = command.get("task_id")
local_task = tasks.get(task_id) local_task = tasks.get(task_id)
if local_task: if local_task:
try:
local_task.cancel() local_task.cancel()
# Wait briefly for cancellation to complete
await asyncio.sleep(0.1)
except Exception as cancel_error:
log.error(f"Error cancelling task {task_id}: {cancel_error}")
except ExceptionGroup as eg:
# Handle multiple concurrent exceptions
log.error(
f"Multiple errors in task command processing: {len(eg.exceptions)} exceptions"
)
for i, exc in enumerate(eg.exceptions):
log.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
except Exception as e: except Exception as e:
log.exception(f"Error handling distributed task command: {e}") log.exception(f"Error handling distributed task command: {e}")
@ -80,18 +106,36 @@ async def redis_send_command(redis: Redis, command: dict):
async def cleanup_task(redis, task_id: str, id=None): async def cleanup_task(redis, task_id: str, id=None):
""" """
Remove a completed or canceled task from the global `tasks` dictionary. Remove a completed or canceled task from the global `tasks` dictionary with proper exception handling.
""" """
cleanup_errors = []
# Redis cleanup
if redis: if redis:
try:
await redis_cleanup_task(redis, task_id, id) await redis_cleanup_task(redis, task_id, id)
except Exception as e:
cleanup_errors.append(e)
log.error(f"Redis cleanup failed for task {task_id}: {e}")
tasks.pop(task_id, None) # Remove the task if it exists # Local cleanup
try:
# If an ID is provided, remove the task from the item_tasks dictionary tasks.pop(task_id, None)
if id and task_id in item_tasks.get(id, []): if id and task_id in item_tasks.get(id, []):
item_tasks[id].remove(task_id) item_tasks[id].remove(task_id)
if not item_tasks[id]: # If no tasks left for this ID, remove the entry if not item_tasks[id]:
item_tasks.pop(id, None) item_tasks.pop(id, None)
except Exception as e:
cleanup_errors.append(e)
log.error(f"Local cleanup failed for task {task_id}: {e}")
# If multiple errors occurred, group them
if len(cleanup_errors) > 1 and ExceptionGroup:
raise ExceptionGroup(
f"Multiple cleanup errors for task {task_id}", cleanup_errors
)
elif cleanup_errors:
raise cleanup_errors[0]
async def create_task(redis, coroutine, id=None): async def create_task(redis, coroutine, id=None):