mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 20:35:19 +00:00
refac: oauth_sub -> oauth migration
This commit is contained in:
parent
369298a83e
commit
0a4358c3d1
3 changed files with 43 additions and 17 deletions
|
|
@ -88,7 +88,7 @@ class AuthsTable:
|
||||||
name: str,
|
name: str,
|
||||||
profile_image_url: str = "/user.png",
|
profile_image_url: str = "/user.png",
|
||||||
role: str = "pending",
|
role: str = "pending",
|
||||||
oauth_sub: Optional[str] = None,
|
oauth: Optional[dict] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
log.info("insert_new_auth")
|
log.info("insert_new_auth")
|
||||||
|
|
@ -102,7 +102,7 @@ class AuthsTable:
|
||||||
db.add(result)
|
db.add(result)
|
||||||
|
|
||||||
user = Users.insert_new_user(
|
user = Users.insert_new_user(
|
||||||
id, name, email, profile_image_url, role, oauth_sub
|
id, name, email, profile_image_url, role, oauth=oauth
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
|
||||||
|
|
@ -225,7 +225,7 @@ class UsersTable:
|
||||||
email: str,
|
email: str,
|
||||||
profile_image_url: str = "/user.png",
|
profile_image_url: str = "/user.png",
|
||||||
role: str = "pending",
|
role: str = "pending",
|
||||||
oauth_sub: Optional[str] = None,
|
oauth: Optional[dict] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = UserModel(
|
user = UserModel(
|
||||||
|
|
@ -238,7 +238,7 @@ class UsersTable:
|
||||||
"last_active_at": int(time.time()),
|
"last_active_at": int(time.time()),
|
||||||
"created_at": int(time.time()),
|
"created_at": int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
"updated_at": int(time.time()),
|
||||||
"oauth_sub": oauth_sub,
|
"oauth": oauth,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = User(**user.model_dump())
|
result = User(**user.model_dump())
|
||||||
|
|
@ -274,11 +274,15 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
|
def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
user = db.query(User).filter_by(oauth_sub=sub).first()
|
user = (
|
||||||
return UserModel.model_validate(user)
|
db.query(User)
|
||||||
|
.filter(User.oauth.contains({provider: {"sub": sub}}))
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return UserModel.model_validate(user) if user else None
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -493,16 +497,35 @@ class UsersTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_oauth_sub_by_id(
|
def update_user_oauth_by_id(
|
||||||
self, id: str, oauth_sub: str
|
self, id: str, provider: str, sub: str
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
|
"""
|
||||||
|
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
|
||||||
|
Example resulting structure:
|
||||||
|
{
|
||||||
|
"google": { "sub": "123" },
|
||||||
|
"github": { "sub": "abc" }
|
||||||
|
}
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
|
user = db.query(User).filter_by(id=id).first()
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Load existing oauth JSON or create empty
|
||||||
|
oauth = user.oauth or {}
|
||||||
|
|
||||||
|
# Update or insert provider entry
|
||||||
|
oauth[provider] = {"sub": sub}
|
||||||
|
|
||||||
|
# Persist updated JSON
|
||||||
|
db.query(User).filter_by(id=id).update({"oauth": oauth})
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
user = db.query(User).filter_by(id=id).first()
|
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1329,7 +1329,10 @@ class OAuthManager:
|
||||||
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
provider_sub = f"{provider}@{sub}"
|
oauth_data = {}
|
||||||
|
oauth_data[provider] = {
|
||||||
|
"sub": sub,
|
||||||
|
}
|
||||||
|
|
||||||
# Email extraction
|
# Email extraction
|
||||||
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
|
||||||
|
|
@ -1376,12 +1379,12 @@ class OAuthManager:
|
||||||
log.warning(f"Error fetching GitHub email: {e}")
|
log.warning(f"Error fetching GitHub email: {e}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
elif ENABLE_OAUTH_EMAIL_FALLBACK:
|
elif ENABLE_OAUTH_EMAIL_FALLBACK:
|
||||||
email = f"{provider_sub}.local"
|
email = f"{provider}@{sub}.local"
|
||||||
else:
|
else:
|
||||||
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
log.warning(f"OAuth callback failed, email is missing: {user_data}")
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
email = email.lower()
|
|
||||||
|
|
||||||
|
email = email.lower()
|
||||||
# If allowed domains are configured, check if the email domain is in the list
|
# If allowed domains are configured, check if the email domain is in the list
|
||||||
if (
|
if (
|
||||||
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
|
||||||
|
|
@ -1394,7 +1397,7 @@ class OAuthManager:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||||
|
|
||||||
# Check if the user exists
|
# Check if the user exists
|
||||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
user = Users.get_user_by_oauth_sub(provider, sub)
|
||||||
if not user:
|
if not user:
|
||||||
# If the user does not exist, check if merging is enabled
|
# If the user does not exist, check if merging is enabled
|
||||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||||
|
|
@ -1402,7 +1405,7 @@ class OAuthManager:
|
||||||
user = Users.get_user_by_email(email)
|
user = Users.get_user_by_email(email)
|
||||||
if user:
|
if user:
|
||||||
# Update the user with the new oauth sub
|
# Update the user with the new oauth sub
|
||||||
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
Users.update_user_oauth_by_id(user.id, provider, sub)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
determined_role = self.get_user_role(user, user_data)
|
determined_role = self.get_user_role(user, user_data)
|
||||||
|
|
@ -1461,7 +1464,7 @@ class OAuthManager:
|
||||||
name=name,
|
name=name,
|
||||||
profile_image_url=picture_url,
|
profile_image_url=picture_url,
|
||||||
role=self.get_user_role(None, user_data),
|
role=self.get_user_role(None, user_data),
|
||||||
oauth_sub=provider_sub,
|
oauth=oauth_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
if auth_manager_config.WEBHOOK_URL:
|
if auth_manager_config.WEBHOOK_URL:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue