diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 9c84f9c704..7caf57b0aa 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -33,6 +33,7 @@ from fastapi.responses import FileResponse from pydantic import BaseModel +from open_webui.utils.misc import strict_match_mime_type from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.headers import include_user_info_headers from open_webui.config import ( @@ -1155,17 +1156,9 @@ def transcription( stt_supported_content_types = getattr( request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] - ) + ) or ["audio/*", "video/webm"] - if not any( - fnmatch(file.content_type, content_type) - for content_type in ( - stt_supported_content_types - if stt_supported_content_types - and any(t.strip() for t in stt_supported_content_types) - else ["audio/*", "video/webm"] - ) - ): + if not strict_match_mime_type(stt_supported_content_types, file.content_type): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED, diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 3bb28b95d6..6eb7a19cbc 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -47,7 +47,7 @@ from open_webui.storage.provider import Storage from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access - +from open_webui.utils.misc import strict_match_mime_type from pydantic import BaseModel log = logging.getLogger(__name__) @@ -108,17 +108,9 @@ def process_uploaded_file(request, file, file_path, file_item, file_metadata, us if file.content_type: stt_supported_content_types = getattr( request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] - ) + ) or ["audio/*", "video/webm"] - if any( - fnmatch(file.content_type, content_type) - for content_type in ( - stt_supported_content_types - if stt_supported_content_types - and any(t.strip() for t in stt_supported_content_types) - else ["audio/*", "video/webm"] - ) - ): + if strict_match_mime_type(stt_supported_content_types, file.content_type): file_path = Storage.get_file(file_path) result = transcribe(request, file_path, file_metadata, user) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index 21943caff8..e0e21249f6 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Callable, Optional, Sequence, Union import json import aiohttp +import mimeparse import collections.abc @@ -577,6 +578,37 @@ def throttle(interval: float = 10.0): return decorator +def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[str]: + """ + Strictly match the mime type with the supported mime types. + + :param supported: The supported mime types. + :param header: The header to match. + :return: The matched mime type or None if no match is found. + """ + + try: + if isinstance(supported, str): + supported = supported.split(",") + + supported = [s for s in supported if s.strip() and "/" in s] + + match = mimeparse.best_match(supported, header) + if not match: + return None + + _, _, match_params = mimeparse.parse_mime_type(match) + _, _, header_params = mimeparse.parse_mime_type(header) + for k, v in match_params.items(): + if header_params.get(k) != v: + return None + + return match + except Exception as e: + log.exception(f"Failed to match mime type {header}: {e}") + return None + + def extract_urls(text: str) -> list[str]: # Regex pattern to match URLs url_pattern = re.compile( diff --git a/backend/requirements.txt b/backend/requirements.txt index 558b6ecc46..b337a5f8b5 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -20,6 +20,7 @@ aiofiles starlette-compress==1.6.1 httpx[socks,http2,zstd,cli,brotli]==0.28.1 starsessions[redis]==2.2.1 +python-mimeparse==2.0.0 sqlalchemy==2.0.44 alembic==1.17.2