mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 04:45:19 +00:00
refac
This commit is contained in:
parent
91755309ce
commit
6d38ac41b6
2 changed files with 40 additions and 39 deletions
|
|
@ -681,15 +681,16 @@ async def signout(request: Request, response: Response):
|
||||||
if ENABLE_OAUTH_SIGNUP.value:
|
if ENABLE_OAUTH_SIGNUP.value:
|
||||||
# TODO: update this to use oauth_session_tokens in User Object
|
# TODO: update this to use oauth_session_tokens in User Object
|
||||||
oauth_id_token = request.cookies.get("oauth_id_token")
|
oauth_id_token = request.cookies.get("oauth_id_token")
|
||||||
|
|
||||||
if oauth_id_token and OPENID_PROVIDER_URL.value:
|
if oauth_id_token and OPENID_PROVIDER_URL.value:
|
||||||
try:
|
try:
|
||||||
async with ClientSession(trust_env=True) as session:
|
async with ClientSession(trust_env=True) as session:
|
||||||
async with session.get(OPENID_PROVIDER_URL.value) as resp:
|
async with session.get(OPENID_PROVIDER_URL.value) as r:
|
||||||
if resp.status == 200:
|
if r.status == 200:
|
||||||
openid_data = await resp.json()
|
openid_data = await r.json()
|
||||||
logout_url = openid_data.get("end_session_endpoint")
|
logout_url = openid_data.get("end_session_endpoint")
|
||||||
if logout_url:
|
|
||||||
|
|
||||||
|
if logout_url:
|
||||||
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
||||||
response.delete_cookie("oauth_id_token")
|
response.delete_cookie("oauth_id_token")
|
||||||
response.delete_cookie("oauth_access_token")
|
response.delete_cookie("oauth_access_token")
|
||||||
|
|
@ -710,7 +711,7 @@ async def signout(request: Request, response: Response):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=resp.status,
|
status_code=r.status,
|
||||||
detail="Failed to fetch OpenID configuration",
|
detail="Failed to fetch OpenID configuration",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,7 @@ from open_webui.env import (
|
||||||
WEBUI_NAME,
|
WEBUI_NAME,
|
||||||
WEBUI_AUTH_COOKIE_SAME_SITE,
|
WEBUI_AUTH_COOKIE_SAME_SITE,
|
||||||
WEBUI_AUTH_COOKIE_SECURE,
|
WEBUI_AUTH_COOKIE_SECURE,
|
||||||
|
ENABLE_OAUTH_SESSION_TOKENS_COOKIES,
|
||||||
)
|
)
|
||||||
from open_webui.utils.misc import parse_duration
|
from open_webui.utils.misc import parse_duration
|
||||||
from open_webui.utils.auth import get_password_hash, create_token
|
from open_webui.utils.auth import get_password_hash, create_token
|
||||||
|
|
@ -410,6 +411,8 @@ class OAuthManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"OAuth callback error: {e}")
|
log.warning(f"OAuth callback error: {e}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
|
# Try to get userinfo from the token first, some providers include it there
|
||||||
user_data: UserInfo = token.get("userinfo")
|
user_data: UserInfo = token.get("userinfo")
|
||||||
if (
|
if (
|
||||||
(not user_data)
|
(not user_data)
|
||||||
|
|
@ -421,18 +424,19 @@ class OAuthManager:
|
||||||
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
log.warning(f"OAuth callback failed, user data is missing: {token}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
|
# Extract the "sub" claim, using custom claim if configured
|
||||||
if auth_manager_config.OAUTH_SUB_CLAIM:
|
if auth_manager_config.OAUTH_SUB_CLAIM:
|
||||||
sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
sub = user_data.get(auth_manager_config.OAUTH_SUB_CLAIM)
|
||||||
else:
|
else:
|
||||||
# Fallback to the default sub claim if not configured
|
# Fallback to the default sub claim if not configured
|
||||||
sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub"))
|
||||||
|
|
||||||
if not sub:
|
if not sub:
|
||||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
provider_sub = f"{provider}@{sub}"
|
provider_sub = f"{provider}@{sub}"
|
||||||
|
|
||||||
|
# Email extraction
|
||||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||||
email = user_data.get(email_claim, "")
|
email = user_data.get(email_claim, "")
|
||||||
# We currently mandate that email addresses are provided
|
# We currently mandate that email addresses are provided
|
||||||
|
|
@ -480,6 +484,8 @@ class OAuthManager:
|
||||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
email = email.lower()
|
email = email.lower()
|
||||||
|
|
||||||
|
# If allowed domains are configured, check if the email domain is in the list
|
||||||
if (
|
if (
|
||||||
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||||
and email.split("@")[-1]
|
and email.split("@")[-1]
|
||||||
|
|
@ -492,7 +498,6 @@ class OAuthManager:
|
||||||
|
|
||||||
# Check if the user exists
|
# Check if the user exists
|
||||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
# If the user does not exist, check if merging is enabled
|
# If the user does not exist, check if merging is enabled
|
||||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||||
|
|
@ -506,7 +511,6 @@ class OAuthManager:
|
||||||
determined_role = self.get_user_role(user, user_data)
|
determined_role = self.get_user_role(user, user_data)
|
||||||
if user.role != determined_role:
|
if user.role != determined_role:
|
||||||
Users.update_user_role_by_id(user.id, determined_role)
|
Users.update_user_role_by_id(user.id, determined_role)
|
||||||
|
|
||||||
# Update profile picture if enabled and different from current
|
# Update profile picture if enabled and different from current
|
||||||
if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
||||||
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM
|
||||||
|
|
@ -523,8 +527,7 @@ class OAuthManager:
|
||||||
user.id, processed_picture_url
|
user.id, processed_picture_url
|
||||||
)
|
)
|
||||||
log.debug(f"Updated profile picture for user {user.email}")
|
log.debug(f"Updated profile picture for user {user.email}")
|
||||||
|
else:
|
||||||
if not user:
|
|
||||||
# If the user does not exist, check if signups are enabled
|
# If the user does not exist, check if signups are enabled
|
||||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||||
# Check if an existing user with the same email already exists
|
# Check if an existing user with the same email already exists
|
||||||
|
|
@ -543,7 +546,6 @@ class OAuthManager:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
picture_url = "/user.png"
|
picture_url = "/user.png"
|
||||||
|
|
||||||
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM
|
||||||
|
|
||||||
name = user_data.get(username_claim)
|
name = user_data.get(username_claim)
|
||||||
|
|
@ -551,8 +553,6 @@ class OAuthManager:
|
||||||
log.warning("Username claim is missing, using email as name")
|
log.warning("Username claim is missing, using email as name")
|
||||||
name = email
|
name = email
|
||||||
|
|
||||||
role = self.get_user_role(None, user_data)
|
|
||||||
|
|
||||||
user = Auths.insert_new_auth(
|
user = Auths.insert_new_auth(
|
||||||
email=email,
|
email=email,
|
||||||
password=get_password_hash(
|
password=get_password_hash(
|
||||||
|
|
@ -560,7 +560,7 @@ class OAuthManager:
|
||||||
), # Random password, not used
|
), # Random password, not used
|
||||||
name=name,
|
name=name,
|
||||||
profile_image_url=picture_url,
|
profile_image_url=picture_url,
|
||||||
role=role,
|
role=self.get_user_role(None, user_data),
|
||||||
oauth_sub=provider_sub,
|
oauth_sub=provider_sub,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -585,7 +585,6 @@ class OAuthManager:
|
||||||
data={"id": user.id},
|
data={"id": user.id},
|
||||||
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
|
auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT
|
||||||
and user.role != "admin"
|
and user.role != "admin"
|
||||||
|
|
@ -626,6 +625,7 @@ class OAuthManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
if ENABLE_OAUTH_SIGNUP.value:
|
if ENABLE_OAUTH_SIGNUP.value:
|
||||||
|
if ENABLE_OAUTH_SESSION_TOKENS_COOKIES:
|
||||||
oauth_id_token = token.get("id_token")
|
oauth_id_token = token.get("id_token")
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="oauth_id_token",
|
key="oauth_id_token",
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue