Correctly unloads embedding/reranker models

This commit is contained in:
Marko Henning 2025-08-20 13:30:45 +02:00
parent b3a95f40fc
commit 39fe385017
2 changed files with 40 additions and 23 deletions

View file

@ -924,14 +924,16 @@ try:
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
app.state.rf = get_rf(
app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL,
app.state.config.RAG_EXTERNAL_RERANKER_URL,
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
if ENABLE_RAG_HYBRID_SEARCH and not BYPASS_EMBEDDING_AND_RETRIEVAL:
app.state.rf = get_rf(
app.state.config.RAG_RERANKING_ENGINE,
app.state.config.RAG_RERANKING_MODEL,
app.state.config.RAG_EXTERNAL_RERANKER_URL,
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
else:
app.state.rf = None
except Exception as e:
log.error(f"Error updating models: {e}")
pass

View file

@ -321,6 +321,14 @@ async def update_embedding_config(
form_data.embedding_batch_size
)
# unloads current embedding model and clears VRAM cache
request.app.state.ef = None
request.app.state.EMBEDDING_FUNCTION = None
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
request.app.state.ef = get_ef(
request.app.state.config.RAG_EMBEDDING_ENGINE,
request.app.state.config.RAG_EMBEDDING_MODEL,
@ -653,9 +661,6 @@ async def update_rag_config(
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
)
# Free up memory if hybrid search is disabled
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
request.app.state.rf = None
request.app.state.config.TOP_K_RERANKER = (
form_data.TOP_K_RERANKER
@ -838,19 +843,29 @@ async def update_rag_config(
)
try:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True,
)
# Unloading the reranker and clear VRAM memory.
if request.app.state.rf != None:
request.app.state.rf = None
request.app.state.RERANKING_FUNCTION = None
import gc
gc.collect()
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and not request.app.state.BYPASS_EMBEDDING_AND_RETRIEVAL:
request.app.state.rf = get_rf(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.config.RAG_EXTERNAL_RERANKER_URL,
request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
True,
)
request.app.state.RERANKING_FUNCTION = get_reranking_function(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.rf,
)
request.app.state.RERANKING_FUNCTION = get_reranking_function(
request.app.state.config.RAG_RERANKING_ENGINE,
request.app.state.config.RAG_RERANKING_MODEL,
request.app.state.rf,
)
except Exception as e:
log.error(f"Error loading reranking model: {e}")
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False