diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index a78c3b05e0..5d452b0216 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -7,7 +7,7 @@ from open_webui.internal.db import Base, get_db from open_webui.utils.access_control import has_access from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case from sqlalchemy import or_, func, select, and_, text from sqlalchemy.sql import exists @@ -196,11 +196,23 @@ class ChannelTable: def get_dm_channel_by_user_ids(self, user_ids: list[str]) -> Optional[ChannelModel]: with get_db() as db: + # Ensure uniqueness in case a list with duplicates is passed + unique_user_ids = list(set(user_ids)) + + match_count = func.sum( + case( + (ChannelMember.user_id.in_(unique_user_ids), 1), + else_=0, + ) + ) + subquery = ( db.query(ChannelMember.channel_id) - .filter(ChannelMember.user_id.in_(user_ids)) .group_by(ChannelMember.channel_id) - .having(func.count(ChannelMember.user_id) == len(user_ids)) + # 1. Channel must have exactly len(user_ids) members + .having(func.count(ChannelMember.user_id) == len(unique_user_ids)) + # 2. All those members must be in unique_user_ids + .having(match_count == len(unique_user_ids)) .subquery() )