enh/refac: url input handling

This commit is contained in:
Timothy Jaeryang Baek 2025-10-04 02:02:26 -05:00
parent ce83276fa4
commit a2a2bafdf6
5 changed files with 82 additions and 41 deletions

View file

@ -157,3 +157,10 @@ class YoutubeLoader:
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed." f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
) )
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list)) raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
async def aload(self) -> "list[Document]":
    """Asynchronously load YouTube transcripts into `Document` objects.

    Runs the blocking ``self.load`` in the default thread-pool executor so
    the event loop is not blocked while transcripts are fetched.
    Returns whatever ``self.load()`` returns (a list of Documents).
    """
    import asyncio

    # get_running_loop() is the supported API inside a coroutine;
    # get_event_loop() has been deprecated in this context since 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.load)

View file

@ -6,6 +6,7 @@ import requests
import hashlib import hashlib
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import time import time
import re
from urllib.parse import quote from urllib.parse import quote
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@ -16,6 +17,7 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel from open_webui.models.users import UserModel
from open_webui.models.files import Files from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges from open_webui.models.knowledge import Knowledges
@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from open_webui.utils.misc import get_message_list from open_webui.utils.misc import get_message_list
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader
from open_webui.env import ( from open_webui.env import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
def is_youtube_url(url: str) -> bool:
    """Return True if *url* points at a YouTube video page.

    Accepts an optional scheme, the common YouTube subdomains (``www.``,
    ``m.``, ``music.``) and the ``youtu.be`` short domain. Host matching is
    case-insensitive, as URL hosts are.
    """
    youtube_regex = r"^(https?://)?((www|m|music)\.)?(youtube\.com|youtu\.be)/.+$"
    return re.match(youtube_regex, url, re.IGNORECASE) is not None
def get_loader(request, url: str):
    """Pick the right document loader for *url*.

    YouTube links get the transcript loader; everything else falls back to
    the generic web loader, configured from the app state.
    """
    config = request.app.state.config
    if not is_youtube_url(url):
        return get_web_loader(
            url,
            verify_ssl=config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=config.WEB_LOADER_CONCURRENT_REQUESTS,
        )
    return YoutubeLoader(
        url,
        language=config.YOUTUBE_LOADER_LANGUAGE,
        proxy_url=config.YOUTUBE_LOADER_PROXY_URL,
    )
def get_content_from_url(request, url: str) -> tuple[str, list]:
    """Load *url* and return ``(content, docs)``.

    ``content`` is the space-joined text of every loaded document; ``docs``
    is the raw Document list from the loader. The annotation previously said
    ``-> str``, but the function has always returned a 2-tuple — corrected.
    """
    loader = get_loader(request, url)
    docs = loader.load()
    # Flatten all page text into a single space-separated string.
    content = " ".join(doc.page_content for doc in docs)
    return content, docs
class VectorSearchRetriever(BaseRetriever): class VectorSearchRetriever(BaseRetriever):
collection_name: Any collection_name: Any
embedding_function: Any embedding_function: Any
@ -571,6 +603,13 @@ def get_sources_from_items(
"metadatas": [[{"file_id": chat.id, "name": chat.title}]], "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
} }
elif item.get("type") == "url":
content, docs = get_content_from_url(request, item.get("url"))
if docs:
query_result = {
"documents": [[content]],
"metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
}
elif item.get("type") == "file": elif item.get("type") == "file":
if ( if (
item.get("context") == "full" item.get("context") == "full"
@ -736,7 +775,6 @@ def get_sources_from_items(
sources.append(source) sources.append(source)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return sources return sources

View file

@ -71,6 +71,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
from open_webui.retrieval.web.external import search_external from open_webui.retrieval.web.external import search_external
from open_webui.retrieval.utils import ( from open_webui.retrieval.utils import (
get_content_from_url,
get_embedding_function, get_embedding_function,
get_reranking_function, get_reranking_function,
get_model_path, get_model_path,
@ -1691,33 +1692,6 @@ def process_text(
) )
def is_youtube_url(url: str) -> bool:
youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
return re.match(youtube_regex, url) is not None
def get_loader(request, url: str):
if is_youtube_url(url):
return YoutubeLoader(
url,
language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
)
else:
return get_web_loader(
url,
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
)
def get_content_from_url(request, url: str) -> str:
loader = get_loader(request, url)
docs = loader.load()
content = " ".join([doc.page_content for doc in docs])
return content, docs
@router.post("/process/youtube") @router.post("/process/youtube")
@router.post("/process/web") @router.post("/process/web")
def process_web( def process_web(
@ -1733,7 +1707,11 @@ def process_web(
if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
save_docs_to_vector_db( save_docs_to_vector_db(
request, docs, collection_name, overwrite=True, user=user request,
docs,
collection_name,
overwrite=True,
user=user,
) )
else: else:
collection_name = None collection_name = None

View file

@ -40,7 +40,10 @@ from open_webui.routers.tasks import (
generate_image_prompt, generate_image_prompt,
generate_chat_tags, generate_chat_tags,
) )
from open_webui.routers.retrieval import process_web_search, SearchForm from open_webui.routers.retrieval import (
process_web_search,
SearchForm,
)
from open_webui.routers.images import ( from open_webui.routers.images import (
load_b64_image_data, load_b64_image_data,
image_generations, image_generations,
@ -76,6 +79,7 @@ from open_webui.utils.task import (
) )
from open_webui.utils.misc import ( from open_webui.utils.misc import (
deep_update, deep_update,
extract_urls,
get_message_list, get_message_list,
add_or_update_system_message, add_or_update_system_message,
add_or_update_user_message, add_or_update_user_message,
@ -823,7 +827,11 @@ async def chat_completion_files_handler(
if files := body.get("metadata", {}).get("files", None): if files := body.get("metadata", {}).get("files", None):
# Check if all files are in full context mode # Check if all files are in full context mode
all_full_context = all(item.get("context") == "full" for item in files) all_full_context = all(
item.get("context") == "full"
for item in files
if item.get("type") == "file"
)
queries = [] queries = []
if not all_full_context: if not all_full_context:
@ -855,10 +863,6 @@ async def chat_completion_files_handler(
except: except:
pass pass
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
if not all_full_context:
await __event_emitter__( await __event_emitter__(
{ {
"type": "status", "type": "status",
@ -870,6 +874,9 @@ async def chat_completion_files_handler(
} }
) )
if len(queries) == 0:
queries = [get_last_user_message(body["messages"])]
try: try:
# Offload get_sources_from_items to a separate thread # Offload get_sources_from_items to a separate thread
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -908,7 +915,6 @@ async def chat_completion_files_handler(
log.debug(f"rag_contexts:sources: {sources}") log.debug(f"rag_contexts:sources: {sources}")
unique_ids = set() unique_ids = set()
for source in sources or []: for source in sources or []:
if not source or len(source.keys()) == 0: if not source or len(source.keys()) == 0:
continue continue
@ -927,7 +933,6 @@ async def chat_completion_files_handler(
unique_ids.add(_id) unique_ids.add(_id)
sources_count = len(unique_ids) sources_count = len(unique_ids)
await __event_emitter__( await __event_emitter__(
{ {
"type": "status", "type": "status",
@ -1170,8 +1175,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
tool_ids = form_data.pop("tool_ids", None) tool_ids = form_data.pop("tool_ids", None)
files = form_data.pop("files", None) files = form_data.pop("files", None)
# Remove files duplicates prompt = get_last_user_message(form_data["messages"])
if files: urls = extract_urls(prompt)
if files or urls:
if not files:
files = []
files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
# Remove duplicate files based on their content
files = list({json.dumps(f, sort_keys=True): f for f in files}.values()) files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
metadata = { metadata = {
@ -1372,8 +1384,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
) )
context_string = context_string.strip() context_string = context_string.strip()
prompt = get_last_user_message(form_data["messages"])
if prompt is None: if prompt is None:
raise Exception("No user message found") raise Exception("No user message found")

View file

@ -531,3 +531,11 @@ def throttle(interval: float = 10.0):
return wrapper return wrapper
return decorator return decorator
def extract_urls(text: str) -> list[str]:
    """Extract all http/https URLs from free text, in order of appearance.

    The greedy ``[^\\s]+`` match swallows trailing sentence punctuation
    ("see https://x.com." would yield "https://x.com."), so trailing
    punctuation and *unbalanced* closing parens/brackets are stripped.
    Balanced ones (e.g. Wikipedia's ``..._(film)``) are preserved.
    """
    url_pattern = re.compile(r"https?://[^\s]+", re.IGNORECASE)
    urls = []
    for url in url_pattern.findall(text):
        while True:
            trimmed = url.rstrip(".,;:!?\"'")
            # Drop a closing paren/bracket only when it has no opener in the URL.
            if trimmed.endswith(")") and trimmed.count(")") > trimmed.count("("):
                trimmed = trimmed[:-1]
            elif trimmed.endswith("]") and trimmed.count("]") > trimmed.count("["):
                trimmed = trimmed[:-1]
            if trimmed == url:
                break
            url = trimmed
        urls.append(url)
    return urls