mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
enh/refac: redis cluster support
This commit is contained in:
parent
01320d99d6
commit
35400daf19
6 changed files with 71 additions and 13 deletions
|
|
@ -7,7 +7,7 @@ import redis
|
|||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generic, Optional, TypeVar
|
||||
from typing import Generic, Union, Optional, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
|
@ -213,13 +213,14 @@ class PersistentConfig(Generic[T]):
|
|||
|
||||
class AppConfig:
|
||||
_state: dict[str, PersistentConfig]
|
||||
_redis: Optional[redis.Redis] = None
|
||||
_redis: Union[redis.Redis, redis.cluster.RedisCluster] = None
|
||||
_redis_key_prefix: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: Optional[str] = None,
|
||||
redis_sentinels: Optional[list] = [],
|
||||
redis_cluster: Optional[bool] = False,
|
||||
redis_key_prefix: str = "open-webui",
|
||||
):
|
||||
super().__setattr__("_state", {})
|
||||
|
|
@ -227,7 +228,12 @@ class AppConfig:
|
|||
if redis_url:
|
||||
super().__setattr__(
|
||||
"_redis",
|
||||
get_redis_connection(redis_url, redis_sentinels, decode_responses=True),
|
||||
get_redis_connection(
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster,
|
||||
decode_responses=True,
|
||||
),
|
||||
)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
|
|
|
|||
|
|
@ -346,7 +346,10 @@ ENABLE_REALTIME_CHAT_SAVE = (
|
|||
####################################
|
||||
|
||||
REDIS_URL = os.environ.get("REDIS_URL", "")
|
||||
REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true"
|
||||
|
||||
REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui")
|
||||
|
||||
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
|
||||
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
|
||||
|
||||
|
|
@ -489,6 +492,9 @@ ENABLE_WEBSOCKET_SUPPORT = (
|
|||
WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
||||
|
||||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
WEBSOCKET_REDIS_CLUSTER = (
|
||||
os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true"
|
||||
)
|
||||
|
||||
websocket_redis_lock_timeout = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60")
|
||||
|
||||
|
|
@ -498,9 +504,9 @@ except ValueError:
|
|||
WEBSOCKET_REDIS_LOCK_TIMEOUT = 60
|
||||
|
||||
WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "")
|
||||
|
||||
WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379")
|
||||
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT == "":
|
||||
|
|
|
|||
|
|
@ -399,6 +399,7 @@ from open_webui.env import (
|
|||
AUDIT_LOG_LEVEL,
|
||||
CHANGELOG,
|
||||
REDIS_URL,
|
||||
REDIS_CLUSTER,
|
||||
REDIS_KEY_PREFIX,
|
||||
REDIS_SENTINEL_HOSTS,
|
||||
REDIS_SENTINEL_PORT,
|
||||
|
|
@ -525,6 +526,7 @@ async def lifespan(app: FastAPI):
|
|||
redis_sentinels=get_sentinels_from_env(
|
||||
REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
|
||||
),
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
async_mode=True,
|
||||
)
|
||||
|
||||
|
|
@ -580,6 +582,7 @@ app.state.instance_id = None
|
|||
app.state.config = AppConfig(
|
||||
redis_url=REDIS_URL,
|
||||
redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT),
|
||||
redis_cluster=REDIS_CLUSTER,
|
||||
redis_key_prefix=REDIS_KEY_PREFIX,
|
||||
)
|
||||
app.state.redis = None
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from open_webui.env import (
|
|||
ENABLE_WEBSOCKET_SUPPORT,
|
||||
WEBSOCKET_MANAGER,
|
||||
WEBSOCKET_REDIS_URL,
|
||||
WEBSOCKET_REDIS_CLUSTER,
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
WEBSOCKET_SENTINEL_PORT,
|
||||
WEBSOCKET_SENTINEL_HOSTS,
|
||||
|
|
@ -86,6 +87,7 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
redis_sentinels=get_sentinels_from_env(
|
||||
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
|
||||
),
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
async_mode=True,
|
||||
)
|
||||
|
||||
|
|
@ -96,16 +98,19 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
f"{REDIS_KEY_PREFIX}:session_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
USER_POOL = RedisDict(
|
||||
f"{REDIS_KEY_PREFIX}:user_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
USAGE_POOL = RedisDict(
|
||||
f"{REDIS_KEY_PREFIX}:usage_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
|
||||
clean_up_lock = RedisLock(
|
||||
|
|
@ -113,6 +118,7 @@ if WEBSOCKET_MANAGER == "redis":
|
|||
lock_name="usage_cleanup_lock",
|
||||
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
redis_sentinels=redis_sentinels,
|
||||
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
|
||||
)
|
||||
aquire_func = clean_up_lock.aquire_lock
|
||||
renew_func = clean_up_lock.renew_lock
|
||||
|
|
|
|||
|
|
@ -7,13 +7,24 @@ import pycrdt as Y
|
|||
|
||||
|
||||
class RedisLock:
|
||||
def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]):
|
||||
def __init__(
|
||||
self,
|
||||
redis_url,
|
||||
lock_name,
|
||||
timeout_secs,
|
||||
redis_sentinels=[],
|
||||
redis_cluster=False,
|
||||
):
|
||||
|
||||
self.lock_name = lock_name
|
||||
self.lock_id = str(uuid.uuid4())
|
||||
self.timeout_secs = timeout_secs
|
||||
self.lock_obtained = False
|
||||
self.redis = get_redis_connection(
|
||||
redis_url, redis_sentinels, decode_responses=True
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster=redis_cluster,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def aquire_lock(self):
|
||||
|
|
@ -36,10 +47,13 @@ class RedisLock:
|
|||
|
||||
|
||||
class RedisDict:
|
||||
def __init__(self, name, redis_url, redis_sentinels=[]):
|
||||
def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
|
||||
self.name = name
|
||||
self.redis = get_redis_connection(
|
||||
redis_url, redis_sentinels, decode_responses=True
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster=redis_cluster,
|
||||
decode_responses=True,
|
||||
)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
|
|
|
|||
|
|
@ -96,8 +96,8 @@ class SentinelRedisProxy:
|
|||
|
||||
def parse_redis_service_url(redis_url):
|
||||
parsed_url = urlparse(redis_url)
|
||||
if parsed_url.scheme != "redis":
|
||||
raise ValueError("Invalid Redis URL scheme. Must be 'redis'.")
|
||||
if parsed_url.scheme != "redis" and parsed_url.scheme != "rediss":
|
||||
raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")
|
||||
|
||||
return {
|
||||
"username": parsed_url.username or None,
|
||||
|
|
@ -109,10 +109,19 @@ def parse_redis_service_url(redis_url):
|
|||
|
||||
|
||||
def get_redis_connection(
|
||||
redis_url, redis_sentinels, async_mode=False, decode_responses=True
|
||||
redis_url,
|
||||
redis_sentinels,
|
||||
redis_cluster=False,
|
||||
async_mode=False,
|
||||
decode_responses=True,
|
||||
):
|
||||
|
||||
cache_key = (redis_url, tuple(redis_sentinels) if redis_sentinels else (), async_mode, decode_responses)
|
||||
cache_key = (
|
||||
redis_url,
|
||||
tuple(redis_sentinels) if redis_sentinels else (),
|
||||
async_mode,
|
||||
decode_responses,
|
||||
)
|
||||
|
||||
if cache_key in _CONNECTION_CACHE:
|
||||
return _CONNECTION_CACHE[cache_key]
|
||||
|
|
@ -138,6 +147,12 @@ def get_redis_connection(
|
|||
redis_config["service"],
|
||||
async_mode=async_mode,
|
||||
)
|
||||
elif redis_cluster:
|
||||
if not redis_url:
|
||||
raise ValueError("Redis URL must be provided for cluster mode.")
|
||||
return redis.cluster.RedisCluster.from_url(
|
||||
redis_url, decode_responses=decode_responses
|
||||
)
|
||||
elif redis_url:
|
||||
connection = redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
else:
|
||||
|
|
@ -158,8 +173,16 @@ def get_redis_connection(
|
|||
redis_config["service"],
|
||||
async_mode=async_mode,
|
||||
)
|
||||
elif redis_cluster:
|
||||
if not redis_url:
|
||||
raise ValueError("Redis URL must be provided for cluster mode.")
|
||||
return redis.cluster.RedisCluster.from_url(
|
||||
redis_url, decode_responses=decode_responses
|
||||
)
|
||||
elif redis_url:
|
||||
connection = redis.Redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
connection = redis.Redis.from_url(
|
||||
redis_url, decode_responses=decode_responses
|
||||
)
|
||||
|
||||
_CONNECTION_CACHE[cache_key] = connection
|
||||
return connection
|
||||
|
|
|
|||
Loading…
Reference in a new issue