mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: pgvector hnsw index type (#19158)
* Add hnsw index type for pgvector, allowing vector dimensions larger than 2000
* Remove some variable assignments
* Make USE_HALFVEC variable configurable
* Simplify USE_HALFVEC handling
* Raise runtime error if the index requires a rebuild
---------
Co-authored-by: Moritz <moritz.mueller2@tu-dresden.de>
This commit is contained in:
parent
63ebc295ce
commit
6cdb13d5cb
2 changed files with 152 additions and 21 deletions
|
|
@ -2149,6 +2149,21 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
|
||||||
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Opt-in switch for pgvector's half-precision 'halfvec' column type. Only the
# literal string "true" (case-insensitive) enables it.
PGVECTOR_USE_HALFVEC = os.getenv("PGVECTOR_USE_HALFVEC", "false").lower() == "true"

# Reject the configuration as soon as this module is loaded: pgvector can only
# index 'vector' columns of up to 2000 dimensions, so larger embeddings need
# the 'halfvec' type to be enabled explicitly.
if PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH > 2000 and not PGVECTOR_USE_HALFVEC:
    raise ValueError(
        "PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to "
        f"{PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, which exceeds the 2000 dimension limit for "
        "indexing the 'vector' type. Set PGVECTOR_USE_HALFVEC=true to enable the 'halfvec' "
        "type required for high-dimensional embeddings."
    )
|
||||||
|
|
||||||
PGVECTOR_CREATE_EXTENSION = (
|
PGVECTOR_CREATE_EXTENSION = (
|
||||||
os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true"
|
os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
@ -2198,6 +2213,40 @@ else:
|
||||||
except Exception:
|
except Exception:
|
||||||
PGVECTOR_POOL_RECYCLE = 3600
|
PGVECTOR_POOL_RECYCLE = 3600
|
||||||
|
|
||||||
|
def _positive_int_env(name: str, default: int) -> int:
    """Read environment variable *name* as a positive integer.

    Returns *default* when the variable is unset, blank, not an integer, or
    not greater than zero, so a bad value never reaches the SQL that builds
    the vector index.
    """
    raw_value = os.environ.get(name, "")
    try:
        parsed = int(raw_value)
    except (TypeError, ValueError):
        return default
    return parsed if parsed > 0 else default


# Index access method for the vector column: "ivfflat", "hnsw", or "" when the
# variable is unset. Any unrecognised value is treated as unset.
PGVECTOR_INDEX_METHOD = os.getenv("PGVECTOR_INDEX_METHOD", "").strip().lower()
if PGVECTOR_INDEX_METHOD not in ("ivfflat", "hnsw"):
    PGVECTOR_INDEX_METHOD = ""

# Index build parameters; each falls back to its default on invalid input.
PGVECTOR_HNSW_M = _positive_int_env("PGVECTOR_HNSW_M", 16)
PGVECTOR_HNSW_EF_CONSTRUCTION = _positive_int_env("PGVECTOR_HNSW_EF_CONSTRUCTION", 64)
PGVECTOR_IVFFLAT_LISTS = _positive_int_env("PGVECTOR_IVFFLAT_LISTS", 100)
|
||||||
|
|
||||||
# Pinecone
|
# Pinecone
|
||||||
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
||||||
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
|
@ -22,7 +22,7 @@ from sqlalchemy.pool import NullPool, QueuePool
|
||||||
|
|
||||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector, HALFVEC
|
||||||
from sqlalchemy.ext.mutable import MutableDict
|
from sqlalchemy.ext.mutable import MutableDict
|
||||||
from sqlalchemy.exc import NoSuchTableError
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
|
||||||
|
|
@ -44,11 +44,20 @@ from open_webui.config import (
|
||||||
PGVECTOR_POOL_MAX_OVERFLOW,
|
PGVECTOR_POOL_MAX_OVERFLOW,
|
||||||
PGVECTOR_POOL_TIMEOUT,
|
PGVECTOR_POOL_TIMEOUT,
|
||||||
PGVECTOR_POOL_RECYCLE,
|
PGVECTOR_POOL_RECYCLE,
|
||||||
|
PGVECTOR_INDEX_METHOD,
|
||||||
|
PGVECTOR_HNSW_M,
|
||||||
|
PGVECTOR_HNSW_EF_CONSTRUCTION,
|
||||||
|
PGVECTOR_IVFFLAT_LISTS,
|
||||||
|
PGVECTOR_USE_HALFVEC,
|
||||||
)
|
)
|
||||||
|
|
||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||||
|
USE_HALFVEC = PGVECTOR_USE_HALFVEC
|
||||||
|
|
||||||
|
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
|
||||||
|
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -67,7 +76,7 @@ class DocumentChunk(Base):
|
||||||
__tablename__ = "document_chunk"
|
__tablename__ = "document_chunk"
|
||||||
|
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
|
||||||
collection_name = Column(Text, nullable=False)
|
collection_name = Column(Text, nullable=False)
|
||||||
|
|
||||||
if PGVECTOR_PGCRYPTO:
|
if PGVECTOR_PGCRYPTO:
|
||||||
|
|
@ -157,13 +166,9 @@ class PgvectorClient(VectorDBBase):
|
||||||
connection = self.session.connection()
|
connection = self.session.connection()
|
||||||
Base.metadata.create_all(bind=connection)
|
Base.metadata.create_all(bind=connection)
|
||||||
|
|
||||||
# Create an index on the vector column if it doesn't exist
|
index_method, index_options = self._vector_index_configuration()
|
||||||
self.session.execute(
|
self._ensure_vector_index(index_method, index_options)
|
||||||
text(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
|
||||||
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.session.execute(
|
self.session.execute(
|
||||||
text(
|
text(
|
||||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||||
|
|
@ -177,6 +182,80 @@ class PgvectorClient(VectorDBBase):
|
||||||
log.exception(f"Error during initialization: {e}")
|
log.exception(f"Error during initialization: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@staticmethod
def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
    """Return the lower-cased access method named in an index definition.

    Args:
        index_def: A ``CREATE INDEX ...`` statement such as the ``indexdef``
            column of ``pg_indexes``, or ``None``/empty when no index exists.

    Returns:
        The identifier following the ``USING`` keyword (e.g. ``"hnsw"``), or
        ``None`` when the input is empty, not a string, or has no ``USING``
        clause.
    """
    import re  # local import keeps this helper independent of file-level imports

    if not index_def or not isinstance(index_def, str):
        return None
    # Anchor on a word boundary so identifiers that merely end in "using"
    # (e.g. a table called "housing") are not mistaken for the keyword, and
    # capture only identifier characters so "ivfflat(vector" yields "ivfflat".
    found = re.search(r"\busing\s+([a-z_][a-z0-9_]*)", index_def.lower())
    return found.group(1) if found else None
|
||||||
|
|
||||||
|
def _vector_index_configuration(self) -> Tuple[str, str]:
    """Choose the vector index access method and its ``WITH (...)`` clause.

    An explicit ``PGVECTOR_INDEX_METHOD`` wins. Otherwise ``hnsw`` is used
    when halfvec is enabled and ``ivfflat`` when it is not.

    Returns:
        ``(index_method, index_options)`` where ``index_options`` is the SQL
        ``WITH`` clause carrying the tuning parameters for that method.
    """
    if PGVECTOR_INDEX_METHOD:
        index_method = PGVECTOR_INDEX_METHOD
        log.info(
            "Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
            index_method,
        )
    elif USE_HALFVEC:
        index_method = "hnsw"
        # halfvec is an independent opt-in and may be enabled for any
        # VECTOR_LENGTH, so the message reports the setting rather than
        # claiming the length exceeds 2000.
        log.info(
            "PGVECTOR_USE_HALFVEC is enabled (VECTOR_LENGTH=%s); using halfvec column type with hnsw index.",
            VECTOR_LENGTH,
        )
    else:
        index_method = "ivfflat"

    if index_method == "hnsw":
        index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
    else:
        index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"

    return index_method, index_options
|
||||||
|
|
||||||
|
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
    """Create the vector column index unless one already exists.

    Args:
        index_method: Access method the index should use.
        index_options: SQL ``WITH (...)`` clause appended to the statement;
            skipped when empty.

    Raises:
        RuntimeError: If an index of the same name exists but was built with
            a different access method. It is never rebuilt automatically.
    """
    index_name = "idx_document_chunk_vector"

    # Look up the definition of an already-present index in the current schema.
    lookup = text(
        """
        SELECT indexdef
        FROM pg_indexes
        WHERE schemaname = current_schema()
        AND tablename = 'document_chunk'
        AND indexname = :index_name
        """
    )
    current_def = self.session.execute(lookup, {"index_name": index_name}).scalar()
    current_method = self._extract_index_method(current_def)

    if current_method and current_method != index_method:
        raise RuntimeError(
            f"Existing pgvector index '{index_name}' uses method '{current_method}' but configuration now "
            f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
            "Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
            "and recreate it with the new method before restarting Open WebUI."
        )

    # An index is already in place; nothing to create.
    if current_def:
        return

    options_suffix = f" {index_options}" if index_options else ""
    create_statement = (
        f"CREATE INDEX IF NOT EXISTS {index_name} "
        f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
    )
    self.session.execute(text(create_statement + options_suffix))
    log.info(
        "Ensured vector index '%s' using %s%s.",
        index_name,
        index_method,
        options_suffix,
    )
|
||||||
|
|
||||||
def check_vector_length(self) -> None:
|
def check_vector_length(self) -> None:
|
||||||
"""
|
"""
|
||||||
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
||||||
|
|
@ -196,17 +275,20 @@ class PgvectorClient(VectorDBBase):
|
||||||
if "vector" in document_chunk_table.columns:
|
if "vector" in document_chunk_table.columns:
|
||||||
vector_column = document_chunk_table.columns["vector"]
|
vector_column = document_chunk_table.columns["vector"]
|
||||||
vector_type = vector_column.type
|
vector_type = vector_column.type
|
||||||
if isinstance(vector_type, Vector):
|
expected_type = HALFVEC if USE_HALFVEC else Vector
|
||||||
db_vector_length = vector_type.dim
|
|
||||||
if db_vector_length != VECTOR_LENGTH:
|
if not isinstance(vector_type, expected_type):
|
||||||
|
raise Exception(
|
||||||
|
"The 'vector' column type does not match the expected type "
|
||||||
|
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
|
||||||
|
)
|
||||||
|
|
||||||
|
db_vector_length = getattr(vector_type, "dim", None)
|
||||||
|
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
||||||
"Cannot change vector size after initialization without migrating the data."
|
"Cannot change vector size after initialization without migrating the data."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"The 'vector' column exists but is not of type 'Vector'."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||||
|
|
@ -360,11 +442,11 @@ class PgvectorClient(VectorDBBase):
|
||||||
num_queries = len(vectors)
|
num_queries = len(vectors)
|
||||||
|
|
||||||
def vector_expr(vector):
|
def vector_expr(vector):
|
||||||
return cast(array(vector), Vector(VECTOR_LENGTH))
|
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||||
|
|
||||||
# Create the values for query vectors
|
# Create the values for query vectors
|
||||||
qid_col = column("qid", Integer)
|
qid_col = column("qid", Integer)
|
||||||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||||
query_vectors = (
|
query_vectors = (
|
||||||
values(qid_col, q_vector_col)
|
values(qid_col, q_vector_col)
|
||||||
.data(
|
.data(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue