diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index f3ca874538..50ad88b546 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -382,14 +382,14 @@ async def get_channel_by_id( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) write_access = has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, type="write", access_control=channel.access_control, strict=False, db=db ) user_count = len(get_users_with_access("read", channel.access_control)) @@ -762,7 +762,7 @@ async def get_channel_messages( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -830,7 +830,7 @@ async def get_pinned_channel_messages( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1228,7 +1228,7 @@ async def get_channel_message( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1282,7 +1282,7 @@ async def get_channel_message_data( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1336,7 +1336,7 @@ async def pin_channel_message( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1402,7 +1402,7 @@ async def get_channel_thread_messages( ) else: if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1476,7 +1476,7 @@ async def update_message_by_id( user.role != "admin" and message.user_id != user.id and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="read", access_control=channel.access_control, db=db ) ): raise HTTPException( @@ -1543,7 +1543,7 @@ async def add_reaction_to_message( ) else: if user.role != "admin" and not has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, type="write", access_control=channel.access_control, strict=False, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1618,7 +1618,7 @@ async def remove_reaction_by_id_and_user_id_and_name( ) else: if user.role != "admin" and not has_access( - user.id, type="write", access_control=channel.access_control, strict=False + user.id, type="write", access_control=channel.access_control, strict=False, db=db ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -1713,6 +1713,7 @@ async def delete_message_by_id( type="write", access_control=channel.access_control, strict=False, + db=db, ) ): raise HTTPException( diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index b1bd4850fe..e6b2e9d9e9 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -421,14 +421,14 @@ async def get_all_models(request: Request, user: UserModel = None): return models -async def get_filtered_models(models, user): +async def get_filtered_models(models, user, db=None): # Filter models based on user access control filtered_models = [] for model in models.get("models", []): model_info = Models.get_model_by_id(model["model"]) if model_info: if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ): filtered_models.append(model) return filtered_models diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index d5bdf505a9..3e8a10b9fa 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -453,14 +453,14 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: return responses -async def get_filtered_models(models, user): +async def get_filtered_models(models, user, db=None): # Filter models based on user access control filtered_models = [] for model in models.get("data", []): model_info = Models.get_model_by_id(model["id"]) if model_info: if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control + user.id, type="read", access_control=model_info.access_control, db=db ): filtered_models.append(model) return filtered_models diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index 431542d340..6ef3c16091 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -360,7 +360,7 @@ def check_model_access(user, model): raise Exception("Model not found") -def get_filtered_models(models, user): +def get_filtered_models(models, user, db=None): # Filter out models that the user does not have access to if ( user.role == "user" @@ -373,7 +373,7 @@ def get_filtered_models(models, user): } filtered_models = [] - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} for model in models: if model.get("arena"): if has_access( diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index ffdd99b878..ef90c74c24 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1132,7 +1132,7 @@ class OAuthManager: return role - def update_user_groups(self, user, user_data, default_permissions): + def update_user_groups(self, user, user_data, default_permissions, db=None): log.debug("Running OAUTH Group management") oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM @@ -1161,8 +1161,8 @@ class OAuthManager: else: user_oauth_groups = [] - user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) - all_available_groups: list[GroupModel] = Groups.get_all_groups() + user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id, db=db) + all_available_groups: list[GroupModel] = Groups.get_all_groups(db=db) # Create groups if they don't exist and creation is enabled if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: @@ -1188,7 +1188,7 @@ class OAuthManager: ) # Use determined creator ID (admin or fallback to current user) created_group = Groups.insert_new_group( - creator_id, new_group_form + creator_id, new_group_form, db=db ) if created_group: log.info( @@ -1206,7 +1206,7 @@ class OAuthManager: # Refresh the list of all available groups if any were created if groups_created: - all_available_groups = Groups.get_all_groups() + all_available_groups = Groups.get_all_groups(db=db) log.debug("Refreshed list of all available groups after creation.") log.debug(f"Oauth Groups claim: {oauth_claim}") @@ -1227,7 +1227,7 @@ class OAuthManager: log.debug( f"Removing user from group {group_model.name} as it is no longer in their oauth groups" ) - Groups.remove_users_from_group(group_model.id, [user.id]) + Groups.remove_users_from_group(group_model.id, [user.id], db=db) # In case a group is created, but perms are never assigned to the group by hitting "save" group_permissions = group_model.permissions @@ -1242,6 +1242,7 @@ class OAuthManager: permissions=group_permissions, ), overwrite=False, + db=db, ) # Add user to new groups @@ -1257,7 +1258,7 @@ class OAuthManager: f"Adding user to group {group_model.name} as it was found in their oauth groups" ) - Groups.add_users_to_group(group_model.id, [user.id]) + Groups.add_users_to_group(group_model.id, [user.id], db=db) # In case a group is created, but perms are never assigned to the group by hitting "save" group_permissions = group_model.permissions @@ -1272,6 +1273,7 @@ class OAuthManager: permissions=group_permissions, ), overwrite=False, + db=db, ) async def _process_picture_url( @@ -1566,6 +1568,7 @@ class OAuthManager: user=user, user_data=user_data, default_permissions=request.app.state.config.USER_PERMISSIONS, + db=db, ) except Exception as e: