From 25db8225f804560e91e2a9fe8776dba990d46156 Mon Sep 17 00:00:00 2001
From: Dechao Sun <50395154+vincentsunx@users.noreply.github.com>
Date: Fri, 26 Dec 2025 22:32:05 +0800
Subject: [PATCH] openWebUI supports openGauss vector store (#20179)

---
 backend/open_webui/config.py                 |  45 ++
 .../retrieval/vector/dbs/opengauss.py        | 409 ++++++++++++++++++
 .../open_webui/retrieval/vector/factory.py   |   4 +
 backend/open_webui/retrieval/vector/type.py  |   1 +
 4 files changed, 459 insertions(+)
 create mode 100644 backend/open_webui/retrieval/vector/dbs/opengauss.py

diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 9851aa37ba..bea37ecc44 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -2342,6 +2342,51 @@ else:
     except Exception:
         PGVECTOR_IVFFLAT_LISTS = 100
 
+# openGauss
+OPENGAUSS_DB_URL = os.environ.get("OPENGAUSS_DB_URL", DATABASE_URL)
+
+OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int(
+    os.environ.get("OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
+)
+
+OPENGAUSS_POOL_SIZE = os.environ.get("OPENGAUSS_POOL_SIZE", None)
+
+if OPENGAUSS_POOL_SIZE is not None:
+    try:
+        OPENGAUSS_POOL_SIZE = int(OPENGAUSS_POOL_SIZE)
+    except Exception:
+        OPENGAUSS_POOL_SIZE = None
+
+OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get("OPENGAUSS_POOL_MAX_OVERFLOW", 0)
+
+if OPENGAUSS_POOL_MAX_OVERFLOW == "":
+    OPENGAUSS_POOL_MAX_OVERFLOW = 0
+else:
+    try:
+        OPENGAUSS_POOL_MAX_OVERFLOW = int(OPENGAUSS_POOL_MAX_OVERFLOW)
+    except Exception:
+        OPENGAUSS_POOL_MAX_OVERFLOW = 0
+
+OPENGAUSS_POOL_TIMEOUT = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30)
+
+if OPENGAUSS_POOL_TIMEOUT == "":
+    OPENGAUSS_POOL_TIMEOUT = 30
+else:
+    try:
+        OPENGAUSS_POOL_TIMEOUT = int(OPENGAUSS_POOL_TIMEOUT)
+    except Exception:
+        OPENGAUSS_POOL_TIMEOUT = 30
+
+OPENGAUSS_POOL_RECYCLE = os.environ.get("OPENGAUSS_POOL_RECYCLE", 3600)
+
+if OPENGAUSS_POOL_RECYCLE == "":
+    OPENGAUSS_POOL_RECYCLE = 3600
+else:
+    try:
+        OPENGAUSS_POOL_RECYCLE = int(OPENGAUSS_POOL_RECYCLE)
+    except Exception:
+        OPENGAUSS_POOL_RECYCLE = 3600
+
 # Pinecone
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
 PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
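
For reference, a minimal sketch of how the pool settings above are meant to resolve; the env values below are hypothetical examples, not shipped defaults. An unset OPENGAUSS_POOL_SIZE (or one that fails int()) leaves pooling disabled, while an empty string for the other knobs falls back to the default:

    import os

    # Hypothetical example values, set before config.py is imported
    os.environ["OPENGAUSS_POOL_SIZE"] = "10"   # a positive int enables QueuePool
    os.environ["OPENGAUSS_POOL_TIMEOUT"] = ""  # empty string -> default 30

    pool_size = os.environ.get("OPENGAUSS_POOL_SIZE", None)
    if pool_size is not None:
        try:
            pool_size = int(pool_size)
        except Exception:
            pool_size = None  # non-numeric -> pooling stays off (NullPool)

    pool_timeout = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30)
    pool_timeout = 30 if pool_timeout == "" else int(pool_timeout)

    print(pool_size, pool_timeout)  # -> 10 30
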
diff --git a/backend/open_webui/retrieval/vector/dbs/opengauss.py b/backend/open_webui/retrieval/vector/dbs/opengauss.py
new file mode 100644
index 0000000000..13505e1852
--- /dev/null
+++ b/backend/open_webui/retrieval/vector/dbs/opengauss.py
@@ -0,0 +1,409 @@
+from typing import Optional, List, Dict, Any
+import logging
+import re
+import json
+from sqlalchemy import (
+    func,
+    literal,
+    cast,
+    column,
+    create_engine,
+    Column,
+    Integer,
+    MetaData,
+    LargeBinary,
+    select,
+    text,
+    Text,
+    Table,
+    values,
+)
+from sqlalchemy.sql import true
+from sqlalchemy.pool import NullPool, QueuePool
+
+from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
+from sqlalchemy.dialects.postgresql import JSONB, array
+from pgvector.sqlalchemy import Vector
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.exc import NoSuchTableError
+
+from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
+from sqlalchemy.dialects import registry
+
+
+class OpenGaussDialect(PGDialect_psycopg2):
+    name = "opengauss"
+
+    def _get_server_version_info(self, connection):
+        try:
+            version = connection.exec_driver_sql("SELECT version()").scalar()
+            if not version:
+                return (9, 0, 0)
+
+            match = re.search(
+                r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?",
+                version,
+                re.IGNORECASE,
+            )
+            if match:
+                return (
+                    int(match.group(1)),
+                    int(match.group(2)),
+                    int(match.group(3)),
+                )
+
+            return super()._get_server_version_info(connection)
+        except Exception:
+            return (9, 0, 0)
+
+
+# Register dialect
+registry.register("opengauss", __name__, "OpenGaussDialect")
+
+from open_webui.retrieval.vector.utils import process_metadata
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    SearchResult,
+    GetResult,
+)
+from open_webui.config import (
+    OPENGAUSS_DB_URL,
+    OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH,
+    OPENGAUSS_POOL_SIZE,
+    OPENGAUSS_POOL_MAX_OVERFLOW,
+    OPENGAUSS_POOL_TIMEOUT,
+    OPENGAUSS_POOL_RECYCLE,
+)
+
+from open_webui.env import SRC_LOG_LEVELS
+
+VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
+Base = declarative_base()
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class DocumentChunk(Base):
+    __tablename__ = "document_chunk"
+
+    id = Column(Text, primary_key=True)
+    vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
+    collection_name = Column(Text, nullable=False)
+    text = Column(Text, nullable=True)
+    vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True)
+
+
+class OpenGaussClient(VectorDBBase):
+    def __init__(self) -> None:
+        if not OPENGAUSS_DB_URL:
+            from open_webui.internal.db import Session
+
+            self.session = Session
+        else:
+            engine_kwargs = {
+                "pool_pre_ping": True,
+                "dialect": OpenGaussDialect(),
+            }
+
+            if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
+                engine_kwargs.update(
+                    {
+                        "pool_size": OPENGAUSS_POOL_SIZE,
+                        "max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
+                        "pool_timeout": OPENGAUSS_POOL_TIMEOUT,
+                        "pool_recycle": OPENGAUSS_POOL_RECYCLE,
+                        "poolclass": QueuePool,
+                    }
+                )
+            else:
+                engine_kwargs["poolclass"] = NullPool
+
+            engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
+
+            SessionLocal = sessionmaker(
+                autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
+            )
+            self.session = scoped_session(SessionLocal)
+
+        try:
+            connection = self.session.connection()
+            Base.metadata.create_all(bind=connection)
+
+            self.session.execute(
+                text(
+                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
+                    "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
+                )
+            )
+            self.session.execute(
+                text(
+                    "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
+                    "ON document_chunk (collection_name);"
+                )
+            )
+            self.session.commit()
+            log.info("OpenGauss vector database initialization completed.")
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"OpenGauss initialization failed: {e}")
+            raise
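
A quick sanity check of the version-detection regex in OpenGaussDialect above; the version string is a hypothetical example of openGauss's SELECT version() output, not captured from a real server:

    import re

    version = "openGauss 5.0.0 build 12345 compiled at 2024-01-01"  # hypothetical
    match = re.search(
        r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
    )
    print(tuple(int(g) for g in match.groups()) if match else None)  # -> (5, 0, 0)

If the string does not look like openGauss at all, the dialect falls back to the stock psycopg2 version probe, and any error collapses to the conservative (9, 0, 0).
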
+
+    def check_vector_length(self) -> None:
+        metadata = MetaData()
+        try:
+            document_chunk_table = Table(
+                "document_chunk", metadata, autoload_with=self.session.bind
+            )
+        except NoSuchTableError:
+            return
+
+        if "vector" in document_chunk_table.columns:
+            vector_column = document_chunk_table.columns["vector"]
+            vector_type = vector_column.type
+            if isinstance(vector_type, Vector):
+                db_vector_length = vector_type.dim
+                if db_vector_length != VECTOR_LENGTH:
+                    raise Exception(
+                        f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database."
+                    )
+            else:
+                raise Exception("The 'vector' column type is not Vector.")
+        else:
+            raise Exception(
+                "The 'vector' column does not exist in the 'document_chunk' table."
+            )
+
+    def adjust_vector_length(self, vector: List[float]) -> List[float]:
+        current_length = len(vector)
+        if current_length < VECTOR_LENGTH:
+            vector += [0.0] * (VECTOR_LENGTH - current_length)
+        elif current_length > VECTOR_LENGTH:
+            vector = vector[:VECTOR_LENGTH]
+        return vector
+
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            new_items = []
+            for item in items:
+                vector = self.adjust_vector_length(item["vector"])
+                new_chunk = DocumentChunk(
+                    id=item["id"],
+                    vector=vector,
+                    collection_name=collection_name,
+                    text=item["text"],
+                    vmetadata=process_metadata(item["metadata"]),
+                )
+                new_items.append(new_chunk)
+            self.session.bulk_save_objects(new_items)
+            self.session.commit()
+            log.info(
+                f"Inserted {len(new_items)} items into collection '{collection_name}'."
+            )
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Failed to insert data: {e}")
+            raise
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        try:
+            for item in items:
+                vector = self.adjust_vector_length(item["vector"])
+                existing = (
+                    self.session.query(DocumentChunk)
+                    .filter(DocumentChunk.id == item["id"])
+                    .first()
+                )
+                if existing:
+                    existing.vector = vector
+                    existing.text = item["text"]
+                    existing.vmetadata = process_metadata(item["metadata"])
+                    existing.collection_name = collection_name
+                else:
+                    new_chunk = DocumentChunk(
+                        id=item["id"],
+                        vector=vector,
+                        collection_name=collection_name,
+                        text=item["text"],
+                        vmetadata=process_metadata(item["metadata"]),
+                    )
+                    self.session.add(new_chunk)
+            self.session.commit()
+            log.info(
+                f"Upserted {len(items)} items in collection '{collection_name}'."
+            )
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Failed to insert or update data: {e}")
+            raise
+
+    def search(
+        self,
+        collection_name: str,
+        vectors: List[List[float]],
+        limit: Optional[int] = None,
+    ) -> Optional[SearchResult]:
+        try:
+            if not vectors:
+                return None
+
+            vectors = [self.adjust_vector_length(vector) for vector in vectors]
+            num_queries = len(vectors)
+
+            def vector_expr(vector):
+                return cast(array(vector), Vector(VECTOR_LENGTH))
+
+            qid_col = column("qid", Integer)
+            q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
+            query_vectors = (
+                values(qid_col, q_vector_col)
+                .data(
+                    [(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
+                )
+                .alias("query_vectors")
+            )
+
+            result_fields = [
+                DocumentChunk.id,
+                DocumentChunk.text,
+                DocumentChunk.vmetadata,
+                (
+                    DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
+                ).label("distance"),
+            ]
+
+            subq = (
+                select(*result_fields)
+                .where(DocumentChunk.collection_name == collection_name)
+                .order_by(
+                    DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
+                )
+            )
+            if limit is not None:
+                subq = subq.limit(limit)
+            subq = subq.lateral("result")
+
+            stmt = (
+                select(
+                    query_vectors.c.qid,
+                    subq.c.id,
+                    subq.c.text,
+                    subq.c.vmetadata,
+                    subq.c.distance,
+                )
+                .select_from(query_vectors)
+                .join(subq, true())
+                .order_by(query_vectors.c.qid, subq.c.distance)
+            )
+
+            result_proxy = self.session.execute(stmt)
+            results = result_proxy.all()
+
+            ids = [[] for _ in range(num_queries)]
+            distances = [[] for _ in range(num_queries)]
+            documents = [[] for _ in range(num_queries)]
+            metadatas = [[] for _ in range(num_queries)]
+
+            for row in results:
+                qid = int(row.qid)
+                ids[qid].append(row.id)
+                distances[qid].append((2.0 - row.distance) / 2.0)
+                documents[qid].append(row.text)
+                metadatas[qid].append(row.vmetadata)
+
+            self.session.rollback()
+            return SearchResult(
+                ids=ids, distances=distances, documents=documents, metadatas=metadatas
+            )
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Vector search failed: {e}")
+            return None
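
search() reports similarity rather than raw distance: pgvector's cosine_distance ranges over [0, 2] (0 = same direction, 1 = orthogonal, 2 = opposite), and the (2.0 - distance) / 2.0 conversion above maps that onto a [0, 1] score. A tiny worked check:

    for d in (0.0, 1.0, 2.0):  # identical, orthogonal, opposite query/document vectors
        print(d, (2.0 - d) / 2.0)  # -> 1.0, 0.5, 0.0
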
+
+    def query(
+        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+
+            for key, value in filter.items():
+                query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
+
+            if limit is not None:
+                query = query.limit(limit)
+
+            results = query.all()
+
+            if not results:
+                return None
+
+            ids = [[result.id for result in results]]
+            documents = [[result.text for result in results]]
+            metadatas = [[result.vmetadata for result in results]]
+
+            self.session.rollback()
+            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Conditional query failed: {e}")
+            return None
+
+    def get(
+        self, collection_name: str, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+            if limit is not None:
+                query = query.limit(limit)
+
+            results = query.all()
+
+            if not results:
+                return None
+
+            ids = [[result.id for result in results]]
+            documents = [[result.text for result in results]]
+            metadatas = [[result.vmetadata for result in results]]
+
+            self.session.rollback()
+            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Failed to retrieve data: {e}")
+            return None
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        try:
+            query = self.session.query(DocumentChunk).filter(
+                DocumentChunk.collection_name == collection_name
+            )
+            if ids:
+                query = query.filter(DocumentChunk.id.in_(ids))
+            if filter:
+                for key, value in filter.items():
+                    query = query.filter(
+                        DocumentChunk.vmetadata[key].astext == str(value)
+                    )
+            deleted = query.delete(synchronize_session=False)
+            self.session.commit()
+            log.info(f"Deleted {deleted} items from collection '{collection_name}'")
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Failed to delete data: {e}")
+            raise
+
+    def reset(self) -> None:
+        try:
+            deleted = self.session.query(DocumentChunk).delete()
+            self.session.commit()
+            log.info(f"Reset completed. Deleted {deleted} items")
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Reset failed: {e}")
+            raise
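
A hypothetical usage sketch of the filter path in delete() above (the collection name and metadata key are made-up examples): metadata filters compare the JSONB value's text form, so non-string values are matched through str(value):

    from open_webui.retrieval.vector.dbs.opengauss import OpenGaussClient

    client = OpenGaussClient()
    # Removes chunks in 'docs' whose vmetadata->>'source' equals 'manual.pdf'
    client.delete(collection_name="docs", filter={"source": "manual.pdf"})
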
+
+    def close(self) -> None:
+        pass
+
+    def has_collection(self, collection_name: str) -> bool:
+        try:
+            exists = (
+                self.session.query(DocumentChunk)
+                .filter(DocumentChunk.collection_name == collection_name)
+                .first()
+                is not None
+            )
+            self.session.rollback()
+            return exists
+        except Exception as e:
+            self.session.rollback()
+            log.exception(f"Failed to check collection existence: {e}")
+            return False
+
+    def delete_collection(self, collection_name: str) -> None:
+        self.delete(collection_name)
+        log.info(f"Collection '{collection_name}' has been deleted")
\ No newline at end of file
diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py
index b843e0926d..68595fb595 100644
--- a/backend/open_webui/retrieval/vector/factory.py
+++ b/backend/open_webui/retrieval/vector/factory.py
@@ -53,6 +53,10 @@ class Vector:
                 from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
 
                 return PgvectorClient()
+            case VectorType.OPENGAUSS:
+                from open_webui.retrieval.vector.dbs.opengauss import OpenGaussClient
+
+                return OpenGaussClient()
             case VectorType.ELASTICSEARCH:
                 from open_webui.retrieval.vector.dbs.elasticsearch import (
                     ElasticsearchClient,
diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py
index 292cad1e78..de20133fce 100644
--- a/backend/open_webui/retrieval/vector/type.py
+++ b/backend/open_webui/retrieval/vector/type.py
@@ -12,3 +12,4 @@ class VectorType(StrEnum):
     ORACLE23AI = "oracle23ai"
     S3VECTOR = "s3vector"
     WEAVIATE = "weaviate"
+    OPENGAUSS = "opengauss"
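
Assuming the existing VECTOR_DB setting is what routes the factory's match statement (the DSN below is a hypothetical example), selecting the new backend should only require pointing the deployment at it before Open WebUI starts:

    import os

    os.environ["VECTOR_DB"] = "opengauss"  # matches VectorType.OPENGAUSS
    os.environ["OPENGAUSS_DB_URL"] = (
        "opengauss://user:password@localhost:5432/openwebui"  # hypothetical DSN
    )

The opengauss:// scheme resolves through the registry.register() call in opengauss.py; when OPENGAUSS_DB_URL is unset, config.py falls back to the app's DATABASE_URL.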