mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
enh/refac: url input handling
This commit is contained in:
parent
ce83276fa4
commit
a2a2bafdf6
5 changed files with 82 additions and 41 deletions
|
|
@ -157,3 +157,10 @@ class YoutubeLoader:
|
|||
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
|
||||
)
|
||||
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
|
||||
|
||||
async def aload(self) -> Generator[Document, None, None]:
|
||||
"""Asynchronously load YouTube transcripts into `Document` objects."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.load)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import requests
|
|||
import hashlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
import re
|
||||
|
||||
from urllib.parse import quote
|
||||
from huggingface_hub import snapshot_download
|
||||
|
|
@ -16,6 +17,7 @@ from langchain_core.documents import Document
|
|||
from open_webui.config import VECTOR_DB
|
||||
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.files import Files
|
||||
from open_webui.models.knowledge import Knowledges
|
||||
|
|
@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
|
|||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.utils.misc import get_message_list
|
||||
|
||||
from open_webui.retrieval.web.utils import get_web_loader
|
||||
from open_webui.retrieval.loaders.youtube import YoutubeLoader
|
||||
|
||||
|
||||
from open_webui.env import (
|
||||
SRC_LOG_LEVELS,
|
||||
|
|
@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
|
||||
youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
|
||||
return re.match(youtube_regex, url) is not None
|
||||
|
||||
|
||||
def get_loader(request, url: str):
|
||||
if is_youtube_url(url):
|
||||
return YoutubeLoader(
|
||||
url,
|
||||
language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
||||
proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
)
|
||||
else:
|
||||
return get_web_loader(
|
||||
url,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
)
|
||||
|
||||
|
||||
def get_content_from_url(request, url: str) -> str:
|
||||
loader = get_loader(request, url)
|
||||
docs = loader.load()
|
||||
content = " ".join([doc.page_content for doc in docs])
|
||||
return content, docs
|
||||
|
||||
|
||||
class VectorSearchRetriever(BaseRetriever):
|
||||
collection_name: Any
|
||||
embedding_function: Any
|
||||
|
|
@ -571,6 +603,13 @@ def get_sources_from_items(
|
|||
"metadatas": [[{"file_id": chat.id, "name": chat.title}]],
|
||||
}
|
||||
|
||||
elif item.get("type") == "url":
|
||||
content, docs = get_content_from_url(request, item.get("url"))
|
||||
if docs:
|
||||
query_result = {
|
||||
"documents": [[content]],
|
||||
"metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
|
||||
}
|
||||
elif item.get("type") == "file":
|
||||
if (
|
||||
item.get("context") == "full"
|
||||
|
|
@ -736,7 +775,6 @@ def get_sources_from_items(
|
|||
sources.append(source)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
|
|||
from open_webui.retrieval.web.external import search_external
|
||||
|
||||
from open_webui.retrieval.utils import (
|
||||
get_content_from_url,
|
||||
get_embedding_function,
|
||||
get_reranking_function,
|
||||
get_model_path,
|
||||
|
|
@ -1691,33 +1692,6 @@ def process_text(
|
|||
)
|
||||
|
||||
|
||||
def is_youtube_url(url: str) -> bool:
|
||||
youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
|
||||
return re.match(youtube_regex, url) is not None
|
||||
|
||||
|
||||
def get_loader(request, url: str):
|
||||
if is_youtube_url(url):
|
||||
return YoutubeLoader(
|
||||
url,
|
||||
language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
|
||||
proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
|
||||
)
|
||||
else:
|
||||
return get_web_loader(
|
||||
url,
|
||||
verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
|
||||
requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
|
||||
)
|
||||
|
||||
|
||||
def get_content_from_url(request, url: str) -> str:
|
||||
loader = get_loader(request, url)
|
||||
docs = loader.load()
|
||||
content = " ".join([doc.page_content for doc in docs])
|
||||
return content, docs
|
||||
|
||||
|
||||
@router.post("/process/youtube")
|
||||
@router.post("/process/web")
|
||||
def process_web(
|
||||
|
|
@ -1733,7 +1707,11 @@ def process_web(
|
|||
|
||||
if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
|
||||
save_docs_to_vector_db(
|
||||
request, docs, collection_name, overwrite=True, user=user
|
||||
request,
|
||||
docs,
|
||||
collection_name,
|
||||
overwrite=True,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
collection_name = None
|
||||
|
|
|
|||
|
|
@ -40,7 +40,10 @@ from open_webui.routers.tasks import (
|
|||
generate_image_prompt,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.routers.retrieval import (
|
||||
process_web_search,
|
||||
SearchForm,
|
||||
)
|
||||
from open_webui.routers.images import (
|
||||
load_b64_image_data,
|
||||
image_generations,
|
||||
|
|
@ -76,6 +79,7 @@ from open_webui.utils.task import (
|
|||
)
|
||||
from open_webui.utils.misc import (
|
||||
deep_update,
|
||||
extract_urls,
|
||||
get_message_list,
|
||||
add_or_update_system_message,
|
||||
add_or_update_user_message,
|
||||
|
|
@ -823,7 +827,11 @@ async def chat_completion_files_handler(
|
|||
|
||||
if files := body.get("metadata", {}).get("files", None):
|
||||
# Check if all files are in full context mode
|
||||
all_full_context = all(item.get("context") == "full" for item in files)
|
||||
all_full_context = all(
|
||||
item.get("context") == "full"
|
||||
for item in files
|
||||
if item.get("type") == "file"
|
||||
)
|
||||
|
||||
queries = []
|
||||
if not all_full_context:
|
||||
|
|
@ -855,10 +863,6 @@ async def chat_completion_files_handler(
|
|||
except:
|
||||
pass
|
||||
|
||||
if len(queries) == 0:
|
||||
queries = [get_last_user_message(body["messages"])]
|
||||
|
||||
if not all_full_context:
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
|
|
@ -870,6 +874,9 @@ async def chat_completion_files_handler(
|
|||
}
|
||||
)
|
||||
|
||||
if len(queries) == 0:
|
||||
queries = [get_last_user_message(body["messages"])]
|
||||
|
||||
try:
|
||||
# Offload get_sources_from_items to a separate thread
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
@ -908,7 +915,6 @@ async def chat_completion_files_handler(
|
|||
log.debug(f"rag_contexts:sources: {sources}")
|
||||
|
||||
unique_ids = set()
|
||||
|
||||
for source in sources or []:
|
||||
if not source or len(source.keys()) == 0:
|
||||
continue
|
||||
|
|
@ -927,7 +933,6 @@ async def chat_completion_files_handler(
|
|||
unique_ids.add(_id)
|
||||
|
||||
sources_count = len(unique_ids)
|
||||
|
||||
await __event_emitter__(
|
||||
{
|
||||
"type": "status",
|
||||
|
|
@ -1170,8 +1175,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
tool_ids = form_data.pop("tool_ids", None)
|
||||
files = form_data.pop("files", None)
|
||||
|
||||
# Remove files duplicates
|
||||
if files:
|
||||
prompt = get_last_user_message(form_data["messages"])
|
||||
urls = extract_urls(prompt)
|
||||
|
||||
if files or urls:
|
||||
if not files:
|
||||
files = []
|
||||
files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
|
||||
|
||||
# Remove duplicate files based on their content
|
||||
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
|
||||
|
||||
metadata = {
|
||||
|
|
@ -1372,8 +1384,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
|||
)
|
||||
|
||||
context_string = context_string.strip()
|
||||
|
||||
prompt = get_last_user_message(form_data["messages"])
|
||||
if prompt is None:
|
||||
raise Exception("No user message found")
|
||||
|
||||
|
|
|
|||
|
|
@ -531,3 +531,11 @@ def throttle(interval: float = 10.0):
|
|||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def extract_urls(text: str) -> list[str]:
|
||||
# Regex pattern to match URLs
|
||||
url_pattern = re.compile(
|
||||
r"(https?://[^\s]+)", re.IGNORECASE
|
||||
) # Matches http and https URLs
|
||||
return url_pattern.findall(text)
|
||||
|
|
|
|||
Loading…
Reference in a new issue