diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 0364a123c7..754f6e3dfa 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -7,7 +7,10 @@ from open_webui.internal.db import Base, get_db from open_webui.models.groups import Groups from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case +from sqlalchemy.dialects.postgresql import JSONB + + +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case, cast from sqlalchemy import or_, func, select, and_, text from sqlalchemy.sql import exists @@ -293,28 +296,41 @@ class ChannelTable: channels = db.query(Channel).all() return [ChannelModel.model_validate(channel) for channel in channels] - def _has_permission(self, query, filter: dict, permission: str = "read"): + def _has_permission(self, db, query, filter: dict, permission: str = "read"): group_ids = filter.get("group_ids", []) user_id = filter.get("user_id") - json_group_ids = Channel.access_control[permission]["group_ids"] + dialect_name = db.bind.dialect.name + # Public access conditions = [] if group_ids or user_id: - conditions.append(Channel.access_control.is_(None)) + conditions.extend( + [ + Channel.access_control.is_(None), + cast(Channel.access_control, String) == "null", + ] + ) + # User-level permission if user_id: conditions.append(Channel.user_id == user_id) + # Group-level permission if group_ids: group_conditions = [] - for gid in group_ids: - # CASE: gid IN JSON array - # SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%' - # Postgres → access_control->'write'->'group_ids' @> '[gid]' - group_conditions.append(json_group_ids.contains([gid])) - + if dialect_name == "sqlite": + group_conditions.append( + Channel.access_control[permission]["group_ids"].contains([gid]) + ) + elif dialect_name == "postgresql": + group_conditions.append( + cast( + Channel.access_control[permission]["group_ids"], + JSONB, + ).contains([gid]) + ) conditions.append(or_(*group_conditions)) if conditions: @@ -351,7 +367,7 @@ class ChannelTable: ), ) query = self._has_permission( - query, {"user_id": user_id, "group_ids": user_group_ids} + db, query, {"user_id": user_id, "group_ids": user_group_ids} ) standard_channels = query.all() diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 6d3d9858bc..1c44d311ba 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -13,6 +13,8 @@ from pydantic import BaseModel, ConfigDict from sqlalchemy import String, cast, or_, and_, func from sqlalchemy.dialects import postgresql, sqlite + +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy import BigInteger, Column, Text, JSON, Boolean @@ -220,28 +222,41 @@ class ModelsTable: or has_access(user_id, permission, model.access_control, user_group_ids) ] - def _has_write_permission(self, query, filter: dict): + def _has_permission(self, db, query, filter: dict, permission: str = "read"): group_ids = filter.get("group_ids", []) user_id = filter.get("user_id") - json_group_ids = Model.access_control["write"]["group_ids"] + dialect_name = db.bind.dialect.name + # Public access conditions = [] if group_ids or user_id: - conditions.append(Model.access_control.is_(None)) + conditions.extend( + [ + Model.access_control.is_(None), + cast(Model.access_control, String) == "null", + ] + ) + # User-level permission if user_id: conditions.append(Model.user_id == user_id) + # Group-level permission if group_ids: group_conditions = [] - for gid in group_ids: - # CASE: gid IN JSON array - # SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%' - # Postgres → access_control->'write'->'group_ids' @> '[gid]' - group_conditions.append(json_group_ids.contains([gid])) - + if dialect_name == "sqlite": + group_conditions.append( + Model.access_control[permission]["group_ids"].contains([gid]) + ) + elif dialect_name == "postgresql": + group_conditions.append( + cast( + Model.access_control[permission]["group_ids"], + JSONB, + ).contains([gid]) + ) conditions.append(or_(*group_conditions)) if conditions: @@ -267,15 +282,20 @@ class ModelsTable: ) ) - # Apply access control filtering - query = self._has_write_permission(query, filter) - view_option = filter.get("view_option") if view_option == "created": query = query.filter(Model.user_id == user_id) elif view_option == "shared": query = query.filter(Model.user_id != user_id) + # Apply access control filtering + query = self._has_permission( + db, + query, + filter, + permission="write", + ) + tag = filter.get("tag") if tag: # TODO: This is a simple implementation and should be improved for performance