feat/security: Add SSRF protection with configurable blocklist

Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
Timothy Jaeryang Baek 2025-11-18 04:40:55 -05:00
parent 6cdb13d5cb
commit 02238d3113
3 changed files with 102 additions and 25 deletions

backend/open_webui/config.py

@@ -2861,6 +2861,26 @@ ENABLE_RAG_LOCAL_WEB_FETCH = (
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
 )
 
+DEFAULT_WEB_FETCH_FILTER_LIST = [
+    "!169.254.169.254",
+    "!fd00:ec2::254",
+    "!metadata.google.internal",
+    "!metadata.azure.com",
+    "!100.100.100.200",
+]
+
+web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "")
+if web_fetch_filter_list == "":
+    web_fetch_filter_list = []
+else:
+    web_fetch_filter_list = [
+        item.strip() for item in web_fetch_filter_list.split(",") if item.strip()
+    ]
+
+WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list))
+
 YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
     "YOUTUBE_LOADER_LANGUAGE",
     "rag.youtube_loader_language",

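Operationally, the merge above means operator-supplied entries extend the defaults rather than replace them, so the cloud-metadata endpoints stay blocked no matter what WEB_FETCH_FILTER_LIST is set to. A standalone sketch of the parsing behavior, using a hypothetical operator value:

import os

DEFAULT_WEB_FETCH_FILTER_LIST = ["!169.254.169.254", "!metadata.google.internal"]

# Hypothetical operator value: block an internal host, allow-list one domain.
os.environ["WEB_FETCH_FILTER_LIST"] = "!internal.corp.example, example.com"

raw = os.getenv("WEB_FETCH_FILTER_LIST", "")
user_entries = [item.strip() for item in raw.split(",") if item.strip()] if raw else []

# Same union as the config code: defaults always survive, duplicates collapse.
WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + user_entries))
print(sorted(WEB_FETCH_FILTER_LIST))
# ['!169.254.169.254', '!internal.corp.example', '!metadata.google.internal', 'example.com']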
backend/open_webui/retrieval/web/main.py

@@ -5,16 +5,13 @@ from urllib.parse import urlparse
 
 from pydantic import BaseModel
 
+from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname
+
 
 def get_filtered_results(results, filter_list):
     if not filter_list:
         return results
 
-    # Domains starting without "!" → allowed
-    allow_list = [d for d in filter_list if not d.startswith("!")]
-    # Domains starting with "!" → blocked
-    block_list = [d[1:] for d in filter_list if d.startswith("!")]
-
     filtered_results = []
     for result in results:
@@ -23,17 +20,21 @@ def get_filtered_results(results, filter_list):
             continue
 
         domain = urlparse(url).netloc
-        # If allow list is non-empty, require domain to match one of them
-        if allow_list:
-            if not any(domain.endswith(allowed) for allowed in allow_list):
-                continue
-
-        # Block list always removes matches
-        if any(domain.endswith(blocked) for blocked in block_list):
+        if not domain:
             continue
-        filtered_results.append(result)
+
+        hostnames = [domain]
+        try:
+            ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
+            hostnames.extend(ipv4_addresses)
+            hostnames.extend(ipv6_addresses)
+        except Exception:
+            pass
+
+        if any(is_string_allowed(hostname, filter_list) for hostname in hostnames):
+            filtered_results.append(result)
+            continue
 
     return filtered_results
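The rewrite changes what is matched: instead of comparing only the URL's domain against allow/block suffixes, each result's domain is resolved, and the domain plus every A/AAAA address is run through is_string_allowed; one surviving hostname keeps the result. A hedged usage sketch follows; the "url" key and the import path are assumptions, since the hunk elides how url is read from each result:

from open_webui.retrieval.web.main import get_filtered_results

results = [
    {"url": "https://example.com/article"},
    {"url": "http://169.254.169.254/latest/meta-data/"},  # AWS metadata endpoint
]

filtered = get_filtered_results(results, ["!169.254.169.254"])
# The metadata result is dropped: its netloc is the blocked literal IP, so
# is_string_allowed returns False for every hostname derived from it.
# [r["url"] for r in filtered] == ["https://example.com/article"]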

backend/open_webui/retrieval/web/utils.py

@@ -24,6 +24,7 @@ import validators
 from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
 from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.retrieval.loaders.external_web import ExternalWebLoader
 from open_webui.constants import ERROR_MESSAGES
@@ -38,6 +39,7 @@ from open_webui.config import (
     TAVILY_EXTRACT_DEPTH,
     EXTERNAL_WEB_LOADER_URL,
     EXTERNAL_WEB_LOADER_API_KEY,
+    WEB_FETCH_FILTER_LIST,
 )
 from open_webui.env import SRC_LOG_LEVELS
@@ -46,10 +48,71 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
+def get_allow_block_lists(filter_list):
+    allow_list = []
+    block_list = []
+
+    if filter_list:
+        for d in filter_list:
+            if d.startswith("!"):
+                # Domains starting with "!" → blocked
+                block_list.append(d[1:])
+            else:
+                # Domains starting without "!" → allowed
+                allow_list.append(d)
+
+    return allow_list, block_list
+
+
+def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
+    if not filter_list:
+        return True
+
+    allow_list, block_list = get_allow_block_lists(filter_list)
+
+    # If allow list is non-empty, require domain to match one of them
+    if allow_list:
+        if not any(string.endswith(allowed) for allowed in allow_list):
+            return False
+
+    # Block list always removes matches
+    if any(string.endswith(blocked) for blocked in block_list):
+        return False
+
+    return True
+
+
 def validate_url(url: Union[str, Sequence[str]]):
     if isinstance(url, str):
         if isinstance(validators.url(url), validators.ValidationError):
             raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
+        parsed_url = urllib.parse.urlparse(url)
+
+        # Protocol validation - only allow http/https
+        if parsed_url.scheme not in ["http", "https"]:
+            log.warning(
+                f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
+            )
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
+        # Blocklist check using unified filtering logic
+        if WEB_FETCH_FILTER_LIST:
+            if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
+                log.warning(f"URL blocked by filter list: {url}")
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
         if not ENABLE_RAG_LOCAL_WEB_FETCH:
             # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
             parsed_url = urllib.parse.urlparse(url)
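The allow/block semantics deserve spelling out: a bare entry turns the list into an allow-list (non-matching strings are rejected), a "!" entry always blocks, and block wins over allow. A few spot-checks, restating the logic above in isolation:

# Block-list only: everything passes except blocked suffixes.
assert is_string_allowed("example.com", ["!169.254.169.254"]) is True
assert is_string_allowed("169.254.169.254", ["!169.254.169.254"]) is False

# Once any allow entry exists, non-matching strings are rejected too.
assert is_string_allowed("example.com", ["example.com", "!bad.example.com"]) is True
assert is_string_allowed("other.org", ["example.com", "!bad.example.com"]) is False

# Block wins over allow when both match.
assert is_string_allowed("bad.example.com", ["example.com", "!bad.example.com"]) is False

# An empty or missing filter list allows everything.
assert is_string_allowed("anything") is True

Note that matching is plain endswith, so an entry like "example.com" also matches "evilexample.com"; a dot-prefixed entry (".example.com") avoids such suffix collisions, at the cost of excluding the apex domain.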
@@ -82,17 +145,6 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
     return valid_urls
 
 
-def resolve_hostname(hostname):
-    # Get address information
-    addr_info = socket.getaddrinfo(hostname, None)
-
-    # Extract IP addresses from address information
-    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
-    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
-
-    return ipv4_addresses, ipv6_addresses
-
-
 def extract_metadata(soup, url):
     metadata = {"source": url}
     if title := soup.find("title"):
@@ -642,6 +694,10 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
+    if not safe_urls:
+        log.warning(f"All provided URLs were blocked or invalid: {urls}")
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
     web_loader_args = {
         "web_paths": safe_urls,
         "verify_ssl": verify_ssl,