mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: experimental pgvector pgcrypto support
This commit is contained in:
parent
6d4f449085
commit
7f488b3754
2 changed files with 250 additions and 83 deletions
|
|
@ -1825,6 +1825,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
|
|||
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
||||
)
|
||||
|
||||
PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
|
||||
PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)
|
||||
if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
|
||||
raise ValueError(
|
||||
"PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
|
||||
)
|
||||
|
||||
# Pinecone
|
||||
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
||||
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from typing import Optional, List, Dict, Any
|
||||
import logging
|
||||
import json
|
||||
from sqlalchemy import (
|
||||
func,
|
||||
literal,
|
||||
cast,
|
||||
column,
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
LargeBinary,
|
||||
select,
|
||||
text,
|
||||
Text,
|
||||
|
|
@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
|
|||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
from open_webui.config import (
|
||||
PGVECTOR_DB_URL,
|
||||
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
||||
PGVECTOR_PGCRYPTO,
|
||||
PGVECTOR_PGCRYPTO_KEY,
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
|
|
@ -39,12 +48,25 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def pgcrypto_encrypt(val, key):
|
||||
return func.pgp_sym_encrypt(val, literal(key))
|
||||
|
||||
|
||||
def pgcrypto_decrypt(col, key, outtype="text"):
|
||||
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||
collection_name = Column(Text, nullable=False)
|
||||
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
text = Column(LargeBinary, nullable=True)
|
||||
vmetadata = Column(LargeBinary, nullable=True)
|
||||
else:
|
||||
text = Column(Text, nullable=True)
|
||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||
|
||||
|
|
@ -147,6 +169,36 @@ class PgvectorClient(VectorDBBase):
|
|||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
# Use raw SQL for BYTEA/pgcrypto
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO document_chunk
|
||||
(id, vector, collection_name, text, vmetadata)
|
||||
VALUES (
|
||||
:id, :vector, :collection_name,
|
||||
pgp_sym_encrypt(:text, :key),
|
||||
pgp_sym_encrypt(:metadata::text, :key)
|
||||
)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": json.dumps(item["metadata"]),
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
|
||||
|
||||
else:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
|
|
@ -170,6 +222,38 @@ class PgvectorClient(VectorDBBase):
|
|||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
self.session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO document_chunk
|
||||
(id, vector, collection_name, text, vmetadata)
|
||||
VALUES (
|
||||
:id, :vector, :collection_name,
|
||||
pgp_sym_encrypt(:text, :key),
|
||||
pgp_sym_encrypt(:metadata::text, :key)
|
||||
)
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
vector = EXCLUDED.vector,
|
||||
collection_name = EXCLUDED.collection_name,
|
||||
text = EXCLUDED.text,
|
||||
vmetadata = EXCLUDED.vmetadata
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": json.dumps(item["metadata"]),
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
|
||||
else:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
|
|
@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
|
|||
.alias("query_vectors")
|
||||
)
|
||||
|
||||
result_fields = [
|
||||
DocumentChunk.id,
|
||||
]
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text")
|
||||
)
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata")
|
||||
)
|
||||
else:
|
||||
result_fields.append(DocumentChunk.text)
|
||||
result_fields.append(DocumentChunk.vmetadata)
|
||||
result_fields.append(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||
"distance"
|
||||
)
|
||||
)
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
subq = (
|
||||
select(
|
||||
DocumentChunk.id,
|
||||
DocumentChunk.text,
|
||||
DocumentChunk.vmetadata,
|
||||
(
|
||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||
).label("distance"),
|
||||
)
|
||||
select(*result_fields)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
|
|
@ -299,12 +399,38 @@ class PgvectorClient(VectorDBBase):
|
|||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Build where clause for vmetadata filter
|
||||
where_clauses = [DocumentChunk.collection_name == collection_name]
|
||||
for key, value in filter.items():
|
||||
# decrypt then check key: JSON filter after decryption
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
).where(*where_clauses)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
results = self.session.execute(stmt).all()
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
|
@ -331,6 +457,24 @@ class PgvectorClient(VectorDBBase):
|
|||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
).where(DocumentChunk.collection_name == collection_name)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
results = self.session.execute(stmt).all()
|
||||
ids = [[row.id for row in results]]
|
||||
documents = [[row.text for row in results]]
|
||||
metadatas = [[row.vmetadata for row in results]]
|
||||
else:
|
||||
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
|
|
@ -358,6 +502,22 @@ class PgvectorClient(VectorDBBase):
|
|||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
wheres = [DocumentChunk.collection_name == collection_name]
|
||||
if ids:
|
||||
wheres.append(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
wheres.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = DocumentChunk.__table__.delete().where(*wheres)
|
||||
result = self.session.execute(stmt)
|
||||
deleted = result.rowcount
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue