mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
refac
This commit is contained in:
parent
1818f2b3d9
commit
277f3a91f1
2 changed files with 59 additions and 23 deletions
|
|
@ -7,7 +7,10 @@ from open_webui.internal.db import Base, get_db
|
||||||
from open_webui.models.groups import Groups
|
from open_webui.models.groups import Groups
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case, cast
|
||||||
from sqlalchemy import or_, func, select, and_, text
|
from sqlalchemy import or_, func, select, and_, text
|
||||||
from sqlalchemy.sql import exists
|
from sqlalchemy.sql import exists
|
||||||
|
|
||||||
|
|
@ -293,28 +296,41 @@ class ChannelTable:
|
||||||
channels = db.query(Channel).all()
|
channels = db.query(Channel).all()
|
||||||
return [ChannelModel.model_validate(channel) for channel in channels]
|
return [ChannelModel.model_validate(channel) for channel in channels]
|
||||||
|
|
||||||
def _has_permission(self, query, filter: dict, permission: str = "read"):
|
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||||
group_ids = filter.get("group_ids", [])
|
group_ids = filter.get("group_ids", [])
|
||||||
user_id = filter.get("user_id")
|
user_id = filter.get("user_id")
|
||||||
|
|
||||||
json_group_ids = Channel.access_control[permission]["group_ids"]
|
dialect_name = db.bind.dialect.name
|
||||||
|
|
||||||
|
# Public access
|
||||||
conditions = []
|
conditions = []
|
||||||
if group_ids or user_id:
|
if group_ids or user_id:
|
||||||
conditions.append(Channel.access_control.is_(None))
|
conditions.extend(
|
||||||
|
[
|
||||||
|
Channel.access_control.is_(None),
|
||||||
|
cast(Channel.access_control, String) == "null",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# User-level permission
|
||||||
if user_id:
|
if user_id:
|
||||||
conditions.append(Channel.user_id == user_id)
|
conditions.append(Channel.user_id == user_id)
|
||||||
|
|
||||||
|
# Group-level permission
|
||||||
if group_ids:
|
if group_ids:
|
||||||
group_conditions = []
|
group_conditions = []
|
||||||
|
|
||||||
for gid in group_ids:
|
for gid in group_ids:
|
||||||
# CASE: gid IN JSON array
|
if dialect_name == "sqlite":
|
||||||
# SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%'
|
group_conditions.append(
|
||||||
# Postgres → access_control->'write'->'group_ids' @> '[gid]'
|
Channel.access_control[permission]["group_ids"].contains([gid])
|
||||||
group_conditions.append(json_group_ids.contains([gid]))
|
)
|
||||||
|
elif dialect_name == "postgresql":
|
||||||
|
group_conditions.append(
|
||||||
|
cast(
|
||||||
|
Channel.access_control[permission]["group_ids"],
|
||||||
|
JSONB,
|
||||||
|
).contains([gid])
|
||||||
|
)
|
||||||
conditions.append(or_(*group_conditions))
|
conditions.append(or_(*group_conditions))
|
||||||
|
|
||||||
if conditions:
|
if conditions:
|
||||||
|
|
@ -351,7 +367,7 @@ class ChannelTable:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
query = self._has_permission(
|
query = self._has_permission(
|
||||||
query, {"user_id": user_id, "group_ids": user_group_ids}
|
db, query, {"user_id": user_id, "group_ids": user_group_ids}
|
||||||
)
|
)
|
||||||
|
|
||||||
standard_channels = query.all()
|
standard_channels = query.all()
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from sqlalchemy import String, cast, or_, and_, func
|
from sqlalchemy import String, cast, or_, and_, func
|
||||||
from sqlalchemy.dialects import postgresql, sqlite
|
from sqlalchemy.dialects import postgresql, sqlite
|
||||||
|
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -220,28 +222,41 @@ class ModelsTable:
|
||||||
or has_access(user_id, permission, model.access_control, user_group_ids)
|
or has_access(user_id, permission, model.access_control, user_group_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _has_write_permission(self, query, filter: dict):
|
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||||
group_ids = filter.get("group_ids", [])
|
group_ids = filter.get("group_ids", [])
|
||||||
user_id = filter.get("user_id")
|
user_id = filter.get("user_id")
|
||||||
|
|
||||||
json_group_ids = Model.access_control["write"]["group_ids"]
|
dialect_name = db.bind.dialect.name
|
||||||
|
|
||||||
|
# Public access
|
||||||
conditions = []
|
conditions = []
|
||||||
if group_ids or user_id:
|
if group_ids or user_id:
|
||||||
conditions.append(Model.access_control.is_(None))
|
conditions.extend(
|
||||||
|
[
|
||||||
|
Model.access_control.is_(None),
|
||||||
|
cast(Model.access_control, String) == "null",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# User-level permission
|
||||||
if user_id:
|
if user_id:
|
||||||
conditions.append(Model.user_id == user_id)
|
conditions.append(Model.user_id == user_id)
|
||||||
|
|
||||||
|
# Group-level permission
|
||||||
if group_ids:
|
if group_ids:
|
||||||
group_conditions = []
|
group_conditions = []
|
||||||
|
|
||||||
for gid in group_ids:
|
for gid in group_ids:
|
||||||
# CASE: gid IN JSON array
|
if dialect_name == "sqlite":
|
||||||
# SQLite → json_extract(access_control, '$.write.group_ids') LIKE '%gid%'
|
group_conditions.append(
|
||||||
# Postgres → access_control->'write'->'group_ids' @> '[gid]'
|
Model.access_control[permission]["group_ids"].contains([gid])
|
||||||
group_conditions.append(json_group_ids.contains([gid]))
|
)
|
||||||
|
elif dialect_name == "postgresql":
|
||||||
|
group_conditions.append(
|
||||||
|
cast(
|
||||||
|
Model.access_control[permission]["group_ids"],
|
||||||
|
JSONB,
|
||||||
|
).contains([gid])
|
||||||
|
)
|
||||||
conditions.append(or_(*group_conditions))
|
conditions.append(or_(*group_conditions))
|
||||||
|
|
||||||
if conditions:
|
if conditions:
|
||||||
|
|
@ -267,15 +282,20 @@ class ModelsTable:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply access control filtering
|
|
||||||
query = self._has_write_permission(query, filter)
|
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get("view_option")
|
||||||
if view_option == "created":
|
if view_option == "created":
|
||||||
query = query.filter(Model.user_id == user_id)
|
query = query.filter(Model.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == "shared":
|
||||||
query = query.filter(Model.user_id != user_id)
|
query = query.filter(Model.user_id != user_id)
|
||||||
|
|
||||||
|
# Apply access control filtering
|
||||||
|
query = self._has_permission(
|
||||||
|
db,
|
||||||
|
query,
|
||||||
|
filter,
|
||||||
|
permission="write",
|
||||||
|
)
|
||||||
|
|
||||||
tag = filter.get("tag")
|
tag = filter.get("tag")
|
||||||
if tag:
|
if tag:
|
||||||
# TODO: This is a simple implementation and should be improved for performance
|
# TODO: This is a simple implementation and should be improved for performance
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue