diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index dead8458cb..b3dc52d7b2 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -432,13 +432,14 @@ def get_embedding_function( if isinstance(query, list): embeddings = [] for i in range(0, len(query), embedding_batch_size): - embeddings.extend( - func( - query[i : i + embedding_batch_size], - prefix=prefix, - user=user, - ) + batch_embeddings = func( + query[i : i + embedding_batch_size], + prefix=prefix, + user=user, ) + + if isinstance(batch_embeddings, list): + embeddings.extend(batch_embeddings) return embeddings else: return func(query, prefix, user) diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index dd5e2d5bc4..9c3566ab2b 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1334,7 +1334,7 @@ def save_docs_to_vector_db( ) return True - log.info(f"adding to collection {collection_name}") + log.info(f"generating embeddings for {collection_name}") embedding_function = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, @@ -1381,11 +1381,18 @@ def save_docs_to_vector_db( for idx, text in enumerate(texts) ] + log.info(f"adding to collection {collection_name}") VECTOR_DB_CLIENT.insert( collection_name=collection_name, items=items, ) + # Validate the number of items inserted + result = VECTOR_DB_CLIENT.query( + collection_name=collection_name, + filter={"metadata": metadata} if metadata else None, + ) + return True except Exception as e: log.exception(e)