mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 06:35:20 +00:00
refac
This commit is contained in:
parent
9405628e46
commit
475dd91ed7
1 changed files with 33 additions and 22 deletions
|
|
@ -25,6 +25,10 @@ from open_webui.utils.auth import (
|
|||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import get_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -345,13 +349,13 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
|||
)
|
||||
|
||||
|
||||
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
||||
def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup:
|
||||
"""Convert internal Group model to SCIM Group"""
|
||||
member_ids = Groups.get_group_user_ids_by_id(group.id)
|
||||
member_ids = Groups.get_group_user_ids_by_id(group.id, db) or []
|
||||
members = []
|
||||
|
||||
for user_id in member_ids:
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
if user:
|
||||
members.append(
|
||||
SCIMGroupMember(
|
||||
|
|
@ -714,10 +718,11 @@ async def get_groups(
|
|||
count: int = Query(20, ge=1, le=100),
|
||||
filter: Optional[str] = None,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""List SCIM Groups"""
|
||||
# Get all groups
|
||||
groups_list = Groups.get_all_groups()
|
||||
groups_list = Groups.get_all_groups(db=db)
|
||||
|
||||
# Apply pagination
|
||||
total = len(groups_list)
|
||||
|
|
@ -726,7 +731,7 @@ async def get_groups(
|
|||
paginated_groups = groups_list[start:end]
|
||||
|
||||
# Convert to SCIM format
|
||||
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
|
||||
scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups]
|
||||
|
||||
return SCIMListResponse(
|
||||
totalResults=total,
|
||||
|
|
@ -741,6 +746,7 @@ async def get_group(
|
|||
group_id: str,
|
||||
request: Request,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Get SCIM Group by ID"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
|
|
@ -750,7 +756,7 @@ async def get_group(
|
|||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
return group_to_scim(group, request)
|
||||
return group_to_scim(group, request, db=db)
|
||||
|
||||
|
||||
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
|
||||
|
|
@ -758,6 +764,7 @@ async def create_group(
|
|||
request: Request,
|
||||
group_data: SCIMGroupCreateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Create SCIM Group"""
|
||||
# Extract member IDs
|
||||
|
|
@ -775,14 +782,14 @@ async def create_group(
|
|||
)
|
||||
|
||||
# Need to get the creating user's ID - we'll use the first admin
|
||||
admin_user = Users.get_super_admin_user()
|
||||
admin_user = Users.get_super_admin_user(db=db)
|
||||
if not admin_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="No admin user found",
|
||||
)
|
||||
|
||||
new_group = Groups.insert_new_group(admin_user.id, form)
|
||||
new_group = Groups.insert_new_group(admin_user.id, form, db=db)
|
||||
if not new_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -798,12 +805,12 @@ async def create_group(
|
|||
description=new_group.description,
|
||||
)
|
||||
|
||||
Groups.update_group_by_id(new_group.id, update_form)
|
||||
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
|
||||
Groups.update_group_by_id(new_group.id, update_form, db=db)
|
||||
Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db)
|
||||
|
||||
new_group = Groups.get_group_by_id(new_group.id)
|
||||
new_group = Groups.get_group_by_id(new_group.id, db=db)
|
||||
|
||||
return group_to_scim(new_group, request)
|
||||
return group_to_scim(new_group, request, db=db)
|
||||
|
||||
|
||||
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
|
|
@ -812,9 +819,10 @@ async def update_group(
|
|||
request: Request,
|
||||
group_data: SCIMGroupUpdateRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Update SCIM Group (full update)"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = Groups.get_group_by_id(group_id, db=db)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -832,17 +840,17 @@ async def update_group(
|
|||
# Handle members if provided
|
||||
if group_data.members is not None:
|
||||
member_ids = [member.value for member in group_data.members]
|
||||
Groups.set_group_user_ids_by_id(group_id, member_ids)
|
||||
Groups.set_group_user_ids_by_id(group_id, member_ids, db=db)
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
|
||||
if not updated_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update group",
|
||||
)
|
||||
|
||||
return group_to_scim(updated_group, request)
|
||||
return group_to_scim(updated_group, request, db=db)
|
||||
|
||||
|
||||
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
|
|
@ -851,9 +859,10 @@ async def patch_group(
|
|||
request: Request,
|
||||
patch_data: SCIMPatchRequest,
|
||||
_: bool = Depends(get_scim_auth),
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Update SCIM Group (partial update)"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = Groups.get_group_by_id(group_id, db=db)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -878,7 +887,7 @@ async def patch_group(
|
|||
elif path == "members":
|
||||
# Replace all members
|
||||
Groups.set_group_user_ids_by_id(
|
||||
group_id, [member["value"] for member in value]
|
||||
group_id, [member["value"] for member in value], db=db
|
||||
)
|
||||
|
||||
elif op == "add":
|
||||
|
|
@ -887,22 +896,24 @@ async def patch_group(
|
|||
if isinstance(value, list):
|
||||
for member in value:
|
||||
if isinstance(member, dict) and "value" in member:
|
||||
Groups.add_users_to_group(group_id, [member["value"]])
|
||||
Groups.add_users_to_group(
|
||||
group_id, [member["value"]], db=db
|
||||
)
|
||||
elif op == "remove":
|
||||
if path and path.startswith("members[value eq"):
|
||||
# Remove specific member
|
||||
member_id = path.split('"')[1]
|
||||
Groups.remove_users_from_group(group_id, [member_id])
|
||||
Groups.remove_users_from_group(group_id, [member_id], db=db)
|
||||
|
||||
# Update group
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
||||
updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
|
||||
if not updated_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update group",
|
||||
)
|
||||
|
||||
return group_to_scim(updated_group, request)
|
||||
return group_to_scim(updated_group, request, db=db)
|
||||
|
||||
|
||||
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
|
|
|||
Loading…
Reference in a new issue