mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
feat: pgvector hnsw index type (#19158)
* Add hnsw index type for pgvector, allowing vector dimensions larger than 2000 * Remove some variable assignments * Make USE_HALFVEC variable configurable * Simplify USE_HALFVEC handling * Raise a runtime error if the index requires a rebuild --------- Co-authored-by: Moritz <moritz.mueller2@tu-dresden.de>
This commit is contained in:
parent
63ebc295ce
commit
6cdb13d5cb
2 changed files with 152 additions and 21 deletions
|
|
@ -2149,6 +2149,21 @@ PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH = int(
|
|||
os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
|
||||
)
|
||||
|
||||
# Opt-in flag: only the exact (case-insensitive) string "true" enables halfvec.
PGVECTOR_USE_HALFVEC = os.getenv("PGVECTOR_USE_HALFVEC", "false").lower() == "true"

# Fail fast at import time when the configured dimension is above 2000 and
# halfvec has not been enabled.
if not PGVECTOR_USE_HALFVEC and PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH > 2000:
    raise ValueError(
        f"PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH is set to {PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH}, "
        "which exceeds the 2000 dimension limit of the 'vector' type. "
        "Set PGVECTOR_USE_HALFVEC=true to enable the 'halfvec' type required for "
        "high-dimensional embeddings."
    )

# Enabled unless the variable is set to something other than "true".
PGVECTOR_CREATE_EXTENSION = os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true"
|
||||
|
|
@ -2198,6 +2213,40 @@ else:
|
|||
except Exception:
|
||||
PGVECTOR_POOL_RECYCLE = 3600
|
||||
|
||||
def _pgvector_int_env(name: str, default: int, minimum: int = 1) -> int:
    """Read an integer pgvector tuning option from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset, empty, not an
            integer, or below ``minimum``.
        minimum: Smallest accepted value. Values below it are rejected because
            they are later interpolated into ``CREATE INDEX ... WITH (...)``
            and could never form a valid index option.

    Returns:
        The parsed integer, or ``default`` when the value is unusable.
    """
    raw_value = os.environ.get(name, "")
    try:
        parsed = int(raw_value.strip())
    except ValueError:
        # Covers both the empty string and non-numeric input.
        return default
    return parsed if parsed >= minimum else default


# Requested index method; anything other than "ivfflat"/"hnsw" is treated as
# "not set" so the client picks its own default.
PGVECTOR_INDEX_METHOD = os.getenv("PGVECTOR_INDEX_METHOD", "").strip().lower()
if PGVECTOR_INDEX_METHOD not in ("ivfflat", "hnsw", ""):
    PGVECTOR_INDEX_METHOD = ""

# HNSW build parameters.
PGVECTOR_HNSW_M = _pgvector_int_env("PGVECTOR_HNSW_M", 16)
PGVECTOR_HNSW_EF_CONSTRUCTION = _pgvector_int_env("PGVECTOR_HNSW_EF_CONSTRUCTION", 64)

# IVFFlat build parameter.
PGVECTOR_IVFFLAT_LISTS = _pgvector_int_env("PGVECTOR_IVFFLAT_LISTS", 100)
|
||||
|
||||
# Pinecone
|
||||
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
|
||||
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
import logging
|
||||
import json
|
||||
from sqlalchemy import (
|
||||
|
|
@ -22,7 +22,7 @@ from sqlalchemy.pool import NullPool, QueuePool
|
|||
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from pgvector.sqlalchemy import Vector, HALFVEC
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
|
|
@ -44,11 +44,20 @@ from open_webui.config import (
|
|||
PGVECTOR_POOL_MAX_OVERFLOW,
|
||||
PGVECTOR_POOL_TIMEOUT,
|
||||
PGVECTOR_POOL_RECYCLE,
|
||||
PGVECTOR_INDEX_METHOD,
|
||||
PGVECTOR_HNSW_M,
|
||||
PGVECTOR_HNSW_EF_CONSTRUCTION,
|
||||
PGVECTOR_IVFFLAT_LISTS,
|
||||
PGVECTOR_USE_HALFVEC,
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
|
||||
# Local aliases for the configured embedding dimension and halfvec switch.
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC

# The column type and the cosine operator class must be chosen together so the
# index opclass always matches the column type.
if USE_HALFVEC:
    VECTOR_TYPE_FACTORY, VECTOR_OPCLASS = HALFVEC, "halfvec_cosine_ops"
else:
    VECTOR_TYPE_FACTORY, VECTOR_OPCLASS = Vector, "vector_cosine_ops"
|
||||
Base = declarative_base()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -67,7 +76,7 @@ class DocumentChunk(Base):
|
|||
__tablename__ = "document_chunk"
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
|
||||
collection_name = Column(Text, nullable=False)
|
||||
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
|
|
@ -157,13 +166,9 @@ class PgvectorClient(VectorDBBase):
|
|||
connection = self.session.connection()
|
||||
Base.metadata.create_all(bind=connection)
|
||||
|
||||
# Create an index on the vector column if it doesn't exist
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
||||
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
||||
)
|
||||
)
|
||||
index_method, index_options = self._vector_index_configuration()
|
||||
self._ensure_vector_index(index_method, index_options)
|
||||
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||
|
|
@ -177,6 +182,80 @@ class PgvectorClient(VectorDBBase):
|
|||
log.exception(f"Error during initialization: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
    """Return the access method named in an index definition.

    Args:
        index_def: Text of a ``CREATE INDEX`` statement, or ``None`` when no
            index exists.

    Returns:
        The lower-cased identifier that follows the ``USING`` keyword (for
        example ``"hnsw"`` or ``"ivfflat"``), or ``None`` when the input is
        empty, not a string, or has no ``USING`` clause.
    """
    import re  # local import: the file's import block is not fully visible here

    if not index_def or not isinstance(index_def, str):
        return None

    # Match USING as a whole word so identifiers that merely contain "using"
    # (e.g. a table called "housing") are skipped, and stop at the end of the
    # identifier so "hnsw(vector ...)" still yields "hnsw".
    found = re.search(r"\busing\s+([a-z_][a-z0-9_]*)", index_def.lower())
    return found.group(1) if found else None
|
||||
|
||||
def _vector_index_configuration(self) -> Tuple[str, str]:
    """Choose the vector index method and its ``WITH (...)`` options clause.

    Precedence: an explicit ``PGVECTOR_INDEX_METHOD`` wins; otherwise ``hnsw``
    is used when halfvec storage is enabled, and ``ivfflat`` in all other
    cases.

    Returns:
        A ``(index_method, index_options)`` tuple, where ``index_options`` is
        the SQL ``WITH (...)`` fragment for the chosen method.
    """
    if PGVECTOR_INDEX_METHOD:
        index_method = PGVECTOR_INDEX_METHOD
        log.info(
            "Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
            index_method,
        )
    elif USE_HALFVEC:
        index_method = "hnsw"
        # halfvec can be enabled for any dimension, so report the flag and the
        # configured length without claiming the length is above 2000.
        log.info(
            "PGVECTOR_USE_HALFVEC is enabled (VECTOR_LENGTH=%s); "
            "using halfvec column type with hnsw index.",
            VECTOR_LENGTH,
        )
    else:
        index_method = "ivfflat"
        log.info(
            "PGVECTOR_INDEX_METHOD is not set; defaulting to '%s' vector index.",
            index_method,
        )

    if index_method == "hnsw":
        index_options = (
            f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
        )
    else:
        index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"

    return index_method, index_options
|
||||
|
||||
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
    """Create the vector index when missing; refuse to continue on a method mismatch.

    Args:
        index_method: Index access method the configuration asks for.
        index_options: SQL ``WITH (...)`` fragment appended to the
            ``CREATE INDEX`` statement when non-empty.

    Raises:
        RuntimeError: If an index with the expected name already exists but
            was built with a different method. It is never rebuilt here.
    """
    index_name = "idx_document_chunk_vector"

    # Fetch the stored definition of the index in the current schema, if any.
    lookup = text(
        """
        SELECT indexdef
        FROM pg_indexes
        WHERE schemaname = current_schema()
        AND tablename = 'document_chunk'
        AND indexname = :index_name
        """
    )
    current_def = self.session.execute(lookup, {"index_name": index_name}).scalar()
    current_method = self._extract_index_method(current_def)

    if current_method and current_method != index_method:
        raise RuntimeError(
            f"Existing pgvector index '{index_name}' uses method '{current_method}' but configuration now "
            f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
            "Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
            "and recreate it with the new method before restarting Open WebUI."
        )

    # An index is already in place: nothing to create.
    if current_def:
        return

    statement = (
        f"CREATE INDEX IF NOT EXISTS {index_name} "
        f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
    )
    options_suffix = f" {index_options}" if index_options else ""
    if index_options:
        statement = f"{statement} {index_options}"

    self.session.execute(text(statement))
    log.info(
        "Ensured vector index '%s' using %s%s.",
        index_name,
        index_method,
        options_suffix,
    )
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
"""
|
||||
Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
|
||||
|
|
@ -196,17 +275,20 @@ class PgvectorClient(VectorDBBase):
|
|||
if "vector" in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns["vector"]
|
||||
vector_type = vector_column.type
|
||||
if isinstance(vector_type, Vector):
|
||||
db_vector_length = vector_type.dim
|
||||
if db_vector_length != VECTOR_LENGTH:
|
||||
expected_type = HALFVEC if USE_HALFVEC else Vector
|
||||
|
||||
if not isinstance(vector_type, expected_type):
|
||||
raise Exception(
|
||||
"The 'vector' column type does not match the expected type "
|
||||
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
|
||||
)
|
||||
|
||||
db_vector_length = getattr(vector_type, "dim", None)
|
||||
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
||||
"Cannot change vector size after initialization without migrating the data."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column exists but is not of type 'Vector'."
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||
|
|
@ -360,11 +442,11 @@ class PgvectorClient(VectorDBBase):
|
|||
num_queries = len(vectors)
|
||||
|
||||
def vector_expr(vector):
|
||||
return cast(array(vector), Vector(VECTOR_LENGTH))
|
||||
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
|
||||
# Create the values for query vectors
|
||||
qid_col = column("qid", Integer)
|
||||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data(
|
||||
|
|
|
|||
Loading…
Reference in a new issue