enh/refac: url input handling

Timothy Jaeryang Baek 2025-10-04 02:02:26 -05:00
parent ce83276fa4
commit a2a2bafdf6
5 changed files with 82 additions and 41 deletions

backend/open_webui/retrieval/loaders/youtube.py

@@ -157,3 +157,10 @@ class YoutubeLoader:
                f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
            )
            raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

    async def aload(self) -> list[Document]:
        """Asynchronously load YouTube transcripts into `Document` objects."""
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.load)
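
The new aload wrapper keeps the event loop responsive by running the blocking load() on a worker thread. A minimal usage sketch (the video URL and language list are placeholder values; the constructor arguments mirror the get_loader call introduced later in this commit):

import asyncio

from open_webui.retrieval.loaders.youtube import YoutubeLoader


async def main():
    # Placeholder URL and language; in the app these come from config state.
    loader = YoutubeLoader(
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
        language=["en"],
    )
    docs = await loader.aload()  # load() executes in the default thread pool
    print(docs[0].page_content[:200])


asyncio.run(main())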

backend/open_webui/retrieval/utils.py

@@ -6,6 +6,7 @@ import requests
import hashlib
from concurrent.futures import ThreadPoolExecutor
import time
import re
from urllib.parse import quote
from huggingface_hub import snapshot_download
@@ -16,6 +17,7 @@ from langchain_core.documents import Document
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
@@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.access_control import has_access
from open_webui.utils.misc import get_message_list
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader
from open_webui.env import (
    SRC_LOG_LEVELS,
@@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever


def is_youtube_url(url: str) -> bool:
    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
    return re.match(youtube_regex, url) is not None


def get_loader(request, url: str):
    if is_youtube_url(url):
        return YoutubeLoader(
            url,
            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        )
    else:
        return get_web_loader(
            url,
            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
        )


def get_content_from_url(request, url: str) -> tuple[str, list[Document]]:
    loader = get_loader(request, url)
    docs = loader.load()
    content = " ".join([doc.page_content for doc in docs])
    return content, docs


class VectorSearchRetriever(BaseRetriever):
    collection_name: Any
    embedding_function: Any
@@ -571,6 +603,13 @@ def get_sources_from_items(
                    "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
                }

        elif item.get("type") == "url":
            content, docs = get_content_from_url(request, item.get("url"))
            if docs:
                query_result = {
                    "documents": [[content]],
                    "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
                }

        elif item.get("type") == "file":
            if (
                item.get("context") == "full"
@@ -736,7 +775,6 @@ def get_sources_from_items(
                sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
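
With is_youtube_url and get_content_from_url now living in retrieval.utils, any caller can dispatch a URL to the right loader. A quick sanity check of the dispatch regex (pattern copied verbatim from the hunk above):

import re

youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"


def is_youtube_url(url: str) -> bool:
    return re.match(youtube_regex, url) is not None


assert is_youtube_url("https://www.youtube.com/watch?v=abc123")
assert is_youtube_url("youtu.be/abc123")  # scheme and www. are both optional
assert not is_youtube_url("https://example.com/watch?v=abc123")

Note the pattern does not match hosts like m.youtube.com; such URLs fall through to get_web_loader.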

backend/open_webui/routers/retrieval.py

@@ -71,6 +71,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
from open_webui.retrieval.web.external import search_external
from open_webui.retrieval.utils import (
    get_content_from_url,
    get_embedding_function,
    get_reranking_function,
    get_model_path,
@@ -1691,33 +1692,6 @@ def process_text(
    )


def is_youtube_url(url: str) -> bool:
    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
    return re.match(youtube_regex, url) is not None


def get_loader(request, url: str):
    if is_youtube_url(url):
        return YoutubeLoader(
            url,
            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
        )
    else:
        return get_web_loader(
            url,
            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
        )


def get_content_from_url(request, url: str) -> str:
    loader = get_loader(request, url)
    docs = loader.load()
    content = " ".join([doc.page_content for doc in docs])
    return content, docs


@router.post("/process/youtube")
@router.post("/process/web")
def process_web(
@@ -1733,7 +1707,11 @@ def process_web(
        if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
            save_docs_to_vector_db(
                request, docs, collection_name, overwrite=True, user=user
                request,
                docs,
                collection_name,
                overwrite=True,
                user=user,
            )
        else:
            collection_name = None
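
With the helpers moved into retrieval.utils, /process/youtube and /process/web now share one handler. A hedged client sketch; the /api/v1/retrieval prefix and the {"url": ...} body shape are assumptions based on the surrounding codebase, not shown in this diff:

import requests

resp = requests.post(
    "http://localhost:8080/api/v1/retrieval/process/web",  # assumed mount path
    headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder token
    json={"url": "https://youtu.be/abc123"},  # placeholder URL
)
print(resp.status_code, resp.json())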

backend/open_webui/utils/middleware.py

@@ -40,7 +40,10 @@ from open_webui.routers.tasks import (
    generate_image_prompt,
    generate_chat_tags,
)
from open_webui.routers.retrieval import process_web_search, SearchForm
from open_webui.routers.retrieval import (
    process_web_search,
    SearchForm,
)
from open_webui.routers.images import (
    load_b64_image_data,
    image_generations,
@@ -76,6 +79,7 @@ from open_webui.utils.task import (
)
from open_webui.utils.misc import (
    deep_update,
    extract_urls,
    get_message_list,
    add_or_update_system_message,
    add_or_update_user_message,
@@ -823,7 +827,11 @@ async def chat_completion_files_handler(
    if files := body.get("metadata", {}).get("files", None):
        # Check if all files are in full context mode
        all_full_context = all(item.get("context") == "full" for item in files)
        all_full_context = all(
            item.get("context") == "full"
            for item in files
            if item.get("type") == "file"
        )

        queries = []
        if not all_full_context:
@@ -855,10 +863,6 @@
            except:
                pass

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        if not all_full_context:
            await __event_emitter__(
                {
                    "type": "status",
@@ -870,6 +874,9 @@
                }
            )

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Offload get_sources_from_items to a separate thread
            loop = asyncio.get_running_loop()
@@ -908,7 +915,6 @@
        log.debug(f"rag_contexts:sources: {sources}")

        unique_ids = set()
        for source in sources or []:
            if not source or len(source.keys()) == 0:
                continue
@@ -927,7 +933,6 @@
            unique_ids.add(_id)

        sources_count = len(unique_ids)
        await __event_emitter__(
            {
                "type": "status",
@@ -1170,8 +1175,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)

    # Remove files duplicates
    if files:
    prompt = get_last_user_message(form_data["messages"])
    urls = extract_urls(prompt)

    if files or urls:
        if not files:
            files = []
        files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]

        # Remove duplicate files based on their content
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
metadata = {
@@ -1372,8 +1384,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
        )
        context_string = context_string.strip()

        prompt = get_last_user_message(form_data["messages"])
        if prompt is None:
            raise Exception("No user message found")
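
process_chat_payload now lifts URLs out of the last user message, appends them to files as {"type": "url"} items, and de-duplicates by serializing each dict with sorted keys, so entries that differ only in key order still collapse to one. A standalone sketch of that idiom (sample data is illustrative):

import json

files = [
    {"type": "url", "url": "https://example.com", "name": "https://example.com"},
    {"name": "https://example.com", "url": "https://example.com", "type": "url"},
]
# sort_keys=True gives both dicts the same JSON string, so the dict
# comprehension keeps a single entry per unique content.
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
print(len(files))  # -> 1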

backend/open_webui/utils/misc.py

@@ -531,3 +531,11 @@ def throttle(interval: float = 10.0):
        return wrapper

    return decorator


def extract_urls(text: str) -> list[str]:
    # Regex pattern to match URLs
    url_pattern = re.compile(
        r"(https?://[^\s]+)", re.IGNORECASE
    )  # Matches http and https URLs
    return url_pattern.findall(text)
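
A usage sketch for the new helper (sample prompt is illustrative), matching how process_chat_payload feeds the last user message through it:

import re


def extract_urls(text: str) -> list[str]:
    url_pattern = re.compile(r"(https?://[^\s]+)", re.IGNORECASE)
    return url_pattern.findall(text)


print(extract_urls("summarize https://youtu.be/abc123 and https://example.com/post please"))
# -> ['https://youtu.be/abc123', 'https://example.com/post']

Since [^\s]+ is greedy up to the next whitespace, a URL followed directly by punctuation (e.g. a trailing comma) keeps that punctuation in the match.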