refac/fix: milvus query logic

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 03:18:38 +04:00
parent 115231c0e5
commit ad98d4300b

View file

@ -1,5 +1,7 @@
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
from pymilvus import connections, Collection
import json
import logging
from typing import Optional
@ -188,6 +190,8 @@ class MilvusClient(VectorDBBase):
return self._result_to_search_result(result)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
# Construct the filter string for querying
collection_name = collection_name.replace("-", "_")
if not self.has_collection(collection_name):
@ -201,72 +205,36 @@ class MilvusClient(VectorDBBase):
for key, value in filter.items()
]
)
max_limit = 16383 # The maximum number of records per request
all_results = []
if limit is None:
# Milvus default limit for query if not specified is 16384, but docs mention iteration.
# Let's set a practical high number if "all" is intended, or handle true pagination.
# For now, if limit is None, we'll fetch in batches up to a very large number.
# This part could be refined based on expected use cases for "get all".
# For this function signature, None implies "as many as possible" up to Milvus limits.
limit = (
16384 * 10
) # A large number to signify fetching many, will be capped by actual data or max_limit per call.
log.info(
f"Limit not specified for query, fetching up to {limit} results in batches."
)
# Initialize offset and remaining to handle pagination
offset = 0
remaining = limit
collection = Collection(f"{self.collection_prefix}_{collection_name}")
collection.load()
all_results = []
try:
log.info(
f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
)
# Loop until there are no more items to fetch or the desired limit is reached
while remaining > 0:
current_fetch = min(
max_limit, remaining if isinstance(remaining, int) else max_limit
)
log.debug(
f"Querying with offset: {offset}, current_fetch: {current_fetch}"
)
results = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
iterator = collection.query_iterator(
filter=filter_string,
output_fields=[
"id",
"data",
"metadata",
], # Explicitly list needed fields. Vector not usually needed in query.
limit=current_fetch,
offset=offset,
],
limit=limit, # Pass the limit directly; None means no limit.
)
if not results:
log.debug("No more results from query.")
break
all_results.extend(results)
results_count = len(results)
log.debug(f"Fetched {results_count} results in this batch.")
if isinstance(remaining, int):
remaining -= results_count
offset += results_count
# Break the loop if the results returned are less than the requested fetch count (means end of data)
if results_count < current_fetch:
log.debug(
"Fetched less than requested, assuming end of results for this query."
)
while True:
result = iterator.next()
if not result:
iterator.close()
break
all_results += result
log.info(f"Total results from query: {len(all_results)}")
return self._result_to_get_result([all_results])
except Exception as e:
log.exception(
f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"