mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 14:45:18 +00:00
refac
This commit is contained in:
parent
9405628e46
commit
475dd91ed7
1 changed files with 33 additions and 22 deletions
|
|
@ -25,6 +25,10 @@ from open_webui.utils.auth import (
|
||||||
)
|
)
|
||||||
from open_webui.constants import ERROR_MESSAGES
|
from open_webui.constants import ERROR_MESSAGES
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from open_webui.internal.db import get_session
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -345,13 +349,13 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup:
|
||||||
"""Convert internal Group model to SCIM Group"""
|
"""Convert internal Group model to SCIM Group"""
|
||||||
member_ids = Groups.get_group_user_ids_by_id(group.id)
|
member_ids = Groups.get_group_user_ids_by_id(group.id, db) or []
|
||||||
members = []
|
members = []
|
||||||
|
|
||||||
for user_id in member_ids:
|
for user_id in member_ids:
|
||||||
user = Users.get_user_by_id(user_id)
|
user = Users.get_user_by_id(user_id, db=db)
|
||||||
if user:
|
if user:
|
||||||
members.append(
|
members.append(
|
||||||
SCIMGroupMember(
|
SCIMGroupMember(
|
||||||
|
|
@ -714,10 +718,11 @@ async def get_groups(
|
||||||
count: int = Query(20, ge=1, le=100),
|
count: int = Query(20, ge=1, le=100),
|
||||||
filter: Optional[str] = None,
|
filter: Optional[str] = None,
|
||||||
_: bool = Depends(get_scim_auth),
|
_: bool = Depends(get_scim_auth),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""List SCIM Groups"""
|
"""List SCIM Groups"""
|
||||||
# Get all groups
|
# Get all groups
|
||||||
groups_list = Groups.get_all_groups()
|
groups_list = Groups.get_all_groups(db=db)
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
total = len(groups_list)
|
total = len(groups_list)
|
||||||
|
|
@ -726,7 +731,7 @@ async def get_groups(
|
||||||
paginated_groups = groups_list[start:end]
|
paginated_groups = groups_list[start:end]
|
||||||
|
|
||||||
# Convert to SCIM format
|
# Convert to SCIM format
|
||||||
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
|
scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups]
|
||||||
|
|
||||||
return SCIMListResponse(
|
return SCIMListResponse(
|
||||||
totalResults=total,
|
totalResults=total,
|
||||||
|
|
@ -741,6 +746,7 @@ async def get_group(
|
||||||
group_id: str,
|
group_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
_: bool = Depends(get_scim_auth),
|
_: bool = Depends(get_scim_auth),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Get SCIM Group by ID"""
|
"""Get SCIM Group by ID"""
|
||||||
group = Groups.get_group_by_id(group_id)
|
group = Groups.get_group_by_id(group_id)
|
||||||
|
|
@ -750,7 +756,7 @@ async def get_group(
|
||||||
detail=f"Group {group_id} not found",
|
detail=f"Group {group_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return group_to_scim(group, request)
|
return group_to_scim(group, request, db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
|
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
|
||||||
|
|
@ -758,6 +764,7 @@ async def create_group(
|
||||||
request: Request,
|
request: Request,
|
||||||
group_data: SCIMGroupCreateRequest,
|
group_data: SCIMGroupCreateRequest,
|
||||||
_: bool = Depends(get_scim_auth),
|
_: bool = Depends(get_scim_auth),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Create SCIM Group"""
|
"""Create SCIM Group"""
|
||||||
# Extract member IDs
|
# Extract member IDs
|
||||||
|
|
@ -775,14 +782,14 @@ async def create_group(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Need to get the creating user's ID - we'll use the first admin
|
# Need to get the creating user's ID - we'll use the first admin
|
||||||
admin_user = Users.get_super_admin_user()
|
admin_user = Users.get_super_admin_user(db=db)
|
||||||
if not admin_user:
|
if not admin_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="No admin user found",
|
detail="No admin user found",
|
||||||
)
|
)
|
||||||
|
|
||||||
new_group = Groups.insert_new_group(admin_user.id, form)
|
new_group = Groups.insert_new_group(admin_user.id, form, db=db)
|
||||||
if not new_group:
|
if not new_group:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
|
@ -798,12 +805,12 @@ async def create_group(
|
||||||
description=new_group.description,
|
description=new_group.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
Groups.update_group_by_id(new_group.id, update_form)
|
Groups.update_group_by_id(new_group.id, update_form, db=db)
|
||||||
Groups.set_group_user_ids_by_id(new_group.id, member_ids)
|
Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db)
|
||||||
|
|
||||||
new_group = Groups.get_group_by_id(new_group.id)
|
new_group = Groups.get_group_by_id(new_group.id, db=db)
|
||||||
|
|
||||||
return group_to_scim(new_group, request)
|
return group_to_scim(new_group, request, db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
|
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
|
||||||
|
|
@ -812,9 +819,10 @@ async def update_group(
|
||||||
request: Request,
|
request: Request,
|
||||||
group_data: SCIMGroupUpdateRequest,
|
group_data: SCIMGroupUpdateRequest,
|
||||||
_: bool = Depends(get_scim_auth),
|
_: bool = Depends(get_scim_auth),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Update SCIM Group (full update)"""
|
"""Update SCIM Group (full update)"""
|
||||||
group = Groups.get_group_by_id(group_id)
|
group = Groups.get_group_by_id(group_id, db=db)
|
||||||
if not group:
|
if not group:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
|
@ -832,17 +840,17 @@ async def update_group(
|
||||||
# Handle members if provided
|
# Handle members if provided
|
||||||
if group_data.members is not None:
|
if group_data.members is not None:
|
||||||
member_ids = [member.value for member in group_data.members]
|
member_ids = [member.value for member in group_data.members]
|
||||||
Groups.set_group_user_ids_by_id(group_id, member_ids)
|
Groups.set_group_user_ids_by_id(group_id, member_ids, db=db)
|
||||||
|
|
||||||
# Update group
|
# Update group
|
||||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
|
||||||
if not updated_group:
|
if not updated_group:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to update group",
|
detail="Failed to update group",
|
||||||
)
|
)
|
||||||
|
|
||||||
return group_to_scim(updated_group, request)
|
return group_to_scim(updated_group, request, db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
|
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
|
||||||
|
|
@ -851,9 +859,10 @@ async def patch_group(
|
||||||
request: Request,
|
request: Request,
|
||||||
patch_data: SCIMPatchRequest,
|
patch_data: SCIMPatchRequest,
|
||||||
_: bool = Depends(get_scim_auth),
|
_: bool = Depends(get_scim_auth),
|
||||||
|
db: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Update SCIM Group (partial update)"""
|
"""Update SCIM Group (partial update)"""
|
||||||
group = Groups.get_group_by_id(group_id)
|
group = Groups.get_group_by_id(group_id, db=db)
|
||||||
if not group:
|
if not group:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
|
@ -878,7 +887,7 @@ async def patch_group(
|
||||||
elif path == "members":
|
elif path == "members":
|
||||||
# Replace all members
|
# Replace all members
|
||||||
Groups.set_group_user_ids_by_id(
|
Groups.set_group_user_ids_by_id(
|
||||||
group_id, [member["value"] for member in value]
|
group_id, [member["value"] for member in value], db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
elif op == "add":
|
elif op == "add":
|
||||||
|
|
@ -887,22 +896,24 @@ async def patch_group(
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
for member in value:
|
for member in value:
|
||||||
if isinstance(member, dict) and "value" in member:
|
if isinstance(member, dict) and "value" in member:
|
||||||
Groups.add_users_to_group(group_id, [member["value"]])
|
Groups.add_users_to_group(
|
||||||
|
group_id, [member["value"]], db=db
|
||||||
|
)
|
||||||
elif op == "remove":
|
elif op == "remove":
|
||||||
if path and path.startswith("members[value eq"):
|
if path and path.startswith("members[value eq"):
|
||||||
# Remove specific member
|
# Remove specific member
|
||||||
member_id = path.split('"')[1]
|
member_id = path.split('"')[1]
|
||||||
Groups.remove_users_from_group(group_id, [member_id])
|
Groups.remove_users_from_group(group_id, [member_id], db=db)
|
||||||
|
|
||||||
# Update group
|
# Update group
|
||||||
updated_group = Groups.update_group_by_id(group_id, update_form)
|
updated_group = Groups.update_group_by_id(group_id, update_form, db=db)
|
||||||
if not updated_group:
|
if not updated_group:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to update group",
|
detail="Failed to update group",
|
||||||
)
|
)
|
||||||
|
|
||||||
return group_to_scim(updated_group, request)
|
return group_to_scim(updated_group, request, db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue