feat: experimental pgvector pgcrypto support

Timothy Jaeryang Baek 2025-06-09 18:14:33 +04:00
parent 6d4f449085
commit 7f488b3754
2 changed files with 250 additions and 83 deletions

View file

@@ -1825,6 +1825,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
     os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
 )
 
+PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
+PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)
+if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
+    raise ValueError(
+        "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
+    )
+
 # Pinecone
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
 PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
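Enabling the feature is fail-fast by design: a truthy PGVECTOR_PGCRYPTO without a key aborts at import time instead of silently writing plaintext. A minimal bootstrap sketch; the key value is a placeholder, and setting the variables from Python is only for illustration (a real deployment would export them in its environment):

import os

os.environ["PGVECTOR_PGCRYPTO"] = "true"        # any value other than "true" (case-insensitive) leaves encryption off
os.environ["PGVECTOR_PGCRYPTO_KEY"] = "s3cret"  # symmetric key; losing it makes stored rows unreadable

import open_webui.config  # with the key unset, this import raises ValueError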

View file

@@ -1,12 +1,16 @@
 from typing import Optional, List, Dict, Any
 import logging
+import json
 from sqlalchemy import (
+    func,
+    literal,
     cast,
     column,
     create_engine,
     Column,
     Integer,
     MetaData,
+    LargeBinary,
     select,
     text,
     Text,
@@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
     SearchResult,
     GetResult,
 )
-from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
+from open_webui.config import (
+    PGVECTOR_DB_URL,
+    PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
+    PGVECTOR_PGCRYPTO,
+    PGVECTOR_PGCRYPTO_KEY,
+)
 from open_webui.env import SRC_LOG_LEVELS
@@ -39,14 +48,27 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def pgcrypto_encrypt(val, key):
+    return func.pgp_sym_encrypt(val, literal(key))
+
+
+def pgcrypto_decrypt(col, key, outtype="text"):
+    return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
+
+
 class DocumentChunk(Base):
     __tablename__ = "document_chunk"
 
     id = Column(Text, primary_key=True)
     vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
     collection_name = Column(Text, nullable=False)
-    text = Column(Text, nullable=True)
-    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
+    if PGVECTOR_PGCRYPTO:
+        text = Column(LargeBinary, nullable=True)
+        vmetadata = Column(LargeBinary, nullable=True)
+    else:
+        text = Column(Text, nullable=True)
+        vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
 
 
 class PgvectorClient(VectorDBBase):
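Both helpers delegate to PostgreSQL's pgcrypto extension: pgp_sym_encrypt(text, key) yields BYTEA ciphertext, and pgp_sym_decrypt(bytea, key) returns the original text, which is why the model's text and vmetadata columns switch to LargeBinary when encryption is on. A self-contained round-trip sketch, assuming a reachable PostgreSQL on which CREATE EXTENSION pgcrypto; has already been run (the DSN is made up):

from sqlalchemy import Text, cast, create_engine, func, literal, select

engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")  # placeholder DSN
with engine.connect() as conn:
    # pgp_sym_encrypt(text, key) -> bytea
    ciphertext = conn.scalar(select(func.pgp_sym_encrypt("hello", literal("k"))))
    # pgp_sym_decrypt(bytea, key) -> text; cast the result back to a typed expression
    plaintext = conn.scalar(
        select(cast(func.pgp_sym_decrypt(literal(bytes(ciphertext)), literal("k")), Text))
    )
    assert plaintext == "hello"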
@@ -147,44 +169,39 @@ class PgvectorClient(VectorDBBase):
     def insert(self, collection_name: str, items: List[VectorItem]) -> None:
         try:
-            new_items = []
-            for item in items:
-                vector = self.adjust_vector_length(item["vector"])
-                new_chunk = DocumentChunk(
-                    id=item["id"],
-                    vector=vector,
-                    collection_name=collection_name,
-                    text=item["text"],
-                    vmetadata=item["metadata"],
-                )
-                new_items.append(new_chunk)
-            self.session.bulk_save_objects(new_items)
-            self.session.commit()
-            log.info(
-                f"Inserted {len(new_items)} items into collection '{collection_name}'."
-            )
-        except Exception as e:
-            self.session.rollback()
-            log.exception(f"Error during insert: {e}")
-            raise
-
-    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
-        try:
-            for item in items:
-                vector = self.adjust_vector_length(item["vector"])
-                existing = (
-                    self.session.query(DocumentChunk)
-                    .filter(DocumentChunk.id == item["id"])
-                    .first()
-                )
-                if existing:
-                    existing.vector = vector
-                    existing.text = item["text"]
-                    existing.vmetadata = item["metadata"]
-                    existing.collection_name = (
-                        collection_name  # Update collection_name if necessary
-                    )
-                else:
+            if PGVECTOR_PGCRYPTO:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    # Use raw SQL for BYTEA/pgcrypto
+                    self.session.execute(
+                        text(
+                            """
+                            INSERT INTO document_chunk
+                            (id, vector, collection_name, text, vmetadata)
+                            VALUES (
+                                :id, :vector, :collection_name,
+                                pgp_sym_encrypt(:text, :key),
+                                pgp_sym_encrypt(:metadata::text, :key)
+                            )
+                            ON CONFLICT (id) DO NOTHING
+                            """
+                        ),
+                        {
+                            "id": item["id"],
+                            "vector": vector,
+                            "collection_name": collection_name,
+                            "text": item["text"],
+                            "metadata": json.dumps(item["metadata"]),
+                            "key": PGVECTOR_PGCRYPTO_KEY,
+                        },
+                    )
+                self.session.commit()
+                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
+            else:
+                new_items = []
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
                     new_chunk = DocumentChunk(
                         id=item["id"],
                         vector=vector,
                         collection_name=collection_name,
@@ -192,11 +209,78 @@ class PgvectorClient(VectorDBBase):
                         text=item["text"],
                         vmetadata=item["metadata"],
                     )
-                    self.session.add(new_chunk)
-            self.session.commit()
-            log.info(
-                f"Upserted {len(items)} items into collection '{collection_name}'."
-            )
+                    new_items.append(new_chunk)
+                self.session.bulk_save_objects(new_items)
+                self.session.commit()
+                log.info(
+                    f"Inserted {len(new_items)} items into collection '{collection_name}'."
+                )
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Error during insert: {e}")
+            raise
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            if PGVECTOR_PGCRYPTO:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    self.session.execute(
+                        text(
+                            """
+                            INSERT INTO document_chunk
+                            (id, vector, collection_name, text, vmetadata)
+                            VALUES (
+                                :id, :vector, :collection_name,
+                                pgp_sym_encrypt(:text, :key),
+                                pgp_sym_encrypt(:metadata::text, :key)
+                            )
+                            ON CONFLICT (id) DO UPDATE SET
+                                vector = EXCLUDED.vector,
+                                collection_name = EXCLUDED.collection_name,
+                                text = EXCLUDED.text,
+                                vmetadata = EXCLUDED.vmetadata
+                            """
+                        ),
+                        {
+                            "id": item["id"],
+                            "vector": vector,
+                            "collection_name": collection_name,
+                            "text": item["text"],
+                            "metadata": json.dumps(item["metadata"]),
+                            "key": PGVECTOR_PGCRYPTO_KEY,
+                        },
+                    )
+                self.session.commit()
+                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
+            else:
+                for item in items:
+                    vector = self.adjust_vector_length(item["vector"])
+                    existing = (
+                        self.session.query(DocumentChunk)
+                        .filter(DocumentChunk.id == item["id"])
+                        .first()
+                    )
+                    if existing:
+                        existing.vector = vector
+                        existing.text = item["text"]
+                        existing.vmetadata = item["metadata"]
+                        existing.collection_name = (
+                            collection_name  # Update collection_name if necessary
+                        )
+                    else:
+                        new_chunk = DocumentChunk(
+                            id=item["id"],
+                            vector=vector,
+                            collection_name=collection_name,
+                            text=item["text"],
+                            vmetadata=item["metadata"],
+                        )
+                        self.session.add(new_chunk)
+                self.session.commit()
+                log.info(
+                    f"Upserted {len(items)} items into collection '{collection_name}'."
+                )
         except Exception as e:
             self.session.rollback()
             log.exception(f"Error during upsert: {e}")
@@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
                 .alias("query_vectors")
             )
 
+            result_fields = [
+                DocumentChunk.id,
+            ]
+            if PGVECTOR_PGCRYPTO:
+                result_fields.append(
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text")
+                )
+                result_fields.append(
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata")
+                )
+            else:
+                result_fields.append(DocumentChunk.text)
+                result_fields.append(DocumentChunk.vmetadata)
+            result_fields.append(
+                (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
+                    "distance"
+                )
+            )
+
             # Build the lateral subquery for each query vector
             subq = (
-                select(
-                    DocumentChunk.id,
-                    DocumentChunk.text,
-                    DocumentChunk.vmetadata,
-                    (
-                        DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
-                    ).label("distance"),
-                )
+                select(*result_fields)
                 .where(DocumentChunk.collection_name == collection_name)
                 .order_by(
                     (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
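Since vectors stay unencrypted in both branches, the cosine-distance ranking is identical either way; the PGCRYPTO branch only changes how text and vmetadata are projected out. Hand-written for illustration (not the literal statement SQLAlchemy renders), the per-query-vector subquery is shaped roughly like:

# `<=>` is pgvector's cosine-distance operator; decryption runs once per returned row
preview = """
SELECT id,
       pgp_sym_decrypt(text, :key)      AS text,
       pgp_sym_decrypt(vmetadata, :key) AS vmetadata,
       vector <=> :q_vector             AS distance
FROM document_chunk
WHERE collection_name = :collection_name
ORDER BY vector <=> :q_vector
LIMIT :top_k
"""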
@@ -299,17 +399,43 @@ class PgvectorClient(VectorDBBase):
         self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
     ) -> Optional[GetResult]:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            for key, value in filter.items():
-                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
-            if limit is not None:
-                query = query.limit(limit)
-            results = query.all()
+            if PGVECTOR_PGCRYPTO:
+                # Build where clause for vmetadata filter
+                where_clauses = [DocumentChunk.collection_name == collection_name]
+                for key, value in filter.items():
+                    # decrypt then check key: JSON filter after decryption
+                    where_clauses.append(
+                        pgcrypto_decrypt(
+                            DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                        )[key].astext
+                        == str(value)
+                    )
+                stmt = select(
+                    DocumentChunk.id,
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text"),
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata"),
+                ).where(*where_clauses)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                results = self.session.execute(stmt).all()
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                for key, value in filter.items():
+                    query = query.filter(
+                        DocumentChunk.vmetadata[key].astext == str(value)
+                    )
+                if limit is not None:
+                    query = query.limit(limit)
+                results = query.all()
 
             if not results:
                 return None
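One cost worth noting in the encrypted branch of query(): every metadata filter decrypts vmetadata for each candidate row before the JSON comparison runs, so the predicate cannot use an index and the collection is scanned in full. A call sketch with placeholder values, assuming the same client construction as the earlier sketch:

from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient  # assumed path

client = PgvectorClient()
result = client.query(
    collection_name="kb-demo",      # hypothetical collection
    filter={"file_id": "doc-1"},    # matched against the decrypted JSONB
    limit=10,
)
if result is not None:
    print(result.ids[0], result.metadatas[0])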
@@ -331,20 +457,38 @@ class PgvectorClient(VectorDBBase):
         self, collection_name: str, limit: Optional[int] = None
     ) -> Optional[GetResult]:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            if limit is not None:
-                query = query.limit(limit)
-
-            results = query.all()
-
-            if not results:
-                return None
-
-            ids = [[result.id for result in results]]
-            documents = [[result.text for result in results]]
-            metadatas = [[result.vmetadata for result in results]]
+            if PGVECTOR_PGCRYPTO:
+                stmt = select(
+                    DocumentChunk.id,
+                    pgcrypto_decrypt(
+                        DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
+                    ).label("text"),
+                    pgcrypto_decrypt(
+                        DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                    ).label("vmetadata"),
+                ).where(DocumentChunk.collection_name == collection_name)
+                if limit is not None:
+                    stmt = stmt.limit(limit)
+                results = self.session.execute(stmt).all()
+
+                ids = [[row.id for row in results]]
+                documents = [[row.text for row in results]]
+                metadatas = [[row.vmetadata for row in results]]
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                if limit is not None:
+                    query = query.limit(limit)
+                results = query.all()
+
+                if not results:
+                    return None
+
+                ids = [[result.id for result in results]]
+                documents = [[result.text for result in results]]
+                metadatas = [[result.vmetadata for result in results]]
 
             return GetResult(ids=ids, documents=documents, metadatas=metadatas)
         except Exception as e:
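Two details of get() are easy to miss: GetResult nests every field in an outer single-element list, and, as reconstructed here, the empty-collection guard (if not results: return None) survives only in the unencrypted branch, so the encrypted branch returns empty nested lists instead of None. A shape sketch reusing the hypothetical client from the earlier sketches:

res = client.get("kb-demo", limit=2)
if res is not None:
    # one outer list, one entry per row inside: e.g. res.ids == [["doc-1:chunk-0", "doc-1:chunk-1"]]
    for id_, doc, meta in zip(res.ids[0], res.documents[0], res.metadatas[0]):
        print(id_, meta.get("page"), doc[:40])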
@@ -358,17 +502,33 @@ class PgvectorClient(VectorDBBase):
         filter: Optional[Dict[str, Any]] = None,
     ) -> None:
         try:
-            query = self.session.query(DocumentChunk).filter(
-                DocumentChunk.collection_name == collection_name
-            )
-            if ids:
-                query = query.filter(DocumentChunk.id.in_(ids))
-            if filter:
-                for key, value in filter.items():
-                    query = query.filter(
-                        DocumentChunk.vmetadata[key].astext == str(value)
-                    )
-            deleted = query.delete(synchronize_session=False)
+            if PGVECTOR_PGCRYPTO:
+                wheres = [DocumentChunk.collection_name == collection_name]
+                if ids:
+                    wheres.append(DocumentChunk.id.in_(ids))
+                if filter:
+                    for key, value in filter.items():
+                        wheres.append(
+                            pgcrypto_decrypt(
+                                DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
+                            )[key].astext
+                            == str(value)
+                        )
+                stmt = DocumentChunk.__table__.delete().where(*wheres)
+                result = self.session.execute(stmt)
+                deleted = result.rowcount
+            else:
+                query = self.session.query(DocumentChunk).filter(
+                    DocumentChunk.collection_name == collection_name
+                )
+                if ids:
+                    query = query.filter(DocumentChunk.id.in_(ids))
+                if filter:
+                    for key, value in filter.items():
+                        query = query.filter(
+                            DocumentChunk.vmetadata[key].astext == str(value)
+                        )
+                deleted = query.delete(synchronize_session=False)
             self.session.commit()
             log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
         except Exception as e:
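The encrypted delete() switches from Query.delete() to a Core table delete so the same decrypt-then-compare filter used by query() can be applied, and it reads the affected-row count from result.rowcount. Hand-written for illustration, the generated statement is shaped roughly like:

preview = """
DELETE FROM document_chunk
WHERE collection_name = :collection_name
  AND id IN (:id_1, :id_2)                      -- present only when ids are passed
  AND pgp_sym_decrypt(vmetadata, :key)::jsonb
      ->> :meta_key = :meta_value               -- one clause per filter entry
"""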