Repository: https://github.com/open-webui/open-webui.git
feat/security: Add SSRF protection with configurable blocklist
Co-Authored-By: Classic298 <27028174+Classic298@users.noreply.github.com>
parent 6cdb13d5cb
commit 02238d3113
3 changed files with 102 additions and 25 deletions
backend/open_webui/config.py

@@ -2861,6 +2861,26 @@ ENABLE_RAG_LOCAL_WEB_FETCH = (
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
 )
 
+DEFAULT_WEB_FETCH_FILTER_LIST = [
+    "!169.254.169.254",
+    "!fd00:ec2::254",
+    "!metadata.google.internal",
+    "!metadata.azure.com",
+    "!100.100.100.200",
+]
+
+web_fetch_filter_list = os.getenv("WEB_FETCH_FILTER_LIST", "")
+if web_fetch_filter_list == "":
+    web_fetch_filter_list = []
+else:
+    web_fetch_filter_list = [
+        item.strip() for item in web_fetch_filter_list.split(",") if item.strip()
+    ]
+
+WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + web_fetch_filter_list))
+
 
 YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
     "YOUTUBE_LOADER_LANGUAGE",
     "rag.youtube_loader_language",
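The comma-separated WEB_FETCH_FILTER_LIST environment variable is merged with the defaults above. A minimal sketch of the parsing, using a hypothetical operator-supplied value:

# Sketch with a hypothetical value; "!"-prefixed entries are blocked,
# unprefixed entries act as an allow list.
raw = "!10.0.0.1, internal.example.com , "
entries = [item.strip() for item in raw.split(",") if item.strip()]
print(entries)  # ['!10.0.0.1', 'internal.example.com']
# Merged as in the diff (set() deduplicates but does not preserve order):
# WEB_FETCH_FILTER_LIST = list(set(DEFAULT_WEB_FETCH_FILTER_LIST + entries))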
backend/open_webui/retrieval/web/main.py

@@ -5,16 +5,13 @@ from urllib.parse import urlparse
 
 from pydantic import BaseModel
 
+from open_webui.retrieval.web.utils import is_string_allowed, resolve_hostname
+
 
 def get_filtered_results(results, filter_list):
     if not filter_list:
         return results
-
-    # Domains starting without "!" → allowed
-    allow_list = [d for d in filter_list if not d.startswith("!")]
-    # Domains starting with "!" → blocked
-    block_list = [d[1:] for d in filter_list if d.startswith("!")]
 
     filtered_results = []
     for result in results:
@@ -23,17 +20,21 @@ def get_filtered_results(results, filter_list):
             continue
 
         domain = urlparse(url).netloc
-        # If allow list is non-empty, require domain to match one of them
-        if allow_list:
-            if not any(domain.endswith(allowed) for allowed in allow_list):
-                continue
-
-        # Block list always removes matches
-        if any(domain.endswith(blocked) for blocked in block_list):
+        if not domain:
             continue
 
-        filtered_results.append(result)
+        hostnames = [domain]
+        try:
+            ipv4_addresses, ipv6_addresses = resolve_hostname(domain)
+            hostnames.extend(ipv4_addresses)
+            hostnames.extend(ipv6_addresses)
+        except Exception:
+            pass
+
+        if any(is_string_allowed(hostname, filter_list) for hostname in hostnames):
+            filtered_results.append(result)
+            continue
 
     return filtered_results
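The per-result decision now checks the domain plus its resolved IPv4/IPv6 addresses against the filter list via is_string_allowed (both helpers are added in utils.py below). An illustrative sketch of the combined allow/block semantics, with hypothetical hostnames:

filter_list = ["example.com", "!private.example.com"]

# Allowed: matches an allow entry and no block entry.
assert is_string_allowed("docs.example.com", filter_list) is True
# Blocked: a "!" entry always wins, even when the allow list matches.
assert is_string_allowed("private.example.com", filter_list) is False
# Rejected: the allow list is non-empty and nothing matches.
assert is_string_allowed("other.org", filter_list) is False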
backend/open_webui/retrieval/web/utils.py

@@ -24,6 +24,7 @@ import validators
 from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader
 from langchain_community.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
 
 from open_webui.retrieval.loaders.tavily import TavilyLoader
 from open_webui.retrieval.loaders.external_web import ExternalWebLoader
 from open_webui.constants import ERROR_MESSAGES
@@ -38,6 +39,7 @@ from open_webui.config import (
     TAVILY_EXTRACT_DEPTH,
     EXTERNAL_WEB_LOADER_URL,
     EXTERNAL_WEB_LOADER_API_KEY,
+    WEB_FETCH_FILTER_LIST,
 )
 from open_webui.env import SRC_LOG_LEVELS
@@ -46,10 +48,71 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
+def resolve_hostname(hostname):
+    # Get address information
+    addr_info = socket.getaddrinfo(hostname, None)
+
+    # Extract IP addresses from address information
+    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
+    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
+
+    return ipv4_addresses, ipv6_addresses
+
+
+def get_allow_block_lists(filter_list):
+    allow_list = []
+    block_list = []
+
+    if filter_list:
+        for d in filter_list:
+            if d.startswith("!"):
+                # Entries starting with "!" → blocked
+                block_list.append(d[1:])
+            else:
+                # Entries without "!" → allowed
+                allow_list.append(d)
+
+    return allow_list, block_list
+
+
+def is_string_allowed(string: str, filter_list: Optional[list[str]] = None) -> bool:
+    if not filter_list:
+        return True
+
+    allow_list, block_list = get_allow_block_lists(filter_list)
+
+    # If the allow list is non-empty, require the string to match one of its entries
+    if allow_list:
+        if not any(string.endswith(allowed) for allowed in allow_list):
+            return False
+
+    # The block list always removes matches
+    if any(string.endswith(blocked) for blocked in block_list):
+        return False
+
+    return True
+
+
 def validate_url(url: Union[str, Sequence[str]]):
     if isinstance(url, str):
         if isinstance(validators.url(url), validators.ValidationError):
             raise ValueError(ERROR_MESSAGES.INVALID_URL)
 
+        parsed_url = urllib.parse.urlparse(url)
+
+        # Protocol validation - only allow http/https
+        if parsed_url.scheme not in ["http", "https"]:
+            log.warning(
+                f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
+            )
+            raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
+        # Blocklist check using unified filtering logic
+        # (is_string_allowed returns a bool)
+        if WEB_FETCH_FILTER_LIST:
+            if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
+                log.warning(f"URL blocked by filter list: {url}")
+                raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
     if not ENABLE_RAG_LOCAL_WEB_FETCH:
         # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
         parsed_url = urllib.parse.urlparse(url)
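Because validate_url passes the full URL string to is_string_allowed, the suffix match fires on bare metadata-endpoint URLs. An illustrative check against the default entries:

# Illustrative: suffix matching on the full URL string.
filter_list = ["!169.254.169.254", "!metadata.google.internal"]

assert is_string_allowed("http://169.254.169.254", filter_list) is False
assert is_string_allowed("http://metadata.google.internal", filter_list) is False
# No allow entries and no block match, so this passes.
assert is_string_allowed("https://example.com", filter_list) is True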
@@ -82,17 +145,6 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
     return valid_urls
 
 
-def resolve_hostname(hostname):
-    # Get address information
-    addr_info = socket.getaddrinfo(hostname, None)
-
-    # Extract IP addresses from address information
-    ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
-    ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
-
-    return ipv4_addresses, ipv6_addresses
-
-
 def extract_metadata(soup, url):
     metadata = {"source": url}
     if title := soup.find("title"):
@@ -642,6 +694,10 @@ def get_web_loader(
     # Check if the URLs are valid
     safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
 
+    if not safe_urls:
+        log.warning(f"All provided URLs were blocked or invalid: {urls}")
+        raise ValueError(ERROR_MESSAGES.INVALID_URL)
+
     web_loader_args = {
         "web_paths": safe_urls,
         "verify_ssl": verify_ssl,
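A hedged usage sketch of the new guard (URL hypothetical, other get_web_loader parameters left at their defaults): when safe_validate_urls filters out every URL, get_web_loader now raises instead of building a loader over an empty list.

try:
    loader = get_web_loader(["http://169.254.169.254"])  # blocked by the default filter list
except ValueError:
    print("all URLs blocked or invalid")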