diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 600c33afa1..6080337250 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1441,6 +1441,9 @@ def process_file( form_data: ProcessFileForm, user=Depends(get_verified_user), ): + """ + Process a file and save its content to the vector database. + """ if user.role == "admin": file = Files.get_file_by_id(form_data.file_id) else: @@ -1667,7 +1670,7 @@ class ProcessTextForm(BaseModel): @router.post("/process/text") -def process_text( +async def process_text( request: Request, form_data: ProcessTextForm, user=Depends(get_verified_user), @@ -1685,7 +1688,9 @@ def process_text( text_content = form_data.content log.debug(f"text_content: {text_content}") - result = save_docs_to_vector_db(request, docs, collection_name, user=user) + result = await run_in_threadpool( + save_docs_to_vector_db, request, docs, collection_name, user=user + ) if result: return { "status": True, @@ -1701,7 +1706,7 @@ def process_text( @router.post("/process/youtube") @router.post("/process/web") -def process_web( +async def process_web( request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) ): try: @@ -1709,16 +1714,14 @@ def process_web( if not collection_name: collection_name = calculate_sha256_string(form_data.url)[:63] - content, docs = get_content_from_url(request, form_data.url) + content, docs = await run_in_threadpool( + get_content_from_url, request, form_data.url + ) log.debug(f"text_content: {content}") if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: - save_docs_to_vector_db( - request, - docs, - collection_name, - overwrite=True, - user=user, + await run_in_threadpool( + save_docs_to_vector_db, request, docs, collection_name, overwrite=True, user=user ) else: collection_name = None @@ -2405,7 +2408,7 @@ class BatchProcessFilesResponse(BaseModel): @router.post("/process/files/batch") -def process_files_batch( 
+async def process_files_batch( request: Request, form_data: BatchProcessFilesForm, user=Depends(get_verified_user), @@ -2460,12 +2463,8 @@ def process_files_batch( # Save all documents in one batch if all_docs: try: - save_docs_to_vector_db( - request=request, - docs=all_docs, - collection_name=collection_name, - add=True, - user=user, + await run_in_threadpool( + save_docs_to_vector_db, request, all_docs, collection_name, add=True, user=user ) # Update all files with collection name diff --git a/src/lib/apis/retrieval/index.ts b/src/lib/apis/retrieval/index.ts index 6df927fec6..5cb0f60a72 100644 --- a/src/lib/apis/retrieval/index.ts +++ b/src/lib/apis/retrieval/index.ts @@ -295,42 +295,6 @@ export interface SearchDocument { filenames: string[]; } -export const processFile = async ( - token: string, - file_id: string, - collection_name: string | null = null -) => { - let error = null; - - const res = await fetch(`${RETRIEVAL_API_BASE_URL}/process/file`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - file_id: file_id, - collection_name: collection_name ? collection_name : undefined - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .catch((err) => { - error = err.detail; - console.error(err); - return null; - }); - - if (error) { - throw error; - } - - return res; -}; - export const processYoutubeVideo = async (token: string, url: string) => { let error = null;