From 0ac57a088f6fb14eb5ce306965e8e17e0bdb005b Mon Sep 17 00:00:00 2001 From: Anush008 Date: Fri, 4 Jul 2025 12:33:54 +0530 Subject: [PATCH] refactor: More implementation improvements Signed-off-by: Anush008 --- .../vector/dbs/qdrant_multitenancy.py | 201 +++++++----------- 1 file changed, 74 insertions(+), 127 deletions(-) diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index df2c4e2431..fae12c94e3 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Dict, Any from urllib.parse import urlparse import grpc @@ -25,11 +25,24 @@ from qdrant_client.models import models NO_LIMIT = 999999999 TENANT_ID_FIELD = "tenant_id" +DEFAULT_DIMENSION = 384 log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +def _tenant_filter(tenant_id: str) -> models.FieldCondition: + return models.FieldCondition( + key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) + ) + + +def _metadata_filter(key: str, value: Any) -> models.FieldCondition: + return models.FieldCondition( + key=f"metadata.{key}", match=models.MatchValue(value=value) + ) + + class QdrantClient(VectorDBBase): def __init__(self): self.collection_prefix = QDRANT_COLLECTION_PREFIX @@ -48,16 +61,17 @@ class QdrantClient(VectorDBBase): host = parsed.hostname or self.QDRANT_URI http_port = parsed.port or 6333 # default REST port - if self.PREFER_GRPC: - self.client = Qclient( + self.client = ( + Qclient( host=host, port=http_port, grpc_port=self.GRPC_PORT, prefer_grpc=self.PREFER_GRPC, api_key=self.QDRANT_API_KEY, ) - else: - self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) + if self.PREFER_GRPC + else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) + ) # Main collection types for multi-tenancy self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" @@ -67,23 +81,13 @@ class QdrantClient(VectorDBBase): self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" def _result_to_get_result(self, points) -> GetResult: - ids = [] - documents = [] - metadatas = [] - + ids, documents, metadatas = [], [], [] for point in points: payload = point.payload ids.append(point.id) documents.append(payload["text"]) metadatas.append(payload["metadata"]) - - return GetResult( - **{ - "ids": [ids], - "documents": [documents], - "metadatas": [metadatas], - } - ) + return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: """ @@ -116,9 +120,7 @@ class QdrantClient(VectorDBBase): return self.KNOWLEDGE_COLLECTION, tenant_id def _create_multi_tenant_collection( - self, - mt_collection_name: str, - dimension: int = 384, + self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION ): """ Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields. @@ -145,7 +147,19 @@ class QdrantClient(VectorDBBase): ), ) - def _create_points(self, items: list[VectorItem], tenant_id: str): + for field in ("metadata.hash", "metadata.file_id"): + self.client.create_payload_index( + collection_name=mt_collection_name, + field_name=field, + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD, + on_disk=self.QDRANT_ON_DISK, + ), + ) + + def _create_points( + self, items: List[VectorItem], tenant_id: str + ) -> List[PointStruct]: """ Create point structs from vector items with tenant ID. """ @@ -163,16 +177,13 @@ class QdrantClient(VectorDBBase): ] def _ensure_collection( - self, - mt_collection_name: str, - dimension: int = 384, + self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION ): """ Ensure the collection exists and payload indexes are created for tenant_id and metadata fields. """ - if self.client.collection_exists(collection_name=mt_collection_name): - return - self._create_multi_tenant_collection(mt_collection_name, dimension) + if not self.client.collection_exists(collection_name=mt_collection_name): + self._create_multi_tenant_collection(mt_collection_name, dimension) def has_collection(self, collection_name: str) -> bool: """ @@ -183,9 +194,7 @@ class QdrantClient(VectorDBBase): mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) if not self.client.collection_exists(collection_name=mt_collection): return False - tenant_filter = models.FieldCondition( - key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) - ) + tenant_filter = _tenant_filter(tenant_id) count_result = self.client.count( collection_name=mt_collection, count_filter=models.Filter(must=[tenant_filter]), @@ -195,8 +204,8 @@ class QdrantClient(VectorDBBase): def delete( self, collection_name: str, - ids: Optional[list[str]] = None, - filter: Optional[dict] = None, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, ): """ Delete vectors by ID or filter from a collection with tenant isolation. @@ -209,94 +218,55 @@ class QdrantClient(VectorDBBase): log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") return None - tenant_filter = models.FieldCondition( - key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) - ) - must_conditions = [tenant_filter] + must_conditions = [_tenant_filter(tenant_id)] should_conditions = [] if ids: - for id_value in ids: - should_conditions.append( - models.FieldCondition( - key="metadata.id", - match=models.MatchValue(value=id_value), - ), - ) + should_conditions = [_metadata_filter("id", id_value) for id_value in ids] elif filter: - for key, value in filter.items(): - must_conditions.append( - models.FieldCondition( - key=f"metadata.{key}", - match=models.MatchValue(value=value), - ), - ) + must_conditions += [_metadata_filter(k, v) for k, v in filter.items()] try: - update_result = self.client.delete( + return self.client.delete( collection_name=mt_collection, points_selector=models.FilterSelector( filter=models.Filter(must=must_conditions, should=should_conditions) ), ) - - return update_result except Exception as e: log.warning(f"Error deleting from collection {mt_collection}: {e}") return None def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: List[List[float | int]], limit: int ) -> Optional[SearchResult]: """ Search for the nearest neighbor items based on the vectors with tenant isolation. """ - if not self.client: + if not self.client or not vectors: return None mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) if not self.client.collection_exists(collection_name=mt_collection): log.debug(f"Collection {mt_collection} doesn't exist, search returns None") return None - dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None - try: - tenant_filter = models.FieldCondition( - key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) - ) - collection_dim = self.client.get_collection( - mt_collection - ).config.params.vectors.size - if collection_dim != dimension: - if collection_dim < dimension: - vectors = [vector[:collection_dim] for vector in vectors] - else: - vectors = [ - vector + [0] * (collection_dim - dimension) - for vector in vectors - ] - prefetch_query = models.Prefetch( - filter=models.Filter(must=[tenant_filter]), - limit=NO_LIMIT, - ) - query_response = self.client.query_points( - collection_name=mt_collection, - query=vectors[0], - prefetch=prefetch_query, - limit=limit, - ) - get_result = self._result_to_get_result(query_response.points) - return SearchResult( - ids=get_result.ids, - documents=get_result.documents, - metadatas=get_result.metadatas, - distances=[ - [(point.score + 1.0) / 2.0 for point in query_response.points] - ], - ) - except Exception as e: - log.exception(f"Error searching collection '{collection_name}': {e}") - return None + tenant_filter = _tenant_filter(tenant_id) + query_response = self.client.query_points( + collection_name=mt_collection, + query=vectors[0], + limit=limit, + query_filter=models.Filter(must=[tenant_filter]), + ) + get_result = self._result_to_get_result(query_response.points) + return SearchResult( + ids=get_result.ids, + documents=get_result.documents, + metadatas=get_result.metadatas, + distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]], + ) - def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ): """ Query points with filters and tenant isolation. """ @@ -306,19 +276,10 @@ class QdrantClient(VectorDBBase): if not self.client.collection_exists(collection_name=mt_collection): log.debug(f"Collection {mt_collection} doesn't exist, query returns None") return None - if limit is None: limit = NO_LIMIT - tenant_filter = models.FieldCondition( - key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) - ) - field_conditions = [] - for key, value in filter.items(): - field_conditions.append( - models.FieldCondition( - key=f"metadata.{key}", match=models.MatchValue(value=value) - ) - ) + tenant_filter = _tenant_filter(tenant_id) + field_conditions = [_metadata_filter(k, v) for k, v in filter.items()] combined_filter = models.Filter(must=[tenant_filter, *field_conditions]) try: points = self.client.query_points( @@ -341,36 +302,32 @@ class QdrantClient(VectorDBBase): if not self.client.collection_exists(collection_name=mt_collection): log.debug(f"Collection {mt_collection} doesn't exist, get returns None") return None - - tenant_filter = models.FieldCondition( - key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) - ) + tenant_filter = _tenant_filter(tenant_id) try: points = self.client.query_points( collection_name=mt_collection, query_filter=models.Filter(must=[tenant_filter]), limit=NO_LIMIT, ) - return self._result_to_get_result(points.points) except Exception as e: log.exception(f"Error getting collection '{collection_name}': {e}") return None - def upsert(self, collection_name: str, items: list[VectorItem]): + def upsert(self, collection_name: str, items: List[VectorItem]): """ Upsert items with tenant ID. """ if not self.client or not items: return None mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - dimension = len(items[0]["vector"]) if items else None + dimension = len(items[0]["vector"]) self._ensure_collection(mt_collection, dimension) points = self._create_points(items, tenant_id) self.client.upload_points(mt_collection, points) return None - def insert(self, collection_name: str, items: list[VectorItem]): + def insert(self, collection_name: str, items: List[VectorItem]): """ Insert items with tenant ID. """ @@ -382,11 +339,9 @@ class QdrantClient(VectorDBBase): """ if not self.client: return None - - collection_names = self.client.get_collections().collections - for collection_name in collection_names: - if collection_name.name.startswith(self.collection_prefix): - self.client.delete_collection(collection_name=collection_name.name) + for collection in self.client.get_collections().collections: + if collection.name.startswith(self.collection_prefix): + self.client.delete_collection(collection_name=collection.name) def delete_collection(self, collection_name: str): """ @@ -398,17 +353,9 @@ class QdrantClient(VectorDBBase): if not self.client.collection_exists(collection_name=mt_collection): log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") return None - self.client.delete( collection_name=mt_collection, points_selector=models.FilterSelector( - filter=models.Filter( - must=[ - models.FieldCondition( - key=TENANT_ID_FIELD, - match=models.MatchValue(value=tenant_id), - ) - ] - ) + filter=models.Filter(must=[_tenant_filter(tenant_id)]) ), )