This commit is contained in:
Timothy Jaeryang Baek 2025-11-27 07:49:19 -05:00
parent 3b4d7d568b
commit d645cdbaf3

View file

@ -7,7 +7,7 @@ from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access from open_webui.utils.access_control import has_access
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case
from sqlalchemy import or_, func, select, and_, text from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
@ -196,11 +196,23 @@ class ChannelTable:
def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]: def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]:
with get_db() as db: with get_db() as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
match_count = func.sum(
case(
(ChannelMember.user_id.in_(unique_user_ids), 1),
else_=0,
)
)
subquery = ( subquery = (
db.query(ChannelMember.channel_id) db.query(ChannelMember.channel_id)
.filter(ChannelMember.user_id.in_(user_ids))
.group_by(ChannelMember.channel_id) .group_by(ChannelMember.channel_id)
.having(func.count(ChannelMember.user_id) == len(user_ids)) # 1. Channel must have exactly len(user_ids) members
.having(func.count(ChannelMember.user_id) == len(unique_user_ids))
# 2. All those members must be in unique_user_ids
.having(match_count == len(unique_user_ids))
.subquery() .subquery()
) )