diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index a7db457278..74e9f02f11 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -7,7 +7,7 @@ import redis from datetime import datetime from pathlib import Path -from typing import Generic, Optional, TypeVar +from typing import Generic, Union, Optional, TypeVar from urllib.parse import urlparse import requests @@ -213,13 +213,14 @@ class PersistentConfig(Generic[T]): class AppConfig: _state: dict[str, PersistentConfig] - _redis: Optional[redis.Redis] = None + _redis: Union[redis.Redis, redis.cluster.RedisCluster] = None _redis_key_prefix: str def __init__( self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = [], + redis_cluster: Optional[bool] = False, redis_key_prefix: str = "open-webui", ): super().__setattr__("_state", {}) @@ -227,7 +228,12 @@ class AppConfig: if redis_url: super().__setattr__( "_redis", - get_redis_connection(redis_url, redis_sentinels, decode_responses=True), + get_redis_connection( + redis_url, + redis_sentinels, + redis_cluster, + decode_responses=True, + ), ) def __setattr__(self, key, value): diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 7ac7c991a0..2505514dc9 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -346,7 +346,10 @@ ENABLE_REALTIME_CHAT_SAVE = ( #################################### REDIS_URL = os.environ.get("REDIS_URL", "") +REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true" + REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui") + REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") @@ -489,6 +492,9 @@ ENABLE_WEBSOCKET_SUPPORT = ( WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) +WEBSOCKET_REDIS_CLUSTER = ( + os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true" +) websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60") @@ -498,9 +504,9 @@ except ValueError: WEBSOCKET_REDIS_LOCK_TIMEOUT = 60 WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") - WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") + AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") if AIOHTTP_CLIENT_TIMEOUT == "": diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 679cc38a11..19cdd87e98 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -399,6 +399,7 @@ from open_webui.env import ( AUDIT_LOG_LEVEL, CHANGELOG, REDIS_URL, + REDIS_CLUSTER, REDIS_KEY_PREFIX, REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT, @@ -525,6 +526,7 @@ async def lifespan(app: FastAPI): redis_sentinels=get_sentinels_from_env( REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT ), + redis_cluster=REDIS_CLUSTER, async_mode=True, ) @@ -580,6 +582,7 @@ app.state.instance_id = None app.state.config = AppConfig( redis_url=REDIS_URL, redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), + redis_cluster=REDIS_CLUSTER, redis_key_prefix=REDIS_KEY_PREFIX, ) app.state.redis = None diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 997026acad..49323db975 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -22,6 +22,7 @@ from open_webui.env import ( ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, + WEBSOCKET_REDIS_CLUSTER, WEBSOCKET_REDIS_LOCK_TIMEOUT, WEBSOCKET_SENTINEL_PORT, WEBSOCKET_SENTINEL_HOSTS, @@ -86,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis": redis_sentinels=get_sentinels_from_env( WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT ), + redis_cluster=WEBSOCKET_REDIS_CLUSTER, async_mode=True, ) @@ -96,16 +98,19 @@ if WEBSOCKET_MANAGER == "redis": f"{REDIS_KEY_PREFIX}:session_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) USER_POOL = RedisDict( f"{REDIS_KEY_PREFIX}:user_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) USAGE_POOL = RedisDict( f"{REDIS_KEY_PREFIX}:usage_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) clean_up_lock = RedisLock( @@ -113,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis": lock_name="usage_cleanup_lock", timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) aquire_func = clean_up_lock.aquire_lock renew_func = clean_up_lock.renew_lock diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 84ef334156..168d2fd88e 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -7,13 +7,24 @@ import pycrdt as Y class RedisLock: - def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]): + def __init__( + self, + redis_url, + lock_name, + timeout_secs, + redis_sentinels=[], + redis_cluster=False, + ): + self.lock_name = lock_name self.lock_id = str(uuid.uuid4()) self.timeout_secs = timeout_secs self.lock_obtained = False self.redis = get_redis_connection( - redis_url, redis_sentinels, decode_responses=True + redis_url, + redis_sentinels, + redis_cluster=redis_cluster, + decode_responses=True, ) def aquire_lock(self): @@ -36,10 +47,13 @@ class RedisLock: class RedisDict: - def __init__(self, name, redis_url, redis_sentinels=[]): + def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False): self.name = name self.redis = get_redis_connection( - redis_url, redis_sentinels, decode_responses=True + redis_url, + redis_sentinels, + redis_cluster=redis_cluster, + decode_responses=True, ) def __setitem__(self, key, value): diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index 7834295375..c60a6fa517 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -96,8 +96,8 @@ class SentinelRedisProxy: def parse_redis_service_url(redis_url): parsed_url = urlparse(redis_url) - if parsed_url.scheme != "redis": - raise ValueError("Invalid Redis URL scheme. Must be 'redis'.") + if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss": + raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.") return { "username": parsed_url.username or None, @@ -109,10 +109,19 @@ def parse_redis_service_url(redis_url): def get_redis_connection( - redis_url, redis_sentinels, async_mode=False, decode_responses=True + redis_url, + redis_sentinels, + redis_cluster=False, + async_mode=False, + decode_responses=True, ): - cache_key = (redis_url, tuple(redis_sentinels) if redis_sentinels else (), async_mode, decode_responses) + cache_key = ( + redis_url, + tuple(redis_sentinels) if redis_sentinels else (), + async_mode, + decode_responses, + ) if cache_key in _CONNECTION_CACHE: return _CONNECTION_CACHE[cache_key] @@ -138,6 +147,12 @@ def get_redis_connection( redis_config["service"], async_mode=async_mode, ) + elif redis_cluster: + if not redis_url: + raise ValueError("Redis URL must be provided for cluster mode.") + return redis.cluster.RedisCluster.from_url( + redis_url, decode_responses=decode_responses + ) elif redis_url: connection = redis.from_url(redis_url, decode_responses=decode_responses) else: @@ -158,8 +173,16 @@ def get_redis_connection( redis_config["service"], async_mode=async_mode, ) + elif redis_cluster: + if not redis_url: + raise ValueError("Redis URL must be provided for cluster mode.") + return redis.cluster.RedisCluster.from_url( + redis_url, decode_responses=decode_responses + ) elif redis_url: - connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses) + connection = redis.Redis.from_url( + redis_url, decode_responses=decode_responses + ) _CONNECTION_CACHE[cache_key] = connection return connection