diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index b863e84385..51a81b1fd7 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -4,7 +4,6 @@ import mimetypes import os import shutil import asyncio -import torch import uuid from datetime import datetime @@ -287,8 +286,10 @@ async def update_embedding_config( request.app.state.EMBEDDING_FUNCTION = None import gc gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if DEVICE_TYPE == 'cuda': + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() try: request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model @@ -820,8 +821,10 @@ async def update_rag_config( request.app.state.RERANKING_FUNCTION = None import gc gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if DEVICE_TYPE == 'cuda': + import torch + if torch.cuda.is_available(): + torch.cuda.empty_cache() request.app.state.config.RAG_RERANKING_ENGINE = ( form_data.RAG_RERANKING_ENGINE if form_data.RAG_RERANKING_ENGINE is not None