enh: redis dict for internal models state

Co-Authored-By: cw.a <57549718+acwoo97@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek 2025-11-27 01:33:52 -05:00
parent ff4b1b9824
commit b5e5617a41
4 changed files with 28 additions and 2 deletions

View file

@ -61,6 +61,7 @@ from open_webui.utils import logger
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
from open_webui.utils.logger import start_logger
from open_webui.socket.main import (
MODELS,
app as socket_app,
periodic_usage_pool_cleanup,
get_event_emitter,
@ -1217,7 +1218,7 @@ app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE
#
########################################
app.state.MODELS = {}
app.state.MODELS = MODELS
# Add the middleware to the app
if ENABLE_COMPRESSION_MIDDLEWARE:

View file

@ -118,6 +118,14 @@ if WEBSOCKET_MANAGER == "redis":
redis_sentinels = get_sentinels_from_env(
WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
)
MODELS = RedisDict(
f"{REDIS_KEY_PREFIX}:models",
redis_url=WEBSOCKET_REDIS_URL,
redis_sentinels=redis_sentinels,
redis_cluster=WEBSOCKET_REDIS_CLUSTER,
)
SESSION_POOL = RedisDict(
f"{REDIS_KEY_PREFIX}:session_pool",
redis_url=WEBSOCKET_REDIS_URL,
@ -148,6 +156,8 @@ if WEBSOCKET_MANAGER == "redis":
renew_func = clean_up_lock.renew_lock
release_func = clean_up_lock.release_lock
else:
MODELS = {}
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}

View file

@ -86,6 +86,15 @@ class RedisDict:
def items(self):
return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
def set(self, mapping: dict):
pipe = self.redis.pipeline()
pipe.delete(self.name)
if mapping:
pipe.hset(self.name, mapping={k: json.dumps(v) for k, v in mapping.items()})
pipe.execute()
def get(self, key, default=None):
try:
return self[key]

View file

@ -6,6 +6,7 @@ import sys
from aiocache import cached
from fastapi import Request
from open_webui.socket.utils import RedisDict
from open_webui.routers import openai, ollama
from open_webui.functions import get_function_models
@ -323,7 +324,12 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
log.debug(f"get_all_models() returned {len(models)} models")
request.app.state.MODELS = {model["id"]: model for model in models}
models_dict = {model["id"]: model for model in models}
if isinstance(request.app.state.MODELS, RedisDict):
request.app.state.MODELS.set(models_dict)
else:
request.app.state.MODELS = models_dict
return models