This commit is contained in:
Timothy Jaeryang Baek 2025-12-29 01:31:27 +04:00
parent 5d1459df16
commit 2453b75ff0
5 changed files with 28 additions and 24 deletions

View file

@ -382,14 +382,14 @@ async def get_channel_by_id(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
)
write_access = has_access(
user.id, type="write", access_control=channel.access_control, strict=False
user.id, type="write", access_control=channel.access_control, strict=False, db=db
)
user_count = len(get_users_with_access("read", channel.access_control))
@ -762,7 +762,7 @@ async def get_channel_messages(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -830,7 +830,7 @@ async def get_pinned_channel_messages(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1228,7 +1228,7 @@ async def get_channel_message(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1282,7 +1282,7 @@ async def get_channel_message_data(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1336,7 +1336,7 @@ async def pin_channel_message(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1402,7 +1402,7 @@ async def get_channel_thread_messages(
)
else:
if user.role != "admin" and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1476,7 +1476,7 @@ async def update_message_by_id(
user.role != "admin"
and message.user_id != user.id
and not has_access(
user.id, type="read", access_control=channel.access_control
user.id, type="read", access_control=channel.access_control, db=db
)
):
raise HTTPException(
@ -1543,7 +1543,7 @@ async def add_reaction_to_message(
)
else:
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
user.id, type="write", access_control=channel.access_control, strict=False, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1618,7 +1618,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
)
else:
if user.role != "admin" and not has_access(
user.id, type="write", access_control=channel.access_control, strict=False
user.id, type="write", access_control=channel.access_control, strict=False, db=db
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -1713,6 +1713,7 @@ async def delete_message_by_id(
type="write",
access_control=channel.access_control,
strict=False,
db=db,
)
):
raise HTTPException(

View file

@ -421,14 +421,14 @@ async def get_all_models(request: Request, user: UserModel = None):
return models
async def get_filtered_models(models, user):
async def get_filtered_models(models, user, db=None):
# Filter models based on user access control
filtered_models = []
for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
user.id, type="read", access_control=model_info.access_control, db=db
):
filtered_models.append(model)
return filtered_models

View file

@ -453,14 +453,14 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
return responses
async def get_filtered_models(models, user):
async def get_filtered_models(models, user, db=None):
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
user.id, type="read", access_control=model_info.access_control, db=db
):
filtered_models.append(model)
return filtered_models

View file

@ -360,7 +360,7 @@ def check_model_access(user, model):
raise Exception("Model not found")
def get_filtered_models(models, user):
def get_filtered_models(models, user, db=None):
# Filter out models that the user does not have access to
if (
user.role == "user"
@ -373,7 +373,7 @@ def get_filtered_models(models, user):
}
filtered_models = []
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
for model in models:
if model.get("arena"):
if has_access(

View file

@ -1132,7 +1132,7 @@ class OAuthManager:
return role
def update_user_groups(self, user, user_data, default_permissions):
def update_user_groups(self, user, user_data, default_permissions, db=None):
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
@ -1161,8 +1161,8 @@ class OAuthManager:
else:
user_oauth_groups = []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_all_groups()
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id, db=db)
all_available_groups: list[GroupModel] = Groups.get_all_groups(db=db)
# Create groups if they don't exist and creation is enabled
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
@ -1188,7 +1188,7 @@ class OAuthManager:
)
# Use determined creator ID (admin or fallback to current user)
created_group = Groups.insert_new_group(
creator_id, new_group_form
creator_id, new_group_form, db=db
)
if created_group:
log.info(
@ -1206,7 +1206,7 @@ class OAuthManager:
# Refresh the list of all available groups if any were created
if groups_created:
all_available_groups = Groups.get_all_groups()
all_available_groups = Groups.get_all_groups(db=db)
log.debug("Refreshed list of all available groups after creation.")
log.debug(f"Oauth Groups claim: {oauth_claim}")
@ -1227,7 +1227,7 @@ class OAuthManager:
log.debug(
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
)
Groups.remove_users_from_group(group_model.id, [user.id])
Groups.remove_users_from_group(group_model.id, [user.id], db=db)
# In case a group is created, but perms are never assigned to the group by hitting "save"
group_permissions = group_model.permissions
@ -1242,6 +1242,7 @@ class OAuthManager:
permissions=group_permissions,
),
overwrite=False,
db=db,
)
# Add user to new groups
@ -1257,7 +1258,7 @@ class OAuthManager:
f"Adding user to group {group_model.name} as it was found in their oauth groups"
)
Groups.add_users_to_group(group_model.id, [user.id])
Groups.add_users_to_group(group_model.id, [user.id], db=db)
# In case a group is created, but perms are never assigned to the group by hitting "save"
group_permissions = group_model.permissions
@ -1272,6 +1273,7 @@ class OAuthManager:
permissions=group_permissions,
),
overwrite=False,
db=db,
)
async def _process_picture_url(
@ -1566,6 +1568,7 @@ class OAuthManager:
user=user,
user_data=user_data,
default_permissions=request.app.state.config.USER_PERMISSIONS,
db=db,
)
except Exception as e: