From a2a2bafdf67c155fe99c1d90ca787cdaed2fb805 Mon Sep 17 00:00:00 2001
From: Timothy Jaeryang Baek
Date: Sat, 4 Oct 2025 02:02:26 -0500
Subject: [PATCH] enh/refac: url input handling

---
 .../open_webui/retrieval/loaders/youtube.py  |  7 ++++
 backend/open_webui/retrieval/utils.py        | 40 ++++++++++++++++++-
 backend/open_webui/routers/retrieval.py      | 34 +++-------------
 backend/open_webui/utils/middleware.py       | 34 ++++++++++------
 backend/open_webui/utils/misc.py             |  8 ++++
 5 files changed, 82 insertions(+), 41 deletions(-)

diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py
index 360ef0a6c7..da17eaef65 100644
--- a/backend/open_webui/retrieval/loaders/youtube.py
+++ b/backend/open_webui/retrieval/loaders/youtube.py
@@ -157,3 +157,10 @@ class YoutubeLoader:
             f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
         )
         raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
+
+    async def aload(self) -> Generator[Document, None, None]:
+        """Asynchronously load YouTube transcripts into `Document` objects."""
+        import asyncio
+
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.load)
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 65da1592e1..133016d85c 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -6,6 +6,7 @@ import requests
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
 import time
+import re
 
 from urllib.parse import quote
 from huggingface_hub import snapshot_download
@@ -16,6 +17,7 @@ from langchain_core.documents import Document
 
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
+
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
 from open_webui.models.knowledge import Knowledges
@@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
 from open_webui.utils.access_control import has_access
 from open_webui.utils.misc import get_message_list
 
+from open_webui.retrieval.web.utils import get_web_loader
+from open_webui.retrieval.loaders.youtube import YoutubeLoader
+
 
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.retrievers import BaseRetriever
 
 
+def is_youtube_url(url: str) -> bool:
+    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
+    return re.match(youtube_regex, url) is not None
+
+
+def get_loader(request, url: str):
+    if is_youtube_url(url):
+        return YoutubeLoader(
+            url,
+            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
+            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
+        )
+    else:
+        return get_web_loader(
+            url,
+            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
+            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
+        )
+
+
+def get_content_from_url(request, url: str) -> str:
+    loader = get_loader(request, url)
+    docs = loader.load()
+    content = " ".join([doc.page_content for doc in docs])
+    return content, docs
+
+
 class VectorSearchRetriever(BaseRetriever):
     collection_name: Any
     embedding_function: Any
@@ -571,6 +603,13 @@ def get_sources_from_items(
                     "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
                 }
 
+        elif item.get("type") == "url":
+            content, docs = get_content_from_url(request, item.get("url"))
+            if docs:
+                query_result = {
+                    "documents": [[content]],
+                    "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
+                }
         elif item.get("type") == "file":
             if (
                 item.get("context") == "full"
@@ -736,7 +775,6 @@ def get_sources_from_items(
                 sources.append(source)
         except Exception as e:
             log.exception(e)
-
     return sources
 
 
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index 5a1e798d41..9aaac18fa3 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -71,6 +71,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
 from open_webui.retrieval.web.external import search_external
 
 from open_webui.retrieval.utils import (
+    get_content_from_url,
     get_embedding_function,
     get_reranking_function,
     get_model_path,
@@ -1691,33 +1692,6 @@ def process_text(
         )
 
 
-def is_youtube_url(url: str) -> bool:
-    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
-    return re.match(youtube_regex, url) is not None
-
-
-def get_loader(request, url: str):
-    if is_youtube_url(url):
-        return YoutubeLoader(
-            url,
-            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
-        )
-    else:
-        return get_web_loader(
-            url,
-            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
-        )
-
-
-def get_content_from_url(request, url: str) -> str:
-    loader = get_loader(request, url)
-    docs = loader.load()
-    content = " ".join([doc.page_content for doc in docs])
-    return content, docs
-
-
 @router.post("/process/youtube")
 @router.post("/process/web")
 def process_web(
@@ -1733,7 +1707,11 @@
 
         if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
             save_docs_to_vector_db(
-                request, docs, collection_name, overwrite=True, user=user
+                request,
+                docs,
+                collection_name,
+                overwrite=True,
+                user=user,
             )
         else:
             collection_name = None
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index 7501f803f1..10189cfae0 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -40,7 +40,10 @@ from open_webui.routers.tasks import (
     generate_image_prompt,
     generate_chat_tags,
 )
-from open_webui.routers.retrieval import process_web_search, SearchForm
+from open_webui.routers.retrieval import (
+    process_web_search,
+    SearchForm,
+)
 from open_webui.routers.images import (
     load_b64_image_data,
     image_generations,
@@ -76,6 +79,7 @@ from open_webui.utils.task import (
 )
 from open_webui.utils.misc import (
     deep_update,
+    extract_urls,
     get_message_list,
     add_or_update_system_message,
     add_or_update_user_message,
@@ -823,7 +827,11 @@ async def chat_completion_files_handler(
 
     if files := body.get("metadata", {}).get("files", None):
         # Check if all files are in full context mode
-        all_full_context = all(item.get("context") == "full" for item in files)
+        all_full_context = all(
+            item.get("context") == "full"
+            for item in files
+            if item.get("type") == "file"
+        )
 
         queries = []
         if not all_full_context:
@@ -855,10 +863,6 @@
             except:
                 pass
 
-        if len(queries) == 0:
-            queries = [get_last_user_message(body["messages"])]
-
-        if not all_full_context:
             await __event_emitter__(
                 {
                     "type": "status",
@@ -870,6 +874,9 @@
                 }
             )
 
+        if len(queries) == 0:
+            queries = [get_last_user_message(body["messages"])]
+
         try:
             # Offload get_sources_from_items to a separate thread
             loop = asyncio.get_running_loop()
@@ -908,7 +915,6 @@
     log.debug(f"rag_contexts:sources: {sources}")
 
     unique_ids = set()
-
     for source in sources or []:
         if not source or len(source.keys()) == 0:
             continue
@@ -927,7 +933,6 @@
         unique_ids.add(_id)
 
     sources_count = len(unique_ids)
-
     await __event_emitter__(
         {
             "type": "status",
@@ -1170,8 +1175,15 @@
     tool_ids = form_data.pop("tool_ids", None)
     files = form_data.pop("files", None)
 
-    # Remove files duplicates
-    if files:
+    prompt = get_last_user_message(form_data["messages"])
+    urls = extract_urls(prompt)
+
+    if files or urls:
+        if not files:
+            files = []
+        files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
+
+        # Remove duplicate files based on their content
         files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
 
     metadata = {
@@ -1372,8 +1384,6 @@
         )
 
         context_string = context_string.strip()
-
-        prompt = get_last_user_message(form_data["messages"])
         if prompt is None:
             raise Exception("No user message found")
 
diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py
index 8977cf17d4..9984e378fb 100644
--- a/backend/open_webui/utils/misc.py
+++ b/backend/open_webui/utils/misc.py
@@ -531,3 +531,11 @@ def throttle(interval: float = 10.0):
         return wrapper
 
     return decorator
+
+
+def extract_urls(text: str) -> list[str]:
+    # Regex pattern to match URLs
+    url_pattern = re.compile(
+        r"(https?://[^\s]+)", re.IGNORECASE
+    )  # Matches http and https URLs
+    return url_pattern.findall(text)
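Illustrative usage sketch (not part of the patch; assumes only the Python standard library).
It re-implements `extract_urls` and `is_youtube_url` standalone with the same regexes the
patch adds, so the new flow can be exercised without Open WebUI installed: URLs are pulled
out of the prompt, wrapped as url-type "file" items the way process_chat_payload now does,
and each one is routed to either the YouTube loader or the generic web loader the way
get_loader decides.

    import re


    def extract_urls(text: str) -> list[str]:
        # Same pattern the patch adds to utils/misc.py: http(s) runs up to whitespace.
        return re.findall(r"(https?://[^\s]+)", text, re.IGNORECASE)


    def is_youtube_url(url: str) -> bool:
        # Same check the patch moves into retrieval/utils.py.
        youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
        return re.match(youtube_regex, url) is not None


    prompt = "Compare https://example.com/article with https://youtu.be/dQw4w9WgXcQ"
    urls = extract_urls(prompt)

    # process_chat_payload now turns every URL found in the prompt into a url-type
    # "file" item; get_sources_from_items later loads each one via get_content_from_url.
    files = [{"type": "url", "url": url, "name": url} for url in urls]

    for item in files:
        loader = "YoutubeLoader" if is_youtube_url(item["url"]) else "get_web_loader"
        print(f'{item["url"]} -> {loader}')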