diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index a4020062e9..b48ba4f2e2 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1825,6 +1825,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int( os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536") ) +PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true" +PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None) +if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY: + raise ValueError( + "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key." + ) + # Pinecone PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index b6cb2a4e25..de4073e126 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -1,12 +1,16 @@ from typing import Optional, List, Dict, Any import logging +import json from sqlalchemy import ( + func, + literal, cast, column, create_engine, Column, Integer, MetaData, + LargeBinary, select, text, Text, @@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import ( SearchResult, GetResult, ) -from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH +from open_webui.config import ( + PGVECTOR_DB_URL, + PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH, + PGVECTOR_PGCRYPTO, + PGVECTOR_PGCRYPTO_KEY, +) from open_webui.env import SRC_LOG_LEVELS @@ -39,14 +48,27 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +def pgcrypto_encrypt(val, key): + return func.pgp_sym_encrypt(val, literal(key)) + + +def pgcrypto_decrypt(col, key, outtype="text"): + return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype) + + class DocumentChunk(Base): __tablename__ = "document_chunk" id = Column(Text, primary_key=True) vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) collection_name = Column(Text, nullable=False) - text = Column(Text, nullable=True) - vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + + if PGVECTOR_PGCRYPTO: + text = Column(LargeBinary, nullable=True) + vmetadata = Column(LargeBinary, nullable=True) + else: + text = Column(Text, nullable=True) + vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) class PgvectorClient(VectorDBBase): @@ -147,44 +169,39 @@ class PgvectorClient(VectorDBBase): def insert(self, collection_name: str, items: List[VectorItem]) -> None: try: - new_items = [] - for item in items: - vector = self.adjust_vector_length(item["vector"]) - new_chunk = DocumentChunk( - id=item["id"], - vector=vector, - collection_name=collection_name, - text=item["text"], - vmetadata=item["metadata"], - ) - new_items.append(new_chunk) - self.session.bulk_save_objects(new_items) - self.session.commit() - log.info( - f"Inserted {len(new_items)} items into collection '{collection_name}'." - ) - except Exception as e: - self.session.rollback() - log.exception(f"Error during insert: {e}") - raise - - def upsert(self, collection_name: str, items: List[VectorItem]) -> None: - try: - for item in items: - vector = self.adjust_vector_length(item["vector"]) - existing = ( - self.session.query(DocumentChunk) - .filter(DocumentChunk.id == item["id"]) - .first() - ) - if existing: - existing.vector = vector - existing.text = item["text"] - existing.vmetadata = item["metadata"] - existing.collection_name = ( - collection_name # Update collection_name if necessary + if PGVECTOR_PGCRYPTO: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + # Use raw SQL for BYTEA/pgcrypto + self.session.execute( + text( + """ + INSERT INTO document_chunk + (id, vector, collection_name, text, vmetadata) + VALUES ( + :id, :vector, :collection_name, + pgp_sym_encrypt(:text, :key), + pgp_sym_encrypt(:metadata::text, :key) + ) + ON CONFLICT (id) DO NOTHING + """ + ), + { + "id": item["id"], + "vector": vector, + "collection_name": collection_name, + "text": item["text"], + "metadata": json.dumps(item["metadata"]), + "key": PGVECTOR_PGCRYPTO_KEY, + }, ) - else: + self.session.commit() + log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'") + + else: + new_items = [] + for item in items: + vector = self.adjust_vector_length(item["vector"]) new_chunk = DocumentChunk( id=item["id"], vector=vector, @@ -192,11 +209,78 @@ class PgvectorClient(VectorDBBase): text=item["text"], vmetadata=item["metadata"], ) - self.session.add(new_chunk) - self.session.commit() - log.info( - f"Upserted {len(items)} items into collection '{collection_name}'." - ) + new_items.append(new_chunk) + self.session.bulk_save_objects(new_items) + self.session.commit() + log.info( + f"Inserted {len(new_items)} items into collection '{collection_name}'." + ) + except Exception as e: + self.session.rollback() + log.exception(f"Error during insert: {e}") + raise + + def upsert(self, collection_name: str, items: List[VectorItem]) -> None: + try: + if PGVECTOR_PGCRYPTO: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + self.session.execute( + text( + """ + INSERT INTO document_chunk + (id, vector, collection_name, text, vmetadata) + VALUES ( + :id, :vector, :collection_name, + pgp_sym_encrypt(:text, :key), + pgp_sym_encrypt(:metadata::text, :key) + ) + ON CONFLICT (id) DO UPDATE SET + vector = EXCLUDED.vector, + collection_name = EXCLUDED.collection_name, + text = EXCLUDED.text, + vmetadata = EXCLUDED.vmetadata + """ + ), + { + "id": item["id"], + "vector": vector, + "collection_name": collection_name, + "text": item["text"], + "metadata": json.dumps(item["metadata"]), + "key": PGVECTOR_PGCRYPTO_KEY, + }, + ) + self.session.commit() + log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'") + else: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + existing = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.id == item["id"]) + .first() + ) + if existing: + existing.vector = vector + existing.text = item["text"] + existing.vmetadata = item["metadata"] + existing.collection_name = ( + collection_name # Update collection_name if necessary + ) + else: + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=item["metadata"], + ) + self.session.add(new_chunk) + self.session.commit() + log.info( + f"Upserted {len(items)} items into collection '{collection_name}'." + ) except Exception as e: self.session.rollback() log.exception(f"Error during upsert: {e}") @@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase): .alias("query_vectors") ) + result_fields = [ + DocumentChunk.id, + ] + if PGVECTOR_PGCRYPTO: + result_fields.append( + pgcrypto_decrypt( + DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text + ).label("text") + ) + result_fields.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + ).label("vmetadata") + ) + else: + result_fields.append(DocumentChunk.text) + result_fields.append(DocumentChunk.vmetadata) + result_fields.append( + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label( + "distance" + ) + ) + # Build the lateral subquery for each query vector subq = ( - select( - DocumentChunk.id, - DocumentChunk.text, - DocumentChunk.vmetadata, - ( - DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) - ).label("distance"), - ) + select(*result_fields) .where(DocumentChunk.collection_name == collection_name) .order_by( (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) @@ -299,17 +399,43 @@ class PgvectorClient(VectorDBBase): self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None ) -> Optional[GetResult]: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + if PGVECTOR_PGCRYPTO: + # Build where clause for vmetadata filter + where_clauses = [DocumentChunk.collection_name == collection_name] + for key, value in filter.items(): + # decrypt then check key: JSON filter after decryption + where_clauses.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + )[key].astext + == str(value) + ) + stmt = select( + DocumentChunk.id, + pgcrypto_decrypt( + DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text + ).label("text"), + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + ).label("vmetadata"), + ).where(*where_clauses) + if limit is not None: + stmt = stmt.limit(limit) + results = self.session.execute(stmt).all() + else: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) - for key, value in filter.items(): - query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) - if limit is not None: - query = query.limit(limit) + if limit is not None: + query = query.limit(limit) - results = query.all() + results = query.all() if not results: return None @@ -331,20 +457,38 @@ class PgvectorClient(VectorDBBase): self, collection_name: str, limit: Optional[int] = None ) -> Optional[GetResult]: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) - if limit is not None: - query = query.limit(limit) + if PGVECTOR_PGCRYPTO: + stmt = select( + DocumentChunk.id, + pgcrypto_decrypt( + DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text + ).label("text"), + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + ).label("vmetadata"), + ).where(DocumentChunk.collection_name == collection_name) + if limit is not None: + stmt = stmt.limit(limit) + results = self.session.execute(stmt).all() + ids = [[row.id for row in results]] + documents = [[row.text for row in results]] + metadatas = [[row.vmetadata for row in results]] + else: - results = query.all() + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if limit is not None: + query = query.limit(limit) - if not results: - return None + results = query.all() - ids = [[result.id for result in results]] - documents = [[result.text for result in results]] - metadatas = [[result.vmetadata for result in results]] + if not results: + return None + + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: @@ -358,17 +502,33 @@ class PgvectorClient(VectorDBBase): filter: Optional[Dict[str, Any]] = None, ) -> None: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) - if ids: - query = query.filter(DocumentChunk.id.in_(ids)) - if filter: - for key, value in filter.items(): - query = query.filter( - DocumentChunk.vmetadata[key].astext == str(value) - ) - deleted = query.delete(synchronize_session=False) + if PGVECTOR_PGCRYPTO: + wheres = [DocumentChunk.collection_name == collection_name] + if ids: + wheres.append(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + wheres.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + )[key].astext + == str(value) + ) + stmt = DocumentChunk.__table__.delete().where(*wheres) + result = self.session.execute(stmt) + deleted = result.rowcount + else: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if ids: + query = query.filter(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) + deleted = query.delete(synchronize_session=False) self.session.commit() log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: