diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 8c5e3da736..4f9b297bfb 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -501,50 +501,55 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: return response return None - def merge_models_lists(model_lists): + def is_supported_openai_models(model_id): + if any( + name in model_id + for name in [ + "babbage", + "dall-e", + "davinci", + "embedding", + "tts", + "whisper", + ] + ): + return False + return True + + def get_merged_models(model_lists): log.debug(f"merge_models_lists {model_lists}") - merged_list = [] + models = {} - for idx, models in enumerate(model_lists): - if models is not None and "error" not in models: + for idx, model_list in enumerate(model_lists): + if model_list is not None and "error" not in model_list: + for model in model_list: + model_id = model.get("id") or model.get("name") - merged_list.extend( - [ - { + if ( + "api.openai.com" + in request.app.state.config.OPENAI_API_BASE_URLS[idx] + and not is_supported_openai_models(model_id) + ): + # Skip unwanted OpenAI models + continue + + if model_id and model_id not in models: + models[model_id] = { **model, - "name": model.get("name", model["id"]), + "name": model.get("name", model_id), "owned_by": "openai", "openai": model, "connection_type": model.get("connection_type", "external"), "urlIdx": idx, } - for model in models - if (model.get("id") or model.get("name")) - and ( - "api.openai.com" - not in request.app.state.config.OPENAI_API_BASE_URLS[idx] - or not any( - name in model["id"] - for name in [ - "babbage", - "dall-e", - "davinci", - "embedding", - "tts", - "whisper", - ] - ) - ) - ] - ) - return merged_list + return models - models = {"data": merge_models_lists(map(extract_data, responses))} + models = get_merged_models(map(extract_data, responses)) log.debug(f"models: {models}") - request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]} - return models + request.app.state.OPENAI_MODELS = map(extract_data, responses) + return {"data": list(models.values())} @router.get("/models")