mirror of
https://github.com/open-webui/open-webui.git
synced 2026-01-02 14:45:18 +00:00
refac
This commit is contained in:
parent
5d1459df16
commit
2453b75ff0
5 changed files with 28 additions and 24 deletions
|
|
@ -382,14 +382,14 @@ async def get_channel_by_id(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
)
|
)
|
||||||
|
|
||||||
write_access = has_access(
|
write_access = has_access(
|
||||||
user.id, type="write", access_control=channel.access_control, strict=False
|
user.id, type="write", access_control=channel.access_control, strict=False, db=db
|
||||||
)
|
)
|
||||||
|
|
||||||
user_count = len(get_users_with_access("read", channel.access_control))
|
user_count = len(get_users_with_access("read", channel.access_control))
|
||||||
|
|
@ -762,7 +762,7 @@ async def get_channel_messages(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -830,7 +830,7 @@ async def get_pinned_channel_messages(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1228,7 +1228,7 @@ async def get_channel_message(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1282,7 +1282,7 @@ async def get_channel_message_data(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1336,7 +1336,7 @@ async def pin_channel_message(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1402,7 +1402,7 @@ async def get_channel_thread_messages(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1476,7 +1476,7 @@ async def update_message_by_id(
|
||||||
user.role != "admin"
|
user.role != "admin"
|
||||||
and message.user_id != user.id
|
and message.user_id != user.id
|
||||||
and not has_access(
|
and not has_access(
|
||||||
user.id, type="read", access_control=channel.access_control
|
user.id, type="read", access_control=channel.access_control, db=db
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -1543,7 +1543,7 @@ async def add_reaction_to_message(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="write", access_control=channel.access_control, strict=False
|
user.id, type="write", access_control=channel.access_control, strict=False, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1618,7 +1618,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if user.role != "admin" and not has_access(
|
if user.role != "admin" and not has_access(
|
||||||
user.id, type="write", access_control=channel.access_control, strict=False
|
user.id, type="write", access_control=channel.access_control, strict=False, db=db
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
|
@ -1713,6 +1713,7 @@ async def delete_message_by_id(
|
||||||
type="write",
|
type="write",
|
||||||
access_control=channel.access_control,
|
access_control=channel.access_control,
|
||||||
strict=False,
|
strict=False,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
||||||
|
|
@ -421,14 +421,14 @@ async def get_all_models(request: Request, user: UserModel = None):
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
async def get_filtered_models(models, user):
|
async def get_filtered_models(models, user, db=None):
|
||||||
# Filter models based on user access control
|
# Filter models based on user access control
|
||||||
filtered_models = []
|
filtered_models = []
|
||||||
for model in models.get("models", []):
|
for model in models.get("models", []):
|
||||||
model_info = Models.get_model_by_id(model["model"])
|
model_info = Models.get_model_by_id(model["model"])
|
||||||
if model_info:
|
if model_info:
|
||||||
if user.id == model_info.user_id or has_access(
|
if user.id == model_info.user_id or has_access(
|
||||||
user.id, type="read", access_control=model_info.access_control
|
user.id, type="read", access_control=model_info.access_control, db=db
|
||||||
):
|
):
|
||||||
filtered_models.append(model)
|
filtered_models.append(model)
|
||||||
return filtered_models
|
return filtered_models
|
||||||
|
|
|
||||||
|
|
@ -453,14 +453,14 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
|
|
||||||
async def get_filtered_models(models, user):
|
async def get_filtered_models(models, user, db=None):
|
||||||
# Filter models based on user access control
|
# Filter models based on user access control
|
||||||
filtered_models = []
|
filtered_models = []
|
||||||
for model in models.get("data", []):
|
for model in models.get("data", []):
|
||||||
model_info = Models.get_model_by_id(model["id"])
|
model_info = Models.get_model_by_id(model["id"])
|
||||||
if model_info:
|
if model_info:
|
||||||
if user.id == model_info.user_id or has_access(
|
if user.id == model_info.user_id or has_access(
|
||||||
user.id, type="read", access_control=model_info.access_control
|
user.id, type="read", access_control=model_info.access_control, db=db
|
||||||
):
|
):
|
||||||
filtered_models.append(model)
|
filtered_models.append(model)
|
||||||
return filtered_models
|
return filtered_models
|
||||||
|
|
|
||||||
|
|
@ -360,7 +360,7 @@ def check_model_access(user, model):
|
||||||
raise Exception("Model not found")
|
raise Exception("Model not found")
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_models(models, user):
|
def get_filtered_models(models, user, db=None):
|
||||||
# Filter out models that the user does not have access to
|
# Filter out models that the user does not have access to
|
||||||
if (
|
if (
|
||||||
user.role == "user"
|
user.role == "user"
|
||||||
|
|
@ -373,7 +373,7 @@ def get_filtered_models(models, user):
|
||||||
}
|
}
|
||||||
|
|
||||||
filtered_models = []
|
filtered_models = []
|
||||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||||
for model in models:
|
for model in models:
|
||||||
if model.get("arena"):
|
if model.get("arena"):
|
||||||
if has_access(
|
if has_access(
|
||||||
|
|
|
||||||
|
|
@ -1132,7 +1132,7 @@ class OAuthManager:
|
||||||
|
|
||||||
return role
|
return role
|
||||||
|
|
||||||
def update_user_groups(self, user, user_data, default_permissions):
|
def update_user_groups(self, user, user_data, default_permissions, db=None):
|
||||||
log.debug("Running OAUTH Group management")
|
log.debug("Running OAUTH Group management")
|
||||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||||
|
|
||||||
|
|
@ -1161,8 +1161,8 @@ class OAuthManager:
|
||||||
else:
|
else:
|
||||||
user_oauth_groups = []
|
user_oauth_groups = []
|
||||||
|
|
||||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id, db=db)
|
||||||
all_available_groups: list[GroupModel] = Groups.get_all_groups()
|
all_available_groups: list[GroupModel] = Groups.get_all_groups(db=db)
|
||||||
|
|
||||||
# Create groups if they don't exist and creation is enabled
|
# Create groups if they don't exist and creation is enabled
|
||||||
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
|
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
|
||||||
|
|
@ -1188,7 +1188,7 @@ class OAuthManager:
|
||||||
)
|
)
|
||||||
# Use determined creator ID (admin or fallback to current user)
|
# Use determined creator ID (admin or fallback to current user)
|
||||||
created_group = Groups.insert_new_group(
|
created_group = Groups.insert_new_group(
|
||||||
creator_id, new_group_form
|
creator_id, new_group_form, db=db
|
||||||
)
|
)
|
||||||
if created_group:
|
if created_group:
|
||||||
log.info(
|
log.info(
|
||||||
|
|
@ -1206,7 +1206,7 @@ class OAuthManager:
|
||||||
|
|
||||||
# Refresh the list of all available groups if any were created
|
# Refresh the list of all available groups if any were created
|
||||||
if groups_created:
|
if groups_created:
|
||||||
all_available_groups = Groups.get_all_groups()
|
all_available_groups = Groups.get_all_groups(db=db)
|
||||||
log.debug("Refreshed list of all available groups after creation.")
|
log.debug("Refreshed list of all available groups after creation.")
|
||||||
|
|
||||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||||
|
|
@ -1227,7 +1227,7 @@ class OAuthManager:
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
f"Removing user from group {group_model.name} as it is no longer in their oauth groups"
|
||||||
)
|
)
|
||||||
Groups.remove_users_from_group(group_model.id, [user.id])
|
Groups.remove_users_from_group(group_model.id, [user.id], db=db)
|
||||||
|
|
||||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||||
group_permissions = group_model.permissions
|
group_permissions = group_model.permissions
|
||||||
|
|
@ -1242,6 +1242,7 @@ class OAuthManager:
|
||||||
permissions=group_permissions,
|
permissions=group_permissions,
|
||||||
),
|
),
|
||||||
overwrite=False,
|
overwrite=False,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add user to new groups
|
# Add user to new groups
|
||||||
|
|
@ -1257,7 +1258,7 @@ class OAuthManager:
|
||||||
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
f"Adding user to group {group_model.name} as it was found in their oauth groups"
|
||||||
)
|
)
|
||||||
|
|
||||||
Groups.add_users_to_group(group_model.id, [user.id])
|
Groups.add_users_to_group(group_model.id, [user.id], db=db)
|
||||||
|
|
||||||
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
# In case a group is created, but perms are never assigned to the group by hitting "save"
|
||||||
group_permissions = group_model.permissions
|
group_permissions = group_model.permissions
|
||||||
|
|
@ -1272,6 +1273,7 @@ class OAuthManager:
|
||||||
permissions=group_permissions,
|
permissions=group_permissions,
|
||||||
),
|
),
|
||||||
overwrite=False,
|
overwrite=False,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_picture_url(
|
async def _process_picture_url(
|
||||||
|
|
@ -1566,6 +1568,7 @@ class OAuthManager:
|
||||||
user=user,
|
user=user,
|
||||||
user_data=user_data,
|
user_data=user_data,
|
||||||
default_permissions=request.app.state.config.USER_PERMISSIONS,
|
default_permissions=request.app.state.config.USER_PERMISSIONS,
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue