diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 2df6a0ab51..a132d72013 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -12,7 +12,7 @@ from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document from open_webui.config import VECTOR_DB -from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.models.users import UserModel from open_webui.models.files import Files diff --git a/backend/open_webui/retrieval/vector/connector.py b/backend/open_webui/retrieval/vector/connector.py deleted file mode 100644 index 198e6f1761..0000000000 --- a/backend/open_webui/retrieval/vector/connector.py +++ /dev/null @@ -1,30 +0,0 @@ -from open_webui.config import VECTOR_DB - -if VECTOR_DB == "milvus": - from open_webui.retrieval.vector.dbs.milvus import MilvusClient - - VECTOR_DB_CLIENT = MilvusClient() -elif VECTOR_DB == "qdrant": - from open_webui.retrieval.vector.dbs.qdrant import QdrantClient - - VECTOR_DB_CLIENT = QdrantClient() -elif VECTOR_DB == "opensearch": - from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient - - VECTOR_DB_CLIENT = OpenSearchClient() -elif VECTOR_DB == "pgvector": - from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient - - VECTOR_DB_CLIENT = PgvectorClient() -elif VECTOR_DB == "elasticsearch": - from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient - - VECTOR_DB_CLIENT = ElasticsearchClient() -elif VECTOR_DB == "pinecone": - from open_webui.retrieval.vector.dbs.pinecone import PineconeClient - - VECTOR_DB_CLIENT = PineconeClient() -else: - from open_webui.retrieval.vector.dbs.chroma import ChromaClient - - VECTOR_DB_CLIENT = ChromaClient() diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py new file mode 100644 index 0000000000..0216d62a04 --- /dev/null +++ b/backend/open_webui/retrieval/vector/factory.py @@ -0,0 +1,48 @@ +from open_webui.retrieval.vector.main import VectorDBBase +from open_webui.retrieval.vector.type import VectorType +from open_webui.config import VECTOR_DB + + +class Vector: + + @staticmethod + def get_vector(vector_type: str) -> VectorDBBase: + """ + get vector db instance by vector type + """ + match vector_type: + case VectorType.MILVUS: + from open_webui.retrieval.vector.dbs.milvus import MilvusClient + + return MilvusClient() + case VectorType.QDRANT: + from open_webui.retrieval.vector.dbs.qdrant import QdrantClient + + return QdrantClient() + case VectorType.PINECONE: + from open_webui.retrieval.vector.dbs.pinecone import PineconeClient + + return PineconeClient() + case VectorType.OPENSEARCH: + from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient + + return OpenSearchClient() + case VectorType.PGVECTOR: + from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient + + return PgvectorClient() + case VectorType.ELASTICSEARCH: + from open_webui.retrieval.vector.dbs.elasticsearch import ( + ElasticsearchClient, + ) + + return ElasticsearchClient() + case VectorType.CHROMA: + from open_webui.retrieval.vector.dbs.chroma import ChromaClient + + return ChromaClient() + case _: + raise ValueError(f"Unsupported vector type: {vector_type}") + + +VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB) diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py new file mode 100644 index 0000000000..b03bcb4828 --- /dev/null +++ b/backend/open_webui/retrieval/vector/type.py @@ -0,0 +1,11 @@ +from enum import StrEnum + + +class VectorType(StrEnum): + MILVUS = "milvus" + QDRANT = "qdrant" + CHROMA = "chroma" + PINECONE = "pinecone" + ELASTICSEARCH = "elasticsearch" + OPENSEARCH = "opensearch" + PGVECTOR = "pgvector" diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 920130858a..e6e55f4d38 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -10,7 +10,7 @@ from open_webui.models.knowledge import ( KnowledgeUserResponse, ) from open_webui.models.files import Files, FileModel, FileMetadataResponse -from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.routers.retrieval import ( process_file, ProcessFileForm, diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 6d54c9c170..333e9ecc6a 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -4,7 +4,7 @@ import logging from typing import Optional from open_webui.models.memories import Memories, MemoryModel -from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.utils.auth import get_verified_user from open_webui.env import SRC_LOG_LEVELS diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index efefa12fcd..0b414a5519 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -36,7 +36,7 @@ from open_webui.models.knowledge import Knowledges from open_webui.storage.provider import Storage -from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT # Document loaders from open_webui.retrieval.loaders.main import Loader