diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index abc1f53965..b863e84385 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -4,7 +4,7 @@ import mimetypes import os import shutil import asyncio - +import torch import uuid from datetime import datetime @@ -281,6 +281,14 @@ async def update_embedding_config( log.info( f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) + if request.app.state.config.RAG_EMBEDDING_ENGINE == '': + # unloads current internal embedding model and clears VRAM cache + request.app.state.ef = None + request.app.state.EMBEDDING_FUNCTION = None + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() try: request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model @@ -321,14 +329,6 @@ async def update_embedding_config( form_data.embedding_batch_size ) - # unloads current embedding model and clears VRAM cache - request.app.state.ef = None - request.app.state.EMBEDDING_FUNCTION = None - import gc - gc.collect() - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() request.app.state.ef = get_ef( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, @@ -814,6 +814,14 @@ async def update_rag_config( ) # Reranking settings + if request.app.state.config.RAG_RERANKING_ENGINE == '': + # Unloading the internal reranker and clear VRAM memory + request.app.state.rf = None + request.app.state.RERANKING_FUNCTION = None + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() request.app.state.config.RAG_RERANKING_ENGINE = ( form_data.RAG_RERANKING_ENGINE if form_data.RAG_RERANKING_ENGINE is not None @@ -843,15 +851,6 @@ async def update_rag_config( ) try: - # Unloading the reranker and clear VRAM memory. - if request.app.state.rf != None: - request.app.state.rf = None - request.app.state.RERANKING_FUNCTION = None - import gc - gc.collect() - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: request.app.state.rf = get_rf( request.app.state.config.RAG_RERANKING_ENGINE,