diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 28cb853656..8ece4af736 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -2149,6 +2149,21 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int( os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536") ) +PGVECTOR_USE_HALFVEC = ( + os.getenv("PGVECTOR_USE_HALFVEC", "false").lower() == "true" +) + +if ( + PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH > 2000 + and not PGVECTOR_USE_HALFVEC +): + raise ValueError( + "PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to " + f"{PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, which exceeds the 2000 dimension limit of the " + "'vector' type. Set PGVECTOR_USE_HALFVEC=true to enable the 'halfvec' " + "type required for high-dimensional embeddings." + ) + PGVECTOR_CREATE_EXTENSION = ( os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true" ) @@ -2198,6 +2213,40 @@ else: except Exception: PGVECTOR_POOL_RECYCLE = 3600 +PGVECTOR_INDEX_METHOD = os.getenv("PGVECTOR_INDEX_METHOD", "").strip().lower() +if PGVECTOR_INDEX_METHOD not in ("ivfflat", "hnsw", ""): + PGVECTOR_INDEX_METHOD = "" + +PGVECTOR_HNSW_M = os.environ.get("PGVECTOR_HNSW_M", 16) + +if PGVECTOR_HNSW_M == "": + PGVECTOR_HNSW_M = 16 +else: + try: + PGVECTOR_HNSW_M = int(PGVECTOR_HNSW_M) + except ValueError: + PGVECTOR_HNSW_M = 16 + +PGVECTOR_HNSW_EF_CONSTRUCTION = os.environ.get("PGVECTOR_HNSW_EF_CONSTRUCTION", 64) + +if PGVECTOR_HNSW_EF_CONSTRUCTION == "": + PGVECTOR_HNSW_EF_CONSTRUCTION = 64 +else: + try: + PGVECTOR_HNSW_EF_CONSTRUCTION = int(PGVECTOR_HNSW_EF_CONSTRUCTION) + except ValueError: + PGVECTOR_HNSW_EF_CONSTRUCTION = 64 + +PGVECTOR_IVFFLAT_LISTS = os.environ.get("PGVECTOR_IVFFLAT_LISTS", 100) + +if PGVECTOR_IVFFLAT_LISTS == "": + PGVECTOR_IVFFLAT_LISTS = 100 +else: + try: + PGVECTOR_IVFFLAT_LISTS = int(PGVECTOR_IVFFLAT_LISTS) + except ValueError: + PGVECTOR_IVFFLAT_LISTS = 100 + # Pinecone PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) 
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 312b48944c..3ffad4fc48 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Tuple import logging import json from sqlalchemy import ( @@ -22,7 +22,7 @@ from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.dialects.postgresql import JSONB, array -from pgvector.sqlalchemy import Vector +from pgvector.sqlalchemy import Vector, HALFVEC from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.exc import NoSuchTableError @@ -44,11 +44,20 @@ from open_webui.config import ( PGVECTOR_POOL_MAX_OVERFLOW, PGVECTOR_POOL_TIMEOUT, PGVECTOR_POOL_RECYCLE, + PGVECTOR_INDEX_METHOD, + PGVECTOR_HNSW_M, + PGVECTOR_HNSW_EF_CONSTRUCTION, + PGVECTOR_IVFFLAT_LISTS, + PGVECTOR_USE_HALFVEC, ) from open_webui.env import SRC_LOG_LEVELS VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH +USE_HALFVEC = PGVECTOR_USE_HALFVEC + +VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector +VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops" Base = declarative_base() log = logging.getLogger(__name__) @@ -67,7 +76,7 @@ class DocumentChunk(Base): __tablename__ = "document_chunk" id = Column(Text, primary_key=True) - vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) + vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True) collection_name = Column(Text, nullable=False) if PGVECTOR_PGCRYPTO: @@ -157,13 +166,9 @@ class PgvectorClient(VectorDBBase): connection = self.session.connection() Base.metadata.create_all(bind=connection) - # Create an index on the vector column if it doesn't exist - 
self.session.execute( - text( - "CREATE INDEX IF NOT EXISTS idx_document_chunk_vector " - "ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);" - ) - ) + index_method, index_options = self._vector_index_configuration() + self._ensure_vector_index(index_method, index_options) + self.session.execute( text( "CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name " @@ -177,6 +182,80 @@ class PgvectorClient(VectorDBBase): log.exception(f"Error during initialization: {e}") raise + @staticmethod + def _extract_index_method(index_def: Optional[str]) -> Optional[str]: + if not index_def: + return None + try: + after_using = index_def.lower().split("using ", 1)[1] + return after_using.split()[0] + except (IndexError, AttributeError): + return None + + def _vector_index_configuration(self) -> Tuple[str, str]: + if PGVECTOR_INDEX_METHOD: + index_method = PGVECTOR_INDEX_METHOD + log.info( + "Using vector index method '%s' from PGVECTOR_INDEX_METHOD.", + index_method, + ) + elif USE_HALFVEC: + index_method = "hnsw" + log.info( + "PGVECTOR_USE_HALFVEC is enabled (VECTOR_LENGTH=%s); using halfvec column type with hnsw index.", + VECTOR_LENGTH, + ) + else: + index_method = "ivfflat" + + if index_method == "hnsw": + index_options = ( + f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})" + ) + else: + index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})" + + return index_method, index_options + + def _ensure_vector_index(self, index_method: str, index_options: str) -> None: + index_name = "idx_document_chunk_vector" + existing_index_def = self.session.execute( + text( + """ + SELECT indexdef + FROM pg_indexes + WHERE schemaname = current_schema() + AND tablename = 'document_chunk' + AND indexname = :index_name + """ + ), + {"index_name": index_name}, + ).scalar() + + existing_method = self._extract_index_method(existing_index_def) + if existing_method and existing_method != index_method: + raise RuntimeError( + f"Existing pgvector 
index '{index_name}' uses method '{existing_method}' but configuration now " + f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. " + "Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) " + "and recreate it with the new method before restarting Open WebUI." + ) + + if not existing_index_def: + index_sql = ( + f"CREATE INDEX IF NOT EXISTS {index_name} " + f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})" + ) + if index_options: + index_sql = f"{index_sql} {index_options}" + self.session.execute(text(index_sql)) + log.info( + "Ensured vector index '%s' using %s%s.", + index_name, + index_method, + f" {index_options}" if index_options else "", + ) + def check_vector_length(self) -> None: """ Check if the VECTOR_LENGTH matches the existing vector column dimension in the database. @@ -196,16 +275,19 @@ class PgvectorClient(VectorDBBase): if "vector" in document_chunk_table.columns: vector_column = document_chunk_table.columns["vector"] vector_type = vector_column.type - if isinstance(vector_type, Vector): - db_vector_length = vector_type.dim - if db_vector_length != VECTOR_LENGTH: - raise Exception( - f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. " - "Cannot change vector size after initialization without migrating the data." - ) - else: + expected_type = HALFVEC if USE_HALFVEC else Vector + + if not isinstance(vector_type, expected_type): raise Exception( - "The 'vector' column exists but is not of type 'Vector'." + "The 'vector' column type does not match the expected type " + f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}." + ) + + db_vector_length = getattr(vector_type, "dim", None) + if db_vector_length is not None and db_vector_length != VECTOR_LENGTH: + raise Exception( + f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. 
" + "Cannot change vector size after initialization without migrating the data." ) else: raise Exception( @@ -360,11 +442,11 @@ class PgvectorClient(VectorDBBase): num_queries = len(vectors) def vector_expr(vector): - return cast(array(vector), Vector(VECTOR_LENGTH)) + return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) # Create the values for query vectors qid_col = column("qid", Integer) - q_vector_col = column("q_vector", Vector(VECTOR_LENGTH)) + q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) query_vectors = ( values(qid_col, q_vector_col) .data(