mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: has_users
Co-Authored-By: pickle-dice <159401444+hassan-ajek@users.noreply.github.com>
This commit is contained in:
parent
53a328d00b
commit
f24b76d9a3
3 changed files with 13 additions and 14 deletions
|
|
@ -258,6 +258,10 @@ class UsersTable:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
return db.query(User).count()
|
return db.query(User).count()
|
||||||
|
|
||||||
|
def has_users(self) -> bool:
|
||||||
|
with get_db() as db:
|
||||||
|
return db.query(db.query(User).exists()).scalar()
|
||||||
|
|
||||||
def get_first_user(self) -> UserModel:
|
def get_first_user(self) -> UserModel:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
|
|
||||||
|
|
@ -351,11 +351,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||||
user = Users.get_user_by_email(email)
|
user = Users.get_user_by_email(email)
|
||||||
if not user:
|
if not user:
|
||||||
try:
|
try:
|
||||||
user_count = Users.get_num_users()
|
|
||||||
|
|
||||||
role = (
|
role = (
|
||||||
"admin"
|
"admin"
|
||||||
if user_count == 0
|
if not Users.has_users()
|
||||||
else request.app.state.config.DEFAULT_USER_ROLE
|
else request.app.state.config.DEFAULT_USER_ROLE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -489,7 +487,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
if Users.get_user_by_email(admin_email.lower()):
|
if Users.get_user_by_email(admin_email.lower()):
|
||||||
user = Auths.authenticate_user(admin_email.lower(), admin_password)
|
user = Auths.authenticate_user(admin_email.lower(), admin_password)
|
||||||
else:
|
else:
|
||||||
if Users.get_num_users() != 0:
|
if Users.has_users():
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
||||||
|
|
||||||
await signup(
|
await signup(
|
||||||
|
|
@ -556,6 +554,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||||
|
|
||||||
@router.post("/signup", response_model=SessionUserResponse)
|
@router.post("/signup", response_model=SessionUserResponse)
|
||||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
|
has_users = Users.has_users()
|
||||||
|
|
||||||
if WEBUI_AUTH:
|
if WEBUI_AUTH:
|
||||||
if (
|
if (
|
||||||
|
|
@ -566,12 +565,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if Users.get_num_users() != 0:
|
if has_users:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||||
)
|
)
|
||||||
|
|
||||||
user_count = Users.get_num_users()
|
|
||||||
if not validate_email_format(form_data.email.lower()):
|
if not validate_email_format(form_data.email.lower()):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||||
|
|
@ -581,9 +579,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
role = (
|
role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
|
||||||
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
|
|
||||||
)
|
|
||||||
|
|
||||||
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||||
if len(form_data.password.encode("utf-8")) > 72:
|
if len(form_data.password.encode("utf-8")) > 72:
|
||||||
|
|
@ -644,7 +640,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||||
user.id, request.app.state.config.USER_PERMISSIONS
|
user.id, request.app.state.config.USER_PERMISSIONS
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_count == 0:
|
if not has_users:
|
||||||
# Disable signup after the first user is created
|
# Disable signup after the first user is created
|
||||||
request.app.state.config.ENABLE_SIGNUP = False
|
request.app.state.config.ENABLE_SIGNUP = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,11 +88,12 @@ class OAuthManager:
|
||||||
return self.oauth.create_client(provider_name)
|
return self.oauth.create_client(provider_name)
|
||||||
|
|
||||||
def get_user_role(self, user, user_data):
|
def get_user_role(self, user, user_data):
|
||||||
if user and Users.get_num_users() == 1:
|
user_count = Users.get_num_users()
|
||||||
|
if user and user_count == 1:
|
||||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||||
log.debug("Assigning the only user the admin role")
|
log.debug("Assigning the only user the admin role")
|
||||||
return "admin"
|
return "admin"
|
||||||
if not user and Users.get_num_users() == 0:
|
if not user and user_count == 0:
|
||||||
# If there are no users, assign the role "admin", as the first user will be an admin
|
# If there are no users, assign the role "admin", as the first user will be an admin
|
||||||
log.debug("Assigning the first user the admin role")
|
log.debug("Assigning the first user the admin role")
|
||||||
return "admin"
|
return "admin"
|
||||||
|
|
@ -449,8 +450,6 @@ class OAuthManager:
|
||||||
log.debug(f"Updated profile picture for user {user.email}")
|
log.debug(f"Updated profile picture for user {user.email}")
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
user_count = Users.get_num_users()
|
|
||||||
|
|
||||||
# If the user does not exist, check if signups are enabled
|
# If the user does not exist, check if signups are enabled
|
||||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||||
# Check if an existing user with the same email already exists
|
# Check if an existing user with the same email already exists
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue