diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index f7c14ab5ad..727bfe65dd 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -61,6 +61,7 @@ from open_webui.utils import logger
 from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
 from open_webui.utils.logger import start_logger
 from open_webui.socket.main import (
+    MODELS,
     app as socket_app,
     periodic_usage_pool_cleanup,
     get_event_emitter,
@@ -1217,7 +1218,7 @@ app.state.config.VOICE_MODE_PROMPT_TEMPLATE = VOICE_MODE_PROMPT_TEMPLATE
 #
 ########################################
 
-app.state.MODELS = {}
+app.state.MODELS = MODELS
 
 # Add the middleware to the app
 if ENABLE_COMPRESSION_MIDDLEWARE:
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
index f79d1dd958..bbfbfa2703 100644
--- a/backend/open_webui/socket/main.py
+++ b/backend/open_webui/socket/main.py
@@ -118,6 +118,14 @@ if WEBSOCKET_MANAGER == "redis":
     redis_sentinels = get_sentinels_from_env(
         WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
     )
+
+    MODELS = RedisDict(
+        f"{REDIS_KEY_PREFIX}:models",
+        redis_url=WEBSOCKET_REDIS_URL,
+        redis_sentinels=redis_sentinels,
+        redis_cluster=WEBSOCKET_REDIS_CLUSTER,
+    )
+
     SESSION_POOL = RedisDict(
         f"{REDIS_KEY_PREFIX}:session_pool",
         redis_url=WEBSOCKET_REDIS_URL,
@@ -148,6 +156,8 @@ if WEBSOCKET_MANAGER == "redis":
     renew_func = clean_up_lock.renew_lock
     release_func = clean_up_lock.release_lock
 else:
+    MODELS = {}
+
     SESSION_POOL = {}
     USER_POOL = {}
     USAGE_POOL = {}
diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py
index 168d2fd88e..5739a8027a 100644
--- a/backend/open_webui/socket/utils.py
+++ b/backend/open_webui/socket/utils.py
@@ -86,6 +86,15 @@ class RedisDict:
     def items(self):
         return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
 
+    def set(self, mapping: dict):
+        pipe = self.redis.pipeline()
+
+        pipe.delete(self.name)
+        if mapping:
+            pipe.hset(self.name, mapping={k: json.dumps(v) for k, v in mapping.items()})
+
+        pipe.execute()
+
     def get(self, key, default=None):
         try:
             return self[key]
diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py
index 8b53ce5193..525ba22e76 100644
--- a/backend/open_webui/utils/models.py
+++ b/backend/open_webui/utils/models.py
@@ -6,6 +6,7 @@ import sys
 from aiocache import cached
 from fastapi import Request
 
+from open_webui.socket.utils import RedisDict
 from open_webui.routers import openai, ollama
 from open_webui.functions import get_function_models
 
@@ -323,7 +324,12 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
 
     log.debug(f"get_all_models() returned {len(models)} models")
 
-    request.app.state.MODELS = {model["id"]: model for model in models}
+    models_dict = {model["id"]: model for model in models}
+    if isinstance(request.app.state.MODELS, RedisDict):
+        request.app.state.MODELS.set(models_dict)
+    else:
+        request.app.state.MODELS = models_dict
+
     return models
 
 