mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 20:35:19 +00:00
533 lines
16 KiB
Python
533 lines
16 KiB
Python
import json
|
|
import logging
|
|
import time
|
|
from typing import Optional
|
|
import uuid
|
|
|
|
from open_webui.internal.db import Base, get_db
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
from open_webui.models.files import FileMetadataResponse
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
from sqlalchemy import (
|
|
BigInteger,
|
|
Column,
|
|
String,
|
|
Text,
|
|
JSON,
|
|
and_,
|
|
func,
|
|
ForeignKey,
|
|
cast,
|
|
or_,
|
|
)
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
# Verbosity is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
|
|
####################
|
|
# UserGroup DB Schema
|
|
####################
|
|
|
|
|
|
class Group(Base):
    """ORM mapping for the ``group`` table."""

    __tablename__ = "group"

    # Primary key; this module fills it with a UUID4 string on insert.
    id = Column(Text, unique=True, primary_key=True)
    # ID of the user passed in when the group row is created.
    user_id = Column(Text)

    name = Column(Text)
    description = Column(Text)

    # Free-form JSON payloads. The group listing code reads
    # data["config"]["share"] when filtering by the "share" flag.
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)

    permissions = Column(JSON, nullable=True)

    # Epoch timestamps, written as int(time.time()) by this module.
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)
|
|
|
|
|
|
class GroupModel(BaseModel):
    # Pydantic view of a `Group` row. `from_attributes=True` (below) lets
    # `model_validate` read values straight off an ORM object.
    id: str
    user_id: str

    name: str
    description: str

    # Optional JSON payloads; all default to None when absent.
    data: Optional[dict] = None
    meta: Optional[dict] = None

    permissions: Optional[dict] = None

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
class GroupMember(Base):
    """ORM mapping for the ``group_member`` table (one row per membership)."""

    __tablename__ = "group_member"

    # Primary key; this module fills it with a UUID4 string on insert.
    id = Column(Text, unique=True, primary_key=True)
    # Owning group. The foreign key asks the database to cascade deletes
    # from "group" to its membership rows.
    group_id = Column(
        Text,
        ForeignKey("group.id", ondelete="CASCADE"),
        nullable=False,
    )
    # ID of the member; no foreign key is declared for it here.
    user_id = Column(Text, nullable=False)
    # Epoch timestamps, written as int(time.time()) by this module.
    created_at = Column(BigInteger, nullable=True)
    updated_at = Column(BigInteger, nullable=True)
|
|
|
|
|
|
class GroupMemberModel(BaseModel):
    # Pydantic view of a `GroupMember` row.
    id: str
    group_id: str
    user_id: str
    created_at: Optional[int] = None  # timestamp in epoch
    updated_at: Optional[int] = None  # timestamp in epoch
|
|
|
|
|
|
####################
|
|
# Forms
|
|
####################
|
|
|
|
|
|
class GroupResponse(GroupModel):
    # A GroupModel plus the number of membership rows for the group;
    # None when the count was not supplied.
    member_count: Optional[int] = None
|
|
|
|
|
|
class GroupForm(BaseModel):
    # Fields accepted when creating a group. `name` and `description` are
    # required; the JSON payloads are optional.
    name: str
    description: str
    permissions: Optional[dict] = None
    data: Optional[dict] = None
|
|
|
|
|
|
class UserIdsForm(BaseModel):
    # Payload carrying an optional list of user IDs.
    user_ids: Optional[list[str]] = None
|
|
|
|
|
|
class GroupUpdateForm(GroupForm):
    # Same fields as GroupForm; kept as a separate type for update calls.
    pass
|
|
|
|
|
|
class GroupListResponse(BaseModel):
    # One page of groups plus the total number of matching groups.
    items: list[GroupResponse] = []
    total: int = 0
|
|
|
|
|
|
class GroupTable:
|
|
def insert_new_group(
    self, user_id: str, form_data: GroupForm
) -> Optional[GroupModel]:
    """Create a group row from ``form_data`` and return it as a ``GroupModel``.

    Args:
        user_id: Stored in the new row's ``user_id`` column.
        form_data: Group fields; attributes that are ``None`` are omitted so
            the ``GroupModel`` defaults apply.

    Returns:
        The persisted group, or ``None`` if the insert failed. Failures are
        logged and rolled back rather than raised.
    """
    with get_db() as db:
        # Single timestamp so created_at and updated_at always agree.
        now = int(time.time())
        group = GroupModel(
            **{
                **form_data.model_dump(exclude_none=True),
                "id": str(uuid.uuid4()),
                "user_id": user_id,
                "created_at": now,
                "updated_at": now,
            }
        )

        try:
            result = Group(**group.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return GroupModel.model_validate(result)
        except Exception as e:
            # Log instead of silently swallowing, matching update_group_by_id.
            log.exception(e)
            db.rollback()
            return None
|
|
|
|
def get_all_groups(self) -> list[GroupModel]:
    """Return every group as a ``GroupModel``, most recently updated first."""
    with get_db() as db:
        rows = db.query(Group).order_by(Group.updated_at.desc()).all()

        models = []
        for row in rows:
            models.append(GroupModel.model_validate(row))
        return models
|
|
|
|
def get_groups(self, filter) -> list[GroupResponse]:
    """List groups matching ``filter``, most recently updated first.

    Recognised ``filter`` keys (all optional; a falsy ``filter`` disables
    filtering):

    * ``"query"``: case-insensitive substring match on the group name.
    * ``"member_id"``: only groups that have a membership row for this user.
    * ``"share"``: truthy keeps groups whose ``data`` is NULL or whose
      ``data.config.share`` is NULL or true; falsy keeps groups whose
      ``data`` is set and ``data.config.share`` is false.

    Each result carries ``member_count``, looked up per group.
    """
    with get_db() as db:
        q = db.query(Group)

        if filter:
            if "query" in filter:
                name_pattern = f"%{filter['query']}%"
                q = q.filter(Group.name.ilike(name_pattern))

            if "member_id" in filter:
                q = q.join(GroupMember, GroupMember.group_id == Group.id)
                q = q.filter(GroupMember.user_id == filter["member_id"])

            if "share" in filter:
                share_flag = Group.data["config"]["share"].as_boolean()

                # `== True` / `== False` build SQL comparisons; they are not
                # Python truth tests.
                if filter["share"]:
                    condition = or_(
                        Group.data.is_(None),
                        share_flag.is_(None),
                        share_flag == True,  # noqa: E712
                    )
                else:
                    condition = and_(
                        Group.data.isnot(None),
                        share_flag == False,  # noqa: E712
                    )
                q = q.filter(condition)

        rows = q.order_by(Group.updated_at.desc()).all()

        responses = []
        for row in rows:
            payload = GroupModel.model_validate(row).model_dump()
            payload["member_count"] = self.get_group_member_count_by_id(row.id)
            responses.append(GroupResponse.model_validate(payload))
        return responses
|
|
|
|
def search_groups(
    self, filter: Optional[dict] = None, skip: int = 0, limit: int = 30
) -> GroupListResponse:
    """Return one page of groups matching ``filter`` plus the total count.

    Recognised ``filter`` keys (all optional):

    * ``"query"``: case-insensitive substring match on the group name.
    * ``"member_id"``: only groups that have a membership row for this user.
    * ``"share"``: same semantics as ``get_groups``. Truthy keeps groups
      whose ``data`` is NULL or whose ``data.config.share`` is NULL or
      true; falsy keeps groups whose ``data`` is set and
      ``data.config.share`` is false.

    Args:
        filter: Filter mapping described above, or ``None`` for no filtering.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        A plain dict shaped like ``GroupListResponse``: ``"items"`` holds
        ``GroupResponse`` objects ordered by ``updated_at`` descending, and
        ``"total"`` is the number of matching groups before paging.
    """
    with get_db() as db:
        query = db.query(Group)

        if filter:
            if "query" in filter:
                query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
            if "member_id" in filter:
                query = query.join(
                    GroupMember, GroupMember.group_id == Group.id
                ).filter(GroupMember.user_id == filter["member_id"])

            if "share" in filter:
                # The flag is read from data["config"]["share"] as a JSON
                # boolean, exactly as get_groups does. Comparing the raw
                # JSON text with str(True) / str(False) cannot match JSON
                # booleans.
                json_share = Group.data["config"]["share"].as_boolean()

                if filter["share"]:
                    query = query.filter(
                        or_(
                            Group.data.is_(None),
                            json_share.is_(None),
                            json_share == True,  # noqa: E712 (SQL expression)
                        )
                    )
                else:
                    query = query.filter(
                        and_(
                            Group.data.isnot(None),
                            json_share == False,  # noqa: E712 (SQL expression)
                        )
                    )

        total = query.count()
        query = query.order_by(Group.updated_at.desc())
        groups = query.offset(skip).limit(limit).all()

        return {
            "items": [
                # model_validate takes a single object; the fields must be
                # passed as one dict, not as keyword arguments.
                GroupResponse.model_validate(
                    {
                        **GroupModel.model_validate(group).model_dump(),
                        "member_count": self.get_group_member_count_by_id(
                            group.id
                        ),
                    }
                )
                for group in groups
            ],
            "total": total,
        }
|
|
|
|
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
    """Return the groups that have a membership row for ``user_id``.

    Results are ordered by ``updated_at`` descending.
    """
    with get_db() as db:
        membership_query = (
            db.query(Group)
            .join(GroupMember, GroupMember.group_id == Group.id)
            .filter(GroupMember.user_id == user_id)
            .order_by(Group.updated_at.desc())
        )

        found = []
        for row in membership_query.all():
            found.append(GroupModel.model_validate(row))
        return found
|
|
|
|
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
    """Fetch a single group by primary key.

    Returns ``None`` when no row matches or when any exception occurs
    during the lookup (the exception is not logged).
    """
    try:
        with get_db() as db:
            row = db.query(Group).filter_by(id=id).first()
            if not row:
                return None
            return GroupModel.model_validate(row)
    except Exception:
        return None
|
|
|
|
def get_group_user_ids_by_id(self, id: str) -> Optional[list[str]]:
    """Return the user IDs that belong to group ``id``.

    Returns ``None`` (not an empty list) when the group has no membership
    rows.
    """
    with get_db() as db:
        rows = (
            db.query(GroupMember.user_id)
            .filter(GroupMember.group_id == id)
            .all()
        )

        if not rows:
            return None

        # Each row is a one-column result; take its only value.
        user_ids = []
        for row in rows:
            user_ids.append(row[0])
        return user_ids
|
|
|
|
def get_group_user_ids_by_ids(self, group_ids: list[str]) -> dict[str, list[str]]:
    """Map each requested group ID to the user IDs of its members.

    Every ID in ``group_ids`` appears as a key; groups without membership
    rows map to an empty list.
    """
    with get_db() as db:
        rows = (
            db.query(GroupMember.group_id, GroupMember.user_id)
            .filter(GroupMember.group_id.in_(group_ids))
            .all()
        )

        # Seed every requested group first so member-less groups are present.
        members_by_group: dict[str, list[str]] = {}
        for gid in group_ids:
            members_by_group[gid] = []

        for gid, uid in rows:
            members_by_group[gid].append(uid)

        return members_by_group
|
|
|
|
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str]) -> None:
    """Replace the membership of ``group_id`` with exactly ``user_ids``.

    All existing membership rows for the group are deleted and one new row
    is inserted per entry of ``user_ids``, then the change is committed.
    """
    with get_db() as db:
        # Drop the current membership.
        db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()

        # Recreate it from the supplied IDs, sharing one timestamp.
        timestamp = int(time.time())
        replacements = []
        for member_user_id in user_ids:
            replacements.append(
                GroupMember(
                    id=str(uuid.uuid4()),
                    group_id=group_id,
                    user_id=member_user_id,
                    created_at=timestamp,
                    updated_at=timestamp,
                )
            )

        db.add_all(replacements)
        db.commit()
|
|
|
|
def get_group_member_count_by_id(self, id: str) -> int:
    """Return how many membership rows exist for group ``id`` (0 if none)."""
    with get_db() as db:
        counting_query = db.query(func.count(GroupMember.user_id)).filter(
            GroupMember.group_id == id
        )
        # A falsy scalar (None or 0) is normalised to 0.
        return counting_query.scalar() or 0
|
|
|
|
def update_group_by_id(
    self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
) -> Optional[GroupModel]:
    """Apply the non-``None`` fields of ``form_data`` to group ``id``.

    ``updated_at`` is refreshed to the current epoch time. The ``overwrite``
    argument is accepted but not used by this implementation.

    Returns:
        The group as re-read after the commit, or ``None`` if an exception
        occurred (it is logged).
    """
    try:
        with get_db() as db:
            changes = form_data.model_dump(exclude_none=True)
            changes["updated_at"] = int(time.time())

            db.query(Group).filter_by(id=id).update(changes)
            db.commit()
            return self.get_group_by_id(id=id)
    except Exception as e:
        log.exception(e)
        return None
|
|
|
|
def delete_group_by_id(self, id: str) -> bool:
    """Delete group ``id`` together with its membership rows.

    Returns:
        ``True`` when the delete statements committed (even if no row
        matched), ``False`` if an exception occurred (it is logged).
    """
    try:
        with get_db() as db:
            # Bulk Query.delete() bypasses ORM cascades, so remove the
            # memberships explicitly instead of relying only on the
            # database honouring the ON DELETE CASCADE foreign key.
            db.query(GroupMember).filter(GroupMember.group_id == id).delete(
                synchronize_session=False
            )
            db.query(Group).filter_by(id=id).delete()
            db.commit()
            return True
    except Exception as e:
        # Log instead of silently swallowing, matching update_group_by_id.
        log.exception(e)
        return False
|
|
|
|
def delete_all_groups(self) -> bool:
    """Delete every group and every membership row.

    Returns:
        ``True`` on commit, ``False`` if an exception occurred (it is
        logged and the transaction rolled back).
    """
    with get_db() as db:
        try:
            # Bulk Query.delete() bypasses ORM cascades, so clear the
            # memberships explicitly instead of relying only on the
            # database honouring the ON DELETE CASCADE foreign key.
            db.query(GroupMember).delete(synchronize_session=False)
            db.query(Group).delete()
            db.commit()

            return True
        except Exception as e:
            log.exception(e)
            db.rollback()
            return False
|
|
|
|
def remove_user_from_all_groups(self, user_id: str) -> bool:
    """Remove ``user_id`` from every group they currently belong to.

    The ``updated_at`` of each affected group is refreshed.

    Returns:
        ``True`` on commit (including when the user had no memberships),
        ``False`` if an exception occurred (it is logged and the
        transaction rolled back).
    """
    with get_db() as db:
        try:
            # IDs of existing groups that have a membership row for the user.
            group_ids = [
                row[0]
                for row in db.query(Group.id)
                .join(GroupMember, GroupMember.group_id == Group.id)
                .filter(GroupMember.user_id == user_id)
                .all()
            ]

            if group_ids:
                # One bulk delete and one bulk update instead of two
                # statements per group.
                db.query(GroupMember).filter(
                    GroupMember.user_id == user_id,
                    GroupMember.group_id.in_(group_ids),
                ).delete(synchronize_session=False)

                db.query(Group).filter(Group.id.in_(group_ids)).update(
                    {"updated_at": int(time.time())},
                    synchronize_session=False,
                )

            db.commit()
            return True

        except Exception as e:
            log.exception(e)
            db.rollback()
            return False
|
|
|
|
def create_groups_by_group_names(
    self, user_id: str, group_names: list[str]
) -> list[GroupModel]:
    """Create a group for each name in ``group_names`` that does not exist yet.

    Names already present in the database, or repeated within
    ``group_names``, are created at most once. Each group is committed
    individually; a failure for one name is logged and does not stop the
    remaining names from being processed.

    Args:
        user_id: Stored in the ``user_id`` column of every created group.
        group_names: Candidate group names.

    Returns:
        The groups that were actually created, in input order.
    """
    # Names known so far: pre-existing ones plus those created below.
    known_names = {group.name for group in self.get_all_groups()}

    new_groups = []

    with get_db() as db:
        for group_name in group_names:
            if group_name in known_names:
                continue

            now = int(time.time())
            new_group = GroupModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                name=group_name,
                description="",
                created_at=now,
                updated_at=now,
            )
            try:
                result = Group(**new_group.model_dump())
                db.add(result)
                db.commit()
                db.refresh(result)
                new_groups.append(GroupModel.model_validate(result))
                # Remember the name so a repeated entry later in the input
                # does not create a second group with the same name.
                known_names.add(group_name)
            except Exception as e:
                log.exception(e)
                # Reset the session; otherwise later iterations would reuse
                # a transaction that has already failed.
                db.rollback()
                continue
    return new_groups
|
|
|
|
def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool:
    """Make ``user_id``'s memberships match the groups named in ``group_names``.

    The user is removed from groups not named and added to named groups
    they are not yet in; names that match no existing group are ignored.
    ``updated_at`` is refreshed on every group whose membership changed.

    Returns:
        ``True`` on commit, ``False`` if an exception occurred (it is
        logged and the transaction rolled back).
    """
    with get_db() as db:
        try:
            now = int(time.time())

            # Groups the user should end up in.
            wanted_rows = db.query(Group).filter(Group.name.in_(group_names)).all()
            wanted_ids = {row.id for row in wanted_rows}

            # Groups the user is in right now.
            current_rows = (
                db.query(Group)
                .join(GroupMember, GroupMember.group_id == Group.id)
                .filter(GroupMember.user_id == user_id)
                .all()
            )
            current_ids = {row.id for row in current_rows}

            to_join = wanted_ids - current_ids
            to_leave = current_ids - wanted_ids

            if to_leave:
                # Bulk-remove stale memberships and touch those groups.
                db.query(GroupMember).filter(
                    GroupMember.user_id == user_id,
                    GroupMember.group_id.in_(to_leave),
                ).delete(synchronize_session=False)

                db.query(Group).filter(Group.id.in_(to_leave)).update(
                    {"updated_at": now}, synchronize_session=False
                )

            if to_join:
                # Add the missing memberships and touch those groups.
                db.add_all(
                    [
                        GroupMember(
                            id=str(uuid.uuid4()),
                            group_id=gid,
                            user_id=user_id,
                            created_at=now,
                            updated_at=now,
                        )
                        for gid in to_join
                    ]
                )

                db.query(Group).filter(Group.id.in_(to_join)).update(
                    {"updated_at": now}, synchronize_session=False
                )

            db.commit()
            return True

        except Exception as e:
            log.exception(e)
            db.rollback()
            return False
|
|
|
|
def add_users_to_group(
    self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
    """Add ``user_ids`` to group ``id``, skipping users who are already members.

    Duplicates are filtered up front: repeated entries in ``user_ids`` are
    collapsed and existing members are looked up with one query. This
    replaces the insert / rollback-on-error loop, whose ``rollback()`` also
    discarded every membership added earlier in the same call.

    Args:
        id: Group primary key.
        user_ids: Users to add; ``None`` or empty only refreshes the group's
            ``updated_at``.

    Returns:
        The updated group, or ``None`` if the group does not exist or an
        exception occurred (it is logged).
    """
    try:
        with get_db() as db:
            group = db.query(Group).filter_by(id=id).first()
            if not group:
                return None

            now = int(time.time())

            # De-duplicate while preserving the caller's order.
            requested = list(dict.fromkeys(user_ids or []))

            if requested:
                already_members = {
                    row[0]
                    for row in db.query(GroupMember.user_id)
                    .filter(
                        GroupMember.group_id == id,
                        GroupMember.user_id.in_(requested),
                    )
                    .all()
                }

                db.add_all(
                    [
                        GroupMember(
                            id=str(uuid.uuid4()),
                            group_id=id,
                            user_id=user_id,
                            created_at=now,
                            updated_at=now,
                        )
                        for user_id in requested
                        if user_id not in already_members
                    ]
                )

            group.updated_at = now
            db.commit()
            db.refresh(group)

            return GroupModel.model_validate(group)

    except Exception as e:
        log.exception(e)
        return None
|
|
|
|
def remove_users_from_group(
    self, id: str, user_ids: Optional[list[str]] = None
) -> Optional[GroupModel]:
    """Remove ``user_ids`` from group ``id`` and refresh its ``updated_at``.

    Returns:
        The group after the change, the unchanged group when ``user_ids``
        is empty or ``None``, or ``None`` if the group does not exist or an
        exception occurred (it is logged).
    """
    try:
        with get_db() as db:
            target = db.query(Group).filter_by(id=id).first()
            if not target:
                return None

            # Nothing to remove: hand back the group untouched.
            if not user_ids:
                return GroupModel.model_validate(target)

            # One delete statement per user.
            for member_user_id in user_ids:
                membership = db.query(GroupMember).filter(
                    GroupMember.group_id == id,
                    GroupMember.user_id == member_user_id,
                )
                membership.delete()

            target.updated_at = int(time.time())

            db.commit()
            db.refresh(target)
            return GroupModel.model_validate(target)

    except Exception as e:
        log.exception(e)
        return None
|
|
|
|
|
|
# Shared module-level GroupTable instance.
Groups = GroupTable()
|