Mirror of https://github.com/open-webui/open-webui.git (synced 2025-12-13 21:05:19 +00:00)

enh/refac: url input handling

commit a2a2bafdf6 (parent ce83276fa4)
5 changed files with 82 additions and 41 deletions

In brief: URLs pasted into a chat prompt are now detected and attached as url-type items, and a shared get_loader/get_content_from_url path dispatches YouTube links to YoutubeLoader and all other links to the web loader.
@@ -157,3 +157,10 @@ class YoutubeLoader:
                 f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
             )
         raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
+
+    async def aload(self) -> Generator[Document, None, None]:
+        """Asynchronously load YouTube transcripts into `Document` objects."""
+        import asyncio
+
+        loop = asyncio.get_event_loop()
+        return await loop.run_in_executor(None, self.load)
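The new aload is a thin async wrapper: it offloads the blocking load() call to the default thread executor. A minimal usage sketch, assuming an event loop started with asyncio.run and a placeholder video URL (the language value mirrors what the config would normally supply):

import asyncio

from open_webui.retrieval.loaders.youtube import YoutubeLoader


async def main():
    # Placeholder URL; aload() runs the blocking load() in a worker thread
    loader = YoutubeLoader(
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
        language=["en"],
    )
    docs = await loader.aload()
    print(docs[0].page_content[:200])


asyncio.run(main())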
@@ -6,6 +6,7 @@ import requests
 import hashlib
 from concurrent.futures import ThreadPoolExecutor
 import time
+import re
 
 from urllib.parse import quote
 from huggingface_hub import snapshot_download
@@ -16,6 +17,7 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
+
 from open_webui.models.users import UserModel
 from open_webui.models.files import Files
 from open_webui.models.knowledge import Knowledges
@@ -27,6 +29,9 @@ from open_webui.retrieval.vector.main import GetResult
 from open_webui.utils.access_control import has_access
 from open_webui.utils.misc import get_message_list
 
+from open_webui.retrieval.web.utils import get_web_loader
+from open_webui.retrieval.loaders.youtube import YoutubeLoader
+
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -49,6 +54,33 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.retrievers import BaseRetriever
 
 
+def is_youtube_url(url: str) -> bool:
+    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
+    return re.match(youtube_regex, url) is not None
+
+
+def get_loader(request, url: str):
+    if is_youtube_url(url):
+        return YoutubeLoader(
+            url,
+            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
+            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
+        )
+    else:
+        return get_web_loader(
+            url,
+            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
+            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
+        )
+
+
+def get_content_from_url(request, url: str) -> str:
+    loader = get_loader(request, url)
+    docs = loader.load()
+    content = " ".join([doc.page_content for doc in docs])
+    return content, docs
+
+
 class VectorSearchRetriever(BaseRetriever):
     collection_name: Any
     embedding_function: Any
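Centralizing these helpers here (they previously lived in the retrieval router; the removal appears further down) lets both the router and get_sources_from_items share one dispatch path: YouTube links go to YoutubeLoader, everything else to get_web_loader. Note that get_content_from_url actually returns a (content, docs) tuple, so its -> str annotation is stale. A quick standalone check of the dispatch regex, with illustrative URLs:

import re

# Same pattern as the new is_youtube_url helper
youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"

for url in [
    "https://www.youtube.com/watch?v=abc123",  # matches: standard watch URL
    "https://youtu.be/abc123",                 # matches: short link
    "https://example.com/youtube.com/page",    # no match: host must lead
]:
    print(url, bool(re.match(youtube_regex, url)))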
@@ -571,6 +603,13 @@ def get_sources_from_items(
                 "metadatas": [[{"file_id": chat.id, "name": chat.title}]],
             }
 
+        elif item.get("type") == "url":
+            content, docs = get_content_from_url(request, item.get("url"))
+            if docs:
+                query_result = {
+                    "documents": [[content]],
+                    "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]],
+                }
         elif item.get("type") == "file":
             if (
                 item.get("context") == "full"
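For url items the fetched page content becomes a single in-memory pseudo-document rather than a vector-DB hit. A sketch of the query_result shape the branch above produces, with illustrative values:

item = {"type": "url", "url": "https://example.com/post"}

# Hypothetical shape after get_content_from_url succeeds
query_result = {
    "documents": [["full page text ..."]],
    "metadatas": [[{"url": item["url"], "name": item["url"]}]],
}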
@@ -736,7 +775,6 @@ def get_sources_from_items(
             sources.append(source)
         except Exception as e:
             log.exception(e)
 
-
     return sources
@@ -71,6 +71,7 @@ from open_webui.retrieval.web.firecrawl import search_firecrawl
 from open_webui.retrieval.web.external import search_external
 
 from open_webui.retrieval.utils import (
+    get_content_from_url,
     get_embedding_function,
     get_reranking_function,
     get_model_path,
@@ -1691,33 +1692,6 @@ def process_text(
     )
 
 
-def is_youtube_url(url: str) -> bool:
-    youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$"
-    return re.match(youtube_regex, url) is not None
-
-
-def get_loader(request, url: str):
-    if is_youtube_url(url):
-        return YoutubeLoader(
-            url,
-            language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE,
-            proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL,
-        )
-    else:
-        return get_web_loader(
-            url,
-            verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
-            requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS,
-        )
-
-
-def get_content_from_url(request, url: str) -> str:
-    loader = get_loader(request, url)
-    docs = loader.load()
-    content = " ".join([doc.page_content for doc in docs])
-    return content, docs
-
-
 @router.post("/process/youtube")
 @router.post("/process/web")
 def process_web(
@@ -1733,7 +1707,11 @@ def process_web(
 
     if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL:
         save_docs_to_vector_db(
-            request, docs, collection_name, overwrite=True, user=user
+            request,
+            docs,
+            collection_name,
+            overwrite=True,
+            user=user,
         )
     else:
         collection_name = None
@@ -40,7 +40,10 @@ from open_webui.routers.tasks import (
     generate_image_prompt,
     generate_chat_tags,
 )
-from open_webui.routers.retrieval import process_web_search, SearchForm
+from open_webui.routers.retrieval import (
+    process_web_search,
+    SearchForm,
+)
 from open_webui.routers.images import (
     load_b64_image_data,
     image_generations,
@@ -76,6 +79,7 @@ from open_webui.utils.task import (
 )
 from open_webui.utils.misc import (
     deep_update,
+    extract_urls,
     get_message_list,
     add_or_update_system_message,
     add_or_update_user_message,
@@ -823,7 +827,11 @@ async def chat_completion_files_handler(
 
     if files := body.get("metadata", {}).get("files", None):
         # Check if all files are in full context mode
-        all_full_context = all(item.get("context") == "full" for item in files)
+        all_full_context = all(
+            item.get("context") == "full"
+            for item in files
+            if item.get("type") == "file"
+        )
 
         queries = []
         if not all_full_context:
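The reworked expression is a behavioral fix, not just a reflow: url items injected from the prompt carry no "context" key, so without the if item.get("type") == "file" filter a single URL would drag all_full_context to False even when every real file is in full-context mode. A self-contained sketch with an illustrative file list:

files = [
    {"type": "file", "context": "full"},
    {"type": "url", "url": "https://example.com"},  # no "context" key
]

# Old expression: the url item contributes False, sinking the all()
old = all(item.get("context") == "full" for item in files)

# New expression: only file-type items are considered
new = all(
    item.get("context") == "full" for item in files if item.get("type") == "file"
)

print(old, new)  # False True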
@@ -855,10 +863,6 @@ async def chat_completion_files_handler(
         except:
             pass
 
-        if len(queries) == 0:
-            queries = [get_last_user_message(body["messages"])]
-
         if not all_full_context:
             await __event_emitter__(
                 {
                     "type": "status",
@@ -870,6 +874,9 @@ async def chat_completion_files_handler(
                 }
             )
 
+        if len(queries) == 0:
+            queries = [get_last_user_message(body["messages"])]
+
         try:
             # Offload get_sources_from_items to a separate thread
             loop = asyncio.get_running_loop()
@@ -908,7 +915,6 @@ async def chat_completion_files_handler(
         log.debug(f"rag_contexts:sources: {sources}")
 
         unique_ids = set()
-
         for source in sources or []:
             if not source or len(source.keys()) == 0:
                 continue
@@ -927,7 +933,6 @@ async def chat_completion_files_handler(
             unique_ids.add(_id)
 
         sources_count = len(unique_ids)
-
         await __event_emitter__(
             {
                 "type": "status",
@@ -1170,8 +1175,15 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     tool_ids = form_data.pop("tool_ids", None)
     files = form_data.pop("files", None)
 
-    # Remove files duplicates
-    if files:
+    prompt = get_last_user_message(form_data["messages"])
+    urls = extract_urls(prompt)
+
+    if files or urls:
+        if not files:
+            files = []
+        files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]
+
+        # Remove duplicate files based on their content
         files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
 
     metadata = {
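This is the core of the commit: URLs typed into the prompt are promoted to url-type file entries, merged with any explicit attachments, and deduplicated by their sorted-key JSON serialization. (Note that extract_urls(prompt) now runs before the if prompt is None check that survives further down, so a payload with no user message would pass None into the regex.) A standalone sketch with an illustrative prompt:

import json

from open_webui.utils.misc import extract_urls  # helper added in this commit

prompt = "Summarize https://example.com/a and https://example.com/a please"
files = [{"type": "url", "url": "https://example.com/a", "name": "https://example.com/a"}]

urls = extract_urls(prompt)
files = [*files, *[{"type": "url", "url": url, "name": url} for url in urls]]

# Identical entries serialize to the same JSON key and collapse to one
files = list({json.dumps(f, sort_keys=True): f for f in files}.values())
print(len(files))  # 1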
@@ -1372,8 +1384,6 @@ async def process_chat_payload(request, form_data, user, metadata, model):
     )
 
     context_string = context_string.strip()
 
-    prompt = get_last_user_message(form_data["messages"])
-
     if prompt is None:
         raise Exception("No user message found")
@@ -531,3 +531,11 @@ def throttle(interval: float = 10.0):
         return wrapper
 
     return decorator
+
+
+def extract_urls(text: str) -> list[str]:
+    # Regex pattern to match URLs
+    url_pattern = re.compile(
+        r"(https?://[^\s]+)", re.IGNORECASE
+    )  # Matches http and https URLs
+    return url_pattern.findall(text)
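A quick check of the new helper with illustrative text. The greedy [^\s]+ keeps trailing punctuation attached, so a URL that ends a sentence retains its period, and re.IGNORECASE lets an upper-case scheme through:

import re

url_pattern = re.compile(r"(https?://[^\s]+)", re.IGNORECASE)

text = "See https://example.com/docs and HTTP://EXAMPLE.COM/Home."
print(url_pattern.findall(text))
# ['https://example.com/docs', 'HTTP://EXAMPLE.COM/Home.']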