This commit is contained in:
Timothy Jaeryang Baek 2025-12-29 00:39:43 +04:00
parent 9405628e46
commit 475dd91ed7

View file

@ -25,6 +25,10 @@ from open_webui.utils.auth import (
)
from open_webui.constants import ERROR_MESSAGES
from sqlalchemy.orm import Session
from open_webui.internal.db import get_session
log = logging.getLogger(__name__)
router = APIRouter()
@ -345,13 +349,13 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
)
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup:
"""Convert internal Group model to SCIM Group"""
member_ids = Groups.get_group_user_ids_by_id(group.id)
member_ids = Groups.get_group_user_ids_by_id(group.id, db) or []
members = []
for user_id in member_ids:
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(user_id, db=db)
if user:
members.append(
SCIMGroupMember(
@ -714,10 +718,11 @@ async def get_groups(
count: int = Query(20, ge=1, le=100),
filter: Optional[str] = None,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""List SCIM Groups"""
# Get all groups
groups_list = Groups.get_all_groups()
groups_list = Groups.get_all_groups(db=db)
# Apply pagination
total = len(groups_list)
@ -726,7 +731,7 @@ async def get_groups(
paginated_groups = groups_list[start:end]
# Convert to SCIM format
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups]
return SCIMListResponse(
totalResults=total,
@ -741,6 +746,7 @@ async def get_group(
group_id: str,
request: Request,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Get SCIM Group by ID"""
group = Groups.get_group_by_id(group_id)
@ -750,7 +756,7 @@ async def get_group(
detail=f"Group {group_id} not found",
)
return group_to_scim(group, request)
return group_to_scim(group, request, db=db)
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
@ -758,6 +764,7 @@ async def create_group(
request: Request,
group_data: SCIMGroupCreateRequest,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Create SCIM Group"""
# Extract member IDs
@ -775,14 +782,14 @@ async def create_group(
)
# Need to get the creating user's ID - we'll use the first admin
admin_user = Users.get_super_admin_user()
admin_user = Users.get_super_admin_user(db=db)
if not admin_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No admin user found",
)
new_group = Groups.insert_new_group(admin_user.id, form)
new_group = Groups.insert_new_group(admin_user.id, form, db=db)
if not new_group:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -798,12 +805,12 @@ async def create_group(
description=new_group.description,
)
Groups.update_group_by_id(new_group.id, update_form)
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
Groups.update_group_by_id(new_group.id, update_form, db=db)
Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db)
new_group = Groups.get_group_by_id(new_group.id)
new_group = Groups.get_group_by_id(new_group.id, db=db)
return group_to_scim(new_group, request)
return group_to_scim(new_group, request, db=db)
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
@ -812,9 +819,10 @@ async def update_group(
request: Request,
group_data: SCIMGroupUpdateRequest,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Update SCIM Group (full update)"""
group = Groups.get_group_by_id(group_id)
group = Groups.get_group_by_id(group_id, db=db)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -832,17 +840,17 @@ async def update_group(
# Handle members if provided
if group_data.members is not None:
member_ids = [member.value for member in group_data.members]
Groups.set_group_user_ids_by_id(group_id, member_ids)
Groups.set_group_user_ids_by_id(group_id, member_ids, db=db)
# Update group
updated_group = Groups.update_group_by_id(group_id, update_form)
updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
if not updated_group:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update group",
)
return group_to_scim(updated_group, request)
return group_to_scim(updated_group, request, db=db)
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
@ -851,9 +859,10 @@ async def patch_group(
request: Request,
patch_data: SCIMPatchRequest,
_: bool = Depends(get_scim_auth),
db: Session = Depends(get_session),
):
"""Update SCIM Group (partial update)"""
group = Groups.get_group_by_id(group_id)
group = Groups.get_group_by_id(group_id, db=db)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -878,7 +887,7 @@ async def patch_group(
elif path == "members":
# Replace all members
Groups.set_group_user_ids_by_id(
group_id, [member["value"] for member in value]
group_id, [member["value"] for member in value], db=db
)
elif op == "add":
@ -887,22 +896,24 @@ async def patch_group(
if isinstance(value, list):
for member in value:
if isinstance(member, dict) and "value" in member:
Groups.add_users_to_group(group_id, [member["value"]])
Groups.add_users_to_group(
group_id, [member["value"]], db=db
)
elif op == "remove":
if path and path.startswith("members[value eq"):
# Remove specific member
member_id = path.split('"')[1]
Groups.remove_users_from_group(group_id, [member_id])
Groups.remove_users_from_group(group_id, [member_id], db=db)
# Update group
updated_group = Groups.update_group_by_id(group_id, update_form)
updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
if not updated_group:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update group",
)
return group_to_scim(updated_group, request)
return group_to_scim(updated_group, request, db=db)
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)