open-webui/backend/open_webui/retrieval/vector/dbs/milvus.py

394 lines
15 KiB
Python
Raw Normal View History

2024-09-12 05:52:19 +00:00
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
2025-08-13 23:18:38 +00:00
from pymilvus import connections, Collection
2024-09-12 05:52:19 +00:00
import json
import logging
2024-09-10 03:37:06 +00:00
from typing import Optional
2025-07-31 13:45:06 +00:00
2025-09-29 01:17:27 +00:00
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
2024-09-12 05:52:19 +00:00
from open_webui.config import (
2025-01-19 19:59:07 +00:00
MILVUS_URI,
MILVUS_DB,
MILVUS_TOKEN,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
MILVUS_DISKANN_MAX_DEGREE,
MILVUS_DISKANN_SEARCH_LIST_SIZE,
2024-09-12 05:52:19 +00:00
)
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
2024-09-10 03:37:06 +00:00
2025-05-10 15:00:01 +00:00
class MilvusClient(VectorDBBase):
    """Milvus-backed vector database client for Open WebUI collections."""

    def __init__(self):
        # All collections managed by this client share a common prefix so
        # reset() can find and drop them safely.
        self.collection_prefix = "open_webui"

        # Only pass a token when one is configured; some deployments
        # (e.g. a local unauthenticated Milvus) use no token at all.
        connect_kwargs = {"uri": MILVUS_URI, "db_name": MILVUS_DB}
        if MILVUS_TOKEN is not None:
            connect_kwargs["token"] = MILVUS_TOKEN
        self.client = Client(**connect_kwargs)
2024-09-13 05:30:30 +00:00
def _result_to_get_result(self, result) -> GetResult:
ids = []
documents = []
metadatas = []
for match in result:
_ids = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata"))
ids.append(_ids)
documents.append(_documents)
metadatas.append(_metadatas)
return GetResult(
**{
"ids": ids,
"documents": documents,
"metadatas": metadatas,
}
)
def _result_to_search_result(self, result) -> SearchResult:
2024-09-12 05:52:19 +00:00
ids = []
distances = []
documents = []
metadatas = []
for match in result:
_ids = []
_distances = []
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
# normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md
2025-03-25 18:09:17 +00:00
_dist = (item.get("distance") + 1.0) / 2.0
_distances.append(_dist)
2024-09-12 05:52:19 +00:00
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))
ids.append(_ids)
distances.append(_distances)
documents.append(_documents)
metadatas.append(_metadatas)
2024-09-13 05:18:20 +00:00
return SearchResult(
**{
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
}
)
2024-09-12 05:52:19 +00:00
def _create_collection(self, collection_name: str, dimension: int):
schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=True,
)
schema.add_field(
field_name="id",
datatype=DataType.VARCHAR,
is_primary=True,
max_length=65535,
)
schema.add_field(
field_name="vector",
datatype=DataType.FLOAT_VECTOR,
dim=dimension,
description="vector",
)
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
schema.add_field(
field_name="metadata", datatype=DataType.JSON, description="metadata"
)
index_params = self.client.prepare_index_params()
# Use configurations from config.py
index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper()
2025-05-10 15:00:01 +00:00
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
index_creation_params = {}
if index_type == "HNSW":
2025-05-10 15:00:01 +00:00
index_creation_params = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
}
log.info(f"HNSW params: {index_creation_params}")
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type == "DISKANN":
index_creation_params = {
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
}
log.info(f"DISKANN params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
2024-09-12 05:52:19 +00:00
index_params.add_index(
2024-09-26 20:59:09 +00:00
field_name="vector",
index_type=index_type,
metric_type=metric_type,
params=index_creation_params,
2024-09-12 05:52:19 +00:00
)
2024-09-10 03:37:06 +00:00
2024-09-12 05:52:19 +00:00
self.client.create_collection(
collection_name=f"{self.collection_prefix}_{collection_name}",
schema=schema,
index_params=index_params,
)
2025-05-10 15:00:01 +00:00
log.info(
f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
)
2024-09-12 05:52:19 +00:00
2024-09-12 06:00:31 +00:00
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
2024-09-12 06:00:31 +00:00
return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
2024-09-10 03:37:06 +00:00
def delete_collection(self, collection_name: str):
2024-09-12 05:52:19 +00:00
# Delete the collection based on the collection name.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
2024-09-12 05:52:19 +00:00
return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
2024-09-10 03:37:06 +00:00
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
2024-09-13 05:18:20 +00:00
) -> Optional[SearchResult]:
2024-09-12 05:52:19 +00:00
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
# For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended.
2024-09-12 05:52:19 +00:00
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=vectors,
limit=limit,
output_fields=["data", "metadata"],
# search_params=search_params # Potentially add later if needed
2024-09-12 05:52:19 +00:00
)
2024-09-13 05:30:30 +00:00
return self._result_to_search_result(result)
2024-09-10 03:37:06 +00:00
def query(self, collection_name: str, filter: dict, limit: int = -1):
2025-08-13 23:18:38 +00:00
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
2025-05-10 15:00:01 +00:00
log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
2024-10-07 00:58:09 +00:00
return None
filter_expressions = []
for key, value in filter.items():
if isinstance(value, str):
filter_expressions.append(f'metadata["{key}"] == "{value}"')
else:
filter_expressions.append(f'metadata["{key}"] == {value}')
filter_string = " && ".join(filter_expressions)
2024-10-05 16:58:46 +00:00
2025-08-13 23:18:38 +00:00
collection = Collection(f"{self.collection_prefix}_{collection_name}")
collection.load()
2025-05-10 15:00:01 +00:00
2024-10-07 00:58:09 +00:00
try:
2025-05-10 15:00:01 +00:00
log.info(
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
)
2025-08-13 23:18:38 +00:00
iterator = collection.query_iterator(
expr=filter_string,
2025-08-13 23:18:38 +00:00
output_fields=[
"id",
"data",
"metadata",
],
limit=limit if limit > 0 else -1,
2025-08-13 23:18:38 +00:00
)
2025-05-10 15:00:01 +00:00
all_results = []
2025-08-13 23:18:38 +00:00
while True:
batch = iterator.next()
if not batch:
2025-08-13 23:18:38 +00:00
iterator.close()
2024-10-07 00:58:09 +00:00
break
all_results.extend(batch)
2025-05-10 15:00:01 +00:00
log.debug(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results] if all_results else [[]])
2025-08-13 23:18:38 +00:00
2024-10-07 00:58:09 +00:00
except Exception as e:
2025-02-27 06:18:18 +00:00
log.exception(
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
2025-02-27 06:18:18 +00:00
)
2024-10-07 00:58:09 +00:00
return None
2024-10-03 13:53:21 +00:00
2024-09-13 05:18:20 +00:00
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. This can be very resource-intensive for large collections.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
2025-05-10 15:00:01 +00:00
log.warning(
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
)
# Using query with a trivial filter to get all items.
# This will use the paginated query logic.
return self.query(collection_name=collection_name, filter={}, limit=-1)
2024-09-10 03:37:06 +00:00
def insert(self, collection_name: str, items: list[VectorItem]):
2024-09-12 05:52:19 +00:00
# Insert the items into the collection, if the collection does not exist, it will be created.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
2024-09-12 05:52:19 +00:00
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
2025-05-10 15:00:01 +00:00
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
)
if not items:
2025-05-10 15:00:01 +00:00
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
)
raise ValueError(
"Cannot create Milvus collection without items to determine vector dimension."
)
2024-09-12 05:52:19 +00:00
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
2025-05-10 15:00:01 +00:00
log.info(
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
2024-09-12 05:52:19 +00:00
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
2025-09-29 01:17:27 +00:00
"metadata": process_metadata(item["metadata"]),
2024-09-12 05:52:19 +00:00
}
for item in items
],
)
2024-09-10 03:37:06 +00:00
def upsert(self, collection_name: str, items: list[VectorItem]):
2024-09-12 05:52:19 +00:00
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
2024-09-12 05:52:19 +00:00
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
2025-05-10 15:00:01 +00:00
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
)
if not items:
2025-05-10 15:00:01 +00:00
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
)
raise ValueError(
"Cannot create Milvus collection for upsert without items to determine vector dimension."
)
2024-09-12 05:52:19 +00:00
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
2025-05-10 15:00:01 +00:00
log.info(
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
2024-09-12 05:52:19 +00:00
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
2025-09-29 01:17:27 +00:00
"metadata": process_metadata(item["metadata"]),
2024-09-12 05:52:19 +00:00
}
for item in items
],
)
2024-09-10 03:37:06 +00:00
2024-10-03 13:43:50 +00:00
def delete(
self,
collection_name: str,
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids or filter.
2024-10-07 00:58:09 +00:00
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
2025-05-10 15:00:01 +00:00
log.warning(
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None
2024-10-03 13:43:50 +00:00
if ids:
2025-05-10 15:00:01 +00:00
log.info(
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
)
2024-10-03 13:43:50 +00:00
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
ids=ids,
)
elif filter:
filter_string = " && ".join(
[
2024-10-07 00:58:09 +00:00
f'metadata["{key}"] == {json.dumps(value)}'
2024-10-03 13:43:50 +00:00
for key, value in filter.items()
]
)
2025-05-10 15:00:01 +00:00
log.info(
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
)
2024-10-03 13:43:50 +00:00
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter=filter_string,
)
else:
2025-05-10 15:00:01 +00:00
log.warning(
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
)
return None
2024-09-10 03:37:06 +00:00
def reset(self):
# Resets the database. This will delete all collections and item entries that match the prefix.
2025-05-10 15:00:01 +00:00
log.warning(
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
)
2024-09-12 05:52:19 +00:00
collection_names = self.client.list_collections()
deleted_collections = []
for collection_name_full in collection_names:
if collection_name_full.startswith(self.collection_prefix):
try:
self.client.drop_collection(collection_name=collection_name_full)
deleted_collections.append(collection_name_full)
log.info(f"Deleted collection: {collection_name_full}")
except Exception as e:
log.error(f"Error deleting collection {collection_name_full}: {e}")
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")