Merge pull request #16779 from mahenning/fix--clean-unload-embed/reranker-models

Fix: Free VRAM memory when updating embedding / reranking models
commit 5a66f69460
Tim Jaeryang Baek, 2025-08-21 21:38:37 +04:00 (committed by GitHub)
2 changed files with 53 additions and 24 deletions


@@ -924,7 +924,10 @@ try:
         app.state.config.RAG_EMBEDDING_MODEL,
         RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     )
+    if (
+        app.state.config.ENABLE_RAG_HYBRID_SEARCH
+        and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+    ):
         app.state.rf = get_rf(
             app.state.config.RAG_RERANKING_ENGINE,
             app.state.config.RAG_RERANKING_MODEL,
@@ -932,6 +935,8 @@ try:
             app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
             RAG_RERANKING_MODEL_AUTO_UPDATE,
         )
+    else:
+        app.state.rf = None
 except Exception as e:
     log.error(f"Error updating models: {e}")
     pass
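The guard added above means the reranker is only instantiated at startup when hybrid search can actually use it; otherwise app.state.rf is set to None and no reranker weights are loaded into VRAM. A minimal, self-contained sketch of the same pattern, where load_reranker, config, and state are hypothetical stand-ins for get_rf and app.state:

from types import SimpleNamespace

def load_reranker(engine: str, model: str):
    # Stand-in for get_rf: the real function would load reranker
    # weights (potentially onto the GPU) and return a scorer.
    return lambda query, docs: [0.0] * len(docs)

config = SimpleNamespace(
    ENABLE_RAG_HYBRID_SEARCH=False,
    BYPASS_EMBEDDING_AND_RETRIEVAL=False,
)
state = SimpleNamespace(rf=None)

if config.ENABLE_RAG_HYBRID_SEARCH and not config.BYPASS_EMBEDDING_AND_RETRIEVAL:
    state.rf = load_reranker("", "example-reranker")
else:
    # Nothing to rerank with hybrid search off, so hold no model at all.
    state.rf = None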


@@ -5,7 +5,6 @@ import os
 import shutil
 import asyncio
 import uuid
 from datetime import datetime
 from pathlib import Path
@@ -281,6 +280,18 @@ async def update_embedding_config(
     log.info(
         f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
+    if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
+        # Unload the current internal embedding model and clear the VRAM cache
+        request.app.state.ef = None
+        request.app.state.EMBEDDING_FUNCTION = None
+        import gc
+
+        gc.collect()
+        if DEVICE_TYPE == "cuda":
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     try:
         request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
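The unload sequence above follows the usual two-step pattern for releasing GPU memory held by a torch model: drop every Python reference so the object becomes collectable, run gc.collect() so its CUDA tensors are actually destroyed, then call torch.cuda.empty_cache() so the caching allocator returns the freed blocks to the driver (until then, tools like nvidia-smi still report the memory as in use by the process). A standalone sketch of that pattern, assuming torch is installed; the dict plays the role of request.app.state:

import gc

import torch

def unload_model(state: dict, key: str) -> None:
    # Drop the reference so the model object becomes garbage.
    state[key] = None
    # Force collection now; the model's CUDA tensors are freed here.
    gc.collect()
    # Return the now-unused cached blocks to the driver so other
    # processes (and nvidia-smi) see the memory as free again.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

state = {"ef": None}  # state["ef"] would hold the loaded embedding model
unload_model(state, "ef")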
@@ -653,9 +664,6 @@ async def update_rag_config(
         if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
         else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
     )
-    # Free up memory if hybrid search is disabled
-    if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
-        request.app.state.rf = None

     request.app.state.config.TOP_K_RERANKER = (
         form_data.TOP_K_RERANKER
@@ -809,6 +817,18 @@ async def update_rag_config(
     )

     # Reranking settings
+    if request.app.state.config.RAG_RERANKING_ENGINE == "":
+        # Unload the internal reranker and clear VRAM
+        request.app.state.rf = None
+        request.app.state.RERANKING_FUNCTION = None
+        import gc
+
+        gc.collect()
+        if DEVICE_TYPE == "cuda":
+            import torch
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     request.app.state.config.RAG_RERANKING_ENGINE = (
         form_data.RAG_RERANKING_ENGINE
         if form_data.RAG_RERANKING_ENGINE is not None
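One way to check that this sequence actually returns VRAM, assuming a CUDA device is available: allocate a large tensor, free it the same way the diff does, and compare torch's allocation counter before and after.

import gc

import torch

def allocated_mib() -> float:
    # Live tensor memory; torch.cuda.memory_reserved() would instead
    # show what the caching allocator still holds from the driver.
    return torch.cuda.memory_allocated() / 1024**2

if torch.cuda.is_available():
    x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB of float32
    print(f"allocated: {allocated_mib():.0f} MiB")
    del x
    gc.collect()
    torch.cuda.empty_cache()
    print(f"allocated: {allocated_mib():.0f} MiB")  # back near zero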
@@ -838,6 +858,10 @@ async def update_rag_config(
     )

     try:
+        if (
+            request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
+            and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
+        ):
             request.app.state.rf = get_rf(
                 request.app.state.config.RAG_RERANKING_ENGINE,
                 request.app.state.config.RAG_RERANKING_MODEL,
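This mirrors the startup guard in the first file: get_rf only runs when hybrid search is enabled and embedding/retrieval is not bypassed, so saving the RAG config with hybrid search off no longer pulls reranker weights back into VRAM.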