diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 2848d21d6f..061856eef7 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -923,6 +923,18 @@ except Exception:
     pass
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
+
+####################################
+# MODEL_LIST
+####################################
+
+ENABLE_MODEL_LIST_CACHE = PersistentConfig(
+    "ENABLE_MODEL_LIST_CACHE",
+    "models.cache",
+    os.environ.get("ENABLE_MODEL_LIST_CACHE", "False").lower() == "true",
+)
+
+
 ####################################
 # TOOL_SERVERS
 ####################################
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 8e37d9e530..90d8a63a5d 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -36,7 +36,6 @@ from fastapi import (
     applications,
     BackgroundTasks,
 )
-
 from fastapi.openapi.docs import get_swagger_ui_html
 
 from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
 
 from open_webui.utils import logger
 
@@ -116,6 +116,8 @@ from open_webui.config import (
     OPENAI_API_CONFIGS,
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
+    # Model list
+    ENABLE_MODEL_LIST_CACHE,
     # Thread pool size for FastAPI/AnyIO
     THREAD_POOL_SIZE,
     # Tool Server Configs
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
 
     asyncio.create_task(periodic_usage_pool_cleanup())
 
+    if app.state.config.ENABLE_MODEL_LIST_CACHE:
+        await get_all_models(
+            Request(
+                # Mock request for get_all_models; the coroutine must be awaited to run
+                {
+                    "type": "http",
+                    "asgi.version": "3.0",
+                    "asgi.spec_version": "2.0",
+                    "method": "GET",
+                    "path": "/internal",
+                    "query_string": b"",
+                    "headers": Headers({}).raw,
+                    "client": ("127.0.0.1", 12345),
+                    "server": ("127.0.0.1", 80),
+                    "scheme": "http",
+                    "app": app,
+                }
+            ),
+            None,
+        )
+
     yield
 
     if hasattr(app.state, "redis_task_command_listener"):
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
 
 app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 
+########################################
+#
+# MODEL LIST
+#
+########################################
+
+app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
+
 ########################################
 #
 # WEBUI
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
 
 
 @app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+    request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
     def get_filtered_models(models, user):
         filtered_models = []
         for model in models:
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
 
         return filtered_models
 
-    all_models = await get_all_models(request, user=user)
+    if request.app.state.MODELS and (
+        request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
+    ):
+        all_models = list(request.app.state.MODELS.values())
+    else:
+        all_models = await get_all_models(request, user=user)
 
     models = []
     for model in all_models:
diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py
index 44b2ef40cf..5b08889d69 100644
--- a/backend/open_webui/routers/configs.py
+++ b/backend/open_webui/routers/configs.py
@@ -39,32 +39,37 @@ async def export_config(user=Depends(get_admin_user)):
 
 
 ############################
-# Direct Connections Config
+# Connections Config
 ############################
 
 
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
     ENABLE_DIRECT_CONNECTIONS: bool
+    ENABLE_MODEL_LIST_CACHE: bool
 
 
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
     }
 
 
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
     request: Request,
-    form_data: DirectConnectionsConfigForm,
+    form_data: ConnectionsConfigForm,
     user=Depends(get_admin_user),
 ):
     request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
         form_data.ENABLE_DIRECT_CONNECTIONS
     )
+    request.app.state.config.ENABLE_MODEL_LIST_CACHE = form_data.ENABLE_MODEL_LIST_CACHE
+
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
     }
 
 
diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts
index 26dec26c9d..ef983e63bf 100644
--- a/src/lib/apis/configs/index.ts
+++ b/src/lib/apis/configs/index.ts
@@ -58,10 +58,10 @@ export const exportConfig = async (token: string) => {
 	return res;
 };
 
-export const getDirectConnectionsConfig = async (token: string) => {
+export const getConnectionsConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -85,10 +85,10 @@ export const getDirectConnectionsConfig = async (token: string) => {
 	return res;
 };
 
-export const setDirectConnectionsConfig = async (token: string, config: object) => {
+export const setConnectionsConfig = async (token: string, config: object) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
diff --git a/src/lib/components/admin/Settings/Connections.svelte b/src/lib/components/admin/Settings/Connections.svelte
index df0b4f6809..fc61f57730 100644
--- a/src/lib/components/admin/Settings/Connections.svelte
+++ b/src/lib/components/admin/Settings/Connections.svelte
@@ -7,7 +7,7 @@
 	import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';
 	import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai';
 	import { getModels as _getModels } from '$lib/apis';
-	import { getDirectConnectionsConfig, setDirectConnectionsConfig } from '$lib/apis/configs';
+	import { getConnectionsConfig, setConnectionsConfig } from '$lib/apis/configs';
 
 	import { config, models, settings, user } from '$lib/stores';
 
@@ -43,7 +43,7 @@
 	let ENABLE_OPENAI_API: null | boolean = null;
 	let ENABLE_OLLAMA_API: null | boolean = null;
 
-	let directConnectionsConfig = null;
+	let connectionsConfig = null;
 
 	let pipelineUrls = {};
 	let showAddOpenAIConnectionModal = false;
@@ -106,15 +106,13 @@
 		}
 	};
 
-	const updateDirectConnectionsHandler = async () => {
-		const res = await setDirectConnectionsConfig(localStorage.token, directConnectionsConfig).catch(
-			(error) => {
-				toast.error(`${error}`);
-			}
-		);
+	const updateConnectionsHandler = async () => {
+		const res = await setConnectionsConfig(localStorage.token, connectionsConfig).catch((error) => {
+			toast.error(`${error}`);
+		});
 
 		if (res) {
-			toast.success($i18n.t('Direct Connections settings updated'));
+			toast.success($i18n.t('Connections settings updated'));
 			await models.set(await getModels());
 		}
 	};
@@ -150,7 +148,7 @@
 				openaiConfig = await getOpenAIConfig(localStorage.token);
 			})(),
 			(async () => {
-				directConnectionsConfig = await getDirectConnectionsConfig(localStorage.token);
+				connectionsConfig = await getConnectionsConfig(localStorage.token);
 			})()
 		]);
 
@@ -217,7 +215,7 @@