Mirror of https://github.com/open-webui/open-webui.git

enh: ENABLE_MODEL_LIST_CACHE

commit 1a52585769 (parent 2b88f66762)
5 changed files with 105 additions and 27 deletions
@@ -923,6 +923,18 @@ except Exception:
     pass
 
 OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
+
+####################################
+# MODEL_LIST
+####################################
+
+ENABLE_MODEL_LIST_CACHE = PersistentConfig(
+    "ENABLE_MODEL_LIST_CACHE",
+    "models.cache",
+    os.environ.get("ENABLE_MODEL_LIST_CACHE", "False").lower() == "true",
+)
+
+
 ####################################
 # TOOL_SERVERS
 ####################################
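Note on the new flag (this hunk presumably lives in backend/open_webui/config.py): the value is persisted under the "models.cache" config key via PersistentConfig and seeded from the ENABLE_MODEL_LIST_CACHE environment variable. A minimal sketch of the environment parsing only; the helper name below is illustrative and not part of the codebase:

import os

# Mirrors the parsing rule in the hunk above: only the literal string
# "true" (case-insensitive) enables the cache; any other value, including
# "1" or "yes", leaves it disabled.
def model_list_cache_enabled() -> bool:
    return os.environ.get("ENABLE_MODEL_LIST_CACHE", "False").lower() == "true"


if __name__ == "__main__":
    os.environ["ENABLE_MODEL_LIST_CACHE"] = "True"
    assert model_list_cache_enabled() is True

    os.environ["ENABLE_MODEL_LIST_CACHE"] = "1"  # not "true", so disabled
    assert model_list_cache_enabled() is False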
@@ -36,7 +36,6 @@ from fastapi import (
     applications,
     BackgroundTasks,
 )
 
 from fastapi.openapi.docs import get_swagger_ui_html
-
 from fastapi.middleware.cors import CORSMiddleware
@@ -49,6 +48,7 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import Response, StreamingResponse
+from starlette.datastructures import Headers
 
 
 from open_webui.utils import logger
@@ -116,6 +116,8 @@ from open_webui.config import (
     OPENAI_API_CONFIGS,
     # Direct Connections
     ENABLE_DIRECT_CONNECTIONS,
+    # Model list
+    ENABLE_MODEL_LIST_CACHE,
     # Thread pool size for FastAPI/AnyIO
     THREAD_POOL_SIZE,
     # Tool Server Configs
@@ -534,6 +536,27 @@ async def lifespan(app: FastAPI):
 
     asyncio.create_task(periodic_usage_pool_cleanup())
 
+    if app.state.config.ENABLE_MODEL_LIST_CACHE:
+        get_all_models(
+            Request(
+                # Creating a mock request object to pass to get_all_models
+                {
+                    "type": "http",
+                    "asgi.version": "3.0",
+                    "asgi.spec_version": "2.0",
+                    "method": "GET",
+                    "path": "/internal",
+                    "query_string": b"",
+                    "headers": Headers({}).raw,
+                    "client": ("127.0.0.1", 12345),
+                    "server": ("127.0.0.1", 80),
+                    "scheme": "http",
+                    "app": app,
+                }
+            ),
+            None,
+        )
+
     yield
 
     if hasattr(app.state, "redis_task_command_listener"):
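Because the warm-up runs inside the FastAPI lifespan, there is no real HTTP request to hand to get_all_models, so the commit hand-builds an ASGI scope and wraps it in a Starlette Request. A self-contained sketch of the same pattern, under assumptions: fetch_all_models and make_internal_request are illustrative stand-ins rather than open-webui code, and the sketch awaits the coroutine directly, whereas the warm-up could equally be scheduled as a background task.

from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from starlette.datastructures import Headers


async def fetch_all_models(request: Request, user=None) -> list[dict]:
    # Placeholder for illustration: the real app aggregates Ollama/OpenAI
    # models and stores them on request.app.state.MODELS.
    request.app.state.MODELS = {"demo-model": {"id": "demo-model"}}
    return list(request.app.state.MODELS.values())


def make_internal_request(app: FastAPI) -> Request:
    # Minimal ASGI HTTP scope; "app" lets request.app resolve to this app.
    scope = {
        "type": "http",
        "asgi.version": "3.0",
        "asgi.spec_version": "2.0",
        "method": "GET",
        "path": "/internal",
        "query_string": b"",
        "headers": Headers({}).raw,
        "client": ("127.0.0.1", 12345),
        "server": ("127.0.0.1", 80),
        "scheme": "http",
        "app": app,
    }
    return Request(scope)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Awaiting keeps the sketch simple; asyncio.create_task() would avoid
    # delaying startup while the model list is fetched.
    await fetch_all_models(make_internal_request(app), None)
    yield


app = FastAPI(lifespan=lifespan)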
@@ -616,6 +639,14 @@ app.state.TOOL_SERVERS = []
 
 app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
 
+########################################
+#
+# MODEL LIST
+#
+########################################
+
+app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE
+
 ########################################
 #
 # WEBUI
@@ -1191,7 +1222,9 @@ if audit_level != AuditLevel.NONE:
 
 
 @app.get("/api/models")
-async def get_models(request: Request, user=Depends(get_verified_user)):
+async def get_models(
+    request: Request, refresh: bool = False, user=Depends(get_verified_user)
+):
     def get_filtered_models(models, user):
         filtered_models = []
         for model in models:
@@ -1215,7 +1248,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)):
 
         return filtered_models
 
-    all_models = await get_all_models(request, user=user)
+    if request.app.state.MODELS and (
+        request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh
+    ):
+        all_models = list(request.app.state.MODELS.values())
+    else:
+        all_models = await get_all_models(request, user=user)
 
     models = []
     for model in all_models:
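With the extended signature, GET /api/models serves the cached list from app.state.MODELS when ENABLE_MODEL_LIST_CACHE is on and the cache is already populated, while refresh=true forces a re-fetch from the configured connections. A hedged client-side sketch; the base URL, token, the requests dependency, and the {"data": [...]} response shape are assumptions for illustration:

import requests

BASE_URL = "http://localhost:8080"  # assumed local Open WebUI instance
TOKEN = "YOUR_API_TOKEN"            # placeholder API key

headers = {"Authorization": f"Bearer {TOKEN}"}

# Served from the in-memory cache when ENABLE_MODEL_LIST_CACHE is enabled
# and the cache has been warmed.
cached = requests.get(f"{BASE_URL}/api/models", headers=headers)
print(len(cached.json().get("data", [])), "models (possibly cached)")

# refresh=true bypasses the cache and re-queries the configured connections.
fresh = requests.get(
    f"{BASE_URL}/api/models", params={"refresh": "true"}, headers=headers
)
print(len(fresh.json().get("data", [])), "models (refreshed)")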
@@ -39,32 +39,37 @@ async def export_config(user=Depends(get_admin_user)):
 
 
 ############################
-# Direct Connections Config
+# Connections Config
 ############################
 
 
-class DirectConnectionsConfigForm(BaseModel):
+class ConnectionsConfigForm(BaseModel):
     ENABLE_DIRECT_CONNECTIONS: bool
+    ENABLE_MODEL_LIST_CACHE: bool
 
 
-@router.get("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)):
+@router.get("/connections", response_model=ConnectionsConfigForm)
+async def get_connections_config(request: Request, user=Depends(get_admin_user)):
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
     }
 
 
-@router.post("/direct_connections", response_model=DirectConnectionsConfigForm)
-async def set_direct_connections_config(
+@router.post("/connections", response_model=ConnectionsConfigForm)
+async def set_connections_config(
     request: Request,
-    form_data: DirectConnectionsConfigForm,
+    form_data: ConnectionsConfigForm,
     user=Depends(get_admin_user),
 ):
     request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
         form_data.ENABLE_DIRECT_CONNECTIONS
     )
+    request.app.state.config.ENABLE_MODEL_LIST_CACHE = form_data.ENABLE_MODEL_LIST_CACHE
 
     return {
         "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
+        "ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE,
     }
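The admin config endpoints move from /configs/direct_connections to /configs/connections and now carry both flags in one form. A hedged sketch of toggling the cache through the renamed endpoint; the /api/v1 prefix, base URL, and admin token are assumptions, while the field names come from ConnectionsConfigForm above:

import requests

BASE_URL = "http://localhost:8080"    # assumed local instance
ADMIN_TOKEN = "YOUR_ADMIN_API_TOKEN"  # placeholder admin key

headers = {
    "Authorization": f"Bearer {ADMIN_TOKEN}",
    "Content-Type": "application/json",
}

# Read the current connections config (admin only).
current = requests.get(
    f"{BASE_URL}/api/v1/configs/connections", headers=headers
).json()
print(current)

# Enable model-list caching while keeping the direct-connections flag as-is.
updated = requests.post(
    f"{BASE_URL}/api/v1/configs/connections",
    headers=headers,
    json={
        "ENABLE_DIRECT_CONNECTIONS": current["ENABLE_DIRECT_CONNECTIONS"],
        "ENABLE_MODEL_LIST_CACHE": True,
    },
).json()
print(updated["ENABLE_MODEL_LIST_CACHE"])  # expected: True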
@@ -58,10 +58,10 @@ export const exportConfig = async (token: string) => {
 	return res;
 };
 
-export const getDirectConnectionsConfig = async (token: string) => {
+export const getConnectionsConfig = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'GET',
 		headers: {
 			'Content-Type': 'application/json',
@@ -85,10 +85,10 @@ export const getDirectConnectionsConfig = async (token: string) => {
 	return res;
 };
 
-export const setDirectConnectionsConfig = async (token: string, config: object) => {
+export const setConnectionsConfig = async (token: string, config: object) => {
 	let error = null;
 
-	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/direct_connections`, {
+	const res = await fetch(`${WEBUI_API_BASE_URL}/configs/connections`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'application/json',
@@ -7,7 +7,7 @@
 	import { getOllamaConfig, updateOllamaConfig } from '$lib/apis/ollama';
 	import { getOpenAIConfig, updateOpenAIConfig, getOpenAIModels } from '$lib/apis/openai';
 	import { getModels as _getModels } from '$lib/apis';
-	import { getDirectConnectionsConfig, setDirectConnectionsConfig } from '$lib/apis/configs';
+	import { getConnectionsConfig, setConnectionsConfig } from '$lib/apis/configs';
 
 	import { config, models, settings, user } from '$lib/stores';
@@ -43,7 +43,7 @@
 	let ENABLE_OPENAI_API: null | boolean = null;
 	let ENABLE_OLLAMA_API: null | boolean = null;
 
-	let directConnectionsConfig = null;
+	let connectionsConfig = null;
 
 	let pipelineUrls = {};
 	let showAddOpenAIConnectionModal = false;
@@ -106,15 +106,13 @@
 		}
 	};
 
-	const updateDirectConnectionsHandler = async () => {
-		const res = await setDirectConnectionsConfig(localStorage.token, directConnectionsConfig).catch(
-			(error) => {
-				toast.error(`${error}`);
-			}
-		);
+	const updateConnectionsHandler = async () => {
+		const res = await setConnectionsConfig(localStorage.token, connectionsConfig).catch((error) => {
+			toast.error(`${error}`);
+		});
 
 		if (res) {
-			toast.success($i18n.t('Direct Connections settings updated'));
+			toast.success($i18n.t('Connections settings updated'));
 			await models.set(await getModels());
 		}
 	};
@@ -150,7 +148,7 @@
 				openaiConfig = await getOpenAIConfig(localStorage.token);
 			})(),
 			(async () => {
-				directConnectionsConfig = await getDirectConnectionsConfig(localStorage.token);
+				connectionsConfig = await getConnectionsConfig(localStorage.token);
 			})()
 		]);
@@ -217,7 +215,7 @@
 
 <form class="flex flex-col h-full justify-between text-sm" on:submit|preventDefault={submitHandler}>
 	<div class=" overflow-y-scroll scrollbar-hidden h-full">
-		{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && directConnectionsConfig !== null}
+		{#if ENABLE_OPENAI_API !== null && ENABLE_OLLAMA_API !== null && connectionsConfig !== null}
 			<div class="mb-3.5">
 				<div class=" mb-2.5 text-base font-medium">{$i18n.t('General')}</div>
@@ -368,9 +366,9 @@
 						<div class="flex items-center">
 							<div class="">
 								<Switch
-									bind:state={directConnectionsConfig.ENABLE_DIRECT_CONNECTIONS}
+									bind:state={connectionsConfig.ENABLE_DIRECT_CONNECTIONS}
 									on:change={async () => {
-										updateDirectConnectionsHandler();
+										updateConnectionsHandler();
 									}}
 								/>
 							</div>
@@ -383,6 +381,31 @@
 							)}
 						</div>
 					</div>
 
+					<hr class=" border-gray-100 dark:border-gray-850 my-2" />
+
+					<div class="my-2">
+						<div class="flex justify-between items-center text-sm">
+							<div class=" text-xs font-medium">{$i18n.t('Cache Model List')}</div>
+
+							<div class="flex items-center">
+								<div class="">
+									<Switch
+										bind:state={connectionsConfig.ENABLE_MODEL_LIST_CACHE}
+										on:change={async () => {
+											updateConnectionsHandler();
+										}}
+									/>
+								</div>
+							</div>
+						</div>
+
+						<div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
+							{$i18n.t(
+								'Model List Cache allows for faster access to model information by caching it locally.'
+							)}
+						</div>
+					</div>
 				</div>
 			{:else}
 				<div class="flex h-full justify-center">