feat: custom model base model fallback

Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
Author: Timothy Jaeryang Baek
Date: 2025-12-21 20:22:37 +04:00
parent 60c93b4ccc
commit b35aeb8f46
4 changed files with 39 additions and 4 deletions
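
Summary of the change: a custom (workspace) model points at a base model via base_model_id; when that base model is no longer present in request.app.state.MODELS, chat requests fail with "Model not found". With the new ENABLE_CUSTOM_MODEL_FALLBACK flag enabled, the request instead falls back to the first entry of DEFAULT_MODELS. A minimal standalone sketch of that resolution logic, distilled from the diff below (the function and parameter names here are illustrative, not part of the commit):

# Illustrative sketch only; behavior mirrors the chat_completion change below.
def resolve_base_model(base_model_id, available_models, default_models_csv, fallback_enabled):
    # Base model still exists: use it unchanged.
    if base_model_id in available_models:
        return base_model_id
    # Base model is missing and the fallback flag is off: reject the request.
    if not fallback_enabled:
        raise Exception("Model not found")
    # Fallback: take the first entry of the comma-separated DEFAULT_MODELS value.
    first = (default_models_csv or "").split(",")[0]
    fallback_model_id = first.strip() if first else None
    if not fallback_model_id:
        raise Exception("Model not found")
    return fallback_model_id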


@@ -538,6 +538,10 @@ if LICENSE_PUBLIC_KEY:
 # MODELS
 ####################################
 
+ENABLE_CUSTOM_MODEL_FALLBACK = (
+    os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
+)
+
 MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
 if MODELS_CACHE_TTL == "":
     MODELS_CACHE_TTL = None

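A note on the flag parsing above: only the literal string "true" (compared case-insensitively) enables the fallback; values such as "1" or "yes" leave it disabled. A quick check using the same expression as the env.py change:

import os

# Mirrors the expression added to env.py above.
for value in ("true", "True", "1", "yes", ""):
    os.environ["ENABLE_CUSTOM_MODEL_FALLBACK"] = value
    enabled = os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
    print(f"{value!r}: {enabled}")
# 'true': True, 'True': True, '1': False, 'yes': False, '': False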

@@ -439,6 +439,7 @@ from open_webui.config import (
     reset_config,
 )
 from open_webui.env import (
+    ENABLE_CUSTOM_MODEL_FALLBACK,
     LICENSE_KEY,
     AUDIT_EXCLUDED_PATHS,
     AUDIT_LOG_LEVEL,
@@ -1539,6 +1540,7 @@ async def chat_completion(
     metadata = {}
     try:
+        model_info = None
         if not model_item.get("direct", False):
             if model_id not in request.app.state.MODELS:
                 raise Exception("Model not found")
@@ -1556,7 +1558,6 @@
                     raise e
         else:
             model = model_item
-            model_info = None
 
             request.state.direct = True
             request.state.model = model
@@ -1565,6 +1566,26 @@
             model_info.params.model_dump() if model_info and model_info.params else {}
         )
+
+        # Check base model existence for custom models
+        if model_info_params.get("base_model_id"):
+            base_model_id = model_info_params.get("base_model_id")
+            if base_model_id not in request.app.state.MODELS:
+                if ENABLE_CUSTOM_MODEL_FALLBACK:
+                    default_models = (
+                        request.app.state.config.DEFAULT_MODELS or ""
+                    ).split(",")
+                    fallback_model_id = (
+                        default_models[0].strip() if default_models[0] else None
+                    )
+
+                    if fallback_model_id:
+                        request.base_model_id = fallback_model_id
+                    else:
+                        raise Exception("Model not found")
+                else:
+                    raise Exception("Model not found")
+
         # Chat Params
         stream_delta_chunk_size = form_data.get("params", {}).get(
             "stream_delta_chunk_size"

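One edge case the branch above covers: if DEFAULT_MODELS is unset or an empty string, splitting on "," yields [""], so default_models[0] is falsy, fallback_model_id stays None, and the request is still rejected with "Model not found". A small demonstration of the same split/strip logic (the sample DEFAULT_MODELS values are hypothetical):

# Same split/strip expressions as the diff above, applied to sample values.
for raw_default_models in (None, "", "gpt-4o,llama3:latest"):
    default_models = (raw_default_models or "").split(",")
    fallback_model_id = default_models[0].strip() if default_models[0] else None
    print(repr(raw_default_models), "->", repr(fallback_model_id))
# None -> None (no fallback configured, so the exception is still raised)
# '' -> None
# 'gpt-4o,llama3:latest' -> 'gpt-4o' (first default becomes the base model)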

@@ -1278,7 +1278,12 @@ async def generate_chat_completion(
     if model_info:
         if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
+            base_model_id = (
+                request.base_model_id
+                if hasattr(request, "base_model_id")
+                else model_info.base_model_id
+            )  # Use request's base_model_id if available
+            payload["model"] = base_model_id
 
         params = model_info.params.model_dump()


@@ -812,8 +812,13 @@ async def generate_chat_completion(
     # Check model info and override the payload
     if model_info:
         if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
-            model_id = model_info.base_model_id
+            base_model_id = (
+                request.base_model_id
+                if hasattr(request, "base_model_id")
+                else model_info.base_model_id
+            )  # Use request's base_model_id if available
+            payload["model"] = base_model_id
+            model_id = base_model_id
 
         params = model_info.params.model_dump()
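
Both router hunks read the attribute that chat_completion sets only when the fallback fires, which is why hasattr() is used rather than a plain attribute read. A minimal illustration of that handoff (a plain object stands in for the FastAPI request; only the attribute name comes from the diff):

class FakeRequest:
    pass

def pick_base_model(request, stored_base_model_id):
    # request.base_model_id exists only if the fallback was triggered upstream.
    return (
        request.base_model_id
        if hasattr(request, "base_model_id")
        else stored_base_model_id
    )

req = FakeRequest()
print(pick_base_model(req, "stored-base"))   # stored-base (no fallback set)
req.base_model_id = "fallback-model"
print(pick_base_model(req, "stored-base"))   # fallback-model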