This commit is contained in:
Timothy Jaeryang Baek 2025-11-27 07:49:19 -05:00
parent 3b4d7d568b
commit d645cdbaf3

View file

@ -7,7 +7,7 @@ from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
@ -196,11 +196,23 @@ class ChannelTable:
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
with get_db() as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
match_count = func.sum(
case(
(ChannelMember.user_id.in_(unique_user_ids), 1),
else_=0,
)
)
subquery = (
db.query(ChannelMember.channel_id)
.filter(ChannelMember.user_id.in_(user_ids))
.group_by(ChannelMember.channel_id)
.having(func.count(ChannelMember.user_id) == len(user_ids))
# 1. Channel must have exactly len(user_ids) members
.having(func.count(ChannelMember.user_id) == len(unique_user_ids))
# 2. All those members must be in unique_user_ids
.having(match_count == len(unique_user_ids))
.subquery()
)