enh/refac: redis cluster support

This commit is contained in:
Timothy Jaeryang Baek 2025-08-04 14:15:08 +04:00
parent 01320d99d6
commit 35400daf19
6 changed files with 71 additions and 13 deletions

View file

@ -7,7 +7,7 @@ import redis
from datetime import datetime
from pathlib import Path
from typing import Generic, Optional, TypeVar
from typing import Generic, Union, Optional, TypeVar
from urllib.parse import urlparse
import requests
@ -213,13 +213,14 @@ class PersistentConfig(Generic[T]):
class AppConfig:
_state: dict[str, PersistentConfig]
_redis: Optional[redis.Redis] = None
_redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
_redis_key_prefix: str
def __init__(
self,
redis_url: Optional[str] = None,
redis_sentinels: Optional[list] = [],
redis_cluster: Optional[bool] = False,
redis_key_prefix: str = "open-webui",
):
super().__setattr__("_state", {})
@ -227,7 +228,12 @@ class AppConfig:
if redis_url:
super().__setattr__(
"_redis",
get_redis_connection(redis_url, redis_sentinels, decode_responses=True),
get_redis_connection(
redis_url,
redis_sentinels,
redis_cluster,
decode_responses=True,
),
)
def __setattr__(self, key, value):

View file

@ -346,7 +346,10 @@ ENABLE_REALTIME_CHAT_SAVE = (
####################################
REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
@ -489,6 +492,9 @@ ENABLE_WEBSOCKET_SUPPORT = (
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_CLUSTER = (
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
)
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
@ -498,9 +504,9 @@ except ValueError:
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
if AIOHTTP_CLIENT_TIMEOUT == "":

View file

@ -399,6 +399,7 @@ from open_webui.env import (
AUDIT_LOG_LEVEL,
CHANGELOG,
REDIS_URL,
REDIS_CLUSTER,
REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
@ -525,6 +526,7 @@ async def lifespan(app: FastAPI):
redis_sentinels=get_sentinels_from_env(
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
),
redis_cluster=REDIS_CLUSTER,
async_mode=True,
)
@ -580,6 +582,7 @@ app.state.instance_id = None
app.state.config = AppConfig(
redis_url=REDIS_URL,
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
redis_cluster=REDIS_CLUSTER,
redis_key_prefix=REDIS_KEY_PREFIX,
)
app.state.redis = None

View file

@ -22,6 +22,7 @@ from open_webui.env import (
ENABLE_WEBSOCKET_SUPPORT,
WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL,
WEBSOCKET_REDIS_CLUSTER,
WEBSOCKET_REDIS_LOCK_TIMEOUT,
WEBSOCKET_SENTINEL_PORT,
WEBSOCKET_SENTINEL_HOSTS,
@ -86,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis":
redis_sentinels=get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
),
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
async_mode=True,
)
@ -96,16 +98,19 @@ if WEBSOCKET_MANAGER == "redis":
f"{REDIS_KEY_PREFIX}:session_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
)
USER_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:user_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
)
USAGE_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:usage_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
)
clean_up_lock = RedisLock(
@ -113,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis":
lock_name="usage_cleanup_lock",
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
)
aquire_func = clean_up_lock.aquire_lock
renew_func = clean_up_lock.renew_lock

View file

@ -7,13 +7,24 @@ import pycrdt as Y
class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
def __init__(
self,
redis_url,
lock_name,
timeout_secs,
redis_sentinels=[],
redis_cluster=False,
):
self.lock_name = lock_name
self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs
self.lock_obtained = False
self.redis = get_redis_connection(
redis_url, redis_sentinels, decode_responses=True
redis_url,
redis_sentinels,
redis_cluster=redis_cluster,
decode_responses=True,
)
def aquire_lock(self):
@ -36,10 +47,13 @@ class RedisLock:
class RedisDict:
def __init__(self, name, redis_url, redis_sentinels=[]):
def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
self.name = name
self.redis = get_redis_connection(
redis_url, redis_sentinels, decode_responses=True
redis_url,
redis_sentinels,
redis_cluster=redis_cluster,
decode_responses=True,
)
def __setitem__(self, key, value):

View file

@ -96,8 +96,8 @@ class SentinelRedisProxy:
def parse_redis_service_url(redis_url):
parsed_url = urlparse(redis_url)
if parsed_url.scheme != "redis":
raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss":
raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
return {
"username": parsed_url.username or None,
@ -109,10 +109,19 @@ def parse_redis_service_url(redis_url):
def get_redis_connection(
redis_url, redis_sentinels, async_mode=False, decode_responses=True
redis_url,
redis_sentinels,
redis_cluster=False,
async_mode=False,
decode_responses=True,
):
cache_key = (redis_url, tuple(redis_sentinels) if redis_sentinels else (), async_mode, decode_responses)
cache_key = (
redis_url,
tuple(redis_sentinels) if redis_sentinels else (),
async_mode,
decode_responses,
)
if cache_key in _CONNECTION_CACHE:
return _CONNECTION_CACHE[cache_key]
@ -138,6 +147,12 @@ def get_redis_connection(
redis_config["service"],
async_mode=async_mode,
)
elif redis_cluster:
if not redis_url:
raise ValueError("Redis URL must be provided for cluster mode.")
return redis.cluster.RedisCluster.from_url(
redis_url, decode_responses=decode_responses
)
elif redis_url:
connection = redis.from_url(redis_url, decode_responses=decode_responses)
else:
@ -158,8 +173,16 @@ def get_redis_connection(
redis_config["service"],
async_mode=async_mode,
)
elif redis_cluster:
if not redis_url:
raise ValueError("Redis URL must be provided for cluster mode.")
return redis.cluster.RedisCluster.from_url(
redis_url, decode_responses=decode_responses
)
elif redis_url:
connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses)
connection = redis.Redis.from_url(
redis_url, decode_responses=decode_responses
)
_CONNECTION_CACHE[cache_key] = connection
return connection