refactor: More implementation improvements

Signed-off-by: Anush008 <anushshetty90@gmail.com>
This commit is contained in:
Anush008 2025-07-04 12:33:54 +05:30
parent 7c734d3fea
commit 0ac57a088f
No known key found for this signature in database

View file

@ -1,5 +1,5 @@
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple, List, Dict, Any
from urllib.parse import urlparse from urllib.parse import urlparse
import grpc import grpc
@ -25,11 +25,24 @@ from qdrant_client.models import models
NO_LIMIT = 999999999 NO_LIMIT = 999999999
TENANT_ID_FIELD = "tenant_id" TENANT_ID_FIELD = "tenant_id"
DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
return models.FieldCondition(
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
return models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
class QdrantClient(VectorDBBase): class QdrantClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = QDRANT_COLLECTION_PREFIX self.collection_prefix = QDRANT_COLLECTION_PREFIX
@ -48,16 +61,17 @@ class QdrantClient(VectorDBBase):
host = parsed.hostname or self.QDRANT_URI host = parsed.hostname or self.QDRANT_URI
http_port = parsed.port or 6333 # default REST port http_port = parsed.port or 6333 # default REST port
if self.PREFER_GRPC: self.client = (
self.client = Qclient( Qclient(
host=host, host=host,
port=http_port, port=http_port,
grpc_port=self.GRPC_PORT, grpc_port=self.GRPC_PORT,
prefer_grpc=self.PREFER_GRPC, prefer_grpc=self.PREFER_GRPC,
api_key=self.QDRANT_API_KEY, api_key=self.QDRANT_API_KEY,
) )
else: if self.PREFER_GRPC
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
)
# Main collection types for multi-tenancy # Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
@ -67,23 +81,13 @@ class QdrantClient(VectorDBBase):
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
def _result_to_get_result(self, points) -> GetResult: def _result_to_get_result(self, points) -> GetResult:
ids = [] ids, documents, metadatas = [], [], []
documents = []
metadatas = []
for point in points: for point in points:
payload = point.payload payload = point.payload
ids.append(point.id) ids.append(point.id)
documents.append(payload["text"]) documents.append(payload["text"])
metadatas.append(payload["metadata"]) metadatas.append(payload["metadata"])
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
}
)
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
""" """
@ -116,9 +120,7 @@ class QdrantClient(VectorDBBase):
return self.KNOWLEDGE_COLLECTION, tenant_id return self.KNOWLEDGE_COLLECTION, tenant_id
def _create_multi_tenant_collection( def _create_multi_tenant_collection(
self, self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
mt_collection_name: str,
dimension: int = 384,
): ):
""" """
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields. Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
@ -145,7 +147,19 @@ class QdrantClient(VectorDBBase):
), ),
) )
def _create_points(self, items: list[VectorItem], tenant_id: str): for field in ("metadata.hash", "metadata.file_id"):
self.client.create_payload_index(
collection_name=mt_collection_name,
field_name=field,
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
on_disk=self.QDRANT_ON_DISK,
),
)
def _create_points(
self, items: List[VectorItem], tenant_id: str
) -> List[PointStruct]:
""" """
Create point structs from vector items with tenant ID. Create point structs from vector items with tenant ID.
""" """
@ -163,16 +177,13 @@ class QdrantClient(VectorDBBase):
] ]
def _ensure_collection( def _ensure_collection(
self, self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
mt_collection_name: str,
dimension: int = 384,
): ):
""" """
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields. Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
""" """
if self.client.collection_exists(collection_name=mt_collection_name): if not self.client.collection_exists(collection_name=mt_collection_name):
return self._create_multi_tenant_collection(mt_collection_name, dimension)
self._create_multi_tenant_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
""" """
@ -183,9 +194,7 @@ class QdrantClient(VectorDBBase):
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection): if not self.client.collection_exists(collection_name=mt_collection):
return False return False
tenant_filter = models.FieldCondition( tenant_filter = _tenant_filter(tenant_id)
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
count_result = self.client.count( count_result = self.client.count(
collection_name=mt_collection, collection_name=mt_collection,
count_filter=models.Filter(must=[tenant_filter]), count_filter=models.Filter(must=[tenant_filter]),
@ -195,8 +204,8 @@ class QdrantClient(VectorDBBase):
def delete( def delete(
self, self,
collection_name: str, collection_name: str,
ids: Optional[list[str]] = None, ids: Optional[List[str]] = None,
filter: Optional[dict] = None, filter: Optional[Dict[str, Any]] = None,
): ):
""" """
Delete vectors by ID or filter from a collection with tenant isolation. Delete vectors by ID or filter from a collection with tenant isolation.
@ -209,94 +218,55 @@ class QdrantClient(VectorDBBase):
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
return None return None
tenant_filter = models.FieldCondition( must_conditions = [_tenant_filter(tenant_id)]
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
must_conditions = [tenant_filter]
should_conditions = [] should_conditions = []
if ids: if ids:
for id_value in ids: should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
should_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
),
)
elif filter: elif filter:
for key, value in filter.items(): must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
must_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
)
try: try:
update_result = self.client.delete( return self.client.delete(
collection_name=mt_collection, collection_name=mt_collection,
points_selector=models.FilterSelector( points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions) filter=models.Filter(must=must_conditions, should=should_conditions)
), ),
) )
return update_result
except Exception as e: except Exception as e:
log.warning(f"Error deleting from collection {mt_collection}: {e}") log.warning(f"Error deleting from collection {mt_collection}: {e}")
return None return None
def search( def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int self, collection_name: str, vectors: List[List[float | int]], limit: int
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
""" """
Search for the nearest neighbor items based on the vectors with tenant isolation. Search for the nearest neighbor items based on the vectors with tenant isolation.
""" """
if not self.client: if not self.client or not vectors:
return None return None
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
if not self.client.collection_exists(collection_name=mt_collection): if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, search returns None") log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
return None return None
dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None tenant_filter = _tenant_filter(tenant_id)
try: query_response = self.client.query_points(
tenant_filter = models.FieldCondition( collection_name=mt_collection,
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) query=vectors[0],
) limit=limit,
collection_dim = self.client.get_collection( query_filter=models.Filter(must=[tenant_filter]),
mt_collection )
).config.params.vectors.size get_result = self._result_to_get_result(query_response.points)
if collection_dim != dimension: return SearchResult(
if collection_dim < dimension: ids=get_result.ids,
vectors = [vector[:collection_dim] for vector in vectors] documents=get_result.documents,
else: metadatas=get_result.metadatas,
vectors = [ distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
vector + [0] * (collection_dim - dimension) )
for vector in vectors
]
prefetch_query = models.Prefetch(
filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT,
)
query_response = self.client.query_points(
collection_name=mt_collection,
query=vectors[0],
prefetch=prefetch_query,
limit=limit,
)
get_result = self._result_to_get_result(query_response.points)
return SearchResult(
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[
[(point.score + 1.0) / 2.0 for point in query_response.points]
],
)
except Exception as e:
log.exception(f"Error searching collection '{collection_name}': {e}")
return None
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
):
""" """
Query points with filters and tenant isolation. Query points with filters and tenant isolation.
""" """
@ -306,19 +276,10 @@ class QdrantClient(VectorDBBase):
if not self.client.collection_exists(collection_name=mt_collection): if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, query returns None") log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
return None return None
if limit is None: if limit is None:
limit = NO_LIMIT limit = NO_LIMIT
tenant_filter = models.FieldCondition( tenant_filter = _tenant_filter(tenant_id)
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
)
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
)
combined_filter = models.Filter(must=[tenant_filter, *field_conditions]) combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
try: try:
points = self.client.query_points( points = self.client.query_points(
@ -341,36 +302,32 @@ class QdrantClient(VectorDBBase):
if not self.client.collection_exists(collection_name=mt_collection): if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, get returns None") log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
return None return None
tenant_filter = _tenant_filter(tenant_id)
tenant_filter = models.FieldCondition(
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
try: try:
points = self.client.query_points( points = self.client.query_points(
collection_name=mt_collection, collection_name=mt_collection,
query_filter=models.Filter(must=[tenant_filter]), query_filter=models.Filter(must=[tenant_filter]),
limit=NO_LIMIT, limit=NO_LIMIT,
) )
return self._result_to_get_result(points.points) return self._result_to_get_result(points.points)
except Exception as e: except Exception as e:
log.exception(f"Error getting collection '{collection_name}': {e}") log.exception(f"Error getting collection '{collection_name}': {e}")
return None return None
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: List[VectorItem]):
""" """
Upsert items with tenant ID. Upsert items with tenant ID.
""" """
if not self.client or not items: if not self.client or not items:
return None return None
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
dimension = len(items[0]["vector"]) if items else None dimension = len(items[0]["vector"])
self._ensure_collection(mt_collection, dimension) self._ensure_collection(mt_collection, dimension)
points = self._create_points(items, tenant_id) points = self._create_points(items, tenant_id)
self.client.upload_points(mt_collection, points) self.client.upload_points(mt_collection, points)
return None return None
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: List[VectorItem]):
""" """
Insert items with tenant ID. Insert items with tenant ID.
""" """
@ -382,11 +339,9 @@ class QdrantClient(VectorDBBase):
""" """
if not self.client: if not self.client:
return None return None
for collection in self.client.get_collections().collections:
collection_names = self.client.get_collections().collections if collection.name.startswith(self.collection_prefix):
for collection_name in collection_names: self.client.delete_collection(collection_name=collection.name)
if collection_name.name.startswith(self.collection_prefix):
self.client.delete_collection(collection_name=collection_name.name)
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
""" """
@ -398,17 +353,9 @@ class QdrantClient(VectorDBBase):
if not self.client.collection_exists(collection_name=mt_collection): if not self.client.collection_exists(collection_name=mt_collection):
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
return None return None
self.client.delete( self.client.delete(
collection_name=mt_collection, collection_name=mt_collection,
points_selector=models.FilterSelector( points_selector=models.FilterSelector(
filter=models.Filter( filter=models.Filter(must=[_tenant_filter(tenant_id)])
must=[
models.FieldCondition(
key=TENANT_ID_FIELD,
match=models.MatchValue(value=tenant_id),
)
]
)
), ),
) )