From 5c59c50e2d530779b27c4f892531ec00caab191d Mon Sep 17 00:00:00 2001
From: "0xThresh.eth" <0xthresh@protonmail.com>
Date: Sun, 20 Jul 2025 16:48:23 -0600
Subject: [PATCH] more progress on s3 vector

---
 .../retrieval/vector/s3/s3vector.py | 719 ++++++++++++++++++
 1 file changed, 719 insertions(+)
 create mode 100644 backend/open_webui/retrieval/vector/s3/s3vector.py

diff --git a/backend/open_webui/retrieval/vector/s3/s3vector.py b/backend/open_webui/retrieval/vector/s3/s3vector.py
new file mode 100644
index 0000000000..88392d09ba
--- /dev/null
+++ b/backend/open_webui/retrieval/vector/s3/s3vector.py
@@ -0,0 +1,719 @@
+from open_webui.retrieval.vector.main import VectorDBBase, VectorItem, GetResult, SearchResult
+from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
+from open_webui.env import SRC_LOG_LEVELS
+from typing import List, Optional, Dict, Any, Union
+import logging
+import boto3
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+class S3VectorClient(VectorDBBase):
+    """
+    AWS S3 Vector integration for Open WebUI Knowledge.
+    Assumes AWS credentials are available via environment variables or IAM roles.
+    """
+    def __init__(self):
+        self.bucket_name = S3_VECTOR_BUCKET_NAME
+        self.region = S3_VECTOR_REGION
+        self.client = boto3.client("s3vectors", region_name=self.region)
+
+    def _create_index(self, index_name: str, dimension: int, data_type: str = "float32", distance_metric: str = "cosine"):
+        """
+        Create a new index in the S3 vector bucket for the given collection if it does not exist.
+        """
+        if self.has_collection(index_name):
+            return
+        try:
+            self.client.create_index(
+                vectorBucketName=self.bucket_name,
+                indexName=index_name,
+                dataType=data_type,
+                dimension=dimension,
+                distanceMetric=distance_metric,
+            )
+            log.info(f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})")
+        except Exception as e:
+            log.error(f"Error creating S3 index '{index_name}': {e}")
+
+    def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]:
+        """
+        Filter metadata to comply with the S3 Vector API limit of 10 keys maximum.
+        If the AWS S3 Vector feature starts supporting more than 10 keys, this limit should be adjusted, and preferably removed.
+        """
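+        # Illustrative behaviour note (hypothetical values): a metadata dict with,
+        # say, 14 keys is reduced to at most 10, keeping the keys listed in
+        # important_keys below first and filling any remaining slots in dict order.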
+        if not isinstance(metadata, dict) or len(metadata) <= 10:
+            return metadata
+
+        # Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
+        important_keys = [
+            'text',              # THE MOST IMPORTANT - the actual document content
+            'file_id',           # File ID
+            'source',            # Document source file
+            'title',             # Document title
+            'page',              # Page number
+            'total_pages',       # Total pages in document
+            'embedding_config',  # Embedding configuration
+            'created_by',        # User who created it
+            'name',              # Document name
+            'hash',              # Content hash
+        ]
+        filtered_metadata = {}
+
+        # First, add important keys if they exist
+        for key in important_keys:
+            if key in metadata:
+                filtered_metadata[key] = metadata[key]
+                if len(filtered_metadata) >= 10:
+                    break
+
+        # If we still have room, add other keys
+        if len(filtered_metadata) < 10:
+            for key, value in metadata.items():
+                if key not in filtered_metadata:
+                    filtered_metadata[key] = value
+                    if len(filtered_metadata) >= 10:
+                        break
+
+        log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys")
+        return filtered_metadata
+
+    def _check_for_duplicate_file_collections(self, knowledge_collection_name: str, new_items: List[Dict[str, Any]]) -> None:
+        """
+        Check for existing file-specific collections that might create duplicates.
+        """
+        # Extract file IDs from the new items to find corresponding file collections
+        file_ids = set()
+        for item in new_items:
+            metadata = item.get("metadata", {})
+            file_id = metadata.get("file_id")
+            if file_id:
+                file_ids.add(file_id)
+
+        # Check for existing file-specific collections
+        duplicate_collections = []
+        for file_id in file_ids:
+            file_collection_name = f"file-{file_id}"
+            if self.has_collection(file_collection_name):
+                duplicate_collections.append(file_collection_name)
+
+        if duplicate_collections:
+            log.warning(f"Found existing file-specific collections that may contain duplicate vectors: {duplicate_collections}")
+            log.warning("Consider manually deleting these collections to avoid duplicate storage:")
+            for collection in duplicate_collections:
+                log.warning(f" - {collection}")
+            log.warning(f"Continuing with insertion to knowledge collection '{knowledge_collection_name}'")
+        else:
+            log.info(f"No duplicate file-specific collections found for knowledge collection '{knowledge_collection_name}'")
+
+    def has_collection(self, collection_name: str) -> bool:
+        """
+        Check if a vector index (collection) exists in the S3 vector bucket.
+        """
+        try:
+            response = self.client.list_indexes(vectorBucketName=self.bucket_name)
+            indexes = response.get("indexes", [])
+            return any(idx.get("indexName") == collection_name for idx in indexes)
+        except Exception as e:
+            log.error(f"Error listing indexes: {e}")
+            return False
+
+    def delete_collection(self, collection_name: str) -> None:
+        """
+        Delete an entire S3 Vector index/collection.
+        """
+        if not self.has_collection(collection_name):
+            log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
+            return
+
+        try:
+            log.info(f"Deleting collection '{collection_name}'")
+            self.client.delete_index(
+                vectorBucketName=self.bucket_name,
+                indexName=collection_name
+            )
+            log.info(f"Successfully deleted collection '{collection_name}'")
+        except Exception as e:
+            log.error(f"Error deleting collection '{collection_name}': {e}")
+            raise
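+
+    # The insert()/upsert() methods below expect each item to be a dict shaped
+    # roughly like this sketch (only the keys read in this file are shown;
+    # values are hypothetical):
+    #
+    #     {
+    #         "id": "unique-vector-key",
+    #         "vector": [0.12, -0.03, ...],   # one float per embedding dimension
+    #         "text": "chunk text",           # copied into metadata["text"] for retrieval
+    #         "metadata": {"file_id": "...", "source": "...", ...},  # trimmed to 10 keys
+    #     }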
+
+    def insert(self, collection_name: str, items: List[Dict[str, Any]]) -> None:
+        """
+        Insert vector items into the S3 Vector index, creating the index if it does not exist.
+
+        Supports both knowledge collection indexes and file-specific indexes (file-{file_id}).
+        """
+        if not items:
+            log.warning("No items to insert")
+            return
+
+        dimension = len(items[0]["vector"])
+
+        try:
+            if not self.has_collection(collection_name):
+                log.info(f"Index '{collection_name}' does not exist. Creating index.")
+                self._create_index(
+                    index_name=collection_name,
+                    dimension=dimension,
+                    data_type="float32",
+                    distance_metric="cosine",
+                )
+
+            # Check for any existing file-specific collections that might create duplicates
+            self._check_for_duplicate_file_collections(collection_name, items)
+
+            # Prepare vectors for insertion
+            vectors = []
+            for item in items:
+                # Ensure vector data is in the correct format for the S3 Vector API
+                vector_data = item["vector"]
+                if isinstance(vector_data, list):
+                    # Convert list to float32 values as required by the S3 Vector API
+                    vector_data = [float(x) for x in vector_data]
+
+                # Prepare metadata, ensuring the text field is preserved
+                metadata = item.get("metadata", {}).copy()
+
+                # Add the text field to metadata so it's available for retrieval
+                if "text" in item:
+                    metadata["text"] = item["text"]
+                else:
+                    log.warning(f"No 'text' field found in item with ID: {item.get('id')}")
+
+                # Filter metadata to comply with the S3 Vector API limit of 10 keys
+                metadata = self._filter_metadata(metadata, item["id"])
+
+                vectors.append({
+                    "key": item["id"],
+                    "data": {
+                        "float32": vector_data
+                    },
+                    "metadata": metadata
+                })
+            # Insert vectors
+            self.client.put_vectors(
+                vectorBucketName=self.bucket_name,
+                indexName=collection_name,
+                vectors=vectors
+            )
+            log.info(f"Inserted {len(vectors)} vectors into index '{collection_name}'.")
+        except Exception as e:
+            log.error(f"Error inserting vectors: {e}")
+            raise
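+
+    # NOTE (assumption, not verified against the S3 Vectors docs): put_vectors
+    # appears to write by key, so re-putting an existing key replaces that vector.
+    # That overwrite behaviour is what gives upsert() below its insert-or-update
+    # semantics on top of the same put_vectors call used by insert().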
+
+    def upsert(self, collection_name: str, items: List[Dict[str, Any]]) -> None:
+        """
+        Insert or update vector items in the S3 Vector index, creating the index if it does not exist.
+
+        Supports both knowledge collections and file-specific collections for compatibility
+        with existing Open WebUI backend logic.
+        """
+        if not items:
+            log.warning("No items to upsert")
+            return
+
+        dimension = len(items[0]["vector"])
+        log.info(f"Upsert dimension: {dimension}")
+
+        try:
+            if not self.has_collection(collection_name):
+                log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.")
+                self._create_index(
+                    index_name=collection_name,
+                    dimension=dimension,
+                    data_type="float32",
+                    distance_metric="cosine",
+                )
+
+            # Check for any existing file-specific collections that might create duplicates
+            self._check_for_duplicate_file_collections(collection_name, items)
+
+            # Prepare vectors for upsert
+            vectors = []
+            for item in items:
+                # Ensure vector data is in the correct format for the S3 Vector API
+                vector_data = item["vector"]
+                if isinstance(vector_data, list):
+                    # Convert list to float32 values as required by the S3 Vector API
+                    vector_data = [float(x) for x in vector_data]
+
+                # Prepare metadata, ensuring the text field is preserved
+                metadata = item.get("metadata", {}).copy()
+                # Add the text field to metadata so it's available for retrieval
+                if "text" in item:
+                    metadata["text"] = item["text"]
+
+                # Filter metadata to comply with the S3 Vector API limit of 10 keys
+                metadata = self._filter_metadata(metadata, item["id"])
+
+                vectors.append({
+                    "key": item["id"],
+                    "data": {
+                        "float32": vector_data
+                    },
+                    "metadata": metadata
+                })
+            # Upsert vectors (using put_vectors for upsert semantics)
+            log.info(f"Upserting {len(vectors)} vectors. First vector sample: key={vectors[0]['key']}, data_type={type(vectors[0]['data']['float32'])}, data_len={len(vectors[0]['data']['float32'])}")
+            self.client.put_vectors(
+                vectorBucketName=self.bucket_name,
+                indexName=collection_name,
+                vectors=vectors
+            )
+            log.info(f"Upserted {len(vectors)} vectors into index '{collection_name}'.")
+        except Exception as e:
+            log.error(f"Error upserting vectors: {e}")
+            raise
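+
+    # The search() method below assumes query_vectors() responses shaped roughly
+    # like this sketch (field names are the ones read by the parsing code; the
+    # exact response structure is an assumption):
+    #
+    #     {
+    #         "vectors": [
+    #             {"key": "vector-id", "metadata": {"text": "..."}, "distance": 0.12},
+    #         ]
+    #     }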
+
+    def search(self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int) -> Optional[SearchResult]:
+        """
+        Search for similar vectors in a collection using multiple query vectors.
+
+        Uses S3 Vector's query_vectors API to perform similarity search.
+
+        Args:
+            collection_name: Name of the collection to search in
+            vectors: List of query vectors to search with
+            limit: Maximum number of results to return per query
+
+        Returns:
+            SearchResult containing IDs, documents, metadatas, and distances
+        """
+        if not self.has_collection(collection_name):
+            log.warning(f"Collection '{collection_name}' does not exist")
+            return None
+
+        if not vectors:
+            log.warning("No query vectors provided")
+            return None
+
+        try:
+            log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}")
+
+            # Initialize result lists
+            all_ids = []
+            all_documents = []
+            all_metadatas = []
+            all_distances = []
+
+            # Process each query vector
+            for i, query_vector in enumerate(vectors):
+                log.debug(f"Processing query vector {i+1}/{len(vectors)}")
+
+                # Prepare the query vector in S3 Vector format
+                query_vector_dict = {
+                    'float32': [float(x) for x in query_vector]
+                }
+
+                # Call S3 Vector query API
+                response = self.client.query_vectors(
+                    vectorBucketName=self.bucket_name,
+                    indexName=collection_name,
+                    topK=limit,
+                    queryVector=query_vector_dict,
+                    returnMetadata=True,
+                    returnDistance=True
+                )
+
+                # Process results for this query
+                query_ids = []
+                query_documents = []
+                query_metadatas = []
+                query_distances = []
+
+                result_vectors = response.get('vectors', [])
+
+                for vector in result_vectors:
+                    vector_id = vector.get('key')
+                    vector_metadata = vector.get('metadata', {})
+                    vector_distance = vector.get('distance', 0.0)
+
+                    # Extract document text from metadata
+                    document_text = ""
+                    if isinstance(vector_metadata, dict):
+                        # Get the text field first (highest priority)
+                        document_text = vector_metadata.get('text')
+                        if not document_text:
+                            # Fall back to other possible text fields
+                            document_text = (vector_metadata.get('content') or
+                                             vector_metadata.get('document') or
+                                             vector_id)
+                    else:
+                        document_text = vector_id
+
+                    query_ids.append(vector_id)
+                    query_documents.append(document_text)
+                    query_metadatas.append(vector_metadata)
+                    query_distances.append(vector_distance)
+
+                # Add this query's results to the overall results
+                all_ids.append(query_ids)
+                all_documents.append(query_documents)
+                all_metadatas.append(query_metadatas)
+                all_distances.append(query_distances)
+
+            log.info(f"Search completed. Found results for {len(all_ids)} queries")
+
+            # Return SearchResult format
+            return SearchResult(
+                ids=all_ids if all_ids else None,
+                documents=all_documents if all_documents else None,
+                metadatas=all_metadatas if all_metadatas else None,
+                distances=all_distances if all_distances else None
+            )
+
+        except Exception as e:
+            log.error(f"Error searching collection '{collection_name}': {str(e)}")
+            # Handle specific AWS exceptions
+            if hasattr(e, 'response') and 'Error' in e.response:
+                error_code = e.response['Error']['Code']
+                if error_code == 'NotFoundException':
+                    log.warning(f"Collection '{collection_name}' not found")
+                    return None
+                elif error_code == 'ValidationException':
+                    log.error("Invalid query vector dimensions or parameters")
+                    return None
+                elif error_code == 'AccessDeniedException':
+                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                    return None
+            raise
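+
+    # query() and _matches_filter() below support simple metadata filters. An
+    # illustrative (hypothetical) filter combining the supported operators:
+    #
+    #     {
+    #         "file_id": "abc123",                 # plain equality
+    #         "page": {"$in": [1, 2, 3]},          # membership
+    #         "$or": [
+    #             {"created_by": {"$ne": "system"}},
+    #             {"hash": {"$exists": True}},
+    #         ],
+    #     }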
+
+    def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
+        """
+        Query vectors from a collection using a metadata filter.
+
+        For S3 Vector, this uses the list_vectors API with metadata filters.
+        Note: S3 Vector supports metadata filtering, but the exact filter syntax may vary.
+
+        Args:
+            collection_name: Name of the collection to query
+            filter: Dictionary containing metadata filter conditions
+            limit: Maximum number of results to return (optional)
+
+        Returns:
+            GetResult containing IDs, documents, and metadatas
+        """
+        if not self.has_collection(collection_name):
+            log.warning(f"Collection '{collection_name}' does not exist")
+            return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+
+        if not filter:
+            log.warning("No filter provided, returning all vectors")
+            return self.get(collection_name)
+
+        try:
+            log.info(f"Querying collection '{collection_name}' with filter: {filter}")
+
+            # For S3 Vector, we need to use list_vectors and then filter results.
+            # Since S3 Vector may not support complex server-side filtering,
+            # we retrieve all vectors and filter client-side.
+
+            # Get all vectors first
+            all_vectors_result = self.get(collection_name)
+
+            if not all_vectors_result or not all_vectors_result.ids:
+                log.warning("No vectors found in collection")
+                return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+
+            # Extract the lists from the result
+            all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
+            all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else []
+            all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
+
+            # Apply client-side filtering
+            filtered_ids = []
+            filtered_documents = []
+            filtered_metadatas = []
+
+            for i, metadata in enumerate(all_metadatas):
+                if self._matches_filter(metadata, filter):
+                    if i < len(all_ids):
+                        filtered_ids.append(all_ids[i])
+                    if i < len(all_documents):
+                        filtered_documents.append(all_documents[i])
+                    filtered_metadatas.append(metadata)
+
+                    # Apply limit if specified
+                    if limit and len(filtered_ids) >= limit:
+                        break
+
+            log.info(f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total")
+
+            # Return GetResult format
+            if filtered_ids:
+                return GetResult(ids=[filtered_ids], documents=[filtered_documents], metadatas=[filtered_metadatas])
+            else:
+                return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+
+        except Exception as e:
+            log.error(f"Error querying collection '{collection_name}': {str(e)}")
+            # Handle specific AWS exceptions
+            if hasattr(e, 'response') and 'Error' in e.response:
+                error_code = e.response['Error']['Code']
+                if error_code == 'NotFoundException':
+                    log.warning(f"Collection '{collection_name}' not found")
+                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+                elif error_code == 'AccessDeniedException':
+                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+            raise
+
+    def get(self, collection_name: str) -> Optional[GetResult]:
+        """
+        Retrieve all vectors from a collection.
+
+        Uses S3 Vector's list_vectors API to get all vectors with their data and metadata.
+        Handles pagination automatically to retrieve all vectors.
+        """
+        if not self.has_collection(collection_name):
+            log.warning(f"Collection '{collection_name}' does not exist")
+            return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+
+        try:
+            log.info(f"Retrieving all vectors from collection '{collection_name}'")
+
+            # Initialize result lists
+            all_ids = []
+            all_documents = []
+            all_metadatas = []
+
+            # Handle pagination
+            next_token = None
+
+            while True:
+                # Prepare request parameters
+                request_params = {
+                    'vectorBucketName': self.bucket_name,
+                    'indexName': collection_name,
+                    'returnData': False,     # Don't include vector data (not needed for get)
+                    'returnMetadata': True,  # Include metadata
+                    'maxResults': 500        # Use a reasonable page size
+                }
+
+                if next_token:
+                    request_params['nextToken'] = next_token
+
+                # Call S3 Vector API
+                response = self.client.list_vectors(**request_params)
+
+                # Process vectors in this page
+                vectors = response.get('vectors', [])
+
+                for vector in vectors:
+                    vector_id = vector.get('key')
+                    vector_data = vector.get('data', {})
+                    vector_metadata = vector.get('metadata', {})
+
+                    # Extract the actual vector array
+                    vector_array = vector_data.get('float32', [])
+
+                    # For documents, we try to extract text from metadata or use the vector ID
+                    document_text = ""
+                    if isinstance(vector_metadata, dict):
+                        # Get the text field first (highest priority)
+                        document_text = vector_metadata.get('text')
+                        if not document_text:
+                            # Fall back to other possible text fields
+                            document_text = (vector_metadata.get('content') or
+                                             vector_metadata.get('document') or
+                                             vector_id)
+
+                        # Log the actual content for debugging
+                        log.debug(f"Document text preview (first 200 chars): {str(document_text)[:200]}")
+                    else:
+                        document_text = vector_id
+
+                    all_ids.append(vector_id)
+                    all_documents.append(document_text)
+                    all_metadatas.append(vector_metadata)
+
+                # Check if there are more pages
+                next_token = response.get('nextToken')
+                if not next_token:
+                    break
+
+            log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'")
+
+            # Return in GetResult format
+            # Open WebUI's GetResult expects lists of lists, so we wrap each list
+            if all_ids:
+                return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas])
+            else:
+                return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+
+        except Exception as e:
+            log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}")
+            # Handle specific AWS exceptions
+            if hasattr(e, 'response') and 'Error' in e.response:
+                error_code = e.response['Error']['Code']
+                if error_code == 'NotFoundException':
+                    log.warning(f"Collection '{collection_name}' not found")
+                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+                elif error_code == 'AccessDeniedException':
+                    log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
+                    return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
+            raise
+ """ + try: + log.warning("Reset called - this will delete all vector indexes in the S3 bucket") + + # List all indexes + response = self.client.list_indexes(vectorBucketName=self.bucket_name) + indexes = response.get("indexes", []) + + if not indexes: + log.warning("No indexes found to delete") + return + + # Delete all indexes + deleted_count = 0 + for index in indexes: + index_name = index.get("indexName") + if index_name: + try: + self.client.delete_index( + vectorBucketName=self.bucket_name, + indexName=index_name + ) + deleted_count += 1 + log.info(f"Deleted index: {index_name}") + except Exception as e: + log.error(f"Error deleting index '{index_name}': {e}") + + log.info(f"Reset completed: deleted {deleted_count} indexes") + + except Exception as e: + log.error(f"Error during reset: {e}") + raise + + def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool: + """ + Check if metadata matches the given filter conditions. + Supports basic equality matching and simple logical operations. + + Args: + metadata: The metadata to check + filter: The filter conditions to match against + + Returns: + True if metadata matches all filter conditions, False otherwise + """ + if not isinstance(metadata, dict) or not isinstance(filter, dict): + return False + + # Check each filter condition + for key, expected_value in filter.items(): + # Handle special operators + if key.startswith('$'): + if key == '$and': + # All conditions must match + if not isinstance(expected_value, list): + continue + for condition in expected_value: + if not self._matches_filter(metadata, condition): + return False + elif key == '$or': + # At least one condition must match + if not isinstance(expected_value, list): + continue + any_match = False + for condition in expected_value: + if self._matches_filter(metadata, condition): + any_match = True + break + if not any_match: + return False + continue + + # Get the actual value from metadata + actual_value = metadata.get(key) + + # Handle different types of expected values + if isinstance(expected_value, dict): + # Handle comparison operators + for op, op_value in expected_value.items(): + if op == '$eq': + if actual_value != op_value: + return False + elif op == '$ne': + if actual_value == op_value: + return False + elif op == '$in': + if not isinstance(op_value, list) or actual_value not in op_value: + return False + elif op == '$nin': + if isinstance(op_value, list) and actual_value in op_value: + return False + elif op == '$exists': + if bool(op_value) != (key in metadata): + return False + # Add more operators as needed + else: + # Simple equality check + if actual_value != expected_value: + return False + + return True