diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6da60fc860..b6f26a8278 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -513,7 +513,7 @@ async def lifespan(app: FastAPI): async_mode=True, ) - if isinstance(app.state.redis, Redis): + if app.state.redis is not None: app.state.redis_task_command_listener = asyncio.create_task( redis_task_command_listener(app) ) @@ -1424,7 +1424,7 @@ async def stop_task_endpoint( @app.get("/api/tasks") async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)): - return {"tasks": list_tasks(request)} + return {"tasks": await list_tasks(request)} @app.get("/api/tasks/chat/{chat_id}") @@ -1435,7 +1435,7 @@ async def list_tasks_by_chat_id_endpoint( if chat is None or chat.user_id != user.id: return {"task_ids": []} - task_ids = list_task_ids_by_chat_id(request, chat_id) + task_ids = await list_task_ids_by_chat_id(request, chat_id) print(f"Task IDs for chat {chat_id}: {task_ids}") return {"task_ids": task_ids} diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py index 0923159fb0..d73b5800fb 100644 --- a/backend/open_webui/tasks.py +++ b/backend/open_webui/tasks.py @@ -3,7 +3,7 @@ import asyncio from typing import Dict from uuid import uuid4 import json -from redis import Redis +from redis.asyncio import Redis from fastapi import Request from typing import Dict, List, Optional @@ -19,18 +19,16 @@ REDIS_PUBSUB_CHANNEL = "open-webui:tasks:commands" def is_redis(request: Request) -> bool: # Called everywhere a request is available to check Redis - return hasattr(request.app.state, "redis") and isinstance( - request.app.state.redis, Redis - ) + return hasattr(request.app.state, "redis") and (request.app.state.redis is not None) async def redis_task_command_listener(app): redis: Redis = app.state.redis pubsub = redis.pubsub() await pubsub.subscribe(REDIS_PUBSUB_CHANNEL) - print("Subscribed to Redis task command channel") async for message in pubsub.listen(): + print(f"Received message: {message}") if message["type"] != "message": continue try: @@ -49,42 +47,42 @@ async def redis_task_command_listener(app): ### ------------------------------ -def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]): +async def redis_save_task(redis: Redis, task_id: str, chat_id: Optional[str]): pipe = redis.pipeline() pipe.hset(REDIS_TASKS_KEY, task_id, chat_id or "") if chat_id: pipe.sadd(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id) - pipe.execute() + await pipe.execute() -def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]): +async def redis_cleanup_task(redis: Redis, task_id: str, chat_id: Optional[str]): pipe = redis.pipeline() pipe.hdel(REDIS_TASKS_KEY, task_id) if chat_id: pipe.srem(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}", task_id) - if pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute()[-1] == 0: + if (await pipe.scard(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}").execute())[-1] == 0: pipe.delete(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}") # Remove if empty set - pipe.execute() + await pipe.execute() -def redis_list_tasks(redis: Redis) -> List[str]: - return list(redis.hkeys(REDIS_TASKS_KEY)) +async def redis_list_tasks(redis: Redis) -> List[str]: + return list(await redis.hkeys(REDIS_TASKS_KEY)) -def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]: - return list(redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")) +async def redis_list_chat_tasks(redis: Redis, chat_id: str) -> List[str]: + return list(await redis.smembers(f"{REDIS_CHAT_TASKS_KEY}:{chat_id}")) -def redis_send_command(redis: Redis, command: dict): - redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command)) +async def redis_send_command(redis: Redis, command: dict): + await redis.publish(REDIS_PUBSUB_CHANNEL, json.dumps(command)) -def cleanup_task(request, task_id: str, id=None): +async def cleanup_task(request, task_id: str, id=None): """ Remove a completed or canceled task from the global `tasks` dictionary. """ if is_redis(request): - redis_cleanup_task(request.app.state.redis, task_id, id) + await redis_cleanup_task(request.app.state.redis, task_id, id) tasks.pop(task_id, None) # Remove the task if it exists @@ -95,7 +93,7 @@ def cleanup_task(request, task_id: str, id=None): chat_tasks.pop(id, None) -def create_task(request, coroutine, id=None): +async def create_task(request, coroutine, id=None): """ Create a new asyncio task and add it to the global task dictionary. """ @@ -103,7 +101,9 @@ def create_task(request, coroutine, id=None): task = asyncio.create_task(coroutine) # Create the task # Add a done callback for cleanup - task.add_done_callback(lambda t: cleanup_task(request, task_id, id)) + task.add_done_callback( + lambda t: asyncio.create_task(cleanup_task(request, task_id, id)) + ) tasks[task_id] = task # If an ID is provided, associate the task with that ID @@ -113,26 +113,26 @@ def create_task(request, coroutine, id=None): chat_tasks[id] = [task_id] if is_redis(request): - redis_save_task(request.app.state.redis, task_id, id) + await redis_save_task(request.app.state.redis, task_id, id) return task_id, task -def list_tasks(request): +async def list_tasks(request): """ List all currently active task IDs. """ if is_redis(request): - return redis_list_tasks(request.app.state.redis) + return await redis_list_tasks(request.app.state.redis) return list(tasks.keys()) -def list_task_ids_by_chat_id(request, id): +async def list_task_ids_by_chat_id(request, id): """ List all tasks associated with a specific ID. """ if is_redis(request): - return redis_list_chat_tasks(request.app.state.redis, id) + return await redis_list_chat_tasks(request.app.state.redis, id) return chat_tasks.get(id, []) @@ -142,7 +142,7 @@ async def stop_task(request, task_id: str): """ if is_redis(request): # PUBSUB: All instances check if they have this task, and stop if so. - redis_send_command( + await redis_send_command( request.app.state.redis, { "action": "stop", diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index a5a9b8e078..77124eabc0 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -2413,7 +2413,7 @@ async def process_chat_response( await response.background() # background_tasks.add_task(post_response_handler, response, events) - task_id, _ = create_task( + task_id, _ = await create_task( request, post_response_handler(response, events), id=metadata["chat_id"] ) return {"status": True, "task_id": task_id}