mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
refac
This commit is contained in:
parent
3b4d7d568b
commit
d645cdbaf3
1 changed files with 15 additions and 3 deletions
|
|
@ -7,7 +7,7 @@ from open_webui.internal.db import Base, get_db
|
|||
from open_webui.utils.access_control import has_access
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
|
|
@ -196,11 +196,23 @@ class ChannelTable:
|
|||
|
||||
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
|
||||
with get_db() as db:
|
||||
# Ensure uniqueness in case a list with duplicates is passed
|
||||
unique_user_ids = list(set(user_ids))
|
||||
|
||||
match_count = func.sum(
|
||||
case(
|
||||
(ChannelMember.user_id.in_(unique_user_ids), 1),
|
||||
else_=0,
|
||||
)
|
||||
)
|
||||
|
||||
subquery = (
|
||||
db.query(ChannelMember.channel_id)
|
||||
.filter(ChannelMember.user_id.in_(user_ids))
|
||||
.group_by(ChannelMember.channel_id)
|
||||
.having(func.count(ChannelMember.user_id) == len(user_ids))
|
||||
# 1. Channel must have exactly len(user_ids) members
|
||||
.having(func.count(ChannelMember.user_id) == len(unique_user_ids))
|
||||
# 2. All those members must be in unique_user_ids
|
||||
.having(match_count == len(unique_user_ids))
|
||||
.subquery()
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue