mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
feat: experimental pgvector pgcrypto support
This commit is contained in:
parent
6d4f449085
commit
7f488b3754
2 changed files with 250 additions and 83 deletions
|
|
@ -1825,6 +1825,13 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
|
||||||
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pgcrypto settings for the pgvector backend.
# PGVECTOR_PGCRYPTO is True only when the environment variable of the same
# name equals "true" (compared case-insensitively); it defaults to "false".
# PGVECTOR_PGCRYPTO_KEY is read from the environment and is None when unset.
# If the flag is on but the key is missing or empty, a ValueError is raised
# right here at module level instead of later when the key is first used.
PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true"
|
||||||
|
PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None)
|
||||||
|
if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
|
||||||
|
raise ValueError(
|
||||||
|
"PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
|
||||||
|
)
|
||||||
|
|
||||||
# Pinecone
|
# Pinecone
|
||||||
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
||||||
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
func,
|
||||||
|
literal,
|
||||||
cast,
|
cast,
|
||||||
column,
|
column,
|
||||||
create_engine,
|
create_engine,
|
||||||
Column,
|
Column,
|
||||||
Integer,
|
Integer,
|
||||||
MetaData,
|
MetaData,
|
||||||
|
LargeBinary,
|
||||||
select,
|
select,
|
||||||
text,
|
text,
|
||||||
Text,
|
Text,
|
||||||
|
|
@ -28,7 +32,12 @@ from open_webui.retrieval.vector.main import (
|
||||||
SearchResult,
|
SearchResult,
|
||||||
GetResult,
|
GetResult,
|
||||||
)
|
)
|
||||||
from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
from open_webui.config import (
|
||||||
|
PGVECTOR_DB_URL,
|
||||||
|
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
||||||
|
PGVECTOR_PGCRYPTO,
|
||||||
|
PGVECTOR_PGCRYPTO_KEY,
|
||||||
|
)
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
|
|
@ -39,12 +48,25 @@ log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
def pgcrypto_encrypt(val, key):
    """Return a SQL expression calling ``pgp_sym_encrypt(val, key)``.

    Args:
        val: Value or SQL expression to encrypt.
        key: Symmetric passphrase; wrapped with ``literal`` before being
            handed to the SQL function.

    Returns:
        The ``func.pgp_sym_encrypt`` expression. Nothing is executed here.
    """
    passphrase = literal(key)
    encrypted_expr = func.pgp_sym_encrypt(val, passphrase)
    return encrypted_expr
|
||||||
|
|
||||||
|
|
||||||
|
def pgcrypto_decrypt(col, key, outtype=Text):
    """Return a SQL expression that decrypts ``col`` and casts the result.

    Builds ``CAST(pgp_sym_decrypt(col, key) AS outtype)``.

    Args:
        col: Column or SQL expression holding the encrypted value.
        key: Symmetric passphrase; wrapped with ``literal`` before being
            handed to the SQL function.
        outtype: SQLAlchemy type (class or instance) to cast the decrypted
            value to, e.g. ``Text`` or ``JSONB``. Defaults to ``Text``.

    Returns:
        The cast expression. Nothing is executed here.
    """
    # The previous default was the plain string "text", which SQLAlchemy
    # cannot render as a CAST target; a real type object is required.
    decrypted = func.pgp_sym_decrypt(col, literal(key))
    return cast(decrypted, outtype)
|
||||||
|
|
||||||
|
|
||||||
class DocumentChunk(Base):
|
class DocumentChunk(Base):
|
||||||
__tablename__ = "document_chunk"
|
__tablename__ = "document_chunk"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||||
collection_name = Column(Text, nullable=False)
|
collection_name = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
text = Column(LargeBinary, nullable=True)
|
||||||
|
vmetadata = Column(LargeBinary, nullable=True)
|
||||||
|
else:
|
||||||
text = Column(Text, nullable=True)
|
text = Column(Text, nullable=True)
|
||||||
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
|
||||||
|
|
||||||
|
|
@ -147,6 +169,36 @@ class PgvectorClient(VectorDBBase):
|
||||||
|
|
||||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
try:
|
try:
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
for item in items:
|
||||||
|
vector = self.adjust_vector_length(item["vector"])
|
||||||
|
# Use raw SQL for BYTEA/pgcrypto
|
||||||
|
self.session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO document_chunk
|
||||||
|
(id, vector, collection_name, text, vmetadata)
|
||||||
|
VALUES (
|
||||||
|
:id, :vector, :collection_name,
|
||||||
|
pgp_sym_encrypt(:text, :key),
|
||||||
|
pgp_sym_encrypt(:metadata::text, :key)
|
||||||
|
)
|
||||||
|
ON CONFLICT (id) DO NOTHING
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": item["id"],
|
||||||
|
"vector": vector,
|
||||||
|
"collection_name": collection_name,
|
||||||
|
"text": item["text"],
|
||||||
|
"metadata": json.dumps(item["metadata"]),
|
||||||
|
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.session.commit()
|
||||||
|
log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
|
||||||
|
|
||||||
|
else:
|
||||||
new_items = []
|
new_items = []
|
||||||
for item in items:
|
for item in items:
|
||||||
vector = self.adjust_vector_length(item["vector"])
|
vector = self.adjust_vector_length(item["vector"])
|
||||||
|
|
@ -170,6 +222,38 @@ class PgvectorClient(VectorDBBase):
|
||||||
|
|
||||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
try:
|
try:
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
for item in items:
|
||||||
|
vector = self.adjust_vector_length(item["vector"])
|
||||||
|
self.session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO document_chunk
|
||||||
|
(id, vector, collection_name, text, vmetadata)
|
||||||
|
VALUES (
|
||||||
|
:id, :vector, :collection_name,
|
||||||
|
pgp_sym_encrypt(:text, :key),
|
||||||
|
pgp_sym_encrypt(:metadata::text, :key)
|
||||||
|
)
|
||||||
|
ON CONFLICT (id) DO UPDATE SET
|
||||||
|
vector = EXCLUDED.vector,
|
||||||
|
collection_name = EXCLUDED.collection_name,
|
||||||
|
text = EXCLUDED.text,
|
||||||
|
vmetadata = EXCLUDED.vmetadata
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"id": item["id"],
|
||||||
|
"vector": vector,
|
||||||
|
"collection_name": collection_name,
|
||||||
|
"text": item["text"],
|
||||||
|
"metadata": json.dumps(item["metadata"]),
|
||||||
|
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.session.commit()
|
||||||
|
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
|
||||||
|
else:
|
||||||
for item in items:
|
for item in items:
|
||||||
vector = self.adjust_vector_length(item["vector"])
|
vector = self.adjust_vector_length(item["vector"])
|
||||||
existing = (
|
existing = (
|
||||||
|
|
@ -230,16 +314,32 @@ class PgvectorClient(VectorDBBase):
|
||||||
.alias("query_vectors")
|
.alias("query_vectors")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
result_fields = [
|
||||||
|
DocumentChunk.id,
|
||||||
|
]
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
result_fields.append(
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||||
|
).label("text")
|
||||||
|
)
|
||||||
|
result_fields.append(
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||||
|
).label("vmetadata")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result_fields.append(DocumentChunk.text)
|
||||||
|
result_fields.append(DocumentChunk.vmetadata)
|
||||||
|
result_fields.append(
|
||||||
|
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||||
|
"distance"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Build the lateral subquery for each query vector
|
# Build the lateral subquery for each query vector
|
||||||
subq = (
|
subq = (
|
||||||
select(
|
select(*result_fields)
|
||||||
DocumentChunk.id,
|
|
||||||
DocumentChunk.text,
|
|
||||||
DocumentChunk.vmetadata,
|
|
||||||
(
|
|
||||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
|
||||||
).label("distance"),
|
|
||||||
)
|
|
||||||
.where(DocumentChunk.collection_name == collection_name)
|
.where(DocumentChunk.collection_name == collection_name)
|
||||||
.order_by(
|
.order_by(
|
||||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||||
|
|
@ -299,12 +399,38 @@ class PgvectorClient(VectorDBBase):
|
||||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||||
) -> Optional[GetResult]:
|
) -> Optional[GetResult]:
|
||||||
try:
|
try:
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
# Build where clause for vmetadata filter
|
||||||
|
where_clauses = [DocumentChunk.collection_name == collection_name]
|
||||||
|
for key, value in filter.items():
|
||||||
|
# decrypt then check key: JSON filter after decryption
|
||||||
|
where_clauses.append(
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||||
|
)[key].astext
|
||||||
|
== str(value)
|
||||||
|
)
|
||||||
|
stmt = select(
|
||||||
|
DocumentChunk.id,
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||||
|
).label("text"),
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||||
|
).label("vmetadata"),
|
||||||
|
).where(*where_clauses)
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
results = self.session.execute(stmt).all()
|
||||||
|
else:
|
||||||
query = self.session.query(DocumentChunk).filter(
|
query = self.session.query(DocumentChunk).filter(
|
||||||
DocumentChunk.collection_name == collection_name
|
DocumentChunk.collection_name == collection_name
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value in filter.items():
|
for key, value in filter.items():
|
||||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
query = query.filter(
|
||||||
|
DocumentChunk.vmetadata[key].astext == str(value)
|
||||||
|
)
|
||||||
|
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
query = query.limit(limit)
|
query = query.limit(limit)
|
||||||
|
|
@ -331,6 +457,24 @@ class PgvectorClient(VectorDBBase):
|
||||||
self, collection_name: str, limit: Optional[int] = None
|
self, collection_name: str, limit: Optional[int] = None
|
||||||
) -> Optional[GetResult]:
|
) -> Optional[GetResult]:
|
||||||
try:
|
try:
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
stmt = select(
|
||||||
|
DocumentChunk.id,
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||||
|
).label("text"),
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||||
|
).label("vmetadata"),
|
||||||
|
).where(DocumentChunk.collection_name == collection_name)
|
||||||
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
results = self.session.execute(stmt).all()
|
||||||
|
ids = [[row.id for row in results]]
|
||||||
|
documents = [[row.text for row in results]]
|
||||||
|
metadatas = [[row.vmetadata for row in results]]
|
||||||
|
else:
|
||||||
|
|
||||||
query = self.session.query(DocumentChunk).filter(
|
query = self.session.query(DocumentChunk).filter(
|
||||||
DocumentChunk.collection_name == collection_name
|
DocumentChunk.collection_name == collection_name
|
||||||
)
|
)
|
||||||
|
|
@ -358,6 +502,22 @@ class PgvectorClient(VectorDBBase):
|
||||||
filter: Optional[Dict[str, Any]] = None,
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
wheres = [DocumentChunk.collection_name == collection_name]
|
||||||
|
if ids:
|
||||||
|
wheres.append(DocumentChunk.id.in_(ids))
|
||||||
|
if filter:
|
||||||
|
for key, value in filter.items():
|
||||||
|
wheres.append(
|
||||||
|
pgcrypto_decrypt(
|
||||||
|
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||||
|
)[key].astext
|
||||||
|
== str(value)
|
||||||
|
)
|
||||||
|
stmt = DocumentChunk.__table__.delete().where(*wheres)
|
||||||
|
result = self.session.execute(stmt)
|
||||||
|
deleted = result.rowcount
|
||||||
|
else:
|
||||||
query = self.session.query(DocumentChunk).filter(
|
query = self.session.query(DocumentChunk).filter(
|
||||||
DocumentChunk.collection_name == collection_name
|
DocumentChunk.collection_name == collection_name
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue