diff --git a/backend/open_webui/routers/prune.py b/backend/open_webui/routers/prune.py
index 2968764d07..112901118d 100644
--- a/backend/open_webui/routers/prune.py
+++ b/backend/open_webui/routers/prune.py
@@ -17,7 +17,7 @@ from sqlalchemy import text
 from open_webui.utils.auth import get_admin_user
 from open_webui.models.users import Users
-from open_webui.models.chats import Chats
+from open_webui.models.chats import Chat, ChatModel, Chats
 from open_webui.models.files import Files
 from open_webui.models.notes import Notes
 from open_webui.models.prompts import Prompts
@@ -128,6 +128,26 @@ class JSONFileIDExtractor:
         r"/api/v1/files/([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})"
     )
 
+    @classmethod
+    def extract_file_ids(cls, json_string: str) -> Set[str]:
+        """
+        Extract file IDs from a JSON string WITHOUT database validation.
+
+        Args:
+            json_string: JSON content as a string (or any string to scan)
+
+        Returns:
+            Set of extracted file IDs (not validated against the database)
+
+        Note:
+            Use this method when you have a preloaded set of valid file IDs
+            to validate against, avoiding N database queries.
+        """
+        potential_ids = []
+        potential_ids.extend(cls._FILE_ID_PATTERN.findall(json_string))
+        potential_ids.extend(cls._URL_PATTERN.findall(json_string))
+        return set(potential_ids)
+
     @classmethod
     def extract_and_validate_file_ids(cls, json_string: str) -> Set[str]:
         """
@@ -1068,6 +1088,10 @@ def get_active_file_ids() -> Set[str]:
     active_file_ids = set()
 
     try:
+        # Preload all valid file IDs so validation is an O(1) set lookup
+        # per ID instead of one database query per ID
+        all_file_ids = {f.id for f in Files.get_files()}
+        log.debug(f"Preloaded {len(all_file_ids)} file IDs for validation")
         # Scan knowledge bases for file references
         knowledge_bases = Knowledges.get_knowledge_bases()
         log.debug(f"Found {len(knowledge_bases)} knowledge bases")
@@ -1092,26 +1116,34 @@ def get_active_file_ids() -> Set[str]:
 
                 for file_id in file_ids:
                     if isinstance(file_id, str) and file_id.strip():
-                        active_file_ids.add(file_id.strip())
+                        stripped_id = file_id.strip()
+                        # Validate against preloaded set (O(1) lookup)
+                        if stripped_id in all_file_ids:
+                            active_file_ids.add(stripped_id)
 
         # Scan chats for file references
-        chats = Chats.get_chats()
-        log.debug(f"Found {len(chats)} chats to scan for file references")
+        # Stream chats to avoid loading all of them into memory
+        chat_count = 0
+        with get_db() as db:
+            for chat_orm in db.query(Chat).yield_per(1000):
+                chat_count += 1
+                chat = ChatModel.model_validate(chat_orm)
 
-        for chat in chats:
-            if not chat.chat or not isinstance(chat.chat, dict):
-                continue
+                if not chat.chat or not isinstance(chat.chat, dict):
+                    continue
 
-            try:
-                chat_json_str = json.dumps(chat.chat)
-                # Use utility to extract and validate file IDs
-                validated_ids = JSONFileIDExtractor.extract_and_validate_file_ids(
-                    chat_json_str
-                )
-                active_file_ids.update(validated_ids)
+                try:
+                    chat_json_str = json.dumps(chat.chat)
+                    # Extract file IDs without DB queries
+                    extracted_ids = JSONFileIDExtractor.extract_file_ids(chat_json_str)
+                    # Validate against preloaded set (O(1) per ID)
+                    validated_ids = extracted_ids & all_file_ids
+                    active_file_ids.update(validated_ids)
 
-            except Exception as e:
-                log.debug(f"Error processing chat {chat.id} for file references: {e}")
+                except Exception as e:
+                    log.debug(f"Error processing chat {chat.id} for file references: {e}")
+
+        log.debug(f"Scanned {chat_count} chats for file references")
 
         # Scan folders for file references
         try:
@@ -1121,10 +1153,10 @@ def get_active_file_ids() -> Set[str]:
                 if folder.items:
                     try:
                         items_str = json.dumps(folder.items)
-                        # Use utility to extract and validate file IDs
-                        validated_ids = (
-                            JSONFileIDExtractor.extract_and_validate_file_ids(items_str)
-                        )
+                        # Extract file IDs without DB queries
+                        extracted_ids = JSONFileIDExtractor.extract_file_ids(items_str)
+                        # Validate against preloaded set (O(1) per ID)
+                        validated_ids = extracted_ids & all_file_ids
                         active_file_ids.update(validated_ids)
                     except Exception as e:
                         log.debug(f"Error processing folder {folder.id} items: {e}")
@@ -1132,10 +1164,10 @@ def get_active_file_ids() -> Set[str]:
                 if hasattr(folder, "data") and folder.data:
                     try:
                         data_str = json.dumps(folder.data)
-                        # Use utility to extract and validate file IDs
-                        validated_ids = (
-                            JSONFileIDExtractor.extract_and_validate_file_ids(data_str)
-                        )
+                        # Extract file IDs without DB queries
+                        extracted_ids = JSONFileIDExtractor.extract_file_ids(data_str)
+                        # Validate against preloaded set (O(1) per ID)
+                        validated_ids = extracted_ids & all_file_ids
                         active_file_ids.update(validated_ids)
                     except Exception as e:
                         log.debug(f"Error processing folder {folder.id} data: {e}")
@@ -1146,11 +1178,10 @@ def get_active_file_ids() -> Set[str]:
         # Scan standalone messages for file references
        try:
             with get_db() as db:
-                message_results = db.execute(
-                    text("SELECT id, data FROM message WHERE data IS NOT NULL")
-                ).fetchall()
+                stmt = text("SELECT id, data FROM message WHERE data IS NOT NULL")
 
-                for message_id, message_data_json in message_results:
+                for row in db.execute(stmt).yield_per(1000):
+                    message_id, message_data_json = row
                     if message_data_json:
                         try:
                             data_str = (
@@ -1158,12 +1189,10 @@ def get_active_file_ids() -> Set[str]:
                                 if isinstance(message_data_json, dict)
                                 else str(message_data_json)
                             )
-                            # Use utility to extract and validate file IDs
-                            validated_ids = (
-                                JSONFileIDExtractor.extract_and_validate_file_ids(
-                                    data_str
-                                )
-                            )
+                            # Extract file IDs without DB queries
+                            extracted_ids = JSONFileIDExtractor.extract_file_ids(data_str)
+                            # Validate against preloaded set (O(1) per ID)
+                            validated_ids = extracted_ids & all_file_ids
                             active_file_ids.update(validated_ids)
                         except Exception as e:
                             log.debug(
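
Usage sketch: a minimal, self-contained illustration of the preload-then-intersect pattern this patch applies at each scan site. The `_URL_PATTERN` regex mirrors the one visible in the hunk above; `FileIDExtractor`, `_FILE_ID_PATTERN`, and the sample data are hypothetical stand-ins, not the real `JSONFileIDExtractor` internals.

```python
import json
import re
from typing import Set


class FileIDExtractor:
    # _URL_PATTERN mirrors the pattern shown in the diff;
    # _FILE_ID_PATTERN is invented for this sketch.
    _FILE_ID_PATTERN = re.compile(r'"file_id"\s*:\s*"([a-fA-F0-9-]{36})"')
    _URL_PATTERN = re.compile(
        r"/api/v1/files/([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}"
        r"-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})"
    )

    @classmethod
    def extract_file_ids(cls, json_string: str) -> Set[str]:
        """Regex-only extraction; touches no database."""
        ids = set(cls._FILE_ID_PATTERN.findall(json_string))
        ids.update(cls._URL_PATTERN.findall(json_string))
        return ids


# Preload once: one query up front (in the patch, {f.id for f in Files.get_files()})
# instead of one query per candidate ID found in the scanned JSON.
all_file_ids = {"0f8fad5b-d9cb-469f-a165-70867728950e"}

chat = {
    "messages": [
        {"content": "see /api/v1/files/0f8fad5b-d9cb-469f-a165-70867728950e"},
        {"content": "see /api/v1/files/11111111-2222-3333-4444-555555555555"},
    ]
}
extracted = FileIDExtractor.extract_file_ids(json.dumps(chat))

# Validation is a set intersection: O(1) membership per extracted ID,
# and stale references (the second URL) simply drop out.
active_file_ids = extracted & all_file_ids
print(active_file_ids)  # {'0f8fad5b-d9cb-469f-a165-70867728950e'}
```

The same intersection works unchanged whether the extracted IDs come from a chat, a folder's items or data, or a standalone message row, which is why the patch can share one preloaded set across all four scan sites.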