diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index d5b89c8d50..43ab357538 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -924,14 +924,16 @@ try:
         app.state.config.RAG_EMBEDDING_MODEL,
         RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     )
-
-    app.state.rf = get_rf(
-        app.state.config.RAG_RERANKING_ENGINE,
-        app.state.config.RAG_RERANKING_MODEL,
-        app.state.config.RAG_EXTERNAL_RERANKER_URL,
-        app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-        RAG_RERANKING_MODEL_AUTO_UPDATE,
-    )
+    if ENABLE_RAG_HYBRID_SEARCH and not BYPASS_EMBEDDING_AND_RETRIEVAL:
+        app.state.rf = get_rf(
+            app.state.config.RAG_RERANKING_ENGINE,
+            app.state.config.RAG_RERANKING_MODEL,
+            app.state.config.RAG_EXTERNAL_RERANKER_URL,
+            app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+            RAG_RERANKING_MODEL_AUTO_UPDATE,
+        )
+    else:
+        app.state.rf = None
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 4a0d327c0b..a6a0f05da0 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -321,6 +321,14 @@ async def update_embedding_config(
             form_data.embedding_batch_size
         )
 
+        # Unload the current embedding model and clear the VRAM cache
+        request.app.state.ef = None
+        request.app.state.EMBEDDING_FUNCTION = None
+        import gc
+        gc.collect()
+        import torch
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         request.app.state.ef = get_ef(
             request.app.state.config.RAG_EMBEDDING_ENGINE,
             request.app.state.config.RAG_EMBEDDING_MODEL,
@@ -653,9 +661,6 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
-    # Free up memory if hybrid search is disabled
-    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-        request.app.state.rf = None
 
     request.app.state.config.TOP_K_RERANKER = (
         form_data.TOP_K_RERANKER
@@ -838,19 +843,29 @@
     )
 
     try:
-        request.app.state.rf = get_rf(
-            request.app.state.config.RAG_RERANKING_ENGINE,
-            request.app.state.config.RAG_RERANKING_MODEL,
-            request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
-            request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
-            True,
-        )
+        # Unload the reranker and clear the VRAM cache
+        if request.app.state.rf is not None:
+            request.app.state.rf = None
+            request.app.state.RERANKING_FUNCTION = None
+            import gc
+            gc.collect()
+            import torch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+        if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.BYPASS_EMBEDDING_AND_RETRIEVAL:
+            request.app.state.rf = get_rf(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
+                request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
+                True,
+            )
 
-        request.app.state.RERANKING_FUNCTION = get_reranking_function(
-            request.app.state.config.RAG_RERANKING_ENGINE,
-            request.app.state.config.RAG_RERANKING_MODEL,
-            request.app.state.rf,
-        )
+            request.app.state.RERANKING_FUNCTION = get_reranking_function(
+                request.app.state.config.RAG_RERANKING_ENGINE,
+                request.app.state.config.RAG_RERANKING_MODEL,
+                request.app.state.rf,
+            )
     except Exception as e:
         log.error(f"Error loading reranking model: {e}")
         request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False
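
Note: the unload-then-free sequence added above is duplicated, once for the embedding model in update_embedding_config and once for the reranker in update_rag_config. A minimal sketch of how it could be factored into a shared helper (release_models and its signature are hypothetical, not part of this patch):

    import gc

    import torch


    def release_models(state, *attr_names):
        # Hypothetical helper, not part of this patch: drop the given
        # app.state references so the model objects become unreachable.
        for name in attr_names:
            setattr(state, name, None)
        # Collect the now-unreachable models, then ask the CUDA caching
        # allocator to return freed blocks to the driver so VRAM usage
        # actually drops.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

Both call sites would then reduce to a single line, e.g. release_models(request.app.state, "rf", "RERANKING_FUNCTION"), and the inline import gc / import torch statements could move to the top of the module.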