diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 8ece4af736..981d5326c3 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -2861,6 +2861,26 @@ ENABLE_RAG_LOCAL_WEB_FETCH = (
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
 )
 
+
+DEFAULT_WEB_FETCH_FILTER_LIST = [
+    "!169.254.169.254",
+    "!fd00:ec2::254",
+    "!metadata.google.internal",
+    "!metadata.azure.com",
+    "!100.100.100.200",
+]
+
+web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "")
+if web_fetch_filter_list == "":
+    web_fetch_filter_list = []
+else:
+    web_fetch_filter_list = [
+        item.strip() for item in web_fetch_filter_list.split(",") if item.strip()
+    ]
+
+WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list))
+
+
 YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
     "YOUTUBE_LOADER_LANGUAGE",
     "rag.youtube_loader_language",
diff --git a/backend/open_webui/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py
index 8025303d6a..d8cfb11ba0 100644
--- a/backend/open_webui/retrieval/web/main.py
+++ b/backend/open_webui/retrieval/web/main.py
@@ -5,16 +5,17 @@ from urllib.parse import urlparse
 
 from pydantic import BaseModel
 
+from open_webui.retrieval.web.utils import (
+    get_allow_block_lists,
+    is_string_allowed,
+    resolve_hostname,
+)
+
 
 def get_filtered_results(results, filter_list):
     if not filter_list:
         return results
 
-    # Domains starting without "!" → allowed
-    allow_list = [d for d in filter_list if not d.startswith("!")]
-    # Domains starting with "!" → blocked
-    block_list = [d[1:] for d in filter_list if d.startswith("!")]
-
     filtered_results = []
 
     for result in results:
@@ -23,17 +24,27 @@ def get_filtered_results(results, filter_list):
             continue
 
-        domain = urlparse(url).netloc
-
-        # If allow list is non-empty, require domain to match one of them
-        if allow_list:
-            if not any(domain.endswith(allowed) for allowed in allow_list):
-                continue
-
-        # Block list always removes matches
-        if any(domain.endswith(blocked) for blocked in block_list):
+        # Use .hostname (not .netloc) so ports and userinfo cannot defeat matching
+        domain = urlparse(url).hostname
+        if not domain:
             continue
 
-        filtered_results.append(result)
+        hostnames = [domain]
+        try:
+            ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
+            hostnames.extend(ipv4_addresses)
+            hostnames.extend(ipv6_addresses)
+        except Exception:
+            pass
+
+        # The domain must satisfy the allow/block rules, and neither it nor
+        # any address it resolves to may be block-listed
+        _, block_list = get_allow_block_lists(filter_list)
+        if is_string_allowed(domain, filter_list) and not any(
+            hostname.endswith(blocked)
+            for hostname in hostnames
+            for blocked in block_list
+        ):
+            filtered_results.append(result)
 
     return filtered_results
 
diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py
index 91699a157b..df5036487c 100644
--- a/backend/open_webui/retrieval/web/utils.py
+++ b/backend/open_webui/retrieval/web/utils.py
@@ -24,6 +24,7 @@ import validators
 from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
+
 from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.retrieval.loaders.external_web import ExternalWebLoader
 from open_webui.constants import ERROR_MESSAGES
@@ -38,6 +39,7 @@ from open_webui.config import (
     TAVILY_EXTRACT_DEPTH,
     EXTERNAL_WEB_LOADER_URL,
     EXTERNAL_WEB_LOADER_API_KEY,
+    WEB_FETCH_FILTER_LIST,
 )
 from open_webui.env import SRC_LOG_LEVELS
 
@@ -46,10 +48,71 @@
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
+def get_allow_block_lists(filter_list):
+    allow_list = []
+    block_list = []
+
+    if filter_list:
+        for d in filter_list:
+            if d.startswith("!"):
+                # Domains starting with "!" → blocked
+                block_list.append(d[1:])
+            else:
+                # Domains starting without "!" → allowed
+                allow_list.append(d)
+
+    return allow_list, block_list
+
+
+def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
+    if not filter_list:
+        return True
+
+    allow_list, block_list = get_allow_block_lists(filter_list)
+    # If the allow list is non-empty, require the string to match one of its entries
+    if allow_list:
+        if not any(string.endswith(allowed) for allowed in allow_list):
+            return False
+
+    # The block list always removes matches
+    if any(string.endswith(blocked) for blocked in block_list):
+        return False
+
+    return True
+
+
 def validate_url(url: Union[str, Sequence[str]]):
     if isinstance(url, str):
         if isinstance(validators.url(url), validators.ValidationError):
             raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
+        parsed_url = urllib.parse.urlparse(url)
+
+        # Protocol validation - only allow http/https
+        if parsed_url.scheme not in ["http", "https"]:
+            log.warning(
+                f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
+            )
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
+        # Blocklist check using unified filtering logic (match on the hostname)
+        if WEB_FETCH_FILTER_LIST:
+            hostname = parsed_url.hostname or ""
+            if not is_string_allowed(hostname, WEB_FETCH_FILTER_LIST):
+                log.warning(f"URL blocked by filter list: {url}")
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
         if not ENABLE_RAG_LOCAL_WEB_FETCH:
             # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
@@ -82,17 +145,6 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
     return valid_urls
 
 
-def resolve_hostname(hostname):
-    # Get address information
-    addr_info = socket.getaddrinfo(hostname, None)
-
-    # Extract IP addresses from address information
-    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
-    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
-
-    return ipv4_addresses, ipv6_addresses
-
-
 def extract_metadata(soup, url):
     metadata = {"source": url}
     if title := soup.find("title"):
@@ -642,6 +694,10 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
+    if not safe_urls:
+        log.warning(f"All provided URLs were blocked or invalid: {urls}")
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
     web_loader_args = {
         "web_paths": safe_urls,
         "verify_ssl": verify_ssl,
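
A quick way to sanity-check the allow/block semantics introduced in utils.py is to exercise is_string_allowed directly. The snippet below is illustrative rather than part of the change; it assumes the open_webui package is importable, and the hostnames and filter entries are made-up examples:

    # Illustrative only: exercises is_string_allowed() as defined in the diff above.
    from open_webui.retrieval.web.utils import is_string_allowed

    # Block-list-only filter (the shape of DEFAULT_WEB_FETCH_FILTER_LIST):
    # anything that is not block-listed passes.
    block_only = ["!169.254.169.254", "!metadata.google.internal"]
    assert is_string_allowed("example.com", block_only)
    assert not is_string_allowed("169.254.169.254", block_only)
    assert not is_string_allowed("metadata.google.internal", block_only)

    # Mixed filter: any entry without "!" switches on allow-list mode, so a
    # string must match an allowed suffix and must not match a blocked one.
    mixed = ["example.com", "!internal.example.com"]
    assert is_string_allowed("docs.example.com", mixed)
    assert not is_string_allowed("other.org", mixed)  # not allow-listed
    assert not is_string_allowed("api.internal.example.com", mixed)  # block-listed

Note that matching uses plain str.endswith, so it is suffix-based rather than label-based: "badexample.com" would also satisfy the "example.com" allow entry. Writing entries with a leading dot (".example.com") is one way callers can tighten this, at the cost of excluding the apex domain itself.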