refac: db group

This commit is contained in:
Timothy Jaeryang Baek 2025-11-28 22:48:58 -05:00
parent a7c7993bbf
commit c1d760692f
5 changed files with 109 additions and 40 deletions

View file

@ -11,7 +11,18 @@ from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON, func, ForeignKey from sqlalchemy import (
BigInteger,
Column,
String,
Text,
JSON,
and_,
func,
ForeignKey,
cast,
or_,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -41,7 +52,6 @@ class Group(Base):
class GroupModel(BaseModel): class GroupModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
user_id: str user_id: str
@ -56,6 +66,8 @@ class GroupModel(BaseModel):
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
class GroupMember(Base): class GroupMember(Base):
__tablename__ = "group_member" __tablename__ = "group_member"
@ -84,17 +96,8 @@ class GroupMemberModel(BaseModel):
#################### ####################
class GroupResponse(BaseModel): class GroupResponse(GroupModel):
id: str
user_id: str
name: str
description: str
permissions: Optional[dict] = None
data: Optional[dict] = None
meta: Optional[dict] = None
member_count: Optional[int] = None member_count: Optional[int] = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
class GroupForm(BaseModel): class GroupForm(BaseModel):
@ -112,6 +115,11 @@ class GroupUpdateForm(GroupForm):
pass pass
class GroupListResponse(BaseModel):
items: list[GroupResponse] = []
total: int = 0
class GroupTable: class GroupTable:
def insert_new_group( def insert_new_group(
self, user_id: str, form_data: GroupForm self, user_id: str, form_data: GroupForm
@ -140,13 +148,87 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_groups(self) -> list[GroupModel]: def get_all_groups(self) -> list[GroupModel]:
with get_db() as db: with get_db() as db:
groups = db.query(Group).order_by(Group.updated_at.desc()).all()
return [GroupModel.model_validate(group) for group in groups]
def get_groups(self, filter) -> list[GroupResponse]:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
share_value = filter["share"]
json_share = Group.data["config"]["share"].as_boolean()
if share_value:
query = query.filter(
or_(
Group.data.is_(None),
json_share.is_(None),
json_share == True,
)
)
else:
query = query.filter(
and_(Group.data.isnot(None), json_share == False)
)
groups = query.order_by(Group.updated_at.desc()).all()
return [ return [
GroupModel.model_validate(group) GroupResponse.model_validate(
for group in db.query(Group).order_by(Group.updated_at.desc()).all() {
**GroupModel.model_validate(group).model_dump(),
"member_count": self.get_group_member_count_by_id(group.id),
}
)
for group in groups
] ]
def search_groups(
self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
) -> GroupListResponse:
with get_db() as db:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
query = query.join(
GroupMember, GroupMember.group_id == Group.id
).filter(GroupMember.user_id == filter["member_id"])
if "share" in filter:
# 'share' is stored in data JSON, support both sqlite and postgres
share_value = filter["share"]
print("Filtering by share:", share_value)
query = query.filter(
Group.data.op("->>")("share") == str(share_value)
)
total = query.count()
query = query.order_by(Group.updated_at.desc())
groups = query.offset(skip).limit(limit).all()
return {
"items": [
GroupResponse.model_validate(
**GroupModel.model_validate(group).model_dump(),
member_count=self.get_group_member_count_by_id(group.id),
)
for group in groups
],
"total": total,
}
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]: def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db: with get_db() as db:
return [ return [
@ -293,7 +375,7 @@ class GroupTable:
) -> list[GroupModel]: ) -> list[GroupModel]:
# check for existing groups # check for existing groups
existing_groups = self.get_groups() existing_groups = self.get_all_groups()
existing_group_names = {group.name for group in existing_groups} existing_group_names = {group.name for group in existing_groups}
new_groups = [] new_groups = []

View file

@ -32,31 +32,17 @@ router = APIRouter()
@router.get("/", response_model=list[GroupResponse]) @router.get("/", response_model=list[GroupResponse])
async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)):
if user.role == "admin":
groups = Groups.get_groups()
else:
groups = Groups.get_groups_by_member_id(user.id)
group_list = [] filter = {}
if user.role != "admin":
filter["member_id"] = user.id
for group in groups: if share is not None:
if share is not None: filter["share"] = share
# Check if the group has data and a config with share key
if (
group.data
and "share" in group.data.get("config", {})
and group.data["config"]["share"] != share
):
continue
group_list.append( groups = Groups.get_groups(filter=filter)
GroupResponse(
**group.model_dump(),
member_count=Groups.get_group_member_count_by_id(group.id),
)
)
return group_list return groups
############################ ############################

View file

@ -719,7 +719,7 @@ async def get_groups(
): ):
"""List SCIM Groups""" """List SCIM Groups"""
# Get all groups # Get all groups
groups_list = Groups.get_groups() groups_list = Groups.get_all_groups()
# Apply pagination # Apply pagination
total = len(groups_list) total = len(groups_list)

View file

@ -1102,7 +1102,7 @@ class OAuthManager:
user_oauth_groups = [] user_oauth_groups = []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups() all_available_groups: list[GroupModel] = Groups.get_all_groups()
# Create groups if they don't exist and creation is enabled # Create groups if they don't exist and creation is enabled
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
@ -1146,7 +1146,7 @@ class OAuthManager:
# Refresh the list of all available groups if any were created # Refresh the list of all available groups if any were created
if groups_created: if groups_created:
all_available_groups = Groups.get_groups() all_available_groups = Groups.get_all_groups()
log.debug("Refreshed list of all available groups after creation.") log.debug("Refreshed list of all available groups after creation.")
log.debug(f"Oauth Groups claim: {oauth_claim}") log.debug(f"Oauth Groups claim: {oauth_claim}")

View file

@ -100,6 +100,7 @@
<EditGroupModal <EditGroupModal
bind:show={showAddGroupModal} bind:show={showAddGroupModal}
edit={false} edit={false}
tabs={['general', 'permissions']}
permissions={defaultPermissions} permissions={defaultPermissions}
onSubmit={addGroupHandler} onSubmit={addGroupHandler}
/> />