perf: 50x performance improvement for external embeddings (#19296)

* Update utils.py (#77)

Co-authored-by: Claude <noreply@anthropic.com>

* refactor: address code review feedback for embedding performance improvements (#92)

Co-authored-by: Claude <noreply@anthropic.com>

* fix: prevent sentence transformers from blocking async event loop (#95)

Co-authored-by: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
Classic298 2025-11-23 02:54:59 +01:00 committed by GitHub
parent bb3e222e09
commit 902c6cfbea
4 changed files with 289 additions and 118 deletions
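
Most of the speedup comes from how batched requests to external embedding APIs are issued: previously each batch went out sequentially over a blocking HTTP client, while the new code fires all batches concurrently with aiohttp and asyncio.gather. A minimal standalone sketch of that before/after pattern (hypothetical endpoint and payload shape, not part of the diff):

import asyncio
import aiohttp
import requests

URL = "https://example.com/v1/embeddings"   # hypothetical endpoint
batches = [["a", "b"], ["c", "d"], ["e"]]

def embed_sequential():
    # Old pattern: one blocking request per batch; total time ~ sum of latencies.
    return [requests.post(URL, json={"input": b}).json() for b in batches]

async def embed_concurrent():
    # New pattern: all batches in flight at once; total time ~ slowest single request.
    async with aiohttp.ClientSession() as session:
        async def post_batch(batch):
            async with session.post(URL, json={"input": batch}) as r:
                return await r.json()
        return await asyncio.gather(*[post_batch(b) for b in batches])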

View file

@@ -3,6 +3,8 @@ import os
from typing import Optional, Union
import requests
import aiohttp
import asyncio
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
@@ -27,6 +29,7 @@ from open_webui.models.notes import Notes
from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access
from open_webui.utils.headers import include_user_info_headers
from open_webui.utils.misc import get_message_list
from open_webui.retrieval.web.utils import get_web_loader
@@ -87,15 +90,16 @@ class VectorSearchRetriever(BaseRetriever):
embedding_function: Any
top_k: int
def _get_relevant_documents(
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> list[Document]:
embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
result = VECTOR_DB_CLIENT.search(
collection_name=self.collection_name,
vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
vectors=[embedding],
limit=self.top_k,
)
@@ -186,7 +190,7 @@ def get_enriched_texts(collection_result: GetResult) -> list[str]:
return enriched_texts
def query_doc_with_hybrid_search(
async def query_doc_with_hybrid_search(
collection_name: str,
collection_result: GetResult,
query: str,
@@ -262,7 +266,7 @@ def query_doc_with_hybrid_search(
base_compressor=compressor, base_retriever=ensemble_retriever
)
result = compression_retriever.invoke(query)
result = await compression_retriever.ainvoke(query)
distances = [d.metadata.get("score") for d in result]
documents = [d.page_content for d in result]
@@ -381,7 +385,7 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
return merge_get_results(results)
def query_collection(
async def query_collection(
collection_names: list[str],
queries: list[str],
embedding_function,
@@ -406,7 +410,7 @@ def query_collection(
return None, e
# Generate all query embeddings (in one call)
query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
log.debug(
f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
)
@@ -433,7 +437,7 @@ def query_collection(
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
async def query_collection_with_hybrid_search(
collection_names: list[str],
queries: list[str],
embedding_function,
@@ -465,9 +469,9 @@ def query_collection_with_hybrid_search(
f"Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections..."
)
def process_query(collection_name, query):
async def process_query(collection_name, query):
try:
result = query_doc_with_hybrid_search(
result = await query_doc_with_hybrid_search(
collection_name=collection_name,
collection_result=collection_results[collection_name],
query=query,
@@ -493,9 +497,8 @@ def query_collection_with_hybrid_search(
for q in queries
]
with ThreadPoolExecutor() as executor:
future_results = [executor.submit(process_query, cn, q) for cn, q in tasks]
task_results = [future.result() for future in future_results]
# Run all queries in parallel using asyncio.gather
task_results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])
for result, err in task_results:
if err is not None:
@@ -511,6 +514,146 @@ def query_collection_with_hybrid_search(
return merge_and_sort_query_results(results, k=k)
async def generate_openai_batch_embeddings_async(
model: str,
texts: list[str],
url: str = "https://api.openai.com/v1",
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_openai_batch_embeddings_async:model {model} batch size: {len(texts)}"
)
form_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession() as session:
async with session.post(f"{url}/embeddings", headers=headers, json=form_data) as r:
r.raise_for_status()
data = await r.json()
if "data" in data:
return [item["embedding"] for item in data["data"]]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating openai batch embeddings: {e}")
return None
async def generate_azure_openai_batch_embeddings_async(
model: str,
texts: list[str],
url: str,
key: str = "",
version: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_azure_openai_batch_embeddings_async:deployment {model} batch size: {len(texts)}"
)
form_data = {"input": texts}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
full_url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}"
headers = {
"Content-Type": "application/json",
"api-key": key,
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession() as session:
async with session.post(full_url, headers=headers, json=form_data) as r:
r.raise_for_status()
data = await r.json()
if "data" in data:
return [item["embedding"] for item in data["data"]]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating azure openai batch embeddings: {e}")
return None
async def generate_ollama_batch_embeddings_async(
model: str,
texts: list[str],
url: str,
key: str = "",
prefix: str = None,
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
log.debug(
f"generate_ollama_batch_embeddings_async:model {model} batch size: {len(texts)}"
)
form_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
async with aiohttp.ClientSession() as session:
async with session.post(f"{url}/api/embed", headers=headers, json=form_data) as r:
r.raise_for_status()
data = await r.json()
if "embeddings" in data:
return data["embeddings"]
else:
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(f"Error generating ollama batch embeddings: {e}")
return None
async def generate_multiple_async(query, prefix, user, func_async, embedding_batch_size):
if isinstance(query, list):
# Create batches
batches = [
query[i : i + embedding_batch_size]
for i in range(0, len(query), embedding_batch_size)
]
log.debug(f"generate_multiple_async: Processing {len(batches)} batches in parallel")
# Execute all batches in parallel
tasks = [
func_async(batch, prefix=prefix, user=user)
for batch in batches
]
batch_results = await asyncio.gather(*tasks)
# Flatten results
embeddings = []
for batch_embeddings in batch_results:
if isinstance(batch_embeddings, list):
embeddings.extend(batch_embeddings)
log.debug(f"generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches")
return embeddings
else:
return await func_async([query], prefix=prefix, user=user)
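
Illustrative usage of the helper above (not part of this commit; model name and key are placeholders): with embedding_batch_size=2 it splits N texts into ceil(N/2) batches and awaits them concurrently, and because asyncio.gather returns results in input order the flattened embeddings stay aligned with the texts.

async def _example_embed_docs(texts):
    return await generate_multiple_async(
        texts,
        prefix=None,
        user=None,
        func_async=lambda batch, prefix=None, user=None: generate_openai_batch_embeddings_async(
            model="text-embedding-3-small",   # placeholder model
            texts=batch,
            key="sk-placeholder",             # placeholder key
            prefix=prefix,
            user=user,
        ),
        embedding_batch_size=2,
    )
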
def get_embedding_function(
embedding_engine,
embedding_model,
@@ -521,40 +664,58 @@ def get_embedding_function(
azure_api_version=None,
):
if embedding_engine == "":
return lambda query, prefix=None, user=None: embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
# Sentence transformers: CPU-bound sync operation
# Run in thread pool to prevent blocking async event loop
def sync_encode(query, prefix=None):
return embedding_function.encode(
query, **({"prompt": prefix} if prefix else {})
).tolist()
async def async_wrapper(query, prefix=None, user=None):
# Run CPU-bound operation in thread pool
return await asyncio.to_thread(sync_encode, query, prefix)
return async_wrapper
elif embedding_engine in ["ollama", "openai", "azure_openai"]:
func = lambda query, prefix=None, user=None: generate_embeddings(
engine=embedding_engine,
model=embedding_model,
text=query,
prefix=prefix,
url=url,
key=key,
user=user,
azure_api_version=azure_api_version,
)
# Create async function based on engine
if embedding_engine == "openai":
func_async = lambda texts, prefix=None, user=None: generate_openai_batch_embeddings_async(
model=embedding_model,
texts=texts,
url=url,
key=key,
prefix=prefix,
user=user,
)
elif embedding_engine == "azure_openai":
func_async = lambda texts, prefix=None, user=None: generate_azure_openai_batch_embeddings_async(
model=embedding_model,
texts=texts,
url=url,
key=key,
version=azure_api_version,
prefix=prefix,
user=user,
)
elif embedding_engine == "ollama":
func_async = lambda texts, prefix=None, user=None: generate_ollama_batch_embeddings_async(
model=embedding_model,
texts=texts,
url=url,
key=key,
prefix=prefix,
user=user,
)
def generate_multiple(query, prefix, user, func):
# Return async function for parallel batch processing (FastAPI-compatible)
async def embedding_wrapper(query, prefix=None, user=None):
if isinstance(query, list):
embeddings = []
for i in range(0, len(query), embedding_batch_size):
batch_embeddings = func(
query[i : i + embedding_batch_size],
prefix=prefix,
user=user,
)
if isinstance(batch_embeddings, list):
embeddings.extend(batch_embeddings)
return embeddings
return await generate_multiple_async(query, prefix, user, func_async, embedding_batch_size)
else:
return func(query, prefix, user)
result = await func_async([query], prefix=prefix, user=user)
return result[0] if result else None
return lambda query, prefix=None, user=None: generate_multiple(
query, prefix, user, func
)
return embedding_wrapper
else:
raise ValueError(f"Unknown embedding engine: {embedding_engine}")
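
In the local sentence-transformers branch above, asyncio.to_thread hands the blocking encode() call to the default thread-pool executor so the event loop keeps serving other requests. A minimal standalone sketch of the same pattern (the model choice is an arbitrary example, not prescribed by the commit):

import asyncio
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # example model

async def embed(texts, prefix=None):
    # encode() is CPU/GPU-bound and synchronous; run it off the event loop.
    return await asyncio.to_thread(
        lambda: model.encode(texts, **({"prompt": prefix} if prefix else {})).tolist()
    )
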
@@ -572,7 +733,7 @@ def get_reranking_function(reranking_engine, reranking_model, reranking_function
)
def get_sources_from_items(
async def get_sources_from_items(
request,
items,
queries,
@@ -800,7 +961,7 @@ def get_sources_from_items(
query_result = None # Initialize to None
if hybrid_search:
try:
query_result = query_collection_with_hybrid_search(
query_result = await query_collection_with_hybrid_search(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
@@ -818,7 +979,7 @@ def get_sources_from_items(
# fallback to non-hybrid search
if not hybrid_search and query_result is None:
query_result = query_collection(
query_result = await query_collection(
collection_names=collection_names,
queries=queries,
embedding_function=embedding_function,
@@ -1056,26 +1217,25 @@ def generate_embeddings(
else:
text = f"{prefix}{text}"
# Run async embedding generation synchronously
if engine == "ollama":
embeddings = generate_ollama_batch_embeddings(
**{
"model": model,
"texts": text if isinstance(text, list) else [text],
"url": url,
"key": key,
"prefix": prefix,
"user": user,
}
)
embeddings = asyncio.run(generate_ollama_batch_embeddings_async(
model=model,
texts=text if isinstance(text, list) else [text],
url=url,
key=key,
prefix=prefix,
user=user,
))
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "openai":
embeddings = generate_openai_batch_embeddings(
embeddings = asyncio.run(generate_openai_batch_embeddings_async(
model, text if isinstance(text, list) else [text], url, key, prefix, user
)
))
return embeddings[0] if isinstance(text, str) else embeddings
elif engine == "azure_openai":
azure_api_version = kwargs.get("azure_api_version", "")
embeddings = generate_azure_openai_batch_embeddings(
embeddings = asyncio.run(generate_azure_openai_batch_embeddings_async(
model,
text if isinstance(text, list) else [text],
url,
@@ -1083,7 +1243,7 @@ def generate_embeddings(
azure_api_version,
prefix,
user,
)
))
return embeddings[0] if isinstance(text, str) else embeddings
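
asyncio.run() above starts a fresh event loop per call, which lets existing synchronous callers of generate_embeddings keep their call sites unchanged; asyncio.run raises RuntimeError if invoked from a thread that already has a running loop, so async callers should await the *_async functions directly. A toy sketch of the bridging pattern:

import asyncio

async def _fetch_embedding(text):
    return [0.0, 1.0]               # stand-in for an aiohttp-backed embedding call

def fetch_embedding_sync(text):
    # Sync facade: fine from plain scripts and worker threads, not from inside a running loop.
    return asyncio.run(_fetch_embedding(text))

print(fetch_embedding_sync("hello"))
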
@@ -1104,7 +1264,7 @@ class RerankCompressor(BaseDocumentCompressor):
extra = "forbid"
arbitrary_types_allowed = True
def compress_documents(
async def acompress_documents(
self,
documents: Sequence[Document],
query: str,
@@ -1114,12 +1274,14 @@ class RerankCompressor(BaseDocumentCompressor):
scores = None
if reranking:
scores = self.reranking_function(query, documents)
scores = self.reranking_function(
[(query, doc.page_content) for doc in documents]
)
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
document_embedding = self.embedding_function(
query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
document_embedding = await self.embedding_function(
[doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
)
scores = util.cos_sim(query_embedding, document_embedding)[0]
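
For the non-reranking fallback above, util.cos_sim returns a (queries x documents) similarity matrix, so indexing [0] yields one cosine score per document for the single query. A toy sketch:

from sentence_transformers import util

query_embedding = [[1.0, 0.0]]
document_embedding = [[1.0, 0.0], [0.0, 1.0]]
print(util.cos_sim(query_embedding, document_embedding)[0])  # tensor([1., 0.]): one score per document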

View file

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
import logging
import asyncio
from typing import Optional
from open_webui.models.memories import Memories, MemoryModel
@@ -17,7 +18,7 @@ router = APIRouter()
@router.get("/ef")
async def get_embeddings(request: Request):
return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
return {"result": await request.app.state.EMBEDDING_FUNCTION("hello world")}
############################
@@ -51,15 +52,17 @@ async def add_memory(
):
memory = Memories.insert_new_memory(user.id, form_data.content)
vector = await request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
)
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"vector": vector,
"metadata": {"created_at": memory.created_at},
}
],
@@ -86,9 +89,11 @@ async def query_memory(
if not memories:
raise HTTPException(status_code=404, detail="No memories found for user")
vector = await request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)
results = VECTOR_DB_CLIENT.search(
collection_name=f"user-memory-{user.id}",
vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
vectors=[vector],
limit=form_data.k,
)
@@ -105,21 +110,26 @@ async def reset_memory_from_vector_db(
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
# Generate vectors in parallel
vectors = await asyncio.gather(*[
request.app.state.EMBEDDING_FUNCTION(memory.content, user=user)
for memory in memories
])
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"vector": vectors[idx],
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,
},
}
for memory in memories
for idx, memory in enumerate(memories)
],
)
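
The rebuild above leans on asyncio.gather preserving input order, which is what keeps vectors[idx] aligned with memories[idx] even when the embedding calls complete out of order. A toy sketch of that guarantee:

import asyncio

async def fake_embed(text):
    await asyncio.sleep(0.01 * (3 - len(text)))    # shorter texts finish later here
    return [float(len(text))]

async def main():
    texts = ["a", "bb", "ccc"]
    vectors = await asyncio.gather(*[fake_embed(t) for t in texts])
    assert vectors == [[1.0], [2.0], [3.0]]        # results follow input order, not completion order

asyncio.run(main())
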
@@ -164,15 +174,17 @@ async def update_memory_by_id(
raise HTTPException(status_code=404, detail="Memory not found")
if form_data.content is not None:
vector = await request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
)
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
{
"id": memory.id,
"text": memory.content,
"vector": request.app.state.EMBEDDING_FUNCTION(
memory.content, user=user
),
"vector": vector,
"metadata": {
"created_at": memory.created_at,
"updated_at": memory.updated_at,

View file

@@ -1467,11 +1467,12 @@ def save_docs_to_vector_db(
),
)
embeddings = embedding_function(
# Run async embedding in sync context
embeddings = asyncio.run(embedding_function(
list(map(lambda x: x.replace("\n", " "), texts)),
prefix=RAG_EMBEDDING_CONTENT_PREFIX,
user=user,
)
))
log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items")
items = [
@@ -2262,7 +2263,7 @@ class QueryDocForm(BaseModel):
@router.post("/query/doc")
def query_doc_handler(
async def query_doc_handler(
request: Request,
form_data: QueryDocForm,
user=Depends(get_verified_user),
@@ -2275,7 +2276,7 @@ def query_doc_handler(
collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get(
collection_name=form_data.collection_name
)
return query_doc_with_hybrid_search(
return await query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
collection_result=collection_results[form_data.collection_name],
query=form_data.query,
@@ -2285,8 +2286,8 @@ def query_doc_handler(
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
reranking_function=(
(
lambda sentences: request.app.state.RERANKING_FUNCTION(
sentences, user=user
lambda query, documents: request.app.state.RERANKING_FUNCTION(
query, documents, user=user
)
)
if request.app.state.RERANKING_FUNCTION
@@ -2307,11 +2308,12 @@ def query_doc_handler(
user=user,
)
else:
query_embedding = await request.app.state.EMBEDDING_FUNCTION(
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
)
return query_doc(
collection_name=form_data.collection_name,
query_embedding=request.app.state.EMBEDDING_FUNCTION(
form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
),
query_embedding=query_embedding,
k=form_data.k if form_data.k else request.app.state.config.TOP_K,
user=user,
)
@@ -2335,7 +2337,7 @@ class QueryCollectionsForm(BaseModel):
@router.post("/query/collection")
def query_collection_handler(
async def query_collection_handler(
request: Request,
form_data: QueryCollectionsForm,
user=Depends(get_verified_user),
@@ -2344,7 +2346,7 @@ def query_collection_handler(
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
form_data.hybrid is None or form_data.hybrid
):
return query_collection_with_hybrid_search(
return await query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
@@ -2379,7 +2381,7 @@ def query_collection_handler(
),
)
else:
return query_collection(
return await query_collection(
collection_names=form_data.collection_names,
queries=[form_data.query],
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
@@ -2461,7 +2463,7 @@ if ENV == "dev":
@router.get("/ef/{text}")
async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
return {
"result": request.app.state.EMBEDDING_FUNCTION(
"result": await request.app.state.EMBEDDING_FUNCTION(
text, prefix=RAG_EMBEDDING_QUERY_PREFIX
)
}

View file

@@ -993,37 +993,32 @@ async def chat_completion_files_handler(
queries = [get_last_user_message(body["messages"])]
try:
# Offload get_sources_from_items to a separate thread
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
sources = await loop.run_in_executor(
executor,
lambda: get_sources_from_items(
request=request,
items=files,
queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user
),
k=request.app.state.config.TOP_K,
reranking_function=(
(
lambda query, documents: request.app.state.RERANKING_FUNCTION(
query, documents, user=user
)
)
if request.app.state.RERANKING_FUNCTION
else None
),
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=all_full_context
or request.app.state.config.RAG_FULL_CONTEXT,
user=user,
),
)
# Directly await async get_sources_from_items (no thread needed - fully async now)
sources = await get_sources_from_items(
request=request,
items=files,
queries=queries,
embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
query, prefix=prefix, user=user
),
k=request.app.state.config.TOP_K,
reranking_function=(
(
lambda query, documents: request.app.state.RERANKING_FUNCTION(
query, documents, user=user
)
)
if request.app.state.RERANKING_FUNCTION
else None
),
k_reranker=request.app.state.config.TOP_K_RERANKER,
r=request.app.state.config.RELEVANCE_THRESHOLD,
hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
full_context=all_full_context
or request.app.state.config.RAG_FULL_CONTEXT,
user=user,
)
except Exception as e:
log.exception(e)
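
The change in the handler above is about where the waiting happens: the old code had to push a blocking function onto a worker thread with run_in_executor, while the new coroutine is awaited directly and only its CPU-bound pieces (local encoding) hop to a thread internally. A toy contrast sketch:

import asyncio

def blocking_retrieval():
    return ["source-1"]             # stand-in for the old synchronous retrieval call

async def async_retrieval():
    return ["source-1"]             # stand-in for the new coroutine

async def handler():
    loop = asyncio.get_running_loop()
    before = await loop.run_in_executor(None, blocking_retrieval)   # old pattern
    after = await async_retrieval()                                 # new pattern
    return before, after

print(asyncio.run(handler()))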