diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 370737ba55..2c790582fd 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -3,6 +3,8 @@ import os
 from typing import Optional, Union
 
 import requests
+import aiohttp
+import asyncio
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
 import time
@@ -27,6 +29,7 @@ from open_webui.models.notes import Notes
 from open_webui.retrieval.vector.main import GetResult
 
 from open_webui.utils.access_control import has_access
+from open_webui.utils.headers import include_user_info_headers
 from open_webui.utils.misc import get_message_list
 from open_webui.retrieval.web.utils import get_web_loader
 
@@ -87,15 +90,16 @@ class VectorSearchRetriever(BaseRetriever):
     embedding_function: Any
     top_k: int
 
-    def _get_relevant_documents(
+    async def _aget_relevant_documents(
         self,
         query: str,
         *,
         run_manager: CallbackManagerForRetrieverRun,
     ) -> list[Document]:
+        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
+            vectors=[embedding],
             limit=self.top_k,
         )
 
@@ -186,7 +190,7 @@ def get_enriched_texts(collection_result: GetResult) -> list[str]:
     return enriched_texts
 
 
-def query_doc_with_hybrid_search(
+async def query_doc_with_hybrid_search(
     collection_name: str,
     collection_result: GetResult,
     query: str,
@@ -262,7 +266,7 @@ def query_doc_with_hybrid_search(
         base_compressor=compressor, base_retriever=ensemble_retriever
     )
 
-    result = compression_retriever.invoke(query)
+    result = await compression_retriever.ainvoke(query)
 
     distances = [d.metadata.get("score") for d in result]
     documents = [d.page_content for d in result]
@@ -381,7 +385,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
     return merge_get_results(results)
 
 
-def query_collection(
+async def query_collection(
     collection_names: list[str],
     queries: list[str],
     embedding_function,
@@ -406,7 +410,7 @@ def query_collection(
             return None, e
 
     # Generate all query embeddings (in one call)
-    query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )
@@ -433,7 +437,7 @@ def query_collection(
     return merge_and_sort_query_results(results, k=k)
 
 
-def query_collection_with_hybrid_search(
+async def query_collection_with_hybrid_search(
     collection_names: list[str],
     queries: list[str],
     embedding_function,
@@ -465,9 +469,9 @@ def query_collection_with_hybrid_search(
         f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
     )
 
-    def process_query(collection_name, query):
+    async def process_query(collection_name, query):
         try:
-            result = query_doc_with_hybrid_search(
+            result = await query_doc_with_hybrid_search(
                 collection_name=collection_name,
                 collection_result=collection_results[collection_name],
                 query=query,
@@ -493,9 +497,8 @@ def query_collection_with_hybrid_search(
         for q in queries
     ]
 
-    with ThreadPoolExecutor() as executor:
-        future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
-        task_results = [future.result() for future in future_results]
+    # Run all queries in parallel using asyncio.gather
+    task_results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])
 
     for result, err in task_results:
         if err is not None:
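
# --- Editor's sketch (not part of the patch) ---------------------------------
# process_query above is run through asyncio.gather, and each coroutine traps
# its own exception and returns a (result, err) tuple, so one failing
# collection cannot abort the whole gather. The same pattern in isolation,
# with a toy search() standing in for query_doc_with_hybrid_search:
import asyncio


async def search(collection: str, query: str) -> tuple:
    try:
        if collection == "broken":
            raise RuntimeError("collection unavailable")
        await asyncio.sleep(0.01)  # simulate I/O-bound vector search
        return f"{collection}:{query}", None
    except Exception as e:
        return None, e  # the error travels in the tuple, it is never raised


async def main():
    tasks = [("docs", "q1"), ("broken", "q1"), ("notes", "q2")]
    for result, err in await asyncio.gather(*[search(c, q) for c, q in tasks]):
        print("error:" if err else "ok:", err or result)


asyncio.run(main())
# ------------------------------------------------------------------------------
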
@@ -511,6 +514,146 @@ def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k)
 
 
+async def generate_openai_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str = "https://api.openai.com/v1",
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_openai_batch_embeddings_async:model {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{url}/embeddings", headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "data" in data:
+                    return [item["embedding"] for item in data["data"]]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating openai batch embeddings: {e}")
+        return None
+
+
+async def generate_azure_openai_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    version: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_azure_openai_batch_embeddings_async:deployment {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
+
+        headers = {
+            "Content-Type": "application/json",
+            "api-key": key,
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(full_url, headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "data" in data:
+                    return [item["embedding"] for item in data["data"]]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating azure openai batch embeddings: {e}")
+        return None
+
+
+async def generate_ollama_batch_embeddings_async(
+    model: str,
+    texts: list[str],
+    url: str,
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"generate_ollama_batch_embeddings_async:model {model} batch size: {len(texts)}"
+        )
+        form_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
+            headers = include_user_info_headers(headers, user)
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{url}/api/embed", headers=headers, json=form_data) as r:
+                r.raise_for_status()
+                data = await r.json()
+                if "embeddings" in data:
+                    return data["embeddings"]
+                else:
+                    raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating ollama batch embeddings: {e}")
+        return None
+
+
+async def generate_multiple_async(query, prefix, user, func_async, embedding_batch_size):
+    if isinstance(query, list):
+        # Create batches
+        batches = [
+            query[i : i + embedding_batch_size]
+            for i in range(0, len(query), embedding_batch_size)
+        ]
+
+        log.debug(f"generate_multiple_async: Processing {len(batches)} batches in parallel")
+
+        # Execute all batches in parallel
+        tasks = [
+            func_async(batch, prefix=prefix, user=user)
+            for batch in batches
+        ]
+        batch_results = await asyncio.gather(*tasks)
+
+        # Flatten results
+        embeddings = []
+        for batch_embeddings in batch_results:
+            if isinstance(batch_embeddings, list):
+                embeddings.extend(batch_embeddings)
+
+        log.debug(f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches")
+        return embeddings
+    else:
+        return await func_async([query], prefix=prefix, user=user)
+
+
 def get_embedding_function(
     embedding_engine,
     embedding_model,
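
# --- Editor's sketch (not part of the patch) ---------------------------------
# The three generators above share one shape: POST a batch of texts to an
# embeddings endpoint with aiohttp and pull the vectors out of the JSON body.
# A stripped-down version against an OpenAI-compatible endpoint; the URL,
# model name, and key below are placeholders, not values from the patch:
import aiohttp


async def embed_batch(texts: list[str]) -> list[list[float]]:
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer sk-placeholder",  # placeholder credential
    }
    payload = {"input": texts, "model": "text-embedding-3-small"}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.openai.com/v1/embeddings", headers=headers, json=payload
        ) as r:
            r.raise_for_status()
            data = await r.json()
            # OpenAI-style responses key the vectors under "data"
            return [item["embedding"] for item in data["data"]]
# ------------------------------------------------------------------------------
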
@@ -521,40 +664,58 @@ def get_embedding_function(
     azure_api_version=None,
 ):
     if embedding_engine == "":
-        return lambda query, prefix=None, user=None: embedding_function.encode(
-            query, **({"prompt": prefix} if prefix else {})
-        ).tolist()
+        # Sentence transformers: CPU-bound sync operation
+        # Run in thread pool to prevent blocking async event loop
+        def sync_encode(query, prefix=None):
+            return embedding_function.encode(
+                query, **({"prompt": prefix} if prefix else {})
+            ).tolist()
+
+        async def async_wrapper(query, prefix=None, user=None):
+            # Run CPU-bound operation in thread pool
+            return await asyncio.to_thread(sync_encode, query, prefix)
+
+        return async_wrapper
     elif embedding_engine in ["ollama", "openai", "azure_openai"]:
-        func = lambda query, prefix=None, user=None: generate_embeddings(
-            engine=embedding_engine,
-            model=embedding_model,
-            text=query,
-            prefix=prefix,
-            url=url,
-            key=key,
-            user=user,
-            azure_api_version=azure_api_version,
-        )
+        # Create async function based on engine
+        if embedding_engine == "openai":
+            func_async = lambda texts, prefix=None, user=None: generate_openai_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                prefix=prefix,
+                user=user,
+            )
+        elif embedding_engine == "azure_openai":
+            func_async = lambda texts, prefix=None, user=None: generate_azure_openai_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                version=azure_api_version,
+                prefix=prefix,
+                user=user,
+            )
+        elif embedding_engine == "ollama":
+            func_async = lambda texts, prefix=None, user=None: generate_ollama_batch_embeddings_async(
+                model=embedding_model,
+                texts=texts,
+                url=url,
+                key=key,
+                prefix=prefix,
+                user=user,
+            )
 
-        def generate_multiple(query, prefix, user, func):
+        # Return async function for parallel batch processing (FastAPI-compatible)
+        async def embedding_wrapper(query, prefix=None, user=None):
             if isinstance(query, list):
-                embeddings = []
-                for i in range(0, len(query), embedding_batch_size):
-                    batch_embeddings = func(
-                        query[i : i + embedding_batch_size],
-                        prefix=prefix,
-                        user=user,
-                    )
-
-                    if isinstance(batch_embeddings, list):
-                        embeddings.extend(batch_embeddings)
-                return embeddings
+                return await generate_multiple_async(query, prefix, user, func_async, embedding_batch_size)
             else:
-                return func(query, prefix, user)
+                result = await func_async([query], prefix=prefix, user=user)
+                return result[0] if result else None
 
-        return lambda query, prefix=None, user=None: generate_multiple(
-            query, prefix, user, func
-        )
+        return embedding_wrapper
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
@@ -572,7 +733,7 @@ def get_reranking_function(reranking_engine, reranking_model, reranking_function
     )
 
 
-def get_sources_from_items(
+async def get_sources_from_items(
     request,
     items,
     queries,
@@ -800,7 +961,7 @@ def get_sources_from_items(
                 query_result = None  # Initialize to None
                 if hybrid_search:
                     try:
-                        query_result = query_collection_with_hybrid_search(
+                        query_result = await query_collection_with_hybrid_search(
                             collection_names=collection_names,
                             queries=queries,
                             embedding_function=embedding_function,
@@ -818,7 +979,7 @@ def get_sources_from_items(
 
                 # fallback to non-hybrid search
                 if not hybrid_search and query_result is None:
-                    query_result = query_collection(
+                    query_result = await query_collection(
                         collection_names=collection_names,
                         queries=queries,
                         embedding_function=embedding_function,
@@ -1056,26 +1217,25 @@ def generate_embeddings(
         else:
             text = f"{prefix}{text}"
 
+    # Run async embedding generation synchronously
     if engine == "ollama":
-        embeddings = generate_ollama_batch_embeddings(
-            **{
-                "model": model,
-                "texts": text if isinstance(text, list) else [text],
-                "url": url,
-                "key": key,
-                "prefix": prefix,
-                "user": user,
-            }
-        )
+        embeddings = asyncio.run(generate_ollama_batch_embeddings_async(
+            model=model,
+            texts=text if isinstance(text, list) else [text],
+            url=url,
+            key=key,
+            prefix=prefix,
+            user=user,
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
-        embeddings = generate_openai_batch_embeddings(
+        embeddings = asyncio.run(generate_openai_batch_embeddings_async(
             model, text if isinstance(text, list) else [text], url, key, prefix, user
-        )
+        ))
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "azure_openai":
         azure_api_version = kwargs.get("azure_api_version", "")
-        embeddings = generate_azure_openai_batch_embeddings(
+        embeddings = asyncio.run(generate_azure_openai_batch_embeddings_async(
             model,
             text if isinstance(text, list) else [text],
             url,
@@ -1083,7 +1243,7 @@ def generate_embeddings(
             azure_api_version,
             prefix,
             user,
-        )
+        ))
 
         return embeddings[0] if isinstance(text, str) else embeddings
 
@@ -1104,7 +1264,7 @@ class RerankCompressor(BaseDocumentCompressor):
         extra = "forbid"
         arbitrary_types_allowed = True
 
-    def compress_documents(
+    async def acompress_documents(
        self,
         documents: Sequence[Document],
         query: str,
@@ -1114,12 +1274,14 @@ class RerankCompressor(BaseDocumentCompressor):
         scores = None
 
         if reranking:
-            scores = self.reranking_function(query, documents)
+            scores = self.reranking_function(
+                [(query, doc.page_content) for doc in documents]
+            )
         else:
             from sentence_transformers import util
 
-            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
-            document_embedding = self.embedding_function(
+            query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
+            document_embedding = await self.embedding_function(
                 [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
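
# --- Editor's sketch (not part of the patch) ---------------------------------
# get_embedding_function above wraps the blocking, CPU-bound
# SentenceTransformer.encode() in asyncio.to_thread so it cannot stall the
# event loop. The same shape in isolation, with a dummy encode() standing in
# for the real model:
import asyncio


def encode(texts: list[str]) -> list[list[float]]:
    # Stand-in for SentenceTransformer.encode(): synchronous and CPU-bound.
    return [[float(len(t)), 0.0] for t in texts]


async def encode_async(texts: list[str]) -> list[list[float]]:
    # to_thread runs the blocking call on a worker thread and awaits it,
    # keeping the event loop free for other requests meanwhile.
    return await asyncio.to_thread(encode, texts)


print(asyncio.run(encode_async(["hello", "world"])))
# ------------------------------------------------------------------------------
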
diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py
index 11b3d0c96c..fb2a3a0bfd 100644
--- a/backend/open_webui/routers/memories.py
+++ b/backend/open_webui/routers/memories.py
@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Request
 from pydantic import BaseModel
 import logging
+import asyncio
 from typing import Optional
 
 from open_webui.models.memories import Memories, MemoryModel
@@ -17,7 +18,7 @@ router = APIRouter()
 
 @router.get("/ef")
 async def get_embeddings(request: Request):
-    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
+    return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")}
 
 
 ############################
@@ -51,15 +52,17 @@ async def add_memory(
 ):
     memory = Memories.insert_new_memory(user.id, form_data.content)
 
+    vector = await request.app.state.EMBEDDING_FUNCTION(
+        memory.content, user=user
+    )
+
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vector,
                 "metadata": {"created_at": memory.created_at},
             }
         ],
@@ -86,9 +89,11 @@ async def query_memory(
     if not memories:
         raise HTTPException(status_code=404, detail="No memories found for user")
 
+    vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)
+
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
+        vectors=[vector],
         limit=form_data.k,
     )
 
@@ -105,21 +110,26 @@ async def reset_memory_from_vector_db(
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
 
     memories = Memories.get_memories_by_user_id(user.id)
+
+    # Generate vectors in parallel
+    vectors = await asyncio.gather(*[
+        request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
+        for memory in memories
+    ])
+
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(
-                    memory.content, user=user
-                ),
+                "vector": vectors[idx],
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
                 },
             }
-            for memory in memories
+            for idx, memory in enumerate(memories)
         ],
     )
 
@@ -164,15 +174,17 @@ async def update_memory_by_id(
         raise HTTPException(status_code=404, detail="Memory not found")
 
     if form_data.content is not None:
+        vector = await request.app.state.EMBEDDING_FUNCTION(
+            memory.content, user=user
+        )
+
         VECTOR_DB_CLIENT.upsert(
             collection_name=f"user-memory-{user.id}",
             items=[
                 {
                     "id": memory.id,
                     "text": memory.content,
-                    "vector": request.app.state.EMBEDDING_FUNCTION(
-                        memory.content, user=user
-                    ),
+                    "vector": vector,
                     "metadata": {
                         "created_at": memory.created_at,
                         "updated_at": memory.updated_at,
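
# --- Editor's sketch (not part of the patch) ---------------------------------
# reset_memory_from_vector_db above relies on asyncio.gather returning results
# in the order of its input awaitables, which is what makes vectors[idx] line
# up with memories[idx] even when later items finish embedding first:
import asyncio


async def embed(text: str) -> list[float]:
    # Shorter texts "finish" later here, deliberately out of submission order.
    await asyncio.sleep(0.05 / len(text))
    return [float(len(text))]


async def main():
    contents = ["a rather long memory", "short", "mid-sized one"]
    vectors = await asyncio.gather(*[embed(c) for c in contents])
    # gather yields results positionally, not in completion order
    for idx, content in enumerate(contents):
        print(content, "->", vectors[idx])


asyncio.run(main())
# ------------------------------------------------------------------------------
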
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 80ef02caf8..28d9d11b16 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -1467,11 +1467,12 @@ def save_docs_to_vector_db(
                 ),
             )
 
-        embeddings = embedding_function(
+        # Run async embedding in sync context
+        embeddings = asyncio.run(embedding_function(
             list(map(lambda x: x.replace("\n", " "), texts)),
             prefix=RAG_EMBEDDING_CONTENT_PREFIX,
             user=user,
-        )
+        ))
         log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")
 
         items = [
@@ -2262,7 +2263,7 @@ class QueryDocForm(BaseModel):
 
 
 @router.post("/query/doc")
-def query_doc_handler(
+async def query_doc_handler(
     request: Request,
     form_data: QueryDocForm,
     user=Depends(get_verified_user),
@@ -2275,7 +2276,7 @@ def query_doc_handler(
             collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
                 collection_name=form_data.collection_name
             )
-            return query_doc_with_hybrid_search(
+            return await query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 collection_result=collection_results[form_data.collection_name],
                 query=form_data.query,
@@ -2285,8 +2286,8 @@ def query_doc_handler(
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=(
                     (
-                        lambda sentences: request.app.state.RERANKING_FUNCTION(
-                            sentences, user=user
+                        lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                            query, documents, user=user
                         )
                     )
                     if request.app.state.RERANKING_FUNCTION
@@ -2307,11 +2308,12 @@ def query_doc_handler(
                 user=user,
             )
         else:
+            query_embedding = await request.app.state.EMBEDDING_FUNCTION(
+                form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
+            )
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(
-                    form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
-                ),
+                query_embedding=query_embedding,
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 user=user,
             )
@@ -2335,7 +2337,7 @@ class QueryCollectionsForm(BaseModel):
 
 
 @router.post("/query/collection")
-def query_collection_handler(
+async def query_collection_handler(
     request: Request,
     form_data: QueryCollectionsForm,
     user=Depends(get_verified_user),
@@ -2344,7 +2346,7 @@ def query_collection_handler(
         if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
             form_data.hybrid is None or form_data.hybrid
         ):
-            return query_collection_with_hybrid_search(
+            return await query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
                 embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
@@ -2379,7 +2381,7 @@ def query_collection_handler(
                 ),
             )
         else:
-            return query_collection(
+            return await query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
                 embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
@@ -2461,7 +2463,7 @@ if ENV == "dev":
     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
         return {
-            "result": request.app.state.EMBEDDING_FUNCTION(
+            "result": await request.app.state.EMBEDDING_FUNCTION(
                 text, prefix=RAG_EMBEDDING_QUERY_PREFIX
             )
         }
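
# --- Editor's sketch (not part of the patch) ---------------------------------
# save_docs_to_vector_db above stays synchronous and bridges into the async
# embedding function with asyncio.run. That bridge only works when no event
# loop is already running in the calling thread: asyncio.run raises
# RuntimeError otherwise, so the assumption is that this sync path is never
# reached from inside the loop itself.
import asyncio


async def embed(texts: list[str]) -> list[list[float]]:
    await asyncio.sleep(0.01)  # stand-in for an aiohttp call
    return [[float(len(t))] for t in texts]


def embed_sync(texts: list[str]) -> list[list[float]]:
    # Fine from plain sync code; raises RuntimeError inside a running loop.
    return asyncio.run(embed(texts))


print(embed_sync(["hello", "world"]))
# ------------------------------------------------------------------------------
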
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index eebe41d571..4a4e0ea6be 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -993,37 +993,32 @@ async def chat_completion_files_handler(
         queries = [get_last_user_message(body["messages"])]
 
     try:
-        # Offload get_sources_from_items to a separate thread
-        loop = asyncio.get_running_loop()
-        with ThreadPoolExecutor() as executor:
-            sources = await loop.run_in_executor(
-                executor,
-                lambda: get_sources_from_items(
-                    request=request,
-                    items=files,
-                    queries=queries,
-                    embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
-                        query, prefix=prefix, user=user
-                    ),
-                    k=request.app.state.config.TOP_K,
-                    reranking_function=(
-                        (
-                            lambda query, documents: request.app.state.RERANKING_FUNCTION(
-                                query, documents, user=user
-                            )
-                        )
-                        if request.app.state.RERANKING_FUNCTION
-                        else None
-                    ),
-                    k_reranker=request.app.state.config.TOP_K_RERANKER,
-                    r=request.app.state.config.RELEVANCE_THRESHOLD,
-                    hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
-                    hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                    full_context=all_full_context
-                    or request.app.state.config.RAG_FULL_CONTEXT,
-                    user=user,
-                ),
-            )
+        # Directly await async get_sources_from_items (no thread needed - fully async now)
+        sources = await get_sources_from_items(
+            request=request,
+            items=files,
+            queries=queries,
+            embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                query, prefix=prefix, user=user
+            ),
+            k=request.app.state.config.TOP_K,
+            reranking_function=(
+                (
+                    lambda query, documents: request.app.state.RERANKING_FUNCTION(
+                        query, documents, user=user
+                    )
+                )
+                if request.app.state.RERANKING_FUNCTION
+                else None
+            ),
+            k_reranker=request.app.state.config.TOP_K_RERANKER,
+            r=request.app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
+            hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+            full_context=all_full_context
+            or request.app.state.config.RAG_FULL_CONTEXT,
+            user=user,
+        )
 
     except Exception as e:
         log.exception(e)
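
# --- Editor's sketch (not part of the patch) ---------------------------------
# The middleware hunk above replaces the run_in_executor detour with a direct
# await, since get_sources_from_items is itself async now. Both shapes side by
# side, with toy stand-ins for the old sync and new async source lookups:
import asyncio
from concurrent.futures import ThreadPoolExecutor


def blocking_sources(query: str) -> list[str]:
    return [f"source for {query}"]  # old: sync, had to be pushed to a thread


async def async_sources(query: str) -> list[str]:
    await asyncio.sleep(0)  # new: natively async, no thread hop needed
    return [f"source for {query}"]


async def main():
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:  # the old detour
        before = await loop.run_in_executor(executor, blocking_sources, "q")
    after = await async_sources("q")  # the new direct await
    print(before, after)


asyncio.run(main())
# ------------------------------------------------------------------------------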