refac: ENABLE_MODEL_LIST_CACHE -> ENABLE_BASE_MODELS_CACHE

This commit is contained in:
Timothy Jaeryang Baek 2025-06-30 13:27:07 +04:00
parent eac2f36f4f
commit 8a334decf6
5 changed files with 209 additions and 207 deletions

View file

@ -931,13 +931,13 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
#################################### ####################################
# MODEL_LIST # MODELS
#################################### ####################################
ENABLE_MODEL_LIST_CACHE = PersistentConfig( ENABLE_BASE_MODELS_CACHE = PersistentConfig(
"ENABLE_MODEL_LIST_CACHE", "ENABLE_BASE_MODELS_CACHE",
"models.cache", "models.base_models_cache",
os.environ.get("ENABLE_MODEL_LIST_CACHE", "False").lower() == "true", os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true",
) )

View file

@ -117,7 +117,7 @@ from open_webui.config import (
# Direct Connections # Direct Connections
ENABLE_DIRECT_CONNECTIONS, ENABLE_DIRECT_CONNECTIONS,
# Model list # Model list
ENABLE_MODEL_LIST_CACHE, ENABLE_BASE_MODELS_CACHE,
# Thread pool size for FastAPI/AnyIO # Thread pool size for FastAPI/AnyIO
THREAD_POOL_SIZE, THREAD_POOL_SIZE,
# Tool Server Configs # Tool Server Configs
@ -537,7 +537,7 @@ async def lifespan(app: FastAPI):
asyncio.create_task(periodic_usage_pool_cleanup()) asyncio.create_task(periodic_usage_pool_cleanup())
if app.state.config.ENABLE_MODEL_LIST_CACHE: if app.state.config.ENABLE_BASE_MODELS_CACHE:
await get_all_models( await get_all_models(
Request( Request(
# Creating a mock request object to pass to get_all_models # Creating a mock request object to pass to get_all_models
@ -643,11 +643,12 @@ app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS
######################################## ########################################
# #
# MODEL LIST # MODELS
# #
######################################## ########################################
app.state.config.ENABLE_MODEL_LIST_CACHE = ENABLE_MODEL_LIST_CACHE app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE
app.state.BASE_MODELS = []
######################################## ########################################
# #

View file

@ -45,14 +45,14 @@ async def export_config(user=Depends(get_admin_user)):
class ConnectionsConfigForm(BaseModel): class ConnectionsConfigForm(BaseModel):
ENABLE_DIRECT_CONNECTIONS: bool ENABLE_DIRECT_CONNECTIONS: bool
ENABLE_MODEL_LIST_CACHE: bool ENABLE_BASE_MODELS_CACHE: bool
@router.get("/connections", response_model=ConnectionsConfigForm) @router.get("/connections", response_model=ConnectionsConfigForm)
async def get_connections_config(request: Request, user=Depends(get_admin_user)): async def get_connections_config(request: Request, user=Depends(get_admin_user)):
return { return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
"ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE, "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
} }
@ -65,11 +65,13 @@ async def set_connections_config(
request.app.state.config.ENABLE_DIRECT_CONNECTIONS = ( request.app.state.config.ENABLE_DIRECT_CONNECTIONS = (
form_data.ENABLE_DIRECT_CONNECTIONS form_data.ENABLE_DIRECT_CONNECTIONS
) )
request.app.state.config.ENABLE_MODEL_LIST_CACHE = form_data.ENABLE_MODEL_LIST_CACHE request.app.state.config.ENABLE_BASE_MODELS_CACHE = (
form_data.ENABLE_BASE_MODELS_CACHE
)
return { return {
"ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS,
"ENABLE_MODEL_LIST_CACHE": request.app.state.config.ENABLE_MODEL_LIST_CACHE, "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE,
} }

View file

@ -77,176 +77,166 @@ async def get_all_base_models(request: Request, user: UserModel = None):
async def get_all_models(request, refresh: bool = False, user: UserModel = None): async def get_all_models(request, refresh: bool = False, user: UserModel = None):
if request.app.state.MODELS and ( if (
request.app.state.config.ENABLE_MODEL_LIST_CACHE and not refresh request.app.state.MODELS
and request.app.state.BASE_MODELS
and (request.app.state.config.ENABLE_BASE_MODELS_CACHE and not refresh)
): ):
return list(request.app.state.MODELS.values()) models = request.app.state.BASE_MODELS
else: else:
models = await get_all_base_models(request, user=user) models = await get_all_base_models(request, user=user)
request.app.state.BASE_MODELS = models
# If there are no models, return an empty list # If there are no models, return an empty list
if len(models) == 0: if len(models) == 0:
return [] return []
# Add arena models # Add arena models
if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS: if request.app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
arena_models = [] arena_models = []
if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0: if len(request.app.state.config.EVALUATION_ARENA_MODELS) > 0:
arena_models = [ arena_models = [
{ {
"id": model["id"], "id": model["id"],
"name": model["name"], "name": model["name"],
"info": { "info": {
"meta": model["meta"], "meta": model["meta"],
}, },
"object": "model", "object": "model",
"created": int(time.time()), "created": int(time.time()),
"owned_by": "arena", "owned_by": "arena",
"arena": True, "arena": True,
} }
for model in request.app.state.config.EVALUATION_ARENA_MODELS for model in request.app.state.config.EVALUATION_ARENA_MODELS
] ]
else: else:
# Add default arena model # Add default arena model
arena_models = [ arena_models = [
{ {
"id": DEFAULT_ARENA_MODEL["id"], "id": DEFAULT_ARENA_MODEL["id"],
"name": DEFAULT_ARENA_MODEL["name"], "name": DEFAULT_ARENA_MODEL["name"],
"info": { "info": {
"meta": DEFAULT_ARENA_MODEL["meta"], "meta": DEFAULT_ARENA_MODEL["meta"],
}, },
"object": "model", "object": "model",
"created": int(time.time()), "created": int(time.time()),
"owned_by": "arena", "owned_by": "arena",
"arena": True, "arena": True,
} }
] ]
models = models + arena_models models = models + arena_models
global_action_ids = [ global_action_ids = [
function.id for function in Functions.get_global_action_functions() function.id for function in Functions.get_global_action_functions()
] ]
enabled_action_ids = [ enabled_action_ids = [
function.id function.id
for function in Functions.get_functions_by_type("action", active_only=True) for function in Functions.get_functions_by_type("action", active_only=True)
] ]
global_filter_ids = [ global_filter_ids = [
function.id for function in Functions.get_global_filter_functions() function.id for function in Functions.get_global_filter_functions()
] ]
enabled_filter_ids = [ enabled_filter_ids = [
function.id function.id
for function in Functions.get_functions_by_type("filter", active_only=True) for function in Functions.get_functions_by_type("filter", active_only=True)
] ]
custom_models = Models.get_all_models() custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id is None: if custom_model.base_model_id is None:
for model in models: for model in models:
if custom_model.id == model["id"] or ( if custom_model.id == model["id"] or (
model.get("owned_by") == "ollama" model.get("owned_by") == "ollama"
and custom_model.id and custom_model.id
== model["id"].split(":")[ == model["id"].split(":")[
0 0
] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b') ] # Ollama may return model ids in different formats (e.g., 'llama3' vs. 'llama3:7b')
): ):
if custom_model.is_active: if custom_model.is_active:
model["name"] = custom_model.name model["name"] = custom_model.name
model["info"] = custom_model.model_dump() model["info"] = custom_model.model_dump()
# Set action_ids and filter_ids # Set action_ids and filter_ids
action_ids = [] action_ids = []
filter_ids = [] filter_ids = []
if "info" in model and "meta" in model["info"]: if "info" in model and "meta" in model["info"]:
action_ids.extend( action_ids.extend(
model["info"]["meta"].get("actionIds", []) model["info"]["meta"].get("actionIds", [])
) )
filter_ids.extend( filter_ids.extend(
model["info"]["meta"].get("filterIds", []) model["info"]["meta"].get("filterIds", [])
) )
model["action_ids"] = action_ids model["action_ids"] = action_ids
model["filter_ids"] = filter_ids model["filter_ids"] = filter_ids
else: else:
models.remove(model) models.remove(model)
elif custom_model.is_active and ( elif custom_model.is_active and (
custom_model.id not in [model["id"] for model in models] custom_model.id not in [model["id"] for model in models]
): ):
owned_by = "openai" owned_by = "openai"
pipe = None pipe = None
action_ids = [] action_ids = []
filter_ids = [] filter_ids = []
for model in models: for model in models:
if ( if (
custom_model.base_model_id == model["id"] custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0] or custom_model.base_model_id == model["id"].split(":")[0]
): ):
owned_by = model.get("owned_by", "unknown owner") owned_by = model.get("owned_by", "unknown owner")
if "pipe" in model: if "pipe" in model:
pipe = model["pipe"] pipe = model["pipe"]
break break
if custom_model.meta: if custom_model.meta:
meta = custom_model.meta.model_dump() meta = custom_model.meta.model_dump()
if "actionIds" in meta: if "actionIds" in meta:
action_ids.extend(meta["actionIds"]) action_ids.extend(meta["actionIds"])
if "filterIds" in meta: if "filterIds" in meta:
filter_ids.extend(meta["filterIds"]) filter_ids.extend(meta["filterIds"])
models.append( models.append(
{ {
"id": f"{custom_model.id}", "id": f"{custom_model.id}",
"name": custom_model.name, "name": custom_model.name,
"object": "model", "object": "model",
"created": custom_model.created_at, "created": custom_model.created_at,
"owned_by": owned_by, "owned_by": owned_by,
"info": custom_model.model_dump(), "info": custom_model.model_dump(),
"preset": True, "preset": True,
**({"pipe": pipe} if pipe is not None else {}), **({"pipe": pipe} if pipe is not None else {}),
"action_ids": action_ids, "action_ids": action_ids,
"filter_ids": filter_ids, "filter_ids": filter_ids,
} }
) )
# Process action_ids to get the actions # Process action_ids to get the actions
def get_action_items_from_module(function, module): def get_action_items_from_module(function, module):
actions = [] actions = []
if hasattr(module, "actions"): if hasattr(module, "actions"):
actions = module.actions actions = module.actions
return [ return [
{ {
"id": f"{function.id}.{action['id']}", "id": f"{function.id}.{action['id']}",
"name": action.get("name", f"{function.name} ({action['id']})"), "name": action.get("name", f"{function.name} ({action['id']})"),
"description": function.meta.description, "description": function.meta.description,
"icon": action.get( "icon": action.get(
"icon_url", "icon_url",
function.meta.manifest.get("icon_url", None) function.meta.manifest.get("icon_url", None)
or getattr(module, "icon_url", None)
or getattr(module, "icon", None),
),
}
for action in actions
]
else:
return [
{
"id": function.id,
"name": function.name,
"description": function.meta.description,
"icon": function.meta.manifest.get("icon_url", None)
or getattr(module, "icon_url", None) or getattr(module, "icon_url", None)
or getattr(module, "icon", None), or getattr(module, "icon", None),
} ),
] }
for action in actions
# Process filter_ids to get the filters ]
def get_filter_items_from_module(function, module): else:
return [ return [
{ {
"id": function.id, "id": function.id,
@ -258,54 +248,63 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
} }
] ]
def get_function_module_by_id(function_id): # Process filter_ids to get the filters
function_module, _, _ = get_function_module_from_cache(request, function_id) def get_filter_items_from_module(function, module):
return function_module return [
{
"id": function.id,
"name": function.name,
"description": function.meta.description,
"icon": function.meta.manifest.get("icon_url", None)
or getattr(module, "icon_url", None)
or getattr(module, "icon", None),
}
]
for model in models: def get_function_module_by_id(function_id):
action_ids = [ function_module, _, _ = get_function_module_from_cache(request, function_id)
action_id return function_module
for action_id in list(
set(model.pop("action_ids", []) + global_action_ids)
)
if action_id in enabled_action_ids
]
filter_ids = [
filter_id
for filter_id in list(
set(model.pop("filter_ids", []) + global_filter_ids)
)
if filter_id in enabled_filter_ids
]
model["actions"] = [] for model in models:
for action_id in action_ids: action_ids = [
action_function = Functions.get_function_by_id(action_id) action_id
if action_function is None: for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
raise Exception(f"Action not found: {action_id}") if action_id in enabled_action_ids
]
filter_ids = [
filter_id
for filter_id in list(set(model.pop("filter_ids", []) + global_filter_ids))
if filter_id in enabled_filter_ids
]
function_module = get_function_module_by_id(action_id) model["actions"] = []
model["actions"].extend( for action_id in action_ids:
get_action_items_from_module(action_function, function_module) action_function = Functions.get_function_by_id(action_id)
if action_function is None:
raise Exception(f"Action not found: {action_id}")
function_module = get_function_module_by_id(action_id)
model["actions"].extend(
get_action_items_from_module(action_function, function_module)
)
model["filters"] = []
for filter_id in filter_ids:
filter_function = Functions.get_function_by_id(filter_id)
if filter_function is None:
raise Exception(f"Filter not found: {filter_id}")
function_module = get_function_module_by_id(filter_id)
if getattr(function_module, "toggle", None):
model["filters"].extend(
get_filter_items_from_module(filter_function, function_module)
) )
model["filters"] = [] log.debug(f"get_all_models() returned {len(models)} models")
for filter_id in filter_ids:
filter_function = Functions.get_function_by_id(filter_id)
if filter_function is None:
raise Exception(f"Filter not found: {filter_id}")
function_module = get_function_module_by_id(filter_id) request.app.state.MODELS = {model["id"]: model for model in models}
return models
if getattr(function_module, "toggle", None):
model["filters"].extend(
get_filter_items_from_module(filter_function, function_module)
)
log.debug(f"get_all_models() returned {len(models)} models")
request.app.state.MODELS = {model["id"]: model for model in models}
return models
def check_model_access(user, model): def check_model_access(user, model):

View file

@ -386,12 +386,12 @@
<div class="my-2"> <div class="my-2">
<div class="flex justify-between items-center text-sm"> <div class="flex justify-between items-center text-sm">
<div class=" text-xs font-medium">{$i18n.t('Cache Model List')}</div> <div class=" text-xs font-medium">{$i18n.t('Cache Base Model List')}</div>
<div class="flex items-center"> <div class="flex items-center">
<div class=""> <div class="">
<Switch <Switch
bind:state={connectionsConfig.ENABLE_MODEL_LIST_CACHE} bind:state={connectionsConfig.ENABLE_BASE_MODELS_CACHE}
on:change={async () => { on:change={async () => {
updateConnectionsHandler(); updateConnectionsHandler();
}} }}
@ -402,7 +402,7 @@
<div class="mt-1 text-xs text-gray-400 dark:text-gray-500"> <div class="mt-1 text-xs text-gray-400 dark:text-gray-500">
{$i18n.t( {$i18n.t(
'Model List Cache speeds up access by fetching models only at startup or on settings save—faster, but may not show recent model changes.' 'Base Model List Cache speeds up access by fetching base models only at startup or on settings save—faster, but may not show recent base model changes.'
)} )}
</div> </div>
</div> </div>