enh/refac: redis cluster support

This commit is contained in:
Timothy Jaeryang Baek 2025-08-04 14:15:08 +04:00
parent 01320d99d6
commit 35400daf19
6 changed files with 71 additions and 13 deletions

View file

@ -7,7 +7,7 @@ import redis
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Generic, Optional, TypeVar from typing import Generic, Union, Optional, TypeVar
from urllib.parse import urlparse from urllib.parse import urlparse
import requests import requests
@ -213,13 +213,14 @@ class PersistentConfig(Generic[T]):
class AppConfig: class AppConfig:
_state: dict[str, PersistentConfig] _state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None _redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
_redis_key_prefix: str _redis_key_prefix: str
def __init__( def __init__(
self, self,
redis_url: Optional[str] = None, redis_url: Optional[str] = None,
redis_sentinels: Optional[list] = [], redis_sentinels: Optional[list] = [],
redis_cluster: Optional[bool] = False,
redis_key_prefix: str = "open-webui", redis_key_prefix: str = "open-webui",
): ):
super().__setattr__("_state", {}) super().__setattr__("_state", {})
@ -227,7 +228,12 @@ class AppConfig:
if redis_url: if redis_url:
super().__setattr__( super().__setattr__(
"_redis", "_redis",
get_redis_connection(redis_url, redis_sentinels, decode_responses=True), get_redis_connection(
redis_url,
redis_sentinels,
redis_cluster,
decode_responses=True,
),
) )
def __setattr__(self, key, value): def __setattr__(self, key, value):

View file

@ -346,7 +346,10 @@ ENABLE_REALTIME_CHAT_SAVE = (
#################################### ####################################
REDIS_URL = os.environ.get("REDIS_URL", "") REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui") REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
@ -489,6 +492,9 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_CLUSTER = (
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
)
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60") websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
@ -498,9 +504,9 @@ except ValueError:
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60 WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
if AIOHTTP_CLIENT_TIMEOUT == "": if AIOHTTP_CLIENT_TIMEOUT == "":

View file

@ -399,6 +399,7 @@ from open_webui.env import (
AUDIT_LOG_LEVEL, AUDIT_LOG_LEVEL,
CHANGELOG, CHANGELOG,
REDIS_URL, REDIS_URL,
REDIS_CLUSTER,
REDIS_KEY_PREFIX, REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT, REDIS_SENTINEL_PORT,
@ -525,6 +526,7 @@ async def lifespan(app: FastAPI):
redis_sentinels=get_sentinels_from_env( redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
), ),
redis_cluster=REDIS_CLUSTER,
async_mode=True, async_mode=True,
) )
@ -580,6 +582,7 @@ app.state.instance_id = None
app.state.config = AppConfig( app.state.config = AppConfig(
redis_url=REDIS_URL, redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
redis_cluster=REDIS_CLUSTER,
redis_key_prefix=REDIS_KEY_PREFIX, redis_key_prefix=REDIS_KEY_PREFIX,
) )
app.state.redis = None app.state.redis = None

View file

@ -22,6 +22,7 @@ from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT, ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER, WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL, WEBSOCKET_REDIS_URL,
WEBSOCKET_REDIS_CLUSTER,
WEBSOCKET_REDIS_LOCK_TIMEOUT, WEBSOCKET_REDIS_LOCK_TIMEOUT,
WEBSOCKET_SENTINEL_PORT, WEBSOCKET_SENTINEL_PORT,
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_HOSTS,
@ -86,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis":
redis_sentinels=get_sentinels_from_env( redis_sentinels=get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
), ),
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
async_mode=True, async_mode=True,
) )
@ -96,16 +98,19 @@ if WEBSOCKET_MANAGER == "redis":
f"{REDIS_KEY_PREFIX}:session_pool", f"{REDIS_KEY_PREFIX}:session_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
) )
USER_POOL = RedisDict( USER_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:user_pool", f"{REDIS_KEY_PREFIX}:user_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
) )
USAGE_POOL = RedisDict( USAGE_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:usage_pool", f"{REDIS_KEY_PREFIX}:usage_pool",
redis_url=WEBSOCKET_REDIS_URL, redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
) )
clean_up_lock = RedisLock( clean_up_lock = RedisLock(
@ -113,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis":
lock_name="usage_cleanup_lock", lock_name="usage_cleanup_lock",
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
redis_sentinels=redis_sentinels, redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
) )
aquire_func = clean_up_lock.aquire_lock aquire_func = clean_up_lock.aquire_lock
renew_func = clean_up_lock.renew_lock renew_func = clean_up_lock.renew_lock

View file

@ -7,13 +7,24 @@ import pycrdt as Y
class RedisLock: class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]): def __init__(
self,
redis_url,
lock_name,
timeout_secs,
redis_sentinels=[],
redis_cluster=False,
):
self.lock_name = lock_name self.lock_name = lock_name
self.lock_id = str(uuid.uuid4()) self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs self.timeout_secs = timeout_secs
self.lock_obtained = False self.lock_obtained = False
self.redis = get_redis_connection( self.redis = get_redis_connection(
redis_url, redis_sentinels, decode_responses=True redis_url,
redis_sentinels,
redis_cluster=redis_cluster,
decode_responses=True,
) )
def aquire_lock(self): def aquire_lock(self):
@ -36,10 +47,13 @@ class RedisLock:
class RedisDict: class RedisDict:
def __init__(self, name, redis_url, redis_sentinels=[]): def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
self.name = name self.name = name
self.redis = get_redis_connection( self.redis = get_redis_connection(
redis_url, redis_sentinels, decode_responses=True redis_url,
redis_sentinels,
redis_cluster=redis_cluster,
decode_responses=True,
) )
def __setitem__(self, key, value): def __setitem__(self, key, value):

View file

@ -96,8 +96,8 @@ class SentinelRedisProxy:
def parse_redis_service_url(redis_url): def parse_redis_service_url(redis_url):
parsed_url = urlparse(redis_url) parsed_url = urlparse(redis_url)
if parsed_url.scheme != "redis": if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss":
raise ValueError("Invalid Redis URL scheme. Must be 'redis'.") raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
return { return {
"username": parsed_url.username or None, "username": parsed_url.username or None,
@ -109,10 +109,19 @@ def parse_redis_service_url(redis_url):
def get_redis_connection( def get_redis_connection(
redis_url, redis_sentinels, async_mode=False, decode_responses=True redis_url,
redis_sentinels,
redis_cluster=False,
async_mode=False,
decode_responses=True,
): ):
cache_key = (redis_url, tuple(redis_sentinels) if redis_sentinels else (), async_mode, decode_responses) cache_key = (
redis_url,
tuple(redis_sentinels) if redis_sentinels else (),
async_mode,
decode_responses,
)
if cache_key in _CONNECTION_CACHE: if cache_key in _CONNECTION_CACHE:
return _CONNECTION_CACHE[cache_key] return _CONNECTION_CACHE[cache_key]
@ -138,6 +147,12 @@ def get_redis_connection(
redis_config["service"], redis_config["service"],
async_mode=async_mode, async_mode=async_mode,
) )
elif redis_cluster:
if not redis_url:
raise ValueError("Redis URL must be provided for cluster mode.")
return redis.cluster.RedisCluster.from_url(
redis_url, decode_responses=decode_responses
)
elif redis_url: elif redis_url:
connection = redis.from_url(redis_url, decode_responses=decode_responses) connection = redis.from_url(redis_url, decode_responses=decode_responses)
else: else:
@ -158,8 +173,16 @@ def get_redis_connection(
redis_config["service"], redis_config["service"],
async_mode=async_mode, async_mode=async_mode,
) )
elif redis_cluster:
if not redis_url:
raise ValueError("Redis URL must be provided for cluster mode.")
return redis.cluster.RedisCluster.from_url(
redis_url, decode_responses=decode_responses
)
elif redis_url: elif redis_url:
connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses) connection = redis.Redis.from_url(
redis_url, decode_responses=decode_responses
)
_CONNECTION_CACHE[cache_key] = connection _CONNECTION_CACHE[cache_key] = connection
return connection return connection