mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
feat/enh: optional password validation
This commit is contained in:
parent
e6c7495c1a
commit
680cde8f9b
5 changed files with 62 additions and 12 deletions
|
|
@ -45,7 +45,7 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
)
|
)
|
||||||
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
||||||
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
||||||
INVALID_PASSWORD = (
|
INCORRECT_PASSWORD = (
|
||||||
"The password provided is incorrect. Please check for typos and try again."
|
"The password provided is incorrect. Please check for typos and try again."
|
||||||
)
|
)
|
||||||
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
|
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
|
||||||
|
|
@ -105,6 +105,10 @@ class ERROR_MESSAGES(str, Enum):
|
||||||
)
|
)
|
||||||
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
||||||
|
|
||||||
|
INVALID_PASSWORD = lambda err="": (
|
||||||
|
err if err else "The password does not meet the required validation criteria."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TASKS(str, Enum):
|
class TASKS(str, Enum):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ import shutil
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
import markdown
|
import markdown
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
@ -429,6 +431,17 @@ WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ENABLE_PASSWORD_VALIDATION = (
|
||||||
|
os.environ.get("ENABLE_PASSWORD_VALIDATION", "False").lower() == "true"
|
||||||
|
)
|
||||||
|
PASSWORD_VALIDATION_REGEX_PATTERN = os.environ.get(
|
||||||
|
"PASSWORD_VALIDATION_REGEX_PATTERN",
|
||||||
|
"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[^\w\s]).{8,}$",
|
||||||
|
)
|
||||||
|
|
||||||
|
PASSWORD_VALIDATION_REGEX_PATTERN = re.compile(PASSWORD_VALIDATION_REGEX_PATTERN)
|
||||||
|
|
||||||
|
|
||||||
BYPASS_MODEL_ACCESS_CONTROL = (
|
BYPASS_MODEL_ACCESS_CONTROL = (
|
||||||
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from open_webui.utils.misc import parse_duration, validate_email_format
|
from open_webui.utils.misc import parse_duration, validate_email_format
|
||||||
from open_webui.utils.auth import (
|
from open_webui.utils.auth import (
|
||||||
|
validate_password,
|
||||||
verify_password,
|
verify_password,
|
||||||
decode_token,
|
decode_token,
|
||||||
invalidate_token,
|
invalidate_token,
|
||||||
|
|
@ -181,10 +182,14 @@ async def update_password(
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
|
try:
|
||||||
|
validate_password(form_data.password)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(400, detail=str(e))
|
||||||
hashed = get_password_hash(form_data.new_password)
|
hashed = get_password_hash(form_data.new_password)
|
||||||
return Auths.update_user_password_by_id(user.id, hashed)
|
return Auths.update_user_password_by_id(user.id, hashed)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
|
|
@ -627,16 +632,14 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
|
try:
|
||||||
|
validate_password(form_data.password)
|
||||||
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
except Exception as e:
|
||||||
if len(form_data.password.encode("utf-8")) > 72:
|
raise HTTPException(400, detail=str(e))
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
|
|
||||||
)
|
|
||||||
|
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
|
|
||||||
|
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
form_data.email.lower(),
|
form_data.email.lower(),
|
||||||
hashed,
|
hashed,
|
||||||
|
|
@ -805,6 +808,11 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
try:
|
||||||
|
validate_password(form_data.password)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(400, detail=str(e))
|
||||||
|
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
form_data.email.lower(),
|
form_data.email.lower(),
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,12 @@ from open_webui.constants import ERROR_MESSAGES
|
||||||
from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
|
from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
|
||||||
|
|
||||||
|
|
||||||
from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
|
from open_webui.utils.auth import (
|
||||||
|
get_admin_user,
|
||||||
|
get_password_hash,
|
||||||
|
get_verified_user,
|
||||||
|
validate_password,
|
||||||
|
)
|
||||||
from open_webui.utils.access_control import get_permissions, has_permission
|
from open_webui.utils.access_control import get_permissions, has_permission
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -497,8 +502,12 @@ async def update_user_by_id(
|
||||||
)
|
)
|
||||||
|
|
||||||
if form_data.password:
|
if form_data.password:
|
||||||
|
try:
|
||||||
|
validate_password(form_data.password)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(400, detail=str(e))
|
||||||
|
|
||||||
hashed = get_password_hash(form_data.password)
|
hashed = get_password_hash(form_data.password)
|
||||||
log.debug(f"hashed: {hashed}")
|
|
||||||
Auths.update_user_password_by_id(user_id, hashed)
|
Auths.update_user_password_by_id(user_id, hashed)
|
||||||
|
|
||||||
Auths.update_email_by_id(user_id, form_data.email.lower())
|
Auths.update_email_by_id(user_id, form_data.email.lower())
|
||||||
|
|
|
||||||
|
|
@ -28,8 +28,10 @@ from open_webui.models.users import Users
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
||||||
from open_webui.env import (
|
from open_webui.env import (
|
||||||
|
ENABLE_PASSWORD_VALIDATION,
|
||||||
OFFLINE_MODE,
|
OFFLINE_MODE,
|
||||||
LICENSE_BLOB,
|
LICENSE_BLOB,
|
||||||
|
PASSWORD_VALIDATION_REGEX_PATTERN,
|
||||||
REDIS_KEY_PREFIX,
|
REDIS_KEY_PREFIX,
|
||||||
pk,
|
pk,
|
||||||
WEBUI_SECRET_KEY,
|
WEBUI_SECRET_KEY,
|
||||||
|
|
@ -162,6 +164,20 @@ def get_password_hash(password: str) -> str:
|
||||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_password(password: str) -> bool:
|
||||||
|
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||||
|
if len(password.encode("utf-8")) > 72:
|
||||||
|
raise Exception(
|
||||||
|
ERROR_MESSAGES.PASSWORD_TOO_LONG,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ENABLE_PASSWORD_VALIDATION:
|
||||||
|
if not PASSWORD_VALIDATION_REGEX_PATTERN.match(password):
|
||||||
|
raise Exception(ERROR_MESSAGES.INVALID_PASSWORD())
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""Verify a password against its hash"""
|
"""Verify a password against its hash"""
|
||||||
return (
|
return (
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue