diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py
index 6e07c28016..059ea43cc0 100644
--- a/backend/open_webui/retrieval/vector/dbs/milvus.py
+++ b/backend/open_webui/retrieval/vector/dbs/milvus.py
@@ -1,5 +1,7 @@
 from pymilvus import MilvusClient as Client
 from pymilvus import FieldSchema, DataType
+from pymilvus import connections, Collection
+
 import json
 import logging
 from typing import Optional
@@ -188,6 +190,8 @@ class MilvusClient(VectorDBBase):
         return self._result_to_search_result(result)
 
     def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
+        connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
+
         # Construct the filter string for querying
         collection_name = collection_name.replace("-", "_")
         if not self.has_collection(collection_name):
@@ -201,72 +205,36 @@ class MilvusClient(VectorDBBase):
                 for key, value in filter.items()
             ]
         )
 
-        max_limit = 16383  # The maximum number of records per request
-        all_results = []
-        if limit is None:
-            # Milvus default limit for query if not specified is 16384, but docs mention iteration.
-            # Let's set a practical high number if "all" is intended, or handle true pagination.
-            # For now, if limit is None, we'll fetch in batches up to a very large number.
-            # This part could be refined based on expected use cases for "get all".
-            # For this function signature, None implies "as many as possible" up to Milvus limits.
-            limit = (
-                16384 * 10
-            )  # A large number to signify fetching many, will be capped by actual data or max_limit per call.
-            log.info(
-                f"Limit not specified for query, fetching up to {limit} results in batches."
-            )
-
-        # Initialize offset and remaining to handle pagination
-        offset = 0
-        remaining = limit
+        collection = Collection(f"{self.collection_prefix}_{collection_name}")
+        collection.load()
+
+        all_results = []
 
         try:
             log.info(
                 f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
             )
-            # Loop until there are no more items to fetch or the desired limit is reached
-            while remaining > 0:
-                current_fetch = min(
-                    max_limit, remaining if isinstance(remaining, int) else max_limit
-                )
-                log.debug(
-                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
-                )
-                results = self.client.query(
-                    collection_name=f"{self.collection_prefix}_{collection_name}",
-                    filter=filter_string,
-                    output_fields=[
-                        "id",
-                        "data",
-                        "metadata",
-                    ],  # Explicitly list needed fields. Vector not usually needed in query.
-                    limit=current_fetch,
-                    offset=offset,
-                )
+            iterator = collection.query_iterator(
+                expr=filter_string,  # the ORM Collection API names this parameter expr, not filter
+                output_fields=[
+                    "id",
+                    "data",
+                    "metadata",
+                ],
+                limit=limit if limit is not None else -1,  # pymilvus expects -1 (UNLIMITED) rather than None for "no limit"
+            )
 
-                if not results:
-                    log.debug("No more results from query.")
-                    break
-
-                all_results.extend(results)
-                results_count = len(results)
-                log.debug(f"Fetched {results_count} results in this batch.")
-
-                if isinstance(remaining, int):
-                    remaining -= results_count
-
-                offset += results_count
-
-                # Break the loop if the results returned are less than the requested fetch count (means end of data)
-                if results_count < current_fetch:
-                    log.debug(
-                        "Fetched less than requested, assuming end of results for this query."
-                    )
+            while True:
+                result = iterator.next()
+                if not result:
+                    iterator.close()
                     break
+                all_results += result
 
             log.info(f"Total results from query: {len(all_results)}")
             return self._result_to_get_result([all_results])
+
         except Exception as e:
             log.exception(
                 f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
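
Reviewer note: this patch swaps manual offset/limit pagination, which is bounded by Milvus's 16384-row result window per `query` call (the reason for the old `max_limit` workaround), for a server-side `QueryIterator` that streams matches in batches. Below is a minimal, self-contained sketch of that pattern outside the `MilvusClient` wrapper. The connection settings, collection name, and filter expression are placeholders, and the `batch_size` value is illustrative:

```python
from pymilvus import Collection, connections

# Placeholder connection settings; the patched module passes its
# MILVUS_URI / MILVUS_TOKEN / MILVUS_DB config values here instead.
connections.connect(uri="http://localhost:19530", token="", db_name="default")

collection = Collection("open_webui_example")  # hypothetical existing collection
collection.load()

# query_iterator pages through matches server-side, so there is no manual
# offset bookkeeping and no per-request result-window cap to work around.
iterator = collection.query_iterator(
    expr='metadata["source"] == "docs"',  # example filter expression
    output_fields=["id", "data", "metadata"],
    batch_size=1000,  # rows fetched per round trip
    limit=-1,  # -1 (UNLIMITED) asks for every match
)

all_results = []
while True:
    batch = iterator.next()
    if not batch:
        iterator.close()  # release the server-side iterator state
        break
    all_results += batch

print(f"fetched {len(all_results)} rows")
```

One point worth flagging in review: `iterator.next()` returns an empty list once the result set is exhausted, so the drain loop must call `iterator.close()` before breaking to free the iterator's server-side resources, exactly as the patched `query()` does.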