diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index b4fdc97d82..243b8212a8 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -547,16 +547,16 @@ else: CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get( - "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "10" + "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30" ) if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "": - CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 10 + CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 else: try: CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = int(CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES) except Exception: - CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 10 + CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 #################################### diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py index f5510eec35..b410cbab50 100644 --- a/backend/open_webui/utils/files.py +++ b/backend/open_webui/utils/files.py @@ -3,6 +3,20 @@ from open_webui.routers.images import ( upload_image, ) +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + UploadFile, +) + +from open_webui.routers.files import upload_file_handler + +import mimetypes +import base64 +import io + def get_image_url_from_base64(request, base64_image_string, metadata, user): if "data:image/png;base64" in base64_image_string: @@ -19,3 +33,65 @@ def get_image_url_from_base64(request, base64_image_string, metadata, user): ) return image_url return None + + +def load_b64_audio_data(b64_str): + try: + if "," in b64_str: + header, b64_data = b64_str.split(",", 1) + else: + b64_data = b64_str + header = "data:audio/wav;base64" + audio_data = base64.b64decode(b64_data) + content_type = ( + header.split(";")[0].split(":")[1] if ";" in header else "audio/wav" + ) + return audio_data, content_type + except Exception as e: + print(f"Error decoding base64 audio data: {e}") + return None, None + + +def upload_audio(request, audio_data, content_type, metadata, user): + audio_format = mimetypes.guess_extension(content_type) + file = UploadFile( + file=io.BytesIO(audio_data), + filename=f"generated-{audio_format}", # will be converted to a unique ID on upload_file + headers={ + "content-type": content_type, + }, + ) + file_item = upload_file_handler( + request, + file=file, + metadata=metadata, + process=False, + user=user, + ) + url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) + return url + + +def get_audio_url_from_base64(request, base64_audio_string, metadata, user): + if "data:audio/wav;base64" in base64_audio_string: + audio_url = "" + # Extract base64 audio data from the line + audio_data, content_type = load_b64_audio_data(base64_audio_string) + if audio_data is not None: + audio_url = upload_audio( + request, + audio_data, + content_type, + metadata, + user, + ) + return audio_url + return None + + +def get_file_url_from_base64(request, base64_file_string, metadata, user): + if "data:image/png;base64" in base64_file_string: + return get_image_url_from_base64(request, base64_file_string, metadata, user) + elif "data:audio/wav;base64" in base64_file_string: + return get_audio_url_from_base64(request, base64_file_string, metadata, user) + return None diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 3b4fc43858..bf71b77bcd 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -53,7 +53,11 @@ from open_webui.routers.pipelines import ( from open_webui.routers.memories import query_memory, QueryMemoryForm from open_webui.utils.webhook import post_webhook -from open_webui.utils.files import get_image_url_from_base64 +from open_webui.utils.files import ( + get_audio_url_from_base64, + get_file_url_from_base64, + get_image_url_from_base64, +) from open_webui.models.users import UserModel @@ -2573,34 +2577,36 @@ async def process_chat_response( tool_result.remove(item) if tool.get("type") == "mcp": - if ( - isinstance(item, dict) - and item.get("type") == "image" - ): - image_url = get_image_url_from_base64( - request, - f"data:{item.get('mimeType', 'image/png')};base64,{item.get('data', '')}", - { - "chat_id": metadata.get( - "chat_id", None - ), - "message_id": metadata.get( - "message_id", None - ), - "session_id": metadata.get( - "session_id", None - ), - }, - user, - ) + if isinstance(item, dict): + if ( + item.get("type") == "image" + or item.get("type") == "audio" + ): + file_url = get_file_url_from_base64( + request, + f"data:{item.get('mimeType')};base64,{item.get('data', '')}", + { + "chat_id": metadata.get( + "chat_id", None + ), + "message_id": metadata.get( + "message_id", None + ), + "session_id": metadata.get( + "session_id", None + ), + "result": item, + }, + user, + ) - tool_result_files.append( - { - "type": "image", - "url": image_url, - } - ) - tool_result.remove(item) + tool_result_files.append( + { + "type": item.get("type", "data"), + "url": file_url, + } + ) + tool_result.remove(item) if tool_result_files: if not isinstance(tool_result, list): @@ -2612,7 +2618,7 @@ async def process_chat_response( tool_result.append( { "type": file.get("type", "data"), - "content": "Displayed", + "content": "Result is being displayed as a file.", } )