refac: oauth_sub -> oauth migration

This commit is contained in:
Timothy Jaeryang Baek 2025-11-28 06:39:36 -05:00
parent 369298a83e
commit 0a4358c3d1
3 changed files with 43 additions and 17 deletions

View file

@ -88,7 +88,7 @@ class AuthsTable:
name: str, name: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth: Optional[dict] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
log.info("insert_new_auth") log.info("insert_new_auth")
@ -102,7 +102,7 @@ class AuthsTable:
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth=oauth
) )
db.commit() db.commit()

View file

@ -225,7 +225,7 @@ class UsersTable:
email: str, email: str,
profile_image_url: str = "/user.png", profile_image_url: str = "/user.png",
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth: Optional[dict] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db() as db: with get_db() as db:
user = UserModel( user = UserModel(
@ -238,7 +238,7 @@ class UsersTable:
"last_active_at": int(time.time()), "last_active_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
"updated_at": int(time.time()), "updated_at": int(time.time()),
"oauth_sub": oauth_sub, "oauth": oauth,
} }
) )
result = User(**user.model_dump()) result = User(**user.model_dump())
@ -274,11 +274,15 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, provider: str, sub: str) -> Optional[UserModel]:
try: try:
with get_db() as db: with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first() user = (
return UserModel.model_validate(user) db.query(User)
.filter(User.oauth.contains({provider: {"sub": sub}}))
.first()
)
return UserModel.model_validate(user) if user else None
except Exception: except Exception:
return None return None
@ -493,16 +497,35 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_oauth_sub_by_id( def update_user_oauth_by_id(
self, id: str, oauth_sub: str self, id: str, provider: str, sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
"""
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
Example resulting structure:
{
"google": { "sub": "123" },
"github": { "sub": "abc" }
}
"""
try: try:
with get_db() as db: with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) user = db.query(User).filter_by(id=id).first()
if not user:
return None
# Load existing oauth JSON or create empty
oauth = user.oauth or {}
# Update or insert provider entry
oauth[provider] = {"sub": sub}
# Persist updated JSON
db.query(User).filter_by(id=id).update({"oauth": oauth})
db.commit() db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except Exception: except Exception:
return None return None

View file

@ -1329,7 +1329,10 @@ class OAuthManager:
log.warning(f"OAuth callback failed, sub is missing: {user_data}") log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}" oauth_data = {}
oauth_data[provider] = {
"sub": sub,
}
# Email extraction # Email extraction
email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM
@ -1376,12 +1379,12 @@ class OAuthManager:
log.warning(f"Error fetching GitHub email: {e}") log.warning(f"Error fetching GitHub email: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
elif ENABLE_OAUTH_EMAIL_FALLBACK: elif ENABLE_OAUTH_EMAIL_FALLBACK:
email = f"{provider_sub}.local" email = f"{provider}@{sub}.local"
else: else:
log.warning(f"OAuth callback failed, email is missing: {user_data}") log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
email = email.lower()
email = email.lower()
# If allowed domains are configured, check if the email domain is in the list # If allowed domains are configured, check if the email domain is in the list
if ( if (
"*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS
@ -1394,7 +1397,7 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists # Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub) user = Users.get_user_by_oauth_sub(provider, sub)
if not user: if not user:
# If the user does not exist, check if merging is enabled # If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
@ -1402,7 +1405,7 @@ class OAuthManager:
user = Users.get_user_by_email(email) user = Users.get_user_by_email(email)
if user: if user:
# Update the user with the new oauth sub # Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub) Users.update_user_oauth_by_id(user.id, provider, sub)
if user: if user:
determined_role = self.get_user_role(user, user_data) determined_role = self.get_user_role(user, user_data)
@ -1461,7 +1464,7 @@ class OAuthManager:
name=name, name=name,
profile_image_url=picture_url, profile_image_url=picture_url,
role=self.get_user_role(None, user_data), role=self.get_user_role(None, user_data),
oauth_sub=provider_sub, oauth=oauth_data,
) )
if auth_manager_config.WEBHOOK_URL: if auth_manager_config.WEBHOOK_URL: