mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
Merge pull request #16779 from mahenning/fix--clean-unload-embed/reranker-models
Fix: Free VRAM memory when updating embedding / reranking models
This commit is contained in:
commit
5a66f69460
2 changed files with 53 additions and 24 deletions
|
|
@ -924,7 +924,10 @@ try:
|
|||
app.state.config.RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
if (
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
app.state.rf = get_rf(
|
||||
app.state.config.RAG_RERANKING_ENGINE,
|
||||
app.state.config.RAG_RERANKING_MODEL,
|
||||
|
|
@ -932,6 +935,8 @@ try:
|
|||
app.state.config.RAG_EXTERNAL_RERANKER_API_KEY,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
else:
|
||||
app.state.rf = None
|
||||
except Exception as e:
|
||||
log.error(f"Error updating models: {e}")
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import os
|
|||
import shutil
|
||||
import asyncio
|
||||
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
|
@ -281,6 +280,18 @@ async def update_embedding_config(
|
|||
log.info(
|
||||
f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
if request.app.state.config.RAG_EMBEDDING_ENGINE == "":
|
||||
# unloads current internal embedding model and clears VRAM cache
|
||||
request.app.state.ef = None
|
||||
request.app.state.EMBEDDING_FUNCTION = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if DEVICE_TYPE == "cuda":
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
|
|
@ -653,9 +664,6 @@ async def update_rag_config(
|
|||
if form_data.ENABLE_RAG_HYBRID_SEARCH is not None
|
||||
else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
)
|
||||
# Free up memory if hybrid search is disabled
|
||||
if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
request.app.state.rf = None
|
||||
|
||||
request.app.state.config.TOP_K_RERANKER = (
|
||||
form_data.TOP_K_RERANKER
|
||||
|
|
@ -809,6 +817,18 @@ async def update_rag_config(
|
|||
)
|
||||
|
||||
# Reranking settings
|
||||
if request.app.state.config.RAG_RERANKING_ENGINE == "":
|
||||
# Unloading the internal reranker and clear VRAM memory
|
||||
request.app.state.rf = None
|
||||
request.app.state.RERANKING_FUNCTION = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
if DEVICE_TYPE == "cuda":
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
request.app.state.config.RAG_RERANKING_ENGINE = (
|
||||
form_data.RAG_RERANKING_ENGINE
|
||||
if form_data.RAG_RERANKING_ENGINE is not None
|
||||
|
|
@ -838,6 +858,10 @@ async def update_rag_config(
|
|||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
request.app.state.config.ENABLE_RAG_HYBRID_SEARCH
|
||||
and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL
|
||||
):
|
||||
request.app.state.rf = get_rf(
|
||||
request.app.state.config.RAG_RERANKING_ENGINE,
|
||||
request.app.state.config.RAG_RERANKING_MODEL,
|
||||
|
|
|
|||
Loading…
Reference in a new issue