diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 5dc8d7f2a1..cbaefe1f3e 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -924,14 +924,19 @@ try:
         app.state.config.RAG_EMBEDDING_MODEL,
         RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     )
-
-    app.state.rf = get_rf(
-        app.state.config.RAG_RERANKING_ENGINE,
-        app.state.config.RAG_RERANKING_MODEL,
-        app.state.config.RAG_EXTERNAL_RERANKER_URL,
-        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-        RAG_RERANKING_MODEL_AUTO_UPDATE,
-    )
+    if (
+        app.state.config.ENABLE_RAG_HYBRID_SEARCH
+        and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+    ):
+        app.state.rf = get_rf(
+            app.state.config.RAG_RERANKING_ENGINE,
+            app.state.config.RAG_RERANKING_MODEL,
+            app.state.config.RAG_EXTERNAL_RERANKER_URL,
+            app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+            RAG_RERANKING_MODEL_AUTO_UPDATE,
+        )
+    else:
+        app.state.rf = None
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index c9e0b4d0c8..738f2d05fc 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -5,7 +5,6 @@ import os
 import shutil
 import asyncio
-
 import uuid
 from datetime import datetime
 from pathlib import Path
 
@@ -281,6 +280,18 @@ async def update_embedding_config(
     log.info(
         f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
+    if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
+        # Unload the current internal embedding model and clear the VRAM cache
+        request.app.state.ef = None
+        request.app.state.EMBEDDING_FUNCTION = None
+        import gc
+
+        gc.collect()
+        if DEVICE_TYPE == "cuda":
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     try:
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
@@ -653,9 +664,6 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
-    # Free up memory if hybrid search is disabled
-    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-        request.app.state.rf = None
 
     request.app.state.config.TOP_K_RERANKER = (
         form_data.TOP_K_RERANKER
@@ -809,6 +817,18 @@
     )
 
     # Reranking settings
+    if request.app.state.config.RAG_RERANKING_ENGINE == "":
+        # Unload the internal reranker and clear the VRAM cache
+        request.app.state.rf = None
+        request.app.state.RERANKING_FUNCTION = None
+        import gc
+
+        gc.collect()
+        if DEVICE_TYPE == "cuda":
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     request.app.state.config.RAG_RERANKING_ENGINE = (
         form_data.RAG_RERANKING_ENGINE
         if form_data.RAG_RERANKING_ENGINE is not None
@@ -838,19 +858,23 @@
     )
 
     try:
-        request.app.state.rf = get_rf(
-            request.app.state.config.RAG_RERANKING_ENGINE,
-            request.app.state.config.RAG_RERANKING_MODEL,
-            request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
-            request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-            True,
-        )
+        if (
+            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
+            and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+        ):
+            request.app.state.rf = get_rf(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+                True,
+            )
 
-        request.app.state.RERANKING_FUNCTION = get_reranking_function(
-            request.app.state.config.RAG_RERANKING_ENGINE,
-            request.app.state.config.RAG_RERANKING_MODEL,
-            request.app.state.rf,
-        )
+            request.app.state.RERANKING_FUNCTION = get_reranking_function(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.rf,
+            )
     except Exception as e:
         log.error(f"Error loading reranking model: {e}")
         request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
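
The unload-and-free sequence added in `update_embedding_config` and `update_rag_config` is the same four steps both times: drop the references held on `app.state`, run the garbage collector, then release PyTorch's cached CUDA blocks. A minimal standalone sketch of that pattern is below; it assumes a torch-backed model held on app state and a `DEVICE_TYPE` string as in the diff, and the `free_model_vram` helper itself is hypothetical, not part of this PR:

```python
import gc


def free_model_vram(state, attr_names, device_type: str = "cuda") -> None:
    """Drop model references held on `state` and release cached GPU memory.

    Hypothetical helper illustrating the pattern in this diff; `attr_names`
    stands in for pairs like ("ef", "EMBEDDING_FUNCTION") or
    ("rf", "RERANKING_FUNCTION").
    """
    for name in attr_names:
        setattr(state, name, None)  # drop the last strong reference

    gc.collect()  # collect the now-unreferenced model objects immediately

    if device_type == "cuda":
        import torch  # lazy import, as in the diff, so CPU-only setups never touch torch here

        if torch.cuda.is_available():
            # Return cached allocator blocks to the driver; memory still
            # owned by live tensors is unaffected.
            torch.cuda.empty_cache()
```

With a helper like this, each of the two hunks would reduce to a single call, e.g. `free_model_vram(request.app.state, ("rf", "RERANKING_FUNCTION"), DEVICE_TYPE)`. Note that `torch.cuda.empty_cache()` only releases memory the CUDA caching allocator is holding, which is why the `gc.collect()` comes first: the model tensors must actually be unreferenced before their memory lands back in the cache to be released.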