mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
refac
This commit is contained in:
parent
1818f2b3d9
commit
277f3a91f1
2 changed files with 59 additions and 23 deletions
|
|
@ -7,7 +7,10 @@ from open_webui.internal.db import Base, get_db
|
|||
from open_webui.models.groups import Groups
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case, cast
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
|
|
@ -293,28 +296,41 @@ class ChannelTable:
|
|||
channels = db.query(Channel).all()
|
||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||
|
||||
def _has_permission(self, query, filter: dict, permission: str = "read"):
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||
group_ids = filter.get("group_ids", [])
|
||||
user_id = filter.get("user_id")
|
||||
|
||||
json_group_ids = Channel.access_control[permission]["group_ids"]
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
# Public access
|
||||
conditions = []
|
||||
if group_ids or user_id:
|
||||
conditions.append(Channel.access_control.is_(None))
|
||||
conditions.extend(
|
||||
[
|
||||
Channel.access_control.is_(None),
|
||||
cast(Channel.access_control, String) == "null",
|
||||
]
|
||||
)
|
||||
|
||||
# User-level permission
|
||||
if user_id:
|
||||
conditions.append(Channel.user_id == user_id)
|
||||
|
||||
# Group-level permission
|
||||
if group_ids:
|
||||
group_conditions = []
|
||||
|
||||
for gid in group_ids:
|
||||
# CASE: gid IN JSON array
|
||||
# SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%'
|
||||
# Postgres → access_control->'write'->'group_ids' @> '[gid]'
|
||||
group_conditions.append(json_group_ids.contains([gid]))
|
||||
|
||||
if dialect_name == "sqlite":
|
||||
group_conditions.append(
|
||||
Channel.access_control[permission]["group_ids"].contains([gid])
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
group_conditions.append(
|
||||
cast(
|
||||
Channel.access_control[permission]["group_ids"],
|
||||
JSONB,
|
||||
).contains([gid])
|
||||
)
|
||||
conditions.append(or_(*group_conditions))
|
||||
|
||||
if conditions:
|
||||
|
|
@ -351,7 +367,7 @@ class ChannelTable:
|
|||
),
|
||||
)
|
||||
query = self._has_permission(
|
||||
query, {"user_id": user_id, "group_ids": user_group_ids}
|
||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}
|
||||
)
|
||||
|
||||
standard_channels = query.all()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from pydantic import BaseModel, ConfigDict
|
|||
|
||||
from sqlalchemy import String, cast, or_, and_, func
|
||||
from sqlalchemy.dialects import postgresql, sqlite
|
||||
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
|
||||
|
||||
|
|
@ -220,28 +222,41 @@ class ModelsTable:
|
|||
or has_access(user_id, permission, model.access_control, user_group_ids)
|
||||
]
|
||||
|
||||
def _has_write_permission(self, query, filter: dict):
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||
group_ids = filter.get("group_ids", [])
|
||||
user_id = filter.get("user_id")
|
||||
|
||||
json_group_ids = Model.access_control["write"]["group_ids"]
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
# Public access
|
||||
conditions = []
|
||||
if group_ids or user_id:
|
||||
conditions.append(Model.access_control.is_(None))
|
||||
conditions.extend(
|
||||
[
|
||||
Model.access_control.is_(None),
|
||||
cast(Model.access_control, String) == "null",
|
||||
]
|
||||
)
|
||||
|
||||
# User-level permission
|
||||
if user_id:
|
||||
conditions.append(Model.user_id == user_id)
|
||||
|
||||
# Group-level permission
|
||||
if group_ids:
|
||||
group_conditions = []
|
||||
|
||||
for gid in group_ids:
|
||||
# CASE: gid IN JSON array
|
||||
# SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%'
|
||||
# Postgres → access_control->'write'->'group_ids' @> '[gid]'
|
||||
group_conditions.append(json_group_ids.contains([gid]))
|
||||
|
||||
if dialect_name == "sqlite":
|
||||
group_conditions.append(
|
||||
Model.access_control[permission]["group_ids"].contains([gid])
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
group_conditions.append(
|
||||
cast(
|
||||
Model.access_control[permission]["group_ids"],
|
||||
JSONB,
|
||||
).contains([gid])
|
||||
)
|
||||
conditions.append(or_(*group_conditions))
|
||||
|
||||
if conditions:
|
||||
|
|
@ -267,15 +282,20 @@ class ModelsTable:
|
|||
)
|
||||
)
|
||||
|
||||
# Apply access control filtering
|
||||
query = self._has_write_permission(query, filter)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
query = query.filter(Model.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
query = query.filter(Model.user_id != user_id)
|
||||
|
||||
# Apply access control filtering
|
||||
query = self._has_permission(
|
||||
db,
|
||||
query,
|
||||
filter,
|
||||
permission="write",
|
||||
)
|
||||
|
||||
tag = filter.get("tag")
|
||||
if tag:
|
||||
# TODO: This is a simple implementation and should be improved for performance
|
||||
|
|
|
|||
Loading…
Reference in a new issue