diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index c885d764b2..3fcca0e07b 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -139,7 +139,7 @@ async def update_task_config( async def generate_title( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -231,7 +231,7 @@ async def generate_chat_tags( content={"detail": "Tags generation is disabled"}, ) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -293,7 +293,7 @@ async def generate_chat_tags( async def generate_image_prompt( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -374,7 +374,7 @@ async def generate_queries( detail=f"Query generation is disabled", ) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -455,7 +455,7 @@ async def generate_autocompletion( detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}", ) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -518,7 +518,7 @@ async def generate_emoji( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -587,7 +587,7 @@ async def generate_moa_response( request: Request, form_data: dict, user=Depends(get_verified_user) ): - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 72f8eafe3b..03b5d589c6 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -165,7 +165,7 @@ async def generate_chat_completion( if BYPASS_MODEL_ACCESS_CONTROL: bypass_filter = True - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -284,7 +284,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: await get_all_models(request) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } @@ -350,7 +350,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A if not request.app.state.MODELS: await get_all_models(request) - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index e630af6878..7195990011 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -622,7 +622,7 @@ async def process_chat_payload(request, form_data, metadata, user, model): # Initialize events to store additional event to be sent to the client # Initialize contexts and citation - if request.state.direct and request.state.model: + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, } diff --git a/src/routes/+layout.svelte b/src/routes/+layout.svelte index bef581e74b..cae8fba3a2 100644 --- a/src/routes/+layout.svelte +++ b/src/routes/+layout.svelte @@ -273,6 +273,11 @@ const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx]; try { + if (API_CONFIG?.prefix_id) { + const prefixId = API_CONFIG.prefix_id; + form_data['model'] = form_data['model'].replace(`${prefixId}.`, ``); + } + const [res, controller] = await chatCompletion( OPENAI_API_KEY, form_data,