This commit is contained in:
Timothy Jaeryang Baek 2025-11-30 14:06:16 -05:00
parent 1818f2b3d9
commit 277f3a91f1
2 changed files with 59 additions and 23 deletions

View file

@ -7,7 +7,10 @@ from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups from open_webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case, cast
from sqlalchemy import or_, func, select, and_, text from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
@ -293,28 +296,41 @@ class ChannelTable:
channels = db.query(Channel).all() channels = db.query(Channel).all()
return [ChannelModel.model_validate(channel) for channel in channels] return [ChannelModel.model_validate(channel) for channel in channels]
def _has_permission(self, query, filter: dict, permission: str = "read"): def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", []) group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id") user_id = filter.get("user_id")
json_group_ids = Channel.access_control[permission]["group_ids"] dialect_name = db.bind.dialect.name
# Public access
conditions = [] conditions = []
if group_ids or user_id: if group_ids or user_id:
conditions.append(Channel.access_control.is_(None)) conditions.extend(
[
Channel.access_control.is_(None),
cast(Channel.access_control, String) == "null",
]
)
# User-level permission
if user_id: if user_id:
conditions.append(Channel.user_id == user_id) conditions.append(Channel.user_id == user_id)
# Group-level permission
if group_ids: if group_ids:
group_conditions = [] group_conditions = []
for gid in group_ids: for gid in group_ids:
# CASE: gid IN JSON array if dialect_name == "sqlite":
# SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%' group_conditions.append(
# Postgres → access_control->'write'->'group_ids' @> '[gid]' Channel.access_control[permission]["group_ids"].contains([gid])
group_conditions.append(json_group_ids.contains([gid])) )
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Channel.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions)) conditions.append(or_(*group_conditions))
if conditions: if conditions:
@ -351,7 +367,7 @@ class ChannelTable:
), ),
) )
query = self._has_permission( query = self._has_permission(
query, {"user_id": user_id, "group_ids": user_group_ids} db, query, {"user_id": user_id, "group_ids": user_group_ids}
) )
standard_channels = query.all() standard_channels = query.all()

View file

@ -13,6 +13,8 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, cast, or_, and_, func from sqlalchemy import String, cast, or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
@ -220,28 +222,41 @@ class ModelsTable:
or has_access(user_id, permission, model.access_control, user_group_ids) or has_access(user_id, permission, model.access_control, user_group_ids)
] ]
def _has_write_permission(self, query, filter: dict): def _has_permission(self, db, query, filter: dict, permission: str = "read"):
group_ids = filter.get("group_ids", []) group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id") user_id = filter.get("user_id")
json_group_ids = Model.access_control["write"]["group_ids"] dialect_name = db.bind.dialect.name
# Public access
conditions = [] conditions = []
if group_ids or user_id: if group_ids or user_id:
conditions.append(Model.access_control.is_(None)) conditions.extend(
[
Model.access_control.is_(None),
cast(Model.access_control, String) == "null",
]
)
# User-level permission
if user_id: if user_id:
conditions.append(Model.user_id == user_id) conditions.append(Model.user_id == user_id)
# Group-level permission
if group_ids: if group_ids:
group_conditions = [] group_conditions = []
for gid in group_ids: for gid in group_ids:
# CASE: gid IN JSON array if dialect_name == "sqlite":
# SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%' group_conditions.append(
# Postgres → access_control->'write'->'group_ids' @> '[gid]' Model.access_control[permission]["group_ids"].contains([gid])
group_conditions.append(json_group_ids.contains([gid])) )
elif dialect_name == "postgresql":
group_conditions.append(
cast(
Model.access_control[permission]["group_ids"],
JSONB,
).contains([gid])
)
conditions.append(or_(*group_conditions)) conditions.append(or_(*group_conditions))
if conditions: if conditions:
@ -267,15 +282,20 @@ class ModelsTable:
) )
) )
# Apply access control filtering
query = self._has_write_permission(query, filter)
view_option = filter.get("view_option") view_option = filter.get("view_option")
if view_option == "created": if view_option == "created":
query = query.filter(Model.user_id == user_id) query = query.filter(Model.user_id == user_id)
elif view_option == "shared": elif view_option == "shared":
query = query.filter(Model.user_id != user_id) query = query.filter(Model.user_id != user_id)
# Apply access control filtering
query = self._has_permission(
db,
query,
filter,
permission="write",
)
tag = filter.get("tag") tag = filter.get("tag")
if tag: if tag:
# TODO: This is a simple implementation and should be improved for performance # TODO: This is a simple implementation and should be improved for performance