From 9bfc414d2686dd37482ec4db24bc526af1296e90 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Sat, 22 Nov 2025 21:33:14 -0500
Subject: [PATCH] refac

---
 backend/open_webui/retrieval/utils.py | 565 +++++++++++++------------
 1 file changed, 278 insertions(+), 287 deletions(-)

diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 2c790582fd..56065e14fc 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Optional, Union
+from typing import Awaitable, Callable, Optional, Union
 
 import requests
 import aiohttp
@@ -410,7 +410,9 @@ async def query_collection(
         return None, e
 
     # Generate all query embeddings (in one call)
-    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
+    query_embeddings = await embedding_function(
+        queries, prefix=RAG_EMBEDDING_QUERY_PREFIX
+    )
     log.debug(
         f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
     )
@@ -491,14 +493,16 @@ async def query_collection_with_hybrid_search(
     # Prepare tasks for all collections and queries
     # Avoid running any tasks for collections that failed to fetch data (have assigned None)
     tasks = [
-        (cn, q)
-        for cn in collection_names
-        if collection_results[cn] is not None
-        for q in queries
+        (collection_name, query)
+        for collection_name in collection_names
+        if collection_results[collection_name] is not None
+        for query in queries
     ]
 
     # Run all queries in parallel using asyncio.gather
-    task_results = await asyncio.gather(*[process_query(cn, q) for cn, q in tasks])
+    task_results = await asyncio.gather(
+        *[process_query(collection_name, query) for collection_name, query in tasks]
+    )
 
     for result, err in task_results:
         if err is not None:
@@ -514,7 +518,7 @@ async def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k)
 
 
-async def generate_openai_batch_embeddings_async(
+def generate_openai_batch_embeddings(
     model: str,
     texts: list[str],
     url: str = "https://api.openai.com/v1",
@@ -524,7 +528,52 @@ async def generate_openai_batch_embeddings_async(
 ) -> Optional[list[list[float]]]:
     try:
         log.debug(
-            f"generate_openai_batch_embeddings_async:model {model} batch size: {len(texts)}"
+            f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
+        )
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
+        r = requests.post(
+            f"{url}/embeddings",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": quote(user.name, safe=" "),
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
+            },
+            json=json_data,
+        )
+        r.raise_for_status()
+        data = r.json()
+        if "data" in data:
+            return [elem["embedding"] for elem in data["data"]]
+        else:
+            raise Exception("Something went wrong :/")
+    except Exception as e:
+        log.exception(f"Error generating openai batch embeddings: {e}")
+        return None
+
+
+async def agenerate_openai_batch_embeddings(
+    model: str,
+    texts: list[str],
+    url: str = "https://api.openai.com/v1",
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
+) -> Optional[list[list[float]]]:
+    try:
+        log.debug(
+            f"agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}"
{len(texts)}" ) form_data = {"input": texts, "model": model} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): @@ -537,8 +586,10 @@ async def generate_openai_batch_embeddings_async( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession() as session: - async with session.post(f"{url}/embeddings", headers=headers, json=form_data) as r: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + f"{url}/embeddings", headers=headers, json=form_data + ) as r: r.raise_for_status() data = await r.json() if "data" in data: @@ -550,7 +601,7 @@ async def generate_openai_batch_embeddings_async( return None -async def generate_azure_openai_batch_embeddings_async( +def generate_azure_openai_batch_embeddings( model: str, texts: list[str], url: str, @@ -561,7 +612,61 @@ async def generate_azure_openai_batch_embeddings_async( ) -> Optional[list[list[float]]]: try: log.debug( - f"generate_azure_openai_batch_embeddings_async:deployment {model} batch size: {len(texts)}" + f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}" + ) + json_data = {"input": texts} + if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): + json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix + + url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}" + + for _ in range(5): + r = requests.post( + url, + headers={ + "Content-Type": "application/json", + "api-key": key, + **( + { + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, + json=json_data, + ) + if r.status_code == 429: + retry = float(r.headers.get("Retry-After", "1")) + time.sleep(retry) + continue + r.raise_for_status() + data = r.json() + if "data" in data: + return [elem["embedding"] for elem in data["data"]] + else: + raise Exception("Something went wrong :/") + return None + except Exception as e: + log.exception(f"Error generating azure openai batch embeddings: {e}") + return None + + +async def agenerate_azure_openai_batch_embeddings( + model: str, + texts: list[str], + url: str, + key: str = "", + version: str = "", + prefix: str = None, + user: UserModel = None, +) -> Optional[list[list[float]]]: + try: + log.debug( + f"agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}" ) form_data = {"input": texts} if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): @@ -576,7 +681,7 @@ async def generate_azure_openai_batch_embeddings_async( if ENABLE_FORWARD_USER_INFO_HEADERS and user: headers = include_user_info_headers(headers, user) - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(trust_env=True) as session: async with session.post(full_url, headers=headers, json=form_data) as r: r.raise_for_status() data = await r.json() @@ -589,7 +694,7 @@ async def generate_azure_openai_batch_embeddings_async( return None -async def generate_ollama_batch_embeddings_async( +def generate_ollama_batch_embeddings( model: str, texts: list[str], url: str, @@ -599,7 +704,53 @@ async def generate_ollama_batch_embeddings_async( ) -> Optional[list[list[float]]]: try: log.debug( - f"generate_ollama_batch_embeddings_async:model {model} batch size: {len(texts)}" + f"generate_ollama_batch_embeddings:model 
+            return await asyncio.to_thread(
+                (
+                    lambda query, prefix=None: embedding_function.encode(
+                        query, **({"prompt": prefix} if prefix else {})
+                    ).tolist()
+                ),
+                query,
+                prefix,
+            )
 
-        async def async_wrapper(query, prefix=None, user=None):
-            # Run CPU-bound operation in thread pool
-            return await asyncio.to_thread(sync_encode, query, prefix)
-
-        return async_wrapper
+        return async_embedding_function
 
     elif embedding_engine in ["ollama", "openai", "azure_openai"]:
-        # Create async function based on engine
-        if embedding_engine == "openai":
-            func_async = lambda texts, prefix=None, user=None: generate_openai_batch_embeddings_async(
-                model=embedding_model,
-                texts=texts,
-                url=url,
-                key=key,
-                prefix=prefix,
-                user=user,
-            )
-        elif embedding_engine == "azure_openai":
-            func_async = lambda texts, prefix=None, user=None: generate_azure_openai_batch_embeddings_async(
-                model=embedding_model,
-                texts=texts,
-                url=url,
-                key=key,
-                version=azure_api_version,
-                prefix=prefix,
-                user=user,
-            )
-        elif embedding_engine == "ollama":
-            func_async = lambda texts, prefix=None, user=None: generate_ollama_batch_embeddings_async(
-                model=embedding_model,
-                texts=texts,
-                url=url,
-                key=key,
-                prefix=prefix,
-                user=user,
-            )
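+        # One shared dispatcher for all three remote engines: generate_embeddings routes to the engine-specific async client.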
+        embedding_function = lambda query, prefix=None, user=None: generate_embeddings(
+            engine=embedding_engine,
+            model=embedding_model,
+            text=query,
+            prefix=prefix,
+            url=url,
+            key=key,
+            user=user,
+            azure_api_version=azure_api_version,
+        )
 
-        # Return async function for parallel batch processing (FastAPI-compatible)
-        async def embedding_wrapper(query, prefix=None, user=None):
+        async def async_embedding_function(query, prefix=None, user=None):
             if isinstance(query, list):
-                return await generate_multiple_async(query, prefix, user, func_async, embedding_batch_size)
-            else:
-                result = await func_async([query], prefix=prefix, user=user)
-                return result[0] if result else None
+                # Create batches
+                batches = [
+                    query[i : i + embedding_batch_size]
+                    for i in range(0, len(query), embedding_batch_size)
+                ]
+                log.debug(
+                    f"async_embedding_function: Processing {len(batches)} batches in parallel"
+                )
 
-        return embedding_wrapper
+                # Execute all batches in parallel
+                tasks = [
+                    embedding_function(batch, prefix=prefix, user=user)
+                    for batch in batches
+                ]
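+                # asyncio.gather preserves task order, so the flattened embeddings stay aligned with the input texts.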
+                batch_results = await asyncio.gather(*tasks)
+
+                # Flatten results
+                embeddings = []
+                for batch_embeddings in batch_results:
+                    if isinstance(batch_embeddings, list):
+                        embeddings.extend(batch_embeddings)
+
+                log.debug(
+                    f"async_embedding_function: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches"
+                )
+                return embeddings
+            else:
+                return await embedding_function(query, prefix, user)
+
+        return async_embedding_function
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
+async def generate_embeddings(
+    engine: str,
+    model: str,
+    text: Union[str, list[str]],
+    prefix: Union[str, None] = None,
+    **kwargs,
+):
+    url = kwargs.get("url", "")
+    key = kwargs.get("key", "")
+    user = kwargs.get("user")
+
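+    # No dedicated prefix field configured: fall back to prepending the prefix to the text itself.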
f"{url}/api/embed", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {key}", - **( - { - "X-OpenWebUI-User-Name": quote(user.name, safe=" "), - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - }, - json=json_data, - ) - r.raise_for_status() - data = r.json() - - if "embeddings" in data: - return data["embeddings"] - else: - raise "Something went wrong :/" - except Exception as e: - log.exception(f"Error generating ollama batch embeddings: {e}") - return None - - -def generate_embeddings( - engine: str, - model: str, - text: Union[str, list[str]], - prefix: Union[str, None] = None, - **kwargs, -): - url = kwargs.get("url", "") - key = kwargs.get("key", "") - user = kwargs.get("user") - - if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None: - if isinstance(text, list): - text = [f"{prefix}{text_element}" for text_element in text] - else: - text = f"{prefix}{text}" - - # Run async embedding generation synchronously - if engine == "ollama": - embeddings = asyncio.run(generate_ollama_batch_embeddings_async( - model=model, - texts=text if isinstance(text, list) else [text], - url=url, - key=key, - prefix=prefix, - user=user, - )) - return embeddings[0] if isinstance(text, str) else embeddings - elif engine == "openai": - embeddings = asyncio.run(generate_openai_batch_embeddings_async( - model, text if isinstance(text, list) else [text], url, key, prefix, user - )) - return embeddings[0] if isinstance(text, str) else embeddings - elif engine == "azure_openai": - azure_api_version = kwargs.get("azure_api_version", "") - embeddings = asyncio.run(generate_azure_openai_batch_embeddings_async( - model, - text if isinstance(text, list) else [text], - url, - key, - azure_api_version, - prefix, - user, - )) - return embeddings[0] if isinstance(text, str) else embeddings - - import operator from typing import Optional, Sequence @@ -1280,7 +1265,9 @@ class RerankCompressor(BaseDocumentCompressor): else: from sentence_transformers import util - query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX) + query_embedding = await self.embedding_function( + query, RAG_EMBEDDING_QUERY_PREFIX + ) document_embedding = await self.embedding_function( [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX )