feat/security: Add SSRF protection with configurable blocklist

Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>

parent 6cdb13d5cb
commit 02238d3113

3 changed files with 102 additions and 25 deletions
backend/open_webui/config.py

@@ -2861,6 +2861,26 @@ ENABLE_RAG_LOCAL_WEB_FETCH = (
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
 )
 
+
+DEFAULT_WEB_FETCH_FILTER_LIST = [
+    "!169.254.169.254",
+    "!fd00:ec2::254",
+    "!metadata.google.internal",
+    "!metadata.azure.com",
+    "!100.100.100.200",
+]
+
+web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "")
+if web_fetch_filter_list == "":
+    web_fetch_filter_list = []
+else:
+    web_fetch_filter_list = [
+        item.strip() for item in web_fetch_filter_list.split(",") if item.strip()
+    ]
+
+WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list))
+
+
 YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
     "YOUTUBE_LOADER_LANGUAGE",
     "rag.youtube_loader_language",
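As a usage sketch (the value below is hypothetical, not from the commit): WEB_FETCH_FILTER_LIST is a comma-separated environment variable whose "!"-prefixed entries are blocked and bare entries allowed, merged with the defaults above.

import os

# Hypothetical operator-supplied value; "!" marks a blocked entry.
os.environ["WEB_FETCH_FILTER_LIST"] = "example.com,!internal.example.com"

raw = os.getenv("WEB_FETCH_FILTER_LIST", "")
user_entries = [item.strip() for item in raw.split(",") if item.strip()]

defaults = ["!169.254.169.254", "!metadata.google.internal"]  # abbreviated
merged = list(set(defaults + user_entries))
# merged mixes defaults with operator entries; note list(set(...)) discards order.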
backend/open_webui/retrieval/web/main.py

@@ -5,16 +5,13 @@ from urllib.parse import urlparse
 
 from pydantic import BaseModel
 
+from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname
+
 
 def get_filtered_results(results, filter_list):
     if not filter_list:
         return results
 
-    # Domains starting without "!" → allowed
-    allow_list = [d for d in filter_list if not d.startswith("!")]
-    # Domains starting with "!" → blocked
-    block_list = [d[1:] for d in filter_list if d.startswith("!")]
-
     filtered_results = []
 
     for result in results:
@@ -23,17 +20,21 @@ def get_filtered_results(results, filter_list):
             continue
 
         domain = urlparse(url).netloc
-
-        # If allow list is non-empty, require domain to match one of them
-        if allow_list:
-            if not any(domain.endswith(allowed) for allowed in allow_list):
-                continue
-
-        # Block list always removes matches
-        if any(domain.endswith(blocked) for blocked in block_list):
-            continue
-
-        filtered_results.append(result)
+        if not domain:
+            continue
+
+        hostnames = [domain]
+
+        try:
+            ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
+            hostnames.extend(ipv4_addresses)
+            hostnames.extend(ipv6_addresses)
+        except Exception:
+            pass
+
+        if any(is_string_allowed(hostname, filter_list) for hostname in hostnames):
+            filtered_results.append(result)
+            continue
 
     return filtered_results
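A sketch of the new behavior, assuming each result carries a "link" or "url" key and DNS resolution is available (the data below is hypothetical):

results = [
    {"link": "https://docs.example.com/guide"},            # hypothetical URL
    {"link": "http://169.254.169.254/latest/meta-data/"},  # cloud metadata IP
]
filtered = get_filtered_results(results, ["!169.254.169.254"])
# Only the first entry survives: the second URL's hostname list contains
# 169.254.169.254, so is_string_allowed rejects every name it resolves to.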
backend/open_webui/retrieval/web/utils.py

@@ -24,6 +24,7 @@ import validators
 from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
 
 from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.retrieval.loaders.external_web import ExternalWebLoader
+from open_webui.constants import ERROR_MESSAGES

@@ -38,6 +39,7 @@ from open_webui.config import (
     TAVILY_EXTRACT_DEPTH,
     EXTERNAL_WEB_LOADER_URL,
     EXTERNAL_WEB_LOADER_API_KEY,
+    WEB_FETCH_FILTER_LIST,
 )
 from open_webui.env import SRC_LOG_LEVELS
@@ -46,10 +48,71 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
+def get_allow_block_lists(filter_list):
+    allow_list = []
+    block_list = []
+
+    if filter_list:
+        for d in filter_list:
+            if d.startswith("!"):
+                # Entries starting with "!" → blocked
+                block_list.append(d[1:])
+            else:
+                # Entries without "!" → allowed
+                allow_list.append(d)
+
+    return allow_list, block_list
+
+
+def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
+    if not filter_list:
+        return True
+
+    allow_list, block_list = get_allow_block_lists(filter_list)
+
+    # If the allow list is non-empty, require the string to match one of its entries
+    if allow_list:
+        if not any(string.endswith(allowed) for allowed in allow_list):
+            return False
+
+    # Block list always removes matches
+    if any(string.endswith(blocked) for blocked in block_list):
+        return False
+
+    return True
+
+
 def validate_url(url: Union[str, Sequence[str]]):
     if isinstance(url, str):
         if isinstance(validators.url(url), validators.ValidationError):
             raise ValueError(ERROR_MESSAGES.INVALID_URL)
 
+        parsed_url = urllib.parse.urlparse(url)
+
+        # Protocol validation - only allow http/https
+        if parsed_url.scheme not in ["http", "https"]:
+            log.warning(
+                f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
+            )
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
+        # Blocklist check using unified filtering logic
+        # (is_string_allowed returns a bool, so test it directly)
+        if WEB_FETCH_FILTER_LIST and not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
+            log.warning(f"URL blocked by filter list: {url}")
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
         if not ENABLE_RAG_LOCAL_WEB_FETCH:
             # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
            parsed_url = urllib.parse.urlparse(url)
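A quick sanity check of the allow/block semantics; is_string_allowed does pure suffix matching, so this runs without any network access:

assert is_string_allowed("example.com") is True  # no filter list: everything allowed
assert is_string_allowed("metadata.google.internal", ["!metadata.google.internal"]) is False
assert is_string_allowed("docs.example.com", ["example.com", "!bad.example.com"]) is True
assert is_string_allowed("bad.example.com", ["example.com", "!bad.example.com"]) is False

Because matching uses str.endswith, an entry like "example.com" also matches lookalikes such as "notexample.com"; writing entries with a leading dot (".example.com") is one way to avoid that.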
@@ -82,17 +145,6 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
     return valid_urls
 
 
-def resolve_hostname(hostname):
-    # Get address information
-    addr_info = socket.getaddrinfo(hostname, None)
-
-    # Extract IP addresses from address information
-    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
-    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
-
-    return ipv4_addresses, ipv6_addresses
-
-
 def extract_metadata(soup, url):
     metadata = {"source": url}
     if title := soup.find("title"):
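Since resolve_hostname now lives once near the top of utils.py, a minimal check that needs no external DNS:

ipv4, ipv6 = resolve_hostname("localhost")
# getaddrinfo returns one entry per socket type, so the lists may contain
# duplicates, e.g. ['127.0.0.1', '127.0.0.1'] and, if configured, ['::1', '::1'].
print(ipv4, ipv6)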
@@ -642,6 +694,10 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
+    if not safe_urls:
+        log.warning(f"All provided URLs were blocked or invalid: {urls}")
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
     web_loader_args = {
         "web_paths": safe_urls,
         "verify_ssl": verify_ssl,
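With the new guard, a call in which every URL is filtered out fails fast instead of building a loader over an empty list (hypothetical URL; assumes the default blocklist is active):

try:
    get_web_loader("http://169.254.169.254/latest/meta-data/")
except ValueError:
    print("all URLs blocked or invalid")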