refac: db group

This commit is contained in:
Timothy Jaeryang Baek 2025-11-28 22:48:58 -05:00
parent a7c7993bbf
commit c1d760692f
5 changed files with 109 additions and 40 deletions

View file

@ -11,7 +11,18 @@ from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey
from sqlalchemy import (
BigInteger,
Column,
String,
Text,
JSON,
and_,
func,
ForeignKey,
cast,
or_,
)
log = logging.getLogger(__name__)
@ -41,7 +52,6 @@ class Group(Base):
class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
@ -56,6 +66,8 @@ class GroupModel(BaseModel):
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class GroupMember(Base):
__tablename__ = "group_member"
@ -84,17 +96,8 @@ class GroupMemberModel(BaseModel):
####################
class GroupResponse(BaseModel):
id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
class GroupResponse(GroupModel):
member_count: Optional[int] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel):
@ -112,6 +115,11 @@ class GroupUpdateForm(GroupForm):
pass
class GroupListResponse(BaseModel):
items: list[GroupResponse] = []
total: int = 0
class GroupTable:
def insert_new_group(
self, user_id: str, form_data: GroupForm
@ -140,13 +148,87 @@ class GroupTable:
except Exception:
return None
def get_groups(self) -> list[GroupModel]:
def get_all_groups(self) -> list[GroupModel]:
with get_db() as db:
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
return [GroupModel.model_validate(group) for group in groups]
def get_groups(self, filter) -> list[GroupResponse]:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
share_value = filter["share"]
json_share = Group.data["config"]["share"].as_boolean()
if share_value:
query = query.filter(
or_(
Group.data.is_(None),
json_share.is_(None),
json_share == True,
)
)
else:
query = query.filter(
and_(Group.data.isnot(None), json_share == False)
)
groups = query.order_by(Group.updated_at.desc()).all()
return [
GroupModel.model_validate(group)
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
GroupResponse.model_validate(
{
**GroupModel.model_validate(group).model_dump(),
"member_count": self.get_group_member_count_by_id(group.id),
}
)
for group in groups
]
def search_groups(
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
) -> GroupListResponse:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
# 'share' is stored in data JSON, support both sqlite and postgres
share_value = filter["share"]
print("Filtering by share:", share_value)
query = query.filter(
Group.data.op("->>")("share") == str(share_value)
)
total = query.count()
query = query.order_by(Group.updated_at.desc())
groups = query.offset(skip).limit(limit).all()
return {
"items": [
GroupResponse.model_validate(
**GroupModel.model_validate(group).model_dump(),
member_count=self.get_group_member_count_by_id(group.id),
)
for group in groups
],
"total": total,
}
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
@ -293,7 +375,7 @@ class GroupTable:
) -> list[GroupModel]:
# check for existing groups
existing_groups = self.get_groups()
existing_groups = self.get_all_groups()
existing_group_names = {group.name for group in existing_groups}
new_groups = []

View file

@ -32,31 +32,17 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse])
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
if user.role == "admin":
groups = Groups.get_groups()
else:
groups = Groups.get_groups_by_member_id(user.id)
group_list = []
filter = {}
if user.role != "admin":
filter["member_id"] = user.id
for group in groups:
if share is not None:
# Check if the group has data and a config with share key
if (
group.data
and "share" in group.data.get("config", {})
and group.data["config"]["share"] != share
):
continue
filter["share"] = share
group_list.append(
GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
)
groups = Groups.get_groups(filter=filter)
return group_list
return groups
############################

View file

@ -719,7 +719,7 @@ async def get_groups(
):
"""List SCIM Groups"""
# Get all groups
groups_list = Groups.get_groups()
groups_list = Groups.get_all_groups()
# Apply pagination
total = len(groups_list)

View file

@ -1102,7 +1102,7 @@ class OAuthManager:
user_oauth_groups = []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
all_available_groups: list[GroupModel] = Groups.get_all_groups()
# Create groups if they don't exist and creation is enabled
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
@ -1146,7 +1146,7 @@ class OAuthManager:
# Refresh the list of all available groups if any were created
if groups_created:
all_available_groups = Groups.get_groups()
all_available_groups = Groups.get_all_groups()
log.debug("Refreshed list of all available groups after creation.")
log.debug(f"Oauth Groups claim: {oauth_claim}")

View file

@ -100,6 +100,7 @@
<EditGroupModal
bind:show={showAddGroupModal}
edit={false}
tabs={['general', 'permissions']}
permissions={defaultPermissions}
onSubmit={addGroupHandler}
/>