mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-16 22:35:23 +00:00
refactor: More implementation improvements
Signed-off-by: Anush008 <anushshetty90@gmail.com>
This commit is contained in:
parent
7c734d3fea
commit
0ac57a088f
1 changed files with 74 additions and 127 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, List, Dict, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
|
|
@ -25,11 +25,24 @@ from qdrant_client.models import models
|
||||||
|
|
||||||
NO_LIMIT = 999999999
|
NO_LIMIT = 999999999
|
||||||
TENANT_ID_FIELD = "tenant_id"
|
TENANT_ID_FIELD = "tenant_id"
|
||||||
|
DEFAULT_DIMENSION = 384
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
|
||||||
|
return models.FieldCondition(
|
||||||
|
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
|
||||||
|
return models.FieldCondition(
|
||||||
|
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QdrantClient(VectorDBBase):
|
class QdrantClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.collection_prefix = QDRANT_COLLECTION_PREFIX
|
self.collection_prefix = QDRANT_COLLECTION_PREFIX
|
||||||
|
|
@ -48,16 +61,17 @@ class QdrantClient(VectorDBBase):
|
||||||
host = parsed.hostname or self.QDRANT_URI
|
host = parsed.hostname or self.QDRANT_URI
|
||||||
http_port = parsed.port or 6333 # default REST port
|
http_port = parsed.port or 6333 # default REST port
|
||||||
|
|
||||||
if self.PREFER_GRPC:
|
self.client = (
|
||||||
self.client = Qclient(
|
Qclient(
|
||||||
host=host,
|
host=host,
|
||||||
port=http_port,
|
port=http_port,
|
||||||
grpc_port=self.GRPC_PORT,
|
grpc_port=self.GRPC_PORT,
|
||||||
prefer_grpc=self.PREFER_GRPC,
|
prefer_grpc=self.PREFER_GRPC,
|
||||||
api_key=self.QDRANT_API_KEY,
|
api_key=self.QDRANT_API_KEY,
|
||||||
)
|
)
|
||||||
else:
|
if self.PREFER_GRPC
|
||||||
self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
else Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
||||||
|
)
|
||||||
|
|
||||||
# Main collection types for multi-tenancy
|
# Main collection types for multi-tenancy
|
||||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
||||||
|
|
@ -67,23 +81,13 @@ class QdrantClient(VectorDBBase):
|
||||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
|
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
|
||||||
|
|
||||||
def _result_to_get_result(self, points) -> GetResult:
|
def _result_to_get_result(self, points) -> GetResult:
|
||||||
ids = []
|
ids, documents, metadatas = [], [], []
|
||||||
documents = []
|
|
||||||
metadatas = []
|
|
||||||
|
|
||||||
for point in points:
|
for point in points:
|
||||||
payload = point.payload
|
payload = point.payload
|
||||||
ids.append(point.id)
|
ids.append(point.id)
|
||||||
documents.append(payload["text"])
|
documents.append(payload["text"])
|
||||||
metadatas.append(payload["metadata"])
|
metadatas.append(payload["metadata"])
|
||||||
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
return GetResult(
|
|
||||||
**{
|
|
||||||
"ids": [ids],
|
|
||||||
"documents": [documents],
|
|
||||||
"metadatas": [metadatas],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
|
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -116,9 +120,7 @@ class QdrantClient(VectorDBBase):
|
||||||
return self.KNOWLEDGE_COLLECTION, tenant_id
|
return self.KNOWLEDGE_COLLECTION, tenant_id
|
||||||
|
|
||||||
def _create_multi_tenant_collection(
|
def _create_multi_tenant_collection(
|
||||||
self,
|
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
|
||||||
mt_collection_name: str,
|
|
||||||
dimension: int = 384,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
|
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
|
||||||
|
|
@ -145,7 +147,19 @@ class QdrantClient(VectorDBBase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_points(self, items: list[VectorItem], tenant_id: str):
|
for field in ("metadata.hash", "metadata.file_id"):
|
||||||
|
self.client.create_payload_index(
|
||||||
|
collection_name=mt_collection_name,
|
||||||
|
field_name=field,
|
||||||
|
field_schema=models.KeywordIndexParams(
|
||||||
|
type=models.KeywordIndexType.KEYWORD,
|
||||||
|
on_disk=self.QDRANT_ON_DISK,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_points(
|
||||||
|
self, items: List[VectorItem], tenant_id: str
|
||||||
|
) -> List[PointStruct]:
|
||||||
"""
|
"""
|
||||||
Create point structs from vector items with tenant ID.
|
Create point structs from vector items with tenant ID.
|
||||||
"""
|
"""
|
||||||
|
|
@ -163,16 +177,13 @@ class QdrantClient(VectorDBBase):
|
||||||
]
|
]
|
||||||
|
|
||||||
def _ensure_collection(
|
def _ensure_collection(
|
||||||
self,
|
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
|
||||||
mt_collection_name: str,
|
|
||||||
dimension: int = 384,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
|
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
|
||||||
"""
|
"""
|
||||||
if self.client.collection_exists(collection_name=mt_collection_name):
|
if not self.client.collection_exists(collection_name=mt_collection_name):
|
||||||
return
|
self._create_multi_tenant_collection(mt_collection_name, dimension)
|
||||||
self._create_multi_tenant_collection(mt_collection_name, dimension)
|
|
||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -183,9 +194,7 @@ class QdrantClient(VectorDBBase):
|
||||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||||
if not self.client.collection_exists(collection_name=mt_collection):
|
if not self.client.collection_exists(collection_name=mt_collection):
|
||||||
return False
|
return False
|
||||||
tenant_filter = models.FieldCondition(
|
tenant_filter = _tenant_filter(tenant_id)
|
||||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
|
||||||
)
|
|
||||||
count_result = self.client.count(
|
count_result = self.client.count(
|
||||||
collection_name=mt_collection,
|
collection_name=mt_collection,
|
||||||
count_filter=models.Filter(must=[tenant_filter]),
|
count_filter=models.Filter(must=[tenant_filter]),
|
||||||
|
|
@ -195,8 +204,8 @@ class QdrantClient(VectorDBBase):
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
ids: Optional[list[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
filter: Optional[dict] = None,
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete vectors by ID or filter from a collection with tenant isolation.
|
Delete vectors by ID or filter from a collection with tenant isolation.
|
||||||
|
|
@ -209,94 +218,55 @@ class QdrantClient(VectorDBBase):
|
||||||
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
|
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
tenant_filter = models.FieldCondition(
|
must_conditions = [_tenant_filter(tenant_id)]
|
||||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
|
||||||
)
|
|
||||||
must_conditions = [tenant_filter]
|
|
||||||
should_conditions = []
|
should_conditions = []
|
||||||
if ids:
|
if ids:
|
||||||
for id_value in ids:
|
should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
|
||||||
should_conditions.append(
|
|
||||||
models.FieldCondition(
|
|
||||||
key="metadata.id",
|
|
||||||
match=models.MatchValue(value=id_value),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif filter:
|
elif filter:
|
||||||
for key, value in filter.items():
|
must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
|
||||||
must_conditions.append(
|
|
||||||
models.FieldCondition(
|
|
||||||
key=f"metadata.{key}",
|
|
||||||
match=models.MatchValue(value=value),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
update_result = self.client.delete(
|
return self.client.delete(
|
||||||
collection_name=mt_collection,
|
collection_name=mt_collection,
|
||||||
points_selector=models.FilterSelector(
|
points_selector=models.FilterSelector(
|
||||||
filter=models.Filter(must=must_conditions, should=should_conditions)
|
filter=models.Filter(must=must_conditions, should=should_conditions)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return update_result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"Error deleting from collection {mt_collection}: {e}")
|
log.warning(f"Error deleting from collection {mt_collection}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
self, collection_name: str, vectors: List[List[float | int]], limit: int
|
||||||
) -> Optional[SearchResult]:
|
) -> Optional[SearchResult]:
|
||||||
"""
|
"""
|
||||||
Search for the nearest neighbor items based on the vectors with tenant isolation.
|
Search for the nearest neighbor items based on the vectors with tenant isolation.
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
if not self.client or not vectors:
|
||||||
return None
|
return None
|
||||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||||
if not self.client.collection_exists(collection_name=mt_collection):
|
if not self.client.collection_exists(collection_name=mt_collection):
|
||||||
log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
|
log.debug(f"Collection {mt_collection} doesn't exist, search returns None")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None
|
tenant_filter = _tenant_filter(tenant_id)
|
||||||
try:
|
query_response = self.client.query_points(
|
||||||
tenant_filter = models.FieldCondition(
|
collection_name=mt_collection,
|
||||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
query=vectors[0],
|
||||||
)
|
limit=limit,
|
||||||
collection_dim = self.client.get_collection(
|
query_filter=models.Filter(must=[tenant_filter]),
|
||||||
mt_collection
|
)
|
||||||
).config.params.vectors.size
|
get_result = self._result_to_get_result(query_response.points)
|
||||||
if collection_dim != dimension:
|
return SearchResult(
|
||||||
if collection_dim < dimension:
|
ids=get_result.ids,
|
||||||
vectors = [vector[:collection_dim] for vector in vectors]
|
documents=get_result.documents,
|
||||||
else:
|
metadatas=get_result.metadatas,
|
||||||
vectors = [
|
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
|
||||||
vector + [0] * (collection_dim - dimension)
|
)
|
||||||
for vector in vectors
|
|
||||||
]
|
|
||||||
prefetch_query = models.Prefetch(
|
|
||||||
filter=models.Filter(must=[tenant_filter]),
|
|
||||||
limit=NO_LIMIT,
|
|
||||||
)
|
|
||||||
query_response = self.client.query_points(
|
|
||||||
collection_name=mt_collection,
|
|
||||||
query=vectors[0],
|
|
||||||
prefetch=prefetch_query,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
get_result = self._result_to_get_result(query_response.points)
|
|
||||||
return SearchResult(
|
|
||||||
ids=get_result.ids,
|
|
||||||
documents=get_result.documents,
|
|
||||||
metadatas=get_result.metadatas,
|
|
||||||
distances=[
|
|
||||||
[(point.score + 1.0) / 2.0 for point in query_response.points]
|
|
||||||
],
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(f"Error searching collection '{collection_name}': {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
def query(
|
||||||
|
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Query points with filters and tenant isolation.
|
Query points with filters and tenant isolation.
|
||||||
"""
|
"""
|
||||||
|
|
@ -306,19 +276,10 @@ class QdrantClient(VectorDBBase):
|
||||||
if not self.client.collection_exists(collection_name=mt_collection):
|
if not self.client.collection_exists(collection_name=mt_collection):
|
||||||
log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
|
log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if limit is None:
|
if limit is None:
|
||||||
limit = NO_LIMIT
|
limit = NO_LIMIT
|
||||||
tenant_filter = models.FieldCondition(
|
tenant_filter = _tenant_filter(tenant_id)
|
||||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
field_conditions = [_metadata_filter(k, v) for k, v in filter.items()]
|
||||||
)
|
|
||||||
field_conditions = []
|
|
||||||
for key, value in filter.items():
|
|
||||||
field_conditions.append(
|
|
||||||
models.FieldCondition(
|
|
||||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
|
combined_filter = models.Filter(must=[tenant_filter, *field_conditions])
|
||||||
try:
|
try:
|
||||||
points = self.client.query_points(
|
points = self.client.query_points(
|
||||||
|
|
@ -341,36 +302,32 @@ class QdrantClient(VectorDBBase):
|
||||||
if not self.client.collection_exists(collection_name=mt_collection):
|
if not self.client.collection_exists(collection_name=mt_collection):
|
||||||
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
|
log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
|
||||||
return None
|
return None
|
||||||
|
tenant_filter = _tenant_filter(tenant_id)
|
||||||
tenant_filter = models.FieldCondition(
|
|
||||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
points = self.client.query_points(
|
points = self.client.query_points(
|
||||||
collection_name=mt_collection,
|
collection_name=mt_collection,
|
||||||
query_filter=models.Filter(must=[tenant_filter]),
|
query_filter=models.Filter(must=[tenant_filter]),
|
||||||
limit=NO_LIMIT,
|
limit=NO_LIMIT,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._result_to_get_result(points.points)
|
return self._result_to_get_result(points.points)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error getting collection '{collection_name}': {e}")
|
log.exception(f"Error getting collection '{collection_name}': {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
def upsert(self, collection_name: str, items: List[VectorItem]):
|
||||||
"""
|
"""
|
||||||
Upsert items with tenant ID.
|
Upsert items with tenant ID.
|
||||||
"""
|
"""
|
||||||
if not self.client or not items:
|
if not self.client or not items:
|
||||||
return None
|
return None
|
||||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||||
dimension = len(items[0]["vector"]) if items else None
|
dimension = len(items[0]["vector"])
|
||||||
self._ensure_collection(mt_collection, dimension)
|
self._ensure_collection(mt_collection, dimension)
|
||||||
points = self._create_points(items, tenant_id)
|
points = self._create_points(items, tenant_id)
|
||||||
self.client.upload_points(mt_collection, points)
|
self.client.upload_points(mt_collection, points)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
def insert(self, collection_name: str, items: List[VectorItem]):
|
||||||
"""
|
"""
|
||||||
Insert items with tenant ID.
|
Insert items with tenant ID.
|
||||||
"""
|
"""
|
||||||
|
|
@ -382,11 +339,9 @@ class QdrantClient(VectorDBBase):
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
if not self.client:
|
||||||
return None
|
return None
|
||||||
|
for collection in self.client.get_collections().collections:
|
||||||
collection_names = self.client.get_collections().collections
|
if collection.name.startswith(self.collection_prefix):
|
||||||
for collection_name in collection_names:
|
self.client.delete_collection(collection_name=collection.name)
|
||||||
if collection_name.name.startswith(self.collection_prefix):
|
|
||||||
self.client.delete_collection(collection_name=collection_name.name)
|
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str):
|
def delete_collection(self, collection_name: str):
|
||||||
"""
|
"""
|
||||||
|
|
@ -398,17 +353,9 @@ class QdrantClient(VectorDBBase):
|
||||||
if not self.client.collection_exists(collection_name=mt_collection):
|
if not self.client.collection_exists(collection_name=mt_collection):
|
||||||
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
|
log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.client.delete(
|
self.client.delete(
|
||||||
collection_name=mt_collection,
|
collection_name=mt_collection,
|
||||||
points_selector=models.FilterSelector(
|
points_selector=models.FilterSelector(
|
||||||
filter=models.Filter(
|
filter=models.Filter(must=[_tenant_filter(tenant_id)])
|
||||||
must=[
|
|
||||||
models.FieldCondition(
|
|
||||||
key=TENANT_ID_FIELD,
|
|
||||||
match=models.MatchValue(value=tenant_id),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue