import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import aiohttp
from open_webui.utils.oauth import OAuthManager
from open_webui.config import AppConfig


def _async_context(entered):
    """Build a mock usable in ``async with`` that yields *entered* on entry."""
    context = MagicMock()
    context.__aenter__ = AsyncMock(return_value=entered)
    context.__aexit__ = AsyncMock(return_value=None)
    return context


class TestOAuthGoogleGroups:
    """Smoke tests for the Google Cloud Identity groups path of OAuthManager."""

    def setup_method(self):
        """Create a fresh OAuthManager bound to a mocked app for every test."""
        self.oauth_manager = OAuthManager(app=MagicMock())

    async def _fetch_with_session(self, session):
        """Run the Cloud Identity fetch with ``aiohttp.ClientSession`` replaced.

        The patched constructor returns an async context manager that yields
        *session*, mirroring ``async with aiohttp.ClientSession(...) as s``.
        """
        with patch("aiohttp.ClientSession", return_value=_async_context(session)):
            return await self.oauth_manager._fetch_google_groups_via_cloud_identity(
                access_token="test_token",
                user_email="user@company.com",
            )

    @pytest.mark.asyncio
    async def test_fetch_google_groups_success(self):
        """A 200 response yields the group e-mails and a well-formed request."""
        # Payload shaped like a Cloud Identity searchTransitiveGroups reply
        payload = {
            "memberships": [
                {
                    "groupKey": {"id": "admin@company.com"},
                    "group": "groups/123",
                    "displayName": "Admin Group",
                },
                {
                    "groupKey": {"id": "users@company.com"},
                    "group": "groups/456",
                    "displayName": "Users Group",
                },
            ]
        }

        response = MagicMock()
        response.status = 200
        response.json = AsyncMock(return_value=payload)

        # session.get(...) is itself used as an async context manager
        session = MagicMock()
        session.get = MagicMock(return_value=_async_context(response))

        groups = await self._fetch_with_session(session)

        assert groups == ["admin@company.com", "users@company.com"]

        # Exactly one HTTP request must have been issued
        session.get.assert_called_once()
        positional, keyword = session.get.call_args

        # The e-mail must appear percent-encoded ("@" -> "%40") in the URL
        requested_url = positional[0]
        assert "user%40company.com" in requested_url
        assert "searchTransitiveGroups" in requested_url

        # The bearer token travels in the headers keyword argument
        sent_headers = keyword["headers"]
        assert sent_headers["Authorization"] == "Bearer test_token"
        assert sent_headers["Content-Type"] == "application/json"

    @pytest.mark.asyncio
    async def test_fetch_google_groups_api_error(self):
        """A non-200 API response results in an empty group list."""
        response = MagicMock()
        response.status = 403
        response.text = AsyncMock(return_value="Permission denied")

        session = MagicMock()
        session.get = MagicMock(return_value=_async_context(response))

        groups = await self._fetch_with_session(session)

        # Errors are swallowed and reported as "no groups"
        assert groups == []

    @pytest.mark.asyncio
    async def test_fetch_google_groups_network_error(self):
        """A transport-level failure results in an empty group list."""
        # get() raises before any response object exists
        session = MagicMock()
        session.get.side_effect = aiohttp.ClientError("Network error")

        groups = await self._fetch_with_session(session)

        assert groups == []

    @pytest.mark.asyncio
    async def test_get_user_role_with_google_groups(self):
        """Groups fetched from Cloud Identity drive the role decision."""
        config = MagicMock(
            ENABLE_OAUTH_ROLE_MANAGEMENT=True,
            OAUTH_ROLES_CLAIM="groups",
            OAUTH_ALLOWED_ROLES=["users@company.com"],
            OAUTH_ADMIN_ROLES=["admin@company.com"],
            DEFAULT_USER_ROLE="pending",
            OAUTH_EMAIL_CLAIM="email",
        )
        claims = {"email": "user@company.com"}

        with patch("open_webui.utils.oauth.auth_manager_config", config), \
                patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as scope, \
                patch("open_webui.utils.oauth.Users") as users, \
                patch.object(
                    self.oauth_manager, "_fetch_google_groups_via_cloud_identity"
                ) as fetch:
            scope.value = "openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly"
            fetch.return_value = ["admin@company.com", "users@company.com"]
            users.get_num_users.return_value = 5  # Not first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=claims,
                provider="google",
                access_token="test_token",
            )

        # Membership in an admin group wins
        assert role == "admin"
        fetch.assert_called_once_with("test_token", "user@company.com")

    @pytest.mark.asyncio
    async def test_get_user_role_fallback_to_claims(self):
        """Without the Cloud Identity scope, roles come from token claims."""
        config = MagicMock(
            ENABLE_OAUTH_ROLE_MANAGEMENT=True,
            OAUTH_ROLES_CLAIM="groups",
            OAUTH_ALLOWED_ROLES=["users"],
            OAUTH_ADMIN_ROLES=["admin"],
            DEFAULT_USER_ROLE="pending",
            OAUTH_EMAIL_CLAIM="email",
        )
        claims = {
            "email": "user@company.com",
            "groups": ["users"],
        }

        with patch("open_webui.utils.oauth.auth_manager_config", config), \
                patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as scope, \
                patch("open_webui.utils.oauth.Users") as users, \
                patch.object(
                    self.oauth_manager, "_fetch_google_groups_via_cloud_identity"
                ) as fetch:
            # Configured scope lacks the Cloud Identity entry
            scope.value = "openid email profile"
            users.get_num_users.return_value = 5  # Not first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=claims,
                provider="google",
                access_token="test_token",
            )

        # The claim-based path is used and the API is never touched
        assert role == "user"
        fetch.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_user_role_non_google_provider(self):
        """Providers other than Google never trigger the Cloud Identity call."""
        config = MagicMock(
            ENABLE_OAUTH_ROLE_MANAGEMENT=True,
            OAUTH_ROLES_CLAIM="roles",
            OAUTH_ALLOWED_ROLES=["user"],
            OAUTH_ADMIN_ROLES=["admin"],
            DEFAULT_USER_ROLE="pending",
        )
        claims = {"roles": ["user"]}

        with patch("open_webui.utils.oauth.auth_manager_config", config), \
                patch("open_webui.utils.oauth.Users") as users, \
                patch.object(
                    self.oauth_manager, "_fetch_google_groups_via_cloud_identity"
                ) as fetch:
            users.get_num_users.return_value = 5  # Not first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=claims,
                provider="microsoft",
                access_token="test_token",
            )

        assert role == "user"
        fetch.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_user_groups_with_google_groups(self):
        """Group sync consumes ``google_groups`` placed in user_data."""
        config = MagicMock(
            OAUTH_GROUPS_CLAIM="groups",
            OAUTH_BLOCKED_GROUPS="[]",
            ENABLE_OAUTH_GROUP_CREATION=False,
        )

        member = MagicMock()
        member.id = "user123"

        claims = {
            "google_groups": ["developers@company.com", "employees@company.com"]
        }

        # ``name`` is special in the Mock constructor, so set attributes afterwards
        existing_group = MagicMock()
        existing_group.name = "developers@company.com"
        existing_group.id = "group1"
        existing_group.user_ids = []
        existing_group.permissions = {"read": True}
        existing_group.description = "Developers group"

        with patch("open_webui.utils.oauth.auth_manager_config", config), \
                patch("open_webui.utils.oauth.Groups") as groups_model:
            groups_model.get_groups_by_member_id.return_value = []
            groups_model.get_groups.return_value = [existing_group]

            await self.oauth_manager.update_user_groups(
                user=member,
                user_data=claims,
                default_permissions={"read": True},
            )

        # The user's memberships were looked up and a group was updated
        groups_model.get_groups_by_member_id.assert_called_once_with("user123")
        groups_model.update_group_by_id.assert_called()
quote from datetime import datetime, timedelta import re @@ -53,6 +54,7 @@ from open_webui.config import ( OAUTH_UPDATE_PICTURE_ON_LOGIN, WEBHOOK_URL, JWT_EXPIRES_IN, + GOOGLE_OAUTH_SCOPE, AppConfig, ) from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES @@ -832,7 +834,7 @@ class OAuthManager: log.error(f"Exception during token refresh for provider {provider}: {e}") return None - def get_user_role(self, user, user_data): + async def get_user_role(self, user, user_data, provider=None, access_token=None): user_count = Users.get_num_users() if user and user_count == 1: # If the user is the only user, assign the role "admin" - actually repairs role for single user on login @@ -866,6 +868,47 @@ class OAuthManager: if isinstance(claim_data, str) or isinstance(claim_data, int): oauth_roles = [str(claim_data)] + # Check if this is Google OAuth with Cloud Identity scope + if ( + provider == "google" + and access_token + and "https://www.googleapis.com/auth/cloud-identity.groups.readonly" + in GOOGLE_OAUTH_SCOPE.value + ): + + log.debug( + "Google OAuth with Cloud Identity scope detected - fetching groups via API" + ) + user_email = user_data.get(auth_manager_config.OAUTH_EMAIL_CLAIM, "") + if user_email: + try: + google_groups = ( + await self._fetch_google_groups_via_cloud_identity( + access_token, user_email + ) + ) + # Store groups in user_data for potential group management later + if "google_groups" not in user_data: + user_data["google_groups"] = google_groups + + # Use Google groups as oauth_roles for role determination + oauth_roles = google_groups + log.debug(f"Using Google groups as roles: {oauth_roles}") + except Exception as e: + log.error(f"Failed to fetch Google groups: {e}") + # Fall back to default behavior with claims + oauth_roles = [] + + # If not using Google groups or Google groups fetch failed, use traditional claims method + if not oauth_roles: + # Next block extracts the roles from the user data, accepting nested claims of any depth + if 
async def _fetch_google_groups_via_cloud_identity(
    self, access_token: str, user_email: str
) -> list[str]:
    """
    Fetch Google Workspace groups for a user via Cloud Identity API.

    Follows ``nextPageToken`` until the result set is exhausted. The lookup
    is best-effort: any failure is logged and reported as "no groups" so a
    login is never aborted by this call.

    Args:
        access_token: OAuth access token with cloud-identity.groups.readonly scope
        user_email: User's email address

    Returns:
        List of group email addresses the user belongs to. An empty list is
        returned when the user has no matching groups, when ``user_email``
        is empty, or when any request fails. A partially paginated result
        is discarded rather than returned, because callers cannot tell a
        truncated membership list from a complete one.
    """
    if not user_email:
        log.error("Cannot fetch Google groups: user email is empty")
        return []

    base_url = "https://content-cloudidentity.googleapis.com/v1/groups/-/memberships:searchTransitiveGroups"

    # Only groups carrying the "security" label are requested.
    # NOTE(review): user_email is interpolated into a quoted query literal;
    # an address containing "'" would produce a malformed query - confirm
    # whether the API accepts an escaped quote and escape here if so.
    query_string = f"member_key_id == '{user_email}' && 'cloudidentity.googleapis.com/groups.security' in labels"
    encoded_query = quote(query_string)

    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }

    groups: list[str] = []
    page_token = ""
    completed = False  # Set only once the last page was read successfully

    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            while True:
                # Build URL with query parameter, plus the page token if any
                url = f"{base_url}?query={encoded_query}"
                if page_token:
                    url += f"&pageToken={quote(page_token)}"

                log.debug("Fetching Google groups via Cloud Identity API")

                async with session.get(
                    url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL
                ) as resp:
                    if resp.status != 200:
                        error_text = await resp.text()
                        log.error(
                            f"Failed to fetch Google groups (status {resp.status})"
                        )
                        # Log only the API's message, never the raw body
                        try:
                            error_json = json.loads(error_text)
                        except json.JSONDecodeError:
                            log.error("Error response contains non-JSON data")
                        else:
                            error_info = (
                                error_json.get("error")
                                if isinstance(error_json, dict)
                                else None
                            )
                            if isinstance(error_info, dict):
                                log.error(
                                    f"API error: {error_info.get('message', 'Unknown error')}"
                                )
                        break

                    data = await resp.json()

                # Extract group emails from memberships
                memberships = data.get("memberships", [])
                log.debug(f"Found {len(memberships)} memberships")
                for membership in memberships:
                    group_email = membership.get("groupKey", {}).get("id", "")
                    if group_email:
                        groups.append(group_email)
                        log.debug(f"Found group membership: {group_email}")

                # Check for next page
                page_token = data.get("nextPageToken", "")
                if not page_token:
                    completed = True
                    break

    except Exception as e:
        # Deliberate best-effort boundary: a failed lookup must not break login
        log.error(f"Error fetching Google groups via Cloud Identity API: {e}")

    if not completed:
        # Drop pages collected before the failure so callers never act on
        # a truncated membership list.
        return []

    log.info(f"Retrieved {len(groups)} Google groups for user {user_email}")
    return groups
(auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data) @@ -1186,13 +1319,13 @@ class OAuthManager: else: log.warning(f"OAuth callback failed, email is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + email = email.lower() # If allowed domains are configured, check if the email domain is in the list if ( "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS - and email.split("@")[-1] - not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS ): log.warning( f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" @@ -1201,6 +1334,7 @@ class OAuthManager: # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) + if not user: # If the user does not exist, check if merging is enabled if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: @@ -1211,16 +1345,18 @@ class OAuthManager: Users.update_user_oauth_sub_by_id(user.id, provider_sub) if user: - determined_role = self.get_user_role(user, user_data) + determined_role = await self.get_user_role( + user, user_data, provider, token.get("access_token") + ) if user.role != determined_role: Users.update_user_role_by_id(user.id, determined_role) + # Update profile picture if enabled and different from current if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN: picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM if picture_claim: new_picture_url = user_data.get( - picture_claim, - OAUTH_PROVIDERS[provider].get("picture_url", ""), + picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") ) processed_picture_url = await self._process_picture_url( new_picture_url, token.get("access_token") @@ -1230,7 +1366,8 @@ class OAuthManager: user.id, processed_picture_url ) log.debug(f"Updated profile picture for user {user.email}") - else: + + if not user: # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if 
an existing user with the same email already exists @@ -1241,14 +1378,14 @@ class OAuthManager: picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM if picture_claim: picture_url = user_data.get( - picture_claim, - OAUTH_PROVIDERS[provider].get("picture_url", ""), + picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") ) picture_url = await self._process_picture_url( picture_url, token.get("access_token") ) else: picture_url = "/user.png" + username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM name = user_data.get(username_claim) @@ -1256,6 +1393,10 @@ class OAuthManager: log.warning("Username claim is missing, using email as name") name = email + role = await self.get_user_role( + None, user_data, provider, token.get("access_token") + ) + user = Auths.insert_new_auth( email=email, password=get_password_hash( @@ -1263,7 +1404,7 @@ class OAuthManager: ), # Random password, not used name=name, profile_image_url=picture_url, - role=self.get_user_role(None, user_data), + role=role, oauth_sub=provider_sub, ) @@ -1288,14 +1429,17 @@ class OAuthManager: data={"id": user.id}, expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), ) + if ( - auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT - and user.role != "admin" + auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT + and user.role != "admin" ): - self.update_user_groups( + await self.update_user_groups( user=user, user_data=user_data, default_permissions=request.app.state.config.USER_PERMISSIONS, + provider=provider, + access_token=token.get("access_token"), ) except Exception as e: diff --git a/docs/oauth-google-groups.md b/docs/oauth-google-groups.md new file mode 100644 index 0000000000..40bc62ba5b --- /dev/null +++ b/docs/oauth-google-groups.md @@ -0,0 +1,95 @@ +# Google OAuth with Cloud Identity Groups Support + +This example demonstrates how to configure Open WebUI to use Google OAuth with Cloud Identity API for group-based role management. 
+ +## Configuration + +### Environment Variables + +```bash +# Google OAuth Configuration +GOOGLE_CLIENT_ID="your-google-client-id.apps.googleusercontent.com" +GOOGLE_CLIENT_SECRET="your-google-client-secret" + +# IMPORTANT: Include the Cloud Identity Groups scope +GOOGLE_OAUTH_SCOPE="openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly" + +# Enable OAuth features +ENABLE_OAUTH_SIGNUP=true +ENABLE_OAUTH_ROLE_MANAGEMENT=true +ENABLE_OAUTH_GROUP_MANAGEMENT=true + +# Configure admin roles using Google group emails +OAUTH_ADMIN_ROLES="admin@yourcompany.com,superadmin@yourcompany.com" +OAUTH_ALLOWED_ROLES="users@yourcompany.com,employees@yourcompany.com" + +# Optional: Configure group creation +ENABLE_OAUTH_GROUP_CREATION=true +``` + +## How It Works + +1. **Scope Detection**: When a user logs in with Google OAuth, the system checks if the `https://www.googleapis.com/auth/cloud-identity.groups.readonly` scope is present in `GOOGLE_OAUTH_SCOPE`. + +2. **Groups Fetching**: If the scope is present, the system uses the Google Cloud Identity API to fetch the security groups the user belongs to, either directly or through nested groups, instead of relying on claims in the OAuth token. Groups without the `cloudidentity.googleapis.com/groups.security` label are not returned. + +3. **Role Assignment**: + - If the user belongs to any group listed in `OAUTH_ADMIN_ROLES`, they get admin privileges + - If the user belongs to any group listed in `OAUTH_ALLOWED_ROLES`, they get user privileges + - Default role is applied if no matching groups are found + +4. **Group Management**: If `ENABLE_OAUTH_GROUP_MANAGEMENT` is enabled, Open WebUI groups are synchronized with Google Workspace groups. + +## Google Cloud Console Setup + +1. **Enable APIs**: + - Cloud Identity API + - Cloud Identity Groups API + +2. **OAuth 2.0 Setup**: + - Create OAuth 2.0 credentials + - Add authorized redirect URIs + - Configure consent screen + +3. 
**Required Scopes**: + ``` + openid + email + profile + https://www.googleapis.com/auth/cloud-identity.groups.readonly + ``` + +## Example Groups Structure + +``` +Your Google Workspace: +├── admin@yourcompany.com (Admin group) +├── superadmin@yourcompany.com (Super admin group) +├── users@yourcompany.com (Regular users) +├── employees@yourcompany.com (All employees) +└── developers@yourcompany.com (Development team) +``` + +## Fallback Behavior + +If the Cloud Identity scope is not present or the API call fails, the system falls back to the traditional method of reading roles from OAuth token claims. + +## Security Considerations + +- The Cloud Identity API requires proper authentication and authorization +- Only users with appropriate permissions can access group membership information +- Groups are fetched server-side, not exposed to the client +- Access tokens are handled securely and not logged + +## Troubleshooting + +1. **Groups not detected**: Ensure the Cloud Identity API is enabled and the OAuth client has the required scope +2. **Permission denied**: Verify the service account or OAuth client has Cloud Identity API access +3. **No admin role**: Check that the user belongs to a group listed in `OAUTH_ADMIN_ROLES` + +## Benefits Over Token Claims + +- **Real-time**: Groups are fetched fresh on each login +- **Complete**: Gets direct and nested (transitive) memberships in security-labelled groups +- **Accurate**: No dependency on ID token size limits +- **Flexible**: Can handle complex group hierarchies in Google Workspace \ No newline at end of file