diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index bd2fd3d4f7..d26b499700 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -25,6 +25,10 @@ from open_webui.utils.auth import ( ) from open_webui.constants import ERROR_MESSAGES + +from sqlalchemy.orm import Session +from open_webui.internal.db import get_session + log = logging.getLogger(__name__) router = APIRouter() @@ -345,13 +349,13 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser: ) -def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: +def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup: """Convert internal Group model to SCIM Group""" - member_ids = Groups.get_group_user_ids_by_id(group.id) + member_ids = Groups.get_group_user_ids_by_id(group.id, db) or [] members = [] for user_id in member_ids: - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(user_id, db=db) if user: members.append( SCIMGroupMember( @@ -714,10 +718,11 @@ async def get_groups( count: int = Query(20, ge=1, le=100), filter: Optional[str] = None, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """List SCIM Groups""" # Get all groups - groups_list = Groups.get_all_groups() + groups_list = Groups.get_all_groups(db=db) # Apply pagination total = len(groups_list) @@ -726,7 +731,7 @@ async def get_groups( paginated_groups = groups_list[start:end] # Convert to SCIM format - scim_groups = [group_to_scim(group, request) for group in paginated_groups] + scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups] return SCIMListResponse( totalResults=total, @@ -741,6 +746,7 @@ async def get_group( group_id: str, request: Request, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Get SCIM Group by ID""" group = Groups.get_group_by_id(group_id) @@ -750,7 +756,7 @@ async def get_group( detail=f"Group {group_id} not found", ) - return group_to_scim(group, request) + return group_to_scim(group, request, db=db) @router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) @@ -758,6 +764,7 @@ async def create_group( request: Request, group_data: SCIMGroupCreateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Create SCIM Group""" # Extract member IDs @@ -775,14 +782,14 @@ async def create_group( ) # Need to get the creating user's ID - we'll use the first admin - admin_user = Users.get_super_admin_user() + admin_user = Users.get_super_admin_user(db=db) if not admin_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No admin user found", ) - new_group = Groups.insert_new_group(admin_user.id, form) + new_group = Groups.insert_new_group(admin_user.id, form, db=db) if not new_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -798,12 +805,12 @@ async def create_group( description=new_group.description, ) - Groups.update_group_by_id(new_group.id, update_form) - Groups.set_group_user_ids_by_id(new_group.id, member_ids) + Groups.update_group_by_id(new_group.id, update_form, db=db) + Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db) - new_group = Groups.get_group_by_id(new_group.id) + new_group = Groups.get_group_by_id(new_group.id, db=db) - return group_to_scim(new_group, request) + return group_to_scim(new_group, request, db=db) @router.put("/Groups/{group_id}", response_model=SCIMGroup) @@ -812,9 +819,10 @@ async def update_group( request: Request, group_data: SCIMGroupUpdateRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM Group (full update)""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -832,17 +840,17 @@ async def update_group( # Handle members if provided if group_data.members is not None: member_ids = [member.value for member in group_data.members] - Groups.set_group_user_ids_by_id(group_id, member_ids) + Groups.set_group_user_ids_by_id(group_id, member_ids, db=db) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form) + updated_group = Groups.update_group_by_id(group_id, update_form, db=db) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update group", ) - return group_to_scim(updated_group, request) + return group_to_scim(updated_group, request, db=db) @router.patch("/Groups/{group_id}", response_model=SCIMGroup) @@ -851,9 +859,10 @@ async def patch_group( request: Request, patch_data: SCIMPatchRequest, _: bool = Depends(get_scim_auth), + db: Session = Depends(get_session), ): """Update SCIM Group (partial update)""" - group = Groups.get_group_by_id(group_id) + group = Groups.get_group_by_id(group_id, db=db) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -878,7 +887,7 @@ async def patch_group( elif path == "members": # Replace all members Groups.set_group_user_ids_by_id( - group_id, [member["value"] for member in value] + group_id, [member["value"] for member in value], db=db ) elif op == "add": @@ -887,22 +896,24 @@ async def patch_group( if isinstance(value, list): for member in value: if isinstance(member, dict) and "value" in member: - Groups.add_users_to_group(group_id, [member["value"]]) + Groups.add_users_to_group( + group_id, [member["value"]], db=db + ) elif op == "remove": if path and path.startswith("members[value eq"): # Remove specific member member_id = path.split('"')[1] - Groups.remove_users_from_group(group_id, [member_id]) + Groups.remove_users_from_group(group_id, [member_id], db=db) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form) + updated_group = Groups.update_group_by_id(group_id, update_form, db=db) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update group", ) - return group_to_scim(updated_group, request) + return group_to_scim(updated_group, request, db=db) @router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)