Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-12 04:15:25 +00:00)
chore: run formatting
parent 8dcf668448
commit 860f3b3cab
2 changed files with 287 additions and 205 deletions
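The diff below is a formatting-only pass over the S3 Vector backend. The style of the changes (double quotes, magic trailing commas, wrapping at 88 columns) matches Black's defaults, so a commit like this can be reproduced with a sketch along the following lines; the formatter and the file path are assumptions inferred from the diff, not stated in the commit:

import black  # pip install black; the exact version the project used is not recorded here

# Path assumed from the module imported in the second file's hunk below.
path = "backend/open_webui/retrieval/vector/dbs/s3vector.py"

with open(path) as f:
    source = f.read()

# format_str applies Black's default mode: 88-character lines, double quotes,
# and trailing commas that force one-argument-per-line call layouts.
formatted = black.format_str(source, mode=black.Mode())

with open(path, "w") as f:
    f.write(formatted)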
@@ -1,4 +1,9 @@
-from open_webui.retrieval.vector.main import VectorDBBase, VectorItem, GetResult, SearchResult
+from open_webui.retrieval.vector.main import (
+    VectorDBBase,
+    VectorItem,
+    GetResult,
+    SearchResult,
+)
 from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
 from open_webui.env import SRC_LOG_LEVELS
 from typing import List, Optional, Dict, Any, Union
@@ -8,39 +13,48 @@ import boto3
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
 class S3VectorClient(VectorDBBase):
     """
     AWS S3 Vector integration for Open WebUI Knowledge.
     """
 
     def __init__(self):
         self.bucket_name = S3_VECTOR_BUCKET_NAME
         self.region = S3_VECTOR_REGION
 
         # Simple validation - log warnings instead of raising exceptions
         if not self.bucket_name:
             log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
         if not self.region:
             log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
 
         if self.bucket_name and self.region:
             try:
                 self.client = boto3.client("s3vectors", region_name=self.region)
-                log.info(f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'")
+                log.info(
+                    f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
+                )
             except Exception as e:
                 log.error(f"Failed to initialize S3Vector client: {e}")
                 self.client = None
         else:
             self.client = None
 
-    def _create_index(self, index_name: str, dimension: int, data_type: str = "float32", distance_metric: str = "cosine") -> None:
+    def _create_index(
+        self,
+        index_name: str,
+        dimension: int,
+        data_type: str = "float32",
+        distance_metric: str = "cosine",
+    ) -> None:
         """
         Create a new index in the S3 vector bucket for the given collection if it does not exist.
         """
         if self.has_collection(index_name):
             log.debug(f"Index '{index_name}' already exists, skipping creation")
             return
 
         try:
             self.client.create_index(
                 vectorBucketName=self.bucket_name,
@@ -49,40 +63,44 @@ class S3VectorClient(VectorDBBase):
                 dimension=dimension,
                 distanceMetric=distance_metric,
             )
-            log.info(f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})")
+            log.info(
+                f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
+            )
         except Exception as e:
             log.error(f"Error creating S3 index '{index_name}': {e}")
             raise
 
-    def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]:
+    def _filter_metadata(
+        self, metadata: Dict[str, Any], item_id: str
+    ) -> Dict[str, Any]:
         """
         Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
         """
         if not isinstance(metadata, dict) or len(metadata) <= 10:
             return metadata
 
         # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
         important_keys = [
-            'text',  # The actual document content
-            'file_id',  # File ID
-            'source',  # Document source file
-            'title',  # Document title
-            'page',  # Page number
-            'total_pages',  # Total pages in document
-            'embedding_config',  # Embedding configuration
-            'created_by',  # User who created it
-            'name',  # Document name
-            'hash',  # Content hash
+            "text",  # The actual document content
+            "file_id",  # File ID
+            "source",  # Document source file
+            "title",  # Document title
+            "page",  # Page number
+            "total_pages",  # Total pages in document
+            "embedding_config",  # Embedding configuration
+            "created_by",  # User who created it
+            "name",  # Document name
+            "hash",  # Content hash
         ]
         filtered_metadata = {}
 
         # First, add important keys if they exist
         for key in important_keys:
             if key in metadata:
                 filtered_metadata[key] = metadata[key]
                 if len(filtered_metadata) >= 10:
                     break
 
         # If we still have room, add other keys
         if len(filtered_metadata) < 10:
             for key, value in metadata.items():
@@ -90,15 +108,17 @@ class S3VectorClient(VectorDBBase):
                     filtered_metadata[key] = value
                     if len(filtered_metadata) >= 10:
                         break
 
-        log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys")
+        log.warning(
+            f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
+        )
         return filtered_metadata
 
     def has_collection(self, collection_name: str) -> bool:
         """
         Check if a vector index (collection) exists in the S3 vector bucket.
         """
 
         try:
             response = self.client.list_indexes(vectorBucketName=self.bucket_name)
             indexes = response.get("indexes", [])
@@ -106,21 +126,22 @@ class S3VectorClient(VectorDBBase):
         except Exception as e:
             log.error(f"Error listing indexes: {e}")
             return False
 
     def delete_collection(self, collection_name: str) -> None:
         """
         Delete an entire S3 Vector index/collection.
         """
 
         if not self.has_collection(collection_name):
-            log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
+            log.warning(
+                f"Collection '{collection_name}' does not exist, nothing to delete"
+            )
             return
 
         try:
             log.info(f"Deleting collection '{collection_name}'")
             self.client.delete_index(
-                vectorBucketName=self.bucket_name,
-                indexName=collection_name
+                vectorBucketName=self.bucket_name, indexName=collection_name
             )
             log.info(f"Successfully deleted collection '{collection_name}'")
         except Exception as e:
@@ -134,9 +155,9 @@ class S3VectorClient(VectorDBBase):
         if not items:
             log.warning("No items to insert")
             return
 
         dimension = len(items[0]["vector"])
 
         try:
             if not self.has_collection(collection_name):
                 log.info(f"Index '{collection_name}' does not exist. Creating index.")
@@ -146,7 +167,7 @@ class S3VectorClient(VectorDBBase):
                     data_type="float32",
                     distance_metric="cosine",
                 )
 
             # Prepare vectors for insertion
             vectors = []
             for item in items:
@@ -155,28 +176,28 @@ class S3VectorClient(VectorDBBase):
                 if isinstance(vector_data, list):
                     # Convert list to float32 values as required by S3 Vector API
                     vector_data = [float(x) for x in vector_data]
 
                 # Prepare metadata, ensuring the text field is preserved
                 metadata = item.get("metadata", {}).copy()
 
                 # Add the text field to metadata so it's available for retrieval
                 metadata["text"] = item["text"]
 
                 # Filter metadata to comply with S3 Vector API limit of 10 keys
                 metadata = self._filter_metadata(metadata, item["id"])
 
-                vectors.append({
-                    "key": item["id"],
-                    "data": {
-                        "float32": vector_data
-                    },
-                    "metadata": metadata
-                })
+                vectors.append(
+                    {
+                        "key": item["id"],
+                        "data": {"float32": vector_data},
+                        "metadata": metadata,
+                    }
+                )
             # Insert vectors
             self.client.put_vectors(
                 vectorBucketName=self.bucket_name,
                 indexName=collection_name,
-                vectors=vectors
+                vectors=vectors,
             )
             log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
         except Exception as e:
@@ -190,20 +211,22 @@ class S3VectorClient(VectorDBBase):
         if not items:
             log.warning("No items to upsert")
             return
 
         dimension = len(items[0]["vector"])
         log.info(f"Upsert dimension: {dimension}")
 
         try:
             if not self.has_collection(collection_name):
-                log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.")
+                log.info(
+                    f"Index '{collection_name}' does not exist. Creating index for upsert."
+                )
                 self._create_index(
                     index_name=collection_name,
                     dimension=dimension,
                     data_type="float32",
                     distance_metric="cosine",
                 )
 
             # Prepare vectors for upsert
             vectors = []
             for item in items:
@@ -212,65 +235,69 @@ class S3VectorClient(VectorDBBase):
                 if isinstance(vector_data, list):
                     # Convert list to float32 values as required by S3 Vector API
                     vector_data = [float(x) for x in vector_data]
 
                 # Prepare metadata, ensuring the text field is preserved
                 metadata = item.get("metadata", {}).copy()
                 # Add the text field to metadata so it's available for retrieval
                 metadata["text"] = item["text"]
 
                 # Filter metadata to comply with S3 Vector API limit of 10 keys
                 metadata = self._filter_metadata(metadata, item["id"])
 
-                vectors.append({
-                    "key": item["id"],
-                    "data": {
-                        "float32": vector_data
-                    },
-                    "metadata": metadata
-                })
+                vectors.append(
+                    {
+                        "key": item["id"],
+                        "data": {"float32": vector_data},
+                        "metadata": metadata,
+                    }
+                )
             # Upsert vectors (using put_vectors for upsert semantics)
-            log.info(f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}")
+            log.info(
+                f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}"
+            )
             self.client.put_vectors(
                 vectorBucketName=self.bucket_name,
                 indexName=collection_name,
-                vectors=vectors
+                vectors=vectors,
             )
             log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
         except Exception as e:
             log.error(f"Error upserting vectors: {e}")
             raise
 
-    def search(self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int) -> Optional[SearchResult]:
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
         """
         Search for similar vectors in a collection using multiple query vectors.
         """
 
         if not self.has_collection(collection_name):
             log.warning(f"Collection '{collection_name}' does not exist")
             return None
 
         if not vectors:
             log.warning("No query vectors provided")
             return None
 
         try:
-            log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}")
+            log.info(
+                f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
+            )
 
             # Initialize result lists
             all_ids = []
             all_documents = []
             all_metadatas = []
             all_distances = []
 
             # Process each query vector
             for i, query_vector in enumerate(vectors):
                 log.debug(f"Processing query vector {i+1}/{len(vectors)}")
 
                 # Prepare the query vector in S3 Vector format
-                query_vector_dict = {
-                    'float32': [float(x) for x in query_vector]
-                }
+                query_vector_dict = {"float32": [float(x) for x in query_vector]}
 
                 # Call S3 Vector query API
                 response = self.client.query_vectors(
                     vectorBucketName=self.bucket_name,
@@ -278,109 +305,119 @@ class S3VectorClient(VectorDBBase):
                     topK=limit,
                     queryVector=query_vector_dict,
                     returnMetadata=True,
-                    returnDistance=True
+                    returnDistance=True,
                 )
 
                 # Process results for this query
                 query_ids = []
                 query_documents = []
                 query_metadatas = []
                 query_distances = []
 
-                result_vectors = response.get('vectors', [])
+                result_vectors = response.get("vectors", [])
 
                 for vector in result_vectors:
-                    vector_id = vector.get('key')
-                    vector_metadata = vector.get('metadata', {})
-                    vector_distance = vector.get('distance', 0.0)
+                    vector_id = vector.get("key")
+                    vector_metadata = vector.get("metadata", {})
+                    vector_distance = vector.get("distance", 0.0)
 
                     # Extract document text from metadata
                     document_text = ""
                     if isinstance(vector_metadata, dict):
                         # Get the text field first (highest priority)
-                        document_text = vector_metadata.get('text')
+                        document_text = vector_metadata.get("text")
                         if not document_text:
                             # Fallback to other possible text fields
-                            document_text = (vector_metadata.get('content') or
-                                             vector_metadata.get('document') or
-                                             vector_id)
+                            document_text = (
+                                vector_metadata.get("content")
+                                or vector_metadata.get("document")
+                                or vector_id
+                            )
                     else:
                         document_text = vector_id
 
                     query_ids.append(vector_id)
                     query_documents.append(document_text)
                     query_metadatas.append(vector_metadata)
                     query_distances.append(vector_distance)
 
                 # Add this query's results to the overall results
                 all_ids.append(query_ids)
                 all_documents.append(query_documents)
                 all_metadatas.append(query_metadatas)
                 all_distances.append(query_distances)
 
             log.info(f"Search completed. Found results for {len(all_ids)} queries")
 
             # Return SearchResult format
             return SearchResult(
                 ids=all_ids if all_ids else None,
                 documents=all_documents if all_documents else None,
                 metadatas=all_metadatas if all_metadatas else None,
-                distances=all_distances if all_distances else None
+                distances=all_distances if all_distances else None,
             )
 
         except Exception as e:
             log.error(f"Error searching collection '{collection_name}': {str(e)}")
             # Handle specific AWS exceptions
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error']['Code']
-                if error_code == 'NotFoundException':
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"]["Code"]
+                if error_code == "NotFoundException":
                     log.warning(f"Collection '{collection_name}' not found")
                     return None
-                elif error_code == 'ValidationException':
+                elif error_code == "ValidationException":
                     log.error(f"Invalid query vector dimensions or parameters")
                     return None
-                elif error_code == 'AccessDeniedException':
-                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                elif error_code == "AccessDeniedException":
+                    log.error(
+                        f"Access denied for collection '{collection_name}'. Check permissions."
+                    )
                     return None
             raise
 
-    def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
         """
         Query vectors from a collection using metadata filter.
         """
 
         if not self.has_collection(collection_name):
             log.warning(f"Collection '{collection_name}' does not exist")
             return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
 
         if not filter:
             log.warning("No filter provided, returning all vectors")
             return self.get(collection_name)
 
         try:
             log.info(f"Querying collection '{collection_name}' with filter: {filter}")
 
             # For S3 Vector, we need to use list_vectors and then filter results
             # Since S3 Vector may not support complex server-side filtering,
             # we'll retrieve all vectors and filter client-side
 
             # Get all vectors first
             all_vectors_result = self.get(collection_name)
 
             if not all_vectors_result or not all_vectors_result.ids:
                 log.warning("No vectors found in collection")
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
 
             # Extract the lists from the result
             all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
-            all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else []
-            all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
+            all_documents = (
+                all_vectors_result.documents[0] if all_vectors_result.documents else []
+            )
+            all_metadatas = (
+                all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
+            )
 
             # Apply client-side filtering
             filtered_ids = []
             filtered_documents = []
             filtered_metadatas = []
 
             for i, metadata in enumerate(all_metadatas):
                 if self._matches_filter(metadata, filter):
                     if i < len(all_ids):
@@ -388,29 +425,37 @@ class S3VectorClient(VectorDBBase):
                     if i < len(all_documents):
                         filtered_documents.append(all_documents[i])
                     filtered_metadatas.append(metadata)
 
                     # Apply limit if specified
                     if limit and len(filtered_ids) >= limit:
                         break
 
-            log.info(f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total")
+            log.info(
+                f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
+            )
 
             # Return GetResult format
             if filtered_ids:
-                return GetResult(ids=[filtered_ids], documents=[filtered_documents], metadatas=[filtered_metadatas])
+                return GetResult(
+                    ids=[filtered_ids],
+                    documents=[filtered_documents],
+                    metadatas=[filtered_metadatas],
+                )
             else:
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
 
         except Exception as e:
             log.error(f"Error querying collection '{collection_name}': {str(e)}")
             # Handle specific AWS exceptions
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error']['Code']
-                if error_code == 'NotFoundException':
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"]["Code"]
+                if error_code == "NotFoundException":
                     log.warning(f"Collection '{collection_name}' not found")
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-                elif error_code == 'AccessDeniedException':
-                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                elif error_code == "AccessDeniedException":
+                    log.error(
+                        f"Access denied for collection '{collection_name}'. Check permissions."
+                    )
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
             raise
@@ -418,170 +463,203 @@ class S3VectorClient(VectorDBBase):
         """
         Retrieve all vectors from a collection.
         """
 
         if not self.has_collection(collection_name):
             log.warning(f"Collection '{collection_name}' does not exist")
             return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
 
         try:
             log.info(f"Retrieving all vectors from collection '{collection_name}'")
 
             # Initialize result lists
             all_ids = []
             all_documents = []
             all_metadatas = []
 
             # Handle pagination
             next_token = None
 
             while True:
                 # Prepare request parameters
                 request_params = {
-                    'vectorBucketName': self.bucket_name,
-                    'indexName': collection_name,
-                    'returnData': False,  # Don't include vector data (not needed for get)
-                    'returnMetadata': True,  # Include metadata
-                    'maxResults': 500  # Use reasonable page size
+                    "vectorBucketName": self.bucket_name,
+                    "indexName": collection_name,
+                    "returnData": False,  # Don't include vector data (not needed for get)
+                    "returnMetadata": True,  # Include metadata
+                    "maxResults": 500,  # Use reasonable page size
                 }
 
                 if next_token:
-                    request_params['nextToken'] = next_token
+                    request_params["nextToken"] = next_token
 
                 # Call S3 Vector API
                 response = self.client.list_vectors(**request_params)
 
                 # Process vectors in this page
-                vectors = response.get('vectors', [])
+                vectors = response.get("vectors", [])
 
                 for vector in vectors:
-                    vector_id = vector.get('key')
-                    vector_data = vector.get('data', {})
-                    vector_metadata = vector.get('metadata', {})
+                    vector_id = vector.get("key")
+                    vector_data = vector.get("data", {})
+                    vector_metadata = vector.get("metadata", {})
 
                     # Extract the actual vector array
-                    vector_array = vector_data.get('float32', [])
+                    vector_array = vector_data.get("float32", [])
 
                     # For documents, we try to extract text from metadata or use the vector ID
                     document_text = ""
                     if isinstance(vector_metadata, dict):
                         # Get the text field first (highest priority)
-                        document_text = vector_metadata.get('text')
+                        document_text = vector_metadata.get("text")
                         if not document_text:
                             # Fallback to other possible text fields
-                            document_text = (vector_metadata.get('content') or
-                                             vector_metadata.get('document') or
-                                             vector_id)
+                            document_text = (
+                                vector_metadata.get("content")
+                                or vector_metadata.get("document")
+                                or vector_id
+                            )
 
                         # Log the actual content for debugging
-                        log.debug(f"Document text preview (first 200 chars): {str(document_text)[:200]}")
+                        log.debug(
+                            f"Document text preview (first 200 chars): {str(document_text)[:200]}"
+                        )
                     else:
                         document_text = vector_id
 
                     all_ids.append(vector_id)
                     all_documents.append(document_text)
                     all_metadatas.append(vector_metadata)
 
                 # Check if there are more pages
-                next_token = response.get('nextToken')
+                next_token = response.get("nextToken")
                 if not next_token:
                     break
 
-            log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'")
+            log.info(
+                f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
+            )
 
             # Return in GetResult format
             # The Open WebUI GetResult expects lists of lists, so we wrap each list
             if all_ids:
-                return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas])
+                return GetResult(
+                    ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
+                )
             else:
                 return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
 
         except Exception as e:
-            log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}")
+            log.error(
+                f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
+            )
             # Handle specific AWS exceptions
-            if hasattr(e, 'response') and 'Error' in e.response:
-                error_code = e.response['Error']['Code']
-                if error_code == 'NotFoundException':
+            if hasattr(e, "response") and "Error" in e.response:
+                error_code = e.response["Error"]["Code"]
+                if error_code == "NotFoundException":
                     log.warning(f"Collection '{collection_name}' not found")
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
-                elif error_code == 'AccessDeniedException':
-                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                elif error_code == "AccessDeniedException":
+                    log.error(
+                        f"Access denied for collection '{collection_name}'. Check permissions."
+                    )
                     return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
             raise
 
-    def delete(self, collection_name: str, ids: Optional[List[str]] = None, filter: Optional[Dict] = None) -> None:
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict] = None,
+    ) -> None:
         """
         Delete vectors by ID or filter from a collection.
         """
 
         if not self.has_collection(collection_name):
-            log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
+            log.warning(
+                f"Collection '{collection_name}' does not exist, nothing to delete"
+            )
             return
 
         # Check if this is a knowledge collection (not file-specific)
         is_knowledge_collection = not collection_name.startswith("file-")
 
         try:
             if ids:
                 # Delete by specific vector IDs/keys
-                log.info(f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'")
+                log.info(
+                    f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
+                )
                 self.client.delete_vectors(
                     vectorBucketName=self.bucket_name,
                     indexName=collection_name,
-                    keys=ids
+                    keys=ids,
                 )
                 log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")
 
             elif filter:
                 # Handle filter-based deletion
-                log.info(f"Deleting vectors by filter from collection '{collection_name}': {filter}")
+                log.info(
+                    f"Deleting vectors by filter from collection '{collection_name}': {filter}"
+                )
 
                 # If this is a knowledge collection and we have a file_id filter,
                 # also clean up the corresponding file-specific collection
                 if is_knowledge_collection and "file_id" in filter:
                     file_id = filter["file_id"]
                     file_collection_name = f"file-{file_id}"
                     if self.has_collection(file_collection_name):
-                        log.info(f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates")
+                        log.info(
+                            f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
+                        )
                         self.delete_collection(file_collection_name)
 
                 # For the main collection, implement query-then-delete
                 # First, query to get IDs matching the filter
                 query_result = self.query(collection_name, filter)
                 if query_result and query_result.ids and query_result.ids[0]:
                     matching_ids = query_result.ids[0]
-                    log.info(f"Found {len(matching_ids)} vectors matching filter, deleting them")
+                    log.info(
+                        f"Found {len(matching_ids)} vectors matching filter, deleting them"
+                    )
 
                     # Delete the matching vectors by ID
                     self.client.delete_vectors(
                         vectorBucketName=self.bucket_name,
                         indexName=collection_name,
-                        keys=matching_ids
+                        keys=matching_ids,
                     )
-                    log.info(f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter")
+                    log.info(
+                        f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
+                    )
                 else:
                     log.warning("No vectors found matching the filter criteria")
             else:
                 log.warning("No IDs or filter provided for deletion")
         except Exception as e:
-            log.error(f"Error deleting vectors from collection '{collection_name}': {e}")
+            log.error(
+                f"Error deleting vectors from collection '{collection_name}': {e}"
+            )
             raise
 
     def reset(self) -> None:
         """
         Reset/clear all vector data. For S3 Vector, this deletes all indexes.
         """
 
         try:
-            log.warning("Reset called - this will delete all vector indexes in the S3 bucket")
+            log.warning(
+                "Reset called - this will delete all vector indexes in the S3 bucket"
+            )
 
             # List all indexes
             response = self.client.list_indexes(vectorBucketName=self.bucket_name)
             indexes = response.get("indexes", [])
 
             if not indexes:
                 log.warning("No indexes found to delete")
                 return
 
             # Delete all indexes
             deleted_count = 0
             for index in indexes:
@@ -589,39 +667,38 @@ class S3VectorClient(VectorDBBase):
                 if index_name:
                     try:
                         self.client.delete_index(
-                            vectorBucketName=self.bucket_name,
-                            indexName=index_name
+                            vectorBucketName=self.bucket_name, indexName=index_name
                         )
                         deleted_count += 1
                         log.info(f"Deleted index: {index_name}")
                     except Exception as e:
                         log.error(f"Error deleting index '{index_name}': {e}")
 
             log.info(f"Reset completed: deleted {deleted_count} indexes")
 
         except Exception as e:
             log.error(f"Error during reset: {e}")
             raise
 
     def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
         """
         Check if metadata matches the given filter conditions.
         """
         if not isinstance(metadata, dict) or not isinstance(filter, dict):
             return False
 
         # Check each filter condition
         for key, expected_value in filter.items():
             # Handle special operators
-            if key.startswith('$'):
-                if key == '$and':
+            if key.startswith("$"):
+                if key == "$and":
                     # All conditions must match
                     if not isinstance(expected_value, list):
                         continue
                     for condition in expected_value:
                         if not self._matches_filter(metadata, condition):
                             return False
-                elif key == '$or':
+                elif key == "$or":
                     # At least one condition must match
                     if not isinstance(expected_value, list):
                         continue
@@ -633,27 +710,30 @@ class S3VectorClient(VectorDBBase):
                     if not any_match:
                         return False
                 continue
 
             # Get the actual value from metadata
             actual_value = metadata.get(key)
 
             # Handle different types of expected values
             if isinstance(expected_value, dict):
                 # Handle comparison operators
                 for op, op_value in expected_value.items():
-                    if op == '$eq':
+                    if op == "$eq":
                         if actual_value != op_value:
                             return False
-                    elif op == '$ne':
+                    elif op == "$ne":
                         if actual_value == op_value:
                             return False
-                    elif op == '$in':
-                        if not isinstance(op_value, list) or actual_value not in op_value:
+                    elif op == "$in":
+                        if (
+                            not isinstance(op_value, list)
+                            or actual_value not in op_value
+                        ):
                             return False
-                    elif op == '$nin':
+                    elif op == "$nin":
                         if isinstance(op_value, list) and actual_value in op_value:
                             return False
-                    elif op == '$exists':
+                    elif op == "$exists":
                         if bool(op_value) != (key in metadata):
                             return False
                     # Add more operators as needed
@@ -661,5 +741,5 @@ class S3VectorClient(VectorDBBase):
                 # Simple equality check
                 if actual_value != expected_value:
                     return False
 
         return True
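For orientation, here is a hypothetical usage sketch of the client whose formatting changed above. It exercises only methods visible in this diff (insert, search, delete); the item dict mirrors the keys the code reads (id, text, vector, metadata), and the 3-dimensional vectors are toy values, not real embeddings:

from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient

client = S3VectorClient()  # reads S3_VECTOR_BUCKET_NAME / S3_VECTOR_REGION from config

items = [
    {
        "id": "doc-1-chunk-0",
        "text": "Hello from Open WebUI",
        "vector": [0.1, 0.2, 0.3],  # toy values; real dimensions come from the embedder
        "metadata": {"file_id": "doc-1", "page": 1},
    }
]
client.insert("file-doc-1", items)  # creates the index on first use

# One query vector, top 5 results; SearchResult holds parallel per-query lists.
result = client.search("file-doc-1", vectors=[[0.1, 0.2, 0.3]], limit=5)
if result and result.ids:
    print(result.ids[0], result.distances[0])

client.delete("file-doc-1", ids=["doc-1-chunk-0"])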
@@ -28,9 +28,11 @@ class Vector:
                 return QdrantClient()
             case VectorType.PINECONE:
                 from open_webui.retrieval.vector.dbs.pinecone import PineconeClient
+
                 return PineconeClient()
             case VectorType.S3VECTOR:
                 from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient
+
                 return S3VectorClient()
             case VectorType.OPENSEARCH:
                 from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient
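The second file extends a match-based factory; the only change Black makes here is a blank line after each lazy import. A generic sketch of the same pattern shows why the imports sit inside the case arms: optional backends (and their dependencies, such as boto3) load only when selected. The names below are illustrative, not Open WebUI's actual factory API:

from enum import Enum


class VectorType(str, Enum):  # illustrative stand-in for the project's enum
    S3VECTOR = "s3vector"
    PINECONE = "pinecone"


def make_client(vector_type: VectorType):  # hypothetical helper name
    match vector_type:
        case VectorType.S3VECTOR:
            # Lazy import: the backend's dependencies load only if this arm runs.
            from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient

            return S3VectorClient()
        case _:
            raise ValueError(f"Unsupported vector DB: {vector_type}")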