feat: custom model base model fallback
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>

parent 60c93b4ccc
commit b35aeb8f46
4 changed files with 39 additions and 4 deletions

@@ -538,6 +538,10 @@ if LICENSE_PUBLIC_KEY:
 # MODELS
 ####################################
 
+ENABLE_CUSTOM_MODEL_FALLBACK = (
+    os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
+)
+
 MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
 if MODELS_CACHE_TTL == "":
     MODELS_CACHE_TTL = None
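
The flag follows the same pattern as the other boolean environment toggles in this module: only the literal string "true" (case-insensitive) enables it, so values such as "1" or "yes" leave the fallback disabled. A minimal sketch of that parsing behaviour, with a hypothetical env_flag helper standing in for the inline expression:

import os

def env_flag(name: str, default: str = "False") -> bool:
    # Same parsing as ENABLE_CUSTOM_MODEL_FALLBACK above: compare the
    # lower-cased value against the literal string "true".
    return os.environ.get(name, default).lower() == "true"

os.environ["ENABLE_CUSTOM_MODEL_FALLBACK"] = "True"
assert env_flag("ENABLE_CUSTOM_MODEL_FALLBACK") is True

os.environ["ENABLE_CUSTOM_MODEL_FALLBACK"] = "1"  # not "true", so the flag stays off
assert env_flag("ENABLE_CUSTOM_MODEL_FALLBACK") is False

del os.environ["ENABLE_CUSTOM_MODEL_FALLBACK"]
assert env_flag("ENABLE_CUSTOM_MODEL_FALLBACK") is False  # default is "False"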

@@ -439,6 +439,7 @@ from open_webui.config import (
     reset_config,
 )
 from open_webui.env import (
+    ENABLE_CUSTOM_MODEL_FALLBACK,
     LICENSE_KEY,
     AUDIT_EXCLUDED_PATHS,
     AUDIT_LOG_LEVEL,

@@ -1539,6 +1540,7 @@ async def chat_completion(
 
     metadata = {}
     try:
+        model_info = None
         if not model_item.get("direct", False):
             if model_id not in request.app.state.MODELS:
                 raise Exception("Model not found")

@@ -1556,7 +1558,6 @@
                     raise e
         else:
             model = model_item
-            model_info = None
 
             request.state.direct = True
             request.state.model = model
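
With these two hunks, model_info is initialised once at the top of the try block instead of only on the direct-connection path, so the model_info_params lookup further down is defined no matter which branch ran.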

@@ -1565,6 +1566,26 @@
             model_info.params.model_dump() if model_info and model_info.params else {}
         )
+
+        # Check base model existence for custom models
+        if model_info_params.get("base_model_id"):
+            base_model_id = model_info_params.get("base_model_id")
+            if base_model_id not in request.app.state.MODELS:
+                if ENABLE_CUSTOM_MODEL_FALLBACK:
+                    default_models = (
+                        request.app.state.config.DEFAULT_MODELS or ""
+                    ).split(",")
+
+                    fallback_model_id = (
+                        default_models[0].strip() if default_models[0] else None
+                    )
+
+                    if fallback_model_id:
+                        request.base_model_id = fallback_model_id
+                    else:
+                        raise Exception("Model not found")
+                else:
+                    raise Exception("Model not found")
 
         # Chat Params
         stream_delta_chunk_size = form_data.get("params", {}).get(
             "stream_delta_chunk_size"
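
This is the core of the change: when a custom model points at a base model that is no longer present in request.app.state.MODELS, the request either falls back to the first entry of DEFAULT_MODELS (if ENABLE_CUSTOM_MODEL_FALLBACK is set) or fails with "Model not found" as before. Note that splitting an empty DEFAULT_MODELS string yields [""], so the "if default_models[0]" guard is what catches an empty or unset default list. A standalone sketch of the same decision, using a hypothetical resolve_base_model helper that is not part of the codebase:

def resolve_base_model(
    base_model_id: str,
    available_models: set,
    default_models: str,
    fallback_enabled: bool,
) -> str:
    # Mirrors the chat_completion logic: keep the configured base model if it
    # still exists, otherwise optionally fall back to the first DEFAULT_MODELS
    # entry, and raise "Model not found" when no fallback is possible.
    if base_model_id in available_models:
        return base_model_id
    if not fallback_enabled:
        raise Exception("Model not found")
    default_list = (default_models or "").split(",")  # "" -> [""]
    fallback_model_id = default_list[0].strip() if default_list[0] else None
    if not fallback_model_id:
        raise Exception("Model not found")
    return fallback_model_id

# A removed base model falls back to the first default:
assert (
    resolve_base_model("old-base", {"gpt-4o-mini"}, "gpt-4o-mini,llama3", True)
    == "gpt-4o-mini"
)
# A base model that still exists is left untouched:
assert resolve_base_model("llama3", {"llama3"}, "", True) == "llama3"

As in the hunk above, the fallback id itself is not re-checked against the registry; the routers simply use whatever id ends up attached to the request.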

@@ -1278,7 +1278,12 @@ async def generate_chat_completion(
 
     if model_info:
         if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
+            base_model_id = (
+                request.base_model_id
+                if hasattr(request, "base_model_id")
+                else model_info.base_model_id
+            )  # Use request's base_model_id if available
+            payload["model"] = base_model_id
 
         params = model_info.params.model_dump()
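
Both provider routers now resolve the effective base model with the same precedence: a base_model_id attached to the incoming request by chat_completion wins over the one stored on the custom model, and the hasattr check keeps behaviour unchanged when no fallback was set. A small isolated sketch of that pattern (the _Request stand-in is purely illustrative):

class _Request:
    # Stand-in for the incoming request object; chat_completion may or may not
    # have set .base_model_id on it before the router runs.
    pass

def effective_base_model(request, stored_base_model_id: str) -> str:
    # Same precedence as the router change above.
    return (
        request.base_model_id
        if hasattr(request, "base_model_id")
        else stored_base_model_id
    )

req = _Request()
assert effective_base_model(req, "llama3") == "llama3"  # no fallback attached

req.base_model_id = "gpt-4o-mini"
assert effective_base_model(req, "llama3") == "gpt-4o-mini"  # fallback wins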

@@ -812,8 +812,13 @@ async def generate_chat_completion(
     # Check model info and override the payload
     if model_info:
         if model_info.base_model_id:
-            payload["model"] = model_info.base_model_id
-            model_id = model_info.base_model_id
+            base_model_id = (
+                request.base_model_id
+                if hasattr(request, "base_model_id")
+                else model_info.base_model_id
+            )  # Use request's base_model_id if available
+            payload["model"] = base_model_id
+            model_id = base_model_id
 
         params = model_info.params.model_dump()
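
The Ollama-side hunk mirrors the OpenAI-side one but also rebinds the local model_id to the resolved base model, presumably so that later lookups keyed on model_id in that router stay consistent with the payload that is actually sent.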