diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 26623c95a4..46d3b719a6 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -1886,6 +1886,29 @@ if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY:
         "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key."
     )
 
+
+def _pgvector_pool_int(name, default):
+    """Return the integer value of environment variable *name*.
+
+    Falls back to *default* when the variable is unset, blank, or not a
+    valid integer, so a malformed setting never aborts startup.
+    """
+    try:
+        return int(os.environ[name])
+    except (KeyError, ValueError):
+        return default
+
+
+# Connection-pool settings for the pgvector engine.
+# PGVECTOR_POOL_SIZE: None (unset/invalid) leaves pooling at the SQLAlchemy
+# default, a positive value enables a QueuePool of that size, and any other
+# integer disables pooling (NullPool).
+PGVECTOR_POOL_SIZE = _pgvector_pool_int("PGVECTOR_POOL_SIZE", None)
+# The remaining settings are only applied when PGVECTOR_POOL_SIZE > 0.
+PGVECTOR_POOL_MAX_OVERFLOW = _pgvector_pool_int("PGVECTOR_POOL_MAX_OVERFLOW", 0)
+PGVECTOR_POOL_TIMEOUT = _pgvector_pool_int("PGVECTOR_POOL_TIMEOUT", 30)
+PGVECTOR_POOL_RECYCLE = _pgvector_pool_int("PGVECTOR_POOL_RECYCLE", 3600)
+
 # Pinecone
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
 PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py
index 632937ef5b..64f12aa6d0 100644
--- a/backend/open_webui/retrieval/vector/dbs/pgvector.py
+++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py
@@ -18,7 +18,7 @@ from sqlalchemy import (
     values,
 )
 from sqlalchemy.sql import true
-from sqlalchemy.pool import NullPool
+from sqlalchemy.pool import NullPool, QueuePool
 from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
 from sqlalchemy.dialects.postgresql import JSONB, array
 
@@ -37,6 +37,10 @@ from open_webui.config import (
     PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
     PGVECTOR_PGCRYPTO,
     PGVECTOR_PGCRYPTO_KEY,
+    PGVECTOR_POOL_SIZE,
+    PGVECTOR_POOL_MAX_OVERFLOW,
+    PGVECTOR_POOL_TIMEOUT,
+    PGVECTOR_POOL_RECYCLE,
 )
 
 from open_webui.env import SRC_LOG_LEVELS
@@ -80,9 +84,27 @@ class PgvectorClient(VectorDBBase):
 
             self.session = Session
         else:
-            engine = create_engine(
-                PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
-            )
+            # Pool selection is driven by PGVECTOR_POOL_SIZE:
+            #   > 0          -> QueuePool with the configured limits
+            #   other int    -> NullPool (no pooling)
+            #   not an int   -> SQLAlchemy's default pool
+            if isinstance(PGVECTOR_POOL_SIZE, int) and PGVECTOR_POOL_SIZE > 0:
+                engine = create_engine(
+                    PGVECTOR_DB_URL,
+                    pool_size=PGVECTOR_POOL_SIZE,
+                    max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
+                    pool_timeout=PGVECTOR_POOL_TIMEOUT,
+                    pool_recycle=PGVECTOR_POOL_RECYCLE,
+                    pool_pre_ping=True,
+                    poolclass=QueuePool,
+                )
+            elif isinstance(PGVECTOR_POOL_SIZE, int):
+                engine = create_engine(
+                    PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
+                )
+            else:
+                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
+
             SessionLocal = sessionmaker(
                 autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
             )