refac: has_users

Co-Authored-By: pickle-dice <159401444+hassan-ajek@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek 2025-08-05 22:15:22 +04:00
parent 53a328d00b
commit f24b76d9a3
3 changed files with 13 additions and 14 deletions

View file

@ -258,6 +258,10 @@ class UsersTable:
with get_db() as db: with get_db() as db:
return db.query(User).count() return db.query(User).count()
def has_users(self) -> bool:
with get_db() as db:
return db.query(db.query(User).exists()).scalar()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
try: try:
with get_db() as db: with get_db() as db:

View file

@ -351,11 +351,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email)
if not user: if not user:
try: try:
user_count = Users.get_num_users()
role = ( role = (
"admin" "admin"
if user_count == 0 if not Users.has_users()
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
@ -489,7 +487,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if Users.get_user_by_email(admin_email.lower()): if Users.get_user_by_email(admin_email.lower()):
user = Auths.authenticate_user(admin_email.lower(), admin_password) user = Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
if Users.get_num_users() != 0: if Users.has_users():
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup( await signup(
@ -556,6 +554,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse) @router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm): async def signup(request: Request, response: Response, form_data: SignupForm):
has_users = Users.has_users()
if WEBUI_AUTH: if WEBUI_AUTH:
if ( if (
@ -566,12 +565,11 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
else: else:
if Users.get_num_users() != 0: if has_users:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
user_count = Users.get_num_users()
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@ -581,9 +579,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
role = ( role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE
"admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE
)
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing. # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
if len(form_data.password.encode("utf-8")) > 72: if len(form_data.password.encode("utf-8")) > 72:
@ -644,7 +640,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
user.id, request.app.state.config.USER_PERMISSIONS user.id, request.app.state.config.USER_PERMISSIONS
) )
if user_count == 0: if not has_users:
# Disable signup after the first user is created # Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False request.app.state.config.ENABLE_SIGNUP = False

View file

@ -88,11 +88,12 @@ class OAuthManager:
return self.oauth.create_client(provider_name) return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data): def get_user_role(self, user, user_data):
if user and Users.get_num_users() == 1: user_count = Users.get_num_users()
if user and user_count == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login # If the user is the only user, assign the role "admin" - actually repairs role for single user on login
log.debug("Assigning the only user the admin role") log.debug("Assigning the only user the admin role")
return "admin" return "admin"
if not user and Users.get_num_users() == 0: if not user and user_count == 0:
# If there are no users, assign the role "admin", as the first user will be an admin # If there are no users, assign the role "admin", as the first user will be an admin
log.debug("Assigning the first user the admin role") log.debug("Assigning the first user the admin role")
return "admin" return "admin"
@ -449,8 +450,6 @@ class OAuthManager:
log.debug(f"Updated profile picture for user {user.email}") log.debug(f"Updated profile picture for user {user.email}")
if not user: if not user:
user_count = Users.get_num_users()
# If the user does not exist, check if signups are enabled # If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP: if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists # Check if an existing user with the same email already exists