mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 04:15:25 +00:00
refac: db group
This commit is contained in:
parent
a7c7993bbf
commit
c1d760692f
5 changed files with 109 additions and 40 deletions
|
|
@ -11,7 +11,18 @@ from open_webui.models.files import FileMetadataResponse
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
JSON,
|
||||||
|
and_,
|
||||||
|
func,
|
||||||
|
ForeignKey,
|
||||||
|
cast,
|
||||||
|
or_,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -41,7 +52,6 @@ class Group(Base):
|
||||||
|
|
||||||
|
|
||||||
class GroupModel(BaseModel):
|
class GroupModel(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
id: str
|
id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
|
|
@ -56,6 +66,8 @@ class GroupModel(BaseModel):
|
||||||
created_at: int # timestamp in epoch
|
created_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
class GroupMember(Base):
|
class GroupMember(Base):
|
||||||
__tablename__ = "group_member"
|
__tablename__ = "group_member"
|
||||||
|
|
@ -84,17 +96,8 @@ class GroupMemberModel(BaseModel):
|
||||||
####################
|
####################
|
||||||
|
|
||||||
|
|
||||||
class GroupResponse(BaseModel):
|
class GroupResponse(GroupModel):
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
permissions: Optional[dict] = None
|
|
||||||
data: Optional[dict] = None
|
|
||||||
meta: Optional[dict] = None
|
|
||||||
member_count: Optional[int] = None
|
member_count: Optional[int] = None
|
||||||
created_at: int # timestamp in epoch
|
|
||||||
updated_at: int # timestamp in epoch
|
|
||||||
|
|
||||||
|
|
||||||
class GroupForm(BaseModel):
|
class GroupForm(BaseModel):
|
||||||
|
|
@ -112,6 +115,11 @@ class GroupUpdateForm(GroupForm):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GroupListResponse(BaseModel):
|
||||||
|
items: list[GroupResponse] = []
|
||||||
|
total: int = 0
|
||||||
|
|
||||||
|
|
||||||
class GroupTable:
|
class GroupTable:
|
||||||
def insert_new_group(
|
def insert_new_group(
|
||||||
self, user_id: str, form_data: GroupForm
|
self, user_id: str, form_data: GroupForm
|
||||||
|
|
@ -140,13 +148,87 @@ class GroupTable:
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_groups(self) -> list[GroupModel]:
|
def get_all_groups(self) -> list[GroupModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
|
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
|
||||||
|
return [GroupModel.model_validate(group) for group in groups]
|
||||||
|
|
||||||
|
def get_groups(self, filter) -> list[GroupResponse]:
|
||||||
|
with get_db() as db:
|
||||||
|
query = db.query(Group)
|
||||||
|
|
||||||
|
if filter:
|
||||||
|
if "query" in filter:
|
||||||
|
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
||||||
|
if "member_id" in filter:
|
||||||
|
query = query.join(
|
||||||
|
GroupMember, GroupMember.group_id == Group.id
|
||||||
|
).filter(GroupMember.user_id == filter["member_id"])
|
||||||
|
|
||||||
|
if "share" in filter:
|
||||||
|
share_value = filter["share"]
|
||||||
|
json_share = Group.data["config"]["share"].as_boolean()
|
||||||
|
|
||||||
|
if share_value:
|
||||||
|
query = query.filter(
|
||||||
|
or_(
|
||||||
|
Group.data.is_(None),
|
||||||
|
json_share.is_(None),
|
||||||
|
json_share == True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = query.filter(
|
||||||
|
and_(Group.data.isnot(None), json_share == False)
|
||||||
|
)
|
||||||
|
groups = query.order_by(Group.updated_at.desc()).all()
|
||||||
return [
|
return [
|
||||||
GroupModel.model_validate(group)
|
GroupResponse.model_validate(
|
||||||
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
|
{
|
||||||
|
**GroupModel.model_validate(group).model_dump(),
|
||||||
|
"member_count": self.get_group_member_count_by_id(group.id),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group in groups
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def search_groups(
|
||||||
|
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
|
||||||
|
) -> GroupListResponse:
|
||||||
|
with get_db() as db:
|
||||||
|
query = db.query(Group)
|
||||||
|
|
||||||
|
if filter:
|
||||||
|
if "query" in filter:
|
||||||
|
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
||||||
|
if "member_id" in filter:
|
||||||
|
query = query.join(
|
||||||
|
GroupMember, GroupMember.group_id == Group.id
|
||||||
|
).filter(GroupMember.user_id == filter["member_id"])
|
||||||
|
|
||||||
|
if "share" in filter:
|
||||||
|
# 'share' is stored in data JSON, support both sqlite and postgres
|
||||||
|
share_value = filter["share"]
|
||||||
|
print("Filtering by share:", share_value)
|
||||||
|
query = query.filter(
|
||||||
|
Group.data.op("->>")("share") == str(share_value)
|
||||||
|
)
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
query = query.order_by(Group.updated_at.desc())
|
||||||
|
groups = query.offset(skip).limit(limit).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"items": [
|
||||||
|
GroupResponse.model_validate(
|
||||||
|
**GroupModel.model_validate(group).model_dump(),
|
||||||
|
member_count=self.get_group_member_count_by_id(group.id),
|
||||||
|
)
|
||||||
|
for group in groups
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
}
|
||||||
|
|
||||||
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
|
||||||
with get_db() as db:
|
with get_db() as db:
|
||||||
return [
|
return [
|
||||||
|
|
@ -293,7 +375,7 @@ class GroupTable:
|
||||||
) -> list[GroupModel]:
|
) -> list[GroupModel]:
|
||||||
|
|
||||||
# check for existing groups
|
# check for existing groups
|
||||||
existing_groups = self.get_groups()
|
existing_groups = self.get_all_groups()
|
||||||
existing_group_names = {group.name for group in existing_groups}
|
existing_group_names = {group.name for group in existing_groups}
|
||||||
|
|
||||||
new_groups = []
|
new_groups = []
|
||||||
|
|
|
||||||
|
|
@ -32,31 +32,17 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.get("/", response_model=list[GroupResponse])
|
@router.get("/", response_model=list[GroupResponse])
|
||||||
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
|
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
|
||||||
if user.role == "admin":
|
|
||||||
groups = Groups.get_groups()
|
|
||||||
else:
|
|
||||||
groups = Groups.get_groups_by_member_id(user.id)
|
|
||||||
|
|
||||||
group_list = []
|
filter = {}
|
||||||
|
if user.role != "admin":
|
||||||
|
filter["member_id"] = user.id
|
||||||
|
|
||||||
for group in groups:
|
if share is not None:
|
||||||
if share is not None:
|
filter["share"] = share
|
||||||
# Check if the group has data and a config with share key
|
|
||||||
if (
|
|
||||||
group.data
|
|
||||||
and "share" in group.data.get("config", {})
|
|
||||||
and group.data["config"]["share"] != share
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
group_list.append(
|
groups = Groups.get_groups(filter=filter)
|
||||||
GroupResponse(
|
|
||||||
**group.model_dump(),
|
|
||||||
member_count=Groups.get_group_member_count_by_id(group.id),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return group_list
|
return groups
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
|
|
|
||||||
|
|
@ -719,7 +719,7 @@ async def get_groups(
|
||||||
):
|
):
|
||||||
"""List SCIM Groups"""
|
"""List SCIM Groups"""
|
||||||
# Get all groups
|
# Get all groups
|
||||||
groups_list = Groups.get_groups()
|
groups_list = Groups.get_all_groups()
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
total = len(groups_list)
|
total = len(groups_list)
|
||||||
|
|
|
||||||
|
|
@ -1102,7 +1102,7 @@ class OAuthManager:
|
||||||
user_oauth_groups = []
|
user_oauth_groups = []
|
||||||
|
|
||||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
all_available_groups: list[GroupModel] = Groups.get_all_groups()
|
||||||
|
|
||||||
# Create groups if they don't exist and creation is enabled
|
# Create groups if they don't exist and creation is enabled
|
||||||
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
|
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
|
||||||
|
|
@ -1146,7 +1146,7 @@ class OAuthManager:
|
||||||
|
|
||||||
# Refresh the list of all available groups if any were created
|
# Refresh the list of all available groups if any were created
|
||||||
if groups_created:
|
if groups_created:
|
||||||
all_available_groups = Groups.get_groups()
|
all_available_groups = Groups.get_all_groups()
|
||||||
log.debug("Refreshed list of all available groups after creation.")
|
log.debug("Refreshed list of all available groups after creation.")
|
||||||
|
|
||||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||||
|
|
|
||||||
|
|
@ -100,6 +100,7 @@
|
||||||
<EditGroupModal
|
<EditGroupModal
|
||||||
bind:show={showAddGroupModal}
|
bind:show={showAddGroupModal}
|
||||||
edit={false}
|
edit={false}
|
||||||
|
tabs={['general', 'permissions']}
|
||||||
permissions={defaultPermissions}
|
permissions={defaultPermissions}
|
||||||
onSubmit={addGroupHandler}
|
onSubmit={addGroupHandler}
|
||||||
/>
|
/>
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue