This commit is contained in:
Timothy Jaeryang Baek 2025-12-29 00:39:43 +04:00
parent 9405628e46
commit 475dd91ed7

View file

@ -25,6 +25,10 @@ from open_webui.utils.auth import (
) )
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -345,13 +349,13 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
) )
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup:
"""Convert internal Group model to SCIM Group""" """Convert internal Group model to SCIM Group"""
member_ids = Groups.get_group_user_ids_by_id(group.id) member_ids = Groups.get_group_user_ids_by_id(group.id, db) or []
members = [] members = []
for user_id in member_ids: for user_id in member_ids:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id, db=db)
if user: if user:
members.append( members.append(
SCIMGroupMember( SCIMGroupMember(
@ -714,10 +718,11 @@ async def get_groups(
count: int = Query(20, ge=1, le=100), count: int = Query(20, ge=1, le=100),
filter: Optional[str] = None, filter: Optional[str] = None,
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
): ):
"""List SCIM Groups""" """List SCIM Groups"""
# Get all groups # Get all groups
groups_list = Groups.get_all_groups() groups_list = Groups.get_all_groups(db=db)
# Apply pagination # Apply pagination
total = len(groups_list) total = len(groups_list)
@ -726,7 +731,7 @@ async def get_groups(
paginated_groups = groups_list[start:end] paginated_groups = groups_list[start:end]
# Convert to SCIM format # Convert to SCIM format
scim_groups = [group_to_scim(group, request) for group in paginated_groups] scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups]
return SCIMListResponse( return SCIMListResponse(
totalResults=total, totalResults=total,
@ -741,6 +746,7 @@ async def get_group(
group_id: str, group_id: str,
request: Request, request: Request,
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
): ):
"""Get SCIM Group by ID""" """Get SCIM Group by ID"""
group = Groups.get_group_by_id(group_id) group = Groups.get_group_by_id(group_id)
@ -750,7 +756,7 @@ async def get_group(
detail=f"Group {group_id} not found", detail=f"Group {group_id} not found",
) )
return group_to_scim(group, request) return group_to_scim(group, request, db=db)
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) @router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
@ -758,6 +764,7 @@ async def create_group(
request: Request, request: Request,
group_data: SCIMGroupCreateRequest, group_data: SCIMGroupCreateRequest,
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
): ):
"""Create SCIM Group""" """Create SCIM Group"""
# Extract member IDs # Extract member IDs
@ -775,14 +782,14 @@ async def create_group(
) )
# Need to get the creating user's ID - we'll use the first admin # Need to get the creating user's ID - we'll use the first admin
admin_user = Users.get_super_admin_user() admin_user = Users.get_super_admin_user(db=db)
if not admin_user: if not admin_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No admin user found", detail="No admin user found",
) )
new_group = Groups.insert_new_group(admin_user.id, form) new_group = Groups.insert_new_group(admin_user.id, form, db=db)
if not new_group: if not new_group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -798,12 +805,12 @@ async def create_group(
description=new_group.description, description=new_group.description,
) )
Groups.update_group_by_id(new_group.id, update_form) Groups.update_group_by_id(new_group.id, update_form, db=db)
Groups.set_group_user_ids_by_id(new_group.id, member_ids) Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db)
new_group = Groups.get_group_by_id(new_group.id) new_group = Groups.get_group_by_id(new_group.id, db=db)
return group_to_scim(new_group, request) return group_to_scim(new_group, request, db=db)
@router.put("/Groups/{group_id}", response_model=SCIMGroup) @router.put("/Groups/{group_id}", response_model=SCIMGroup)
@ -812,9 +819,10 @@ async def update_group(
request: Request, request: Request,
group_data: SCIMGroupUpdateRequest, group_data: SCIMGroupUpdateRequest,
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
): ):
"""Update SCIM Group (full update)""" """Update SCIM Group (full update)"""
group = Groups.get_group_by_id(group_id) group = Groups.get_group_by_id(group_id, db=db)
if not group: if not group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -832,17 +840,17 @@ async def update_group(
# Handle members if provided # Handle members if provided
if group_data.members is not None: if group_data.members is not None:
member_ids = [member.value for member in group_data.members] member_ids = [member.value for member in group_data.members]
Groups.set_group_user_ids_by_id(group_id, member_ids) Groups.set_group_user_ids_by_id(group_id, member_ids, db=db)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
if not updated_group: if not updated_group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update group", detail="Failed to update group",
) )
return group_to_scim(updated_group, request) return group_to_scim(updated_group, request, db=db)
@router.patch("/Groups/{group_id}", response_model=SCIMGroup) @router.patch("/Groups/{group_id}", response_model=SCIMGroup)
@ -851,9 +859,10 @@ async def patch_group(
request: Request, request: Request,
patch_data: SCIMPatchRequest, patch_data: SCIMPatchRequest,
_: bool = Depends(get_scim_auth), _: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
): ):
"""Update SCIM Group (partial update)""" """Update SCIM Group (partial update)"""
group = Groups.get_group_by_id(group_id) group = Groups.get_group_by_id(group_id, db=db)
if not group: if not group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@ -878,7 +887,7 @@ async def patch_group(
elif path == "members": elif path == "members":
# Replace all members # Replace all members
Groups.set_group_user_ids_by_id( Groups.set_group_user_ids_by_id(
group_id, [member["value"] for member in value] group_id, [member["value"] for member in value], db=db
) )
elif op == "add": elif op == "add":
@ -887,22 +896,24 @@ async def patch_group(
if isinstance(value, list): if isinstance(value, list):
for member in value: for member in value:
if isinstance(member, dict) and "value" in member: if isinstance(member, dict) and "value" in member:
Groups.add_users_to_group(group_id, [member["value"]]) Groups.add_users_to_group(
group_id, [member["value"]], db=db
)
elif op == "remove": elif op == "remove":
if path and path.startswith("members[value eq"): if path and path.startswith("members[value eq"):
# Remove specific member # Remove specific member
member_id = path.split('"')[1] member_id = path.split('"')[1]
Groups.remove_users_from_group(group_id, [member_id]) Groups.remove_users_from_group(group_id, [member_id], db=db)
# Update group # Update group
updated_group = Groups.update_group_by_id(group_id, update_form) updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
if not updated_group: if not updated_group:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update group", detail="Failed to update group",
) )
return group_to_scim(updated_group, request) return group_to_scim(updated_group, request, db=db)
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)