Correctly unloads embedding/reranker models

This commit is contained in:
Marko Henning 2025-08-20 13:30:45 +02:00
parent b3a95f40fc
commit 39fe385017
2 changed files with 40 additions and 23 deletions

View file

@@ -924,14 +924,16 @@ try:
app.state.config.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
) )
if ENABLE_RAG_HYBRID_SEARCH and not BYPASS_EMBEDDING_AND_RETRIEVAL:
app.state.rf = get_rf( app.state.rf = get_rf(
app.state.config.RAG_RERANKING_ENGINE, app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL, app.state.config.RAG_RERANKING_MODEL,
app.state.config.RAG_EXTERNAL_RERANKER_URL, app.state.config.RAG_EXTERNAL_RERANKER_URL,
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
) )
else:
app.state.rf = None
except Exception as e: except Exception as e:
log.error(f"Error updating models: {e}") log.error(f"Error updating models: {e}")
pass pass

View file

@@ -321,6 +321,14 @@ async def update_embedding_config(
form_data.embedding_batch_size form_data.embedding_batch_size
) )
# unloads current embedding model and clears VRAM cache
request.app.state.ef = None
request.app.state.EMBEDDING_FUNCTION = None
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
request.app.state.ef = get_ef( request.app.state.ef = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL, request.app.state.config.RAG_EMBEDDING_MODEL,
@@ -653,9 +661,6 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
) )
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf = None
request.app.state.config.TOP_K_RERANKER = ( request.app.state.config.TOP_K_RERANKER = (
form_data.TOP_K_RERANKER form_data.TOP_K_RERANKER
@@ -838,19 +843,29 @@ async def update_rag_config(
) )
try: try:
request.app.state.rf = get_rf( # Unload the reranker and clear VRAM memory.
request.app.state.config.RAG_RERANKING_ENGINE, if request.app.state.rf != None:
request.app.state.config.RAG_RERANKING_MODEL, request.app.state.rf = None
request.app.state.config.RAG_EXTERNAL_RERANKER_URL, request.app.state.RERANKING_FUNCTION = None
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, import gc
True, gc.collect()
) import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.BYPASS_EMBEDDING_AND_RETRIEVAL:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True,
)
request.app.state.RERANKING_FUNCTION = get_reranking_function( request.app.state.RERANKING_FUNCTION = get_reranking_function(
request.app.state.config.RAG_RERANKING_ENGINE, request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL, request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.rf, request.app.state.rf,
) )
except Exception as e: except Exception as e:
log.error(f"Error loading reranking model: {e}") log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False