From 2bccf8350d0915f69b8020934bb179c52e81b7b5 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Wed, 10 Dec 2025 15:48:42 -0500 Subject: [PATCH] enh: channel files --- .../6283dc0e4d8d_add_channel_file_table.py | 54 +++++++ backend/open_webui/models/channels.py | 140 +++++++++++++++++- src/lib/components/channel/Channel.svelte | 1 + .../components/channel/MessageInput.svelte | 20 +-- 4 files changed, 205 insertions(+), 10 deletions(-) create mode 100644 backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py diff --git a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py new file mode 100644 index 0000000000..59fe57a421 --- /dev/null +++ b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py @@ -0,0 +1,54 @@ +"""Add channel file table + +Revision ID: 6283dc0e4d8d +Revises: 3e0e00844bb0 +Create Date: 2025-12-10 15:11:39.424601 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import open_webui.internal.db + + +# revision identifiers, used by Alembic. +revision: str = "6283dc0e4d8d" +down_revision: Union[str, None] = "3e0e00844bb0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "channel_file", + sa.Column("id", sa.Text(), primary_key=True), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column( + "channel_id", + sa.Text(), + sa.ForeignKey("channel.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "file_id", + sa.Text(), + sa.ForeignKey("file.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + # indexes + sa.Index("ix_channel_file_channel_id", "channel_id"), + sa.Index("ix_channel_file_file_id", "file_id"), + sa.Index("ix_channel_file_user_id", "user_id"), + # unique constraints + sa.UniqueConstraint( + "channel_id", "file_id", name="uq_channel_file_channel_file" + ), # prevent duplicate entries + ) + + +def downgrade() -> None: + op.drop_table("channel_file") diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 754f6e3dfa..ae45d53b4c 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -10,7 +10,18 @@ from pydantic import BaseModel, ConfigDict from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, case, cast +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + ForeignKey, + String, + Text, + JSON, + UniqueConstraint, + case, + cast, +) from sqlalchemy import or_, func, select, and_, text from sqlalchemy.sql import exists @@ -137,6 +148,38 @@ class ChannelMemberModel(BaseModel): updated_at: Optional[int] = None # timestamp in epoch (time_ns) +class ChannelFile(Base): + __tablename__ = "channel_file" + + id = Column(Text, unique=True, primary_key=True) + + channel_id = Column( + Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False + ) + file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False) + user_id = Column(Text, nullable=False) + + created_at = Column(BigInteger, nullable=False) + updated_at = Column(BigInteger, nullable=False) + + __table_args__ = ( + UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"), + ) + + +class ChannelFileModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: str + + channel_id: str + file_id: str + user_id: str + + created_at: int # timestamp in epoch (time_ns) + updated_at: int # timestamp in epoch (time_ns) + + class ChannelWebhook(Base): __tablename__ = "channel_webhook" @@ -642,6 +685,63 @@ class ChannelTable: channel = db.query(Channel).filter(Channel.id == id).first() return ChannelModel.model_validate(channel) if channel else None + def get_channel_by_id_and_user_id( + self, id: str, user_id: str + ) -> Optional[ChannelModel]: + with get_db() as db: + # Fetch the channel + channel: Channel = ( + db.query(Channel) + .filter( + Channel.id == id, + Channel.deleted_at.is_(None), + Channel.archived_at.is_(None), + ) + .first() + ) + + if not channel: + return None + + # If the channel is a group or dm, read access requires membership (active) + if channel.type in ["group", "dm"]: + membership = ( + db.query(ChannelMember) + .filter( + ChannelMember.channel_id == id, + ChannelMember.user_id == user_id, + ChannelMember.is_active.is_(True), + ) + .first() + ) + if membership: + return ChannelModel.model_validate(channel) + else: + return None + + # For channels that are NOT group/dm, fall back to ACL-based read access + query = db.query(Channel).filter(Channel.id == id) + + # Determine user groups + user_group_ids = [ + group.id for group in Groups.get_groups_by_member_id(user_id) + ] + + # Apply ACL rules + query = self._has_permission( + db, + query, + {"user_id": user_id, "group_ids": user_group_ids}, + permission="read", + ) + + channel_allowed = query.first() + return ( + ChannelModel.model_validate(channel_allowed) + if channel_allowed + else None + ) + def update_channel_by_id( self, id: str, form_data: ChannelForm ) -> Optional[ChannelModel]: @@ -663,6 +763,44 @@ class ChannelTable: db.commit() return ChannelModel.model_validate(channel) if channel else None + def add_file_to_channel_by_id( + self, channel_id: str, file_id: str, user_id: str + ) -> Optional[ChannelFileModel]: + with get_db() as db: + channel_file = ChannelFileModel( + **{ + "id": str(uuid.uuid4()), + "channel_id": channel_id, + "file_id": file_id, + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + try: + result = ChannelFile(**channel_file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ChannelFileModel.model_validate(result) + else: + return None + except Exception: + return None + + def remove_file_from_channel_by_id(self, channel_id: str, file_id: str) -> bool: + try: + with get_db() as db: + db.query(ChannelFile).filter_by( + channel_id=channel_id, file_id=file_id + ).delete() + db.commit() + return True + except Exception: + return False + def delete_channel_by_id(self, id: str): with get_db() as db: db.query(Channel).filter(Channel.id == id).delete() diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 0fdc6e835d..f04da31352 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -365,6 +365,7 @@ bind:chatInputElement bind:replyToMessage {typingUsers} + {channel} userSuggestions={true} channelSuggestions={true} disabled={!channel?.write_access} diff --git a/src/lib/components/channel/MessageInput.svelte b/src/lib/components/channel/MessageInput.svelte index 337a3affb8..c4acd4ec49 100644 --- a/src/lib/components/channel/MessageInput.svelte +++ b/src/lib/components/channel/MessageInput.svelte @@ -42,9 +42,10 @@ import XMark from '../icons/XMark.svelte'; export let placeholder = $i18n.t('Type here...'); + export let chatInputElement; export let id = null; - export let chatInputElement; + export let channel = null; export let typingUsers = []; export let inputLoading = false; @@ -459,15 +460,16 @@ try { // During the file upload, file content is automatically extracted. // If the file is an audio file, provide the language for STT. - let metadata = null; - if ( - (file.type.startsWith('audio/') || file.type.startsWith('video/')) && + let metadata = { + channel_id: channel.id, + // If the file is an audio file, provide the language for STT. + ...((file.type.startsWith('audio/') || file.type.startsWith('video/')) && $settings?.audio?.stt?.language - ) { - metadata = { - language: $settings?.audio?.stt?.language - }; - } + ? { + language: $settings?.audio?.stt?.language + } + : {}) + }; const uploadedFile = await uploadFile(localStorage.token, file, metadata, process);