mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-22 17:25:25 +00:00
feat: custom model base model fallback
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
This commit is contained in:
parent
60c93b4ccc
commit
b35aeb8f46
4 changed files with 39 additions and 4 deletions
|
|
@ -538,6 +538,10 @@ if LICENSE_PUBLIC_KEY:
|
||||||
# MODELS
|
# MODELS
|
||||||
####################################
|
####################################
|
||||||
|
|
||||||
|
ENABLE_CUSTOM_MODEL_FALLBACK = (
|
||||||
|
os.environ.get("ENABLE_CUSTOM_MODEL_FALLBACK", "False").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
|
MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1")
|
||||||
if MODELS_CACHE_TTL == "":
|
if MODELS_CACHE_TTL == "":
|
||||||
MODELS_CACHE_TTL = None
|
MODELS_CACHE_TTL = None
|
||||||
|
|
|
||||||
|
|
@ -439,6 +439,7 @@ from open_webui.config import (
|
||||||
reset_config,
|
reset_config,
|
||||||
)
|
)
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
|
ENABLE_CUSTOM_MODEL_FALLBACK,
|
||||||
LICENSE_KEY,
|
LICENSE_KEY,
|
||||||
AUDIT_EXCLUDED_PATHS,
|
AUDIT_EXCLUDED_PATHS,
|
||||||
AUDIT_LOG_LEVEL,
|
AUDIT_LOG_LEVEL,
|
||||||
|
|
@ -1539,6 +1540,7 @@ async def chat_completion(
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
try:
|
try:
|
||||||
|
model_info = None
|
||||||
if not model_item.get("direct", False):
|
if not model_item.get("direct", False):
|
||||||
if model_id not in request.app.state.MODELS:
|
if model_id not in request.app.state.MODELS:
|
||||||
raise Exception("Model not found")
|
raise Exception("Model not found")
|
||||||
|
|
@ -1556,7 +1558,6 @@ async def chat_completion(
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
model = model_item
|
model = model_item
|
||||||
model_info = None
|
|
||||||
|
|
||||||
request.state.direct = True
|
request.state.direct = True
|
||||||
request.state.model = model
|
request.state.model = model
|
||||||
|
|
@ -1565,6 +1566,26 @@ async def chat_completion(
|
||||||
model_info.params.model_dump() if model_info and model_info.params else {}
|
model_info.params.model_dump() if model_info and model_info.params else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check base model existence for custom models
|
||||||
|
if model_info_params.get("base_model_id"):
|
||||||
|
base_model_id = model_info_params.get("base_model_id")
|
||||||
|
if base_model_id not in request.app.state.MODELS:
|
||||||
|
if ENABLE_CUSTOM_MODEL_FALLBACK:
|
||||||
|
default_models = (
|
||||||
|
request.app.state.config.DEFAULT_MODELS or ""
|
||||||
|
).split(",")
|
||||||
|
|
||||||
|
fallback_model_id = (
|
||||||
|
default_models[0].strip() if default_models[0] else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if fallback_model_id:
|
||||||
|
request.base_model_id = fallback_model_id
|
||||||
|
else:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
else:
|
||||||
|
raise Exception("Model not found")
|
||||||
|
|
||||||
# Chat Params
|
# Chat Params
|
||||||
stream_delta_chunk_size = form_data.get("params", {}).get(
|
stream_delta_chunk_size = form_data.get("params", {}).get(
|
||||||
"stream_delta_chunk_size"
|
"stream_delta_chunk_size"
|
||||||
|
|
|
||||||
|
|
@ -1278,7 +1278,12 @@ async def generate_chat_completion(
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
payload["model"] = model_info.base_model_id
|
base_model_id = (
|
||||||
|
request.base_model_id
|
||||||
|
if hasattr(request, "base_model_id")
|
||||||
|
else model_info.base_model_id
|
||||||
|
) # Use request's base_model_id if available
|
||||||
|
payload["model"] = base_model_id
|
||||||
|
|
||||||
params = model_info.params.model_dump()
|
params = model_info.params.model_dump()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -812,8 +812,13 @@ async def generate_chat_completion(
|
||||||
# Check model info and override the payload
|
# Check model info and override the payload
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
payload["model"] = model_info.base_model_id
|
base_model_id = (
|
||||||
model_id = model_info.base_model_id
|
request.base_model_id
|
||||||
|
if hasattr(request, "base_model_id")
|
||||||
|
else model_info.base_model_id
|
||||||
|
) # Use request's base_model_id if available
|
||||||
|
payload["model"] = base_model_id
|
||||||
|
model_id = base_model_id
|
||||||
|
|
||||||
params = model_info.params.model_dump()
|
params = model_info.params.model_dump()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue