diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index d71c889233..2a21dcbfb1 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -47,15 +47,15 @@ class TagChatIdForm(BaseModel): class TagTable: - def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: + async def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: async with get_db() as db: id = name.replace(" ", "_").lower() tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: result = Tag(**tag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return TagModel.model_validate(result) else: @@ -64,42 +64,44 @@ class TagTable: log.exception(f"Error inserting a new tag: {e}") return None - def get_tag_by_name_and_user_id( + async def get_tag_by_name_and_user_id( self, name: str, user_id: str ) -> Optional[TagModel]: try: id = name.replace(" ", "_").lower() async with get_db() as db: - tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() + tag = await db.query(Tag).filter_by(id=id, user_id=user_id).first() return TagModel.model_validate(tag) except Exception: return None - def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: + async def get_tags_by_user_id(self, user_id: str) -> list[TagModel]: async with get_db() as db: return [ TagModel.model_validate(tag) - for tag in (db.query(Tag).filter_by(user_id=user_id).all()) + for tag in (await db.query(Tag).filter_by(user_id=user_id).all()) ] - def get_tags_by_ids_and_user_id( + async def get_tags_by_ids_and_user_id( self, ids: list[str], user_id: str ) -> list[TagModel]: async with get_db() as db: return [ TagModel.model_validate(tag) for tag in ( - db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all() + await db.query(Tag) + .filter(Tag.id.in_(ids), Tag.user_id == user_id) + .all() ) ] - def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: + async def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool: try: async with get_db() as db: id = name.replace(" ", "_").lower() - res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() + res = await db.query(Tag).filter_by(id=id, user_id=user_id).delete() log.debug(f"res: {res}") - db.commit() + await db.commit() return True except Exception as e: log.error(f"delete_tag: {e}") diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 88f0ad625f..0751951949 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -148,9 +148,10 @@ async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user) tag_name = " ".join([word.capitalize() for word in tag_id.split("_")]) if ( tag_id != "none" - and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None + and await Tags.get_tag_by_name_and_user_id(tag_name, user.id) + is None ): - Tags.insert_new_tag(tag_name, user.id) + await Tags.insert_new_tag(tag_name, user.id) return ChatResponse(**chat.model_dump()) except Exception as e: @@ -261,7 +262,7 @@ async def get_user_archived_chats(user=Depends(get_verified_user)): @router.get("/all/tags", response_model=list[TagModel]) async def get_all_user_tags(user=Depends(get_verified_user)): try: - tags = Tags.get_tags_by_user_id(user.id) + tags = await Tags.get_tags_by_user_id(user.id) return tags except Exception as e: log.exception(e) @@ -556,7 +557,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified chat = await Chats.get_chat_by_id(id) for tag in chat.meta.get("tags", []): if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1: - Tags.delete_tag_by_name_and_user_id(tag, user.id) + await Tags.delete_tag_by_name_and_user_id(tag, user.id) result = await Chats.delete_chat_by_id_and_user_id(id, user.id) return result @@ -694,7 +695,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): == 0 ): log.debug(f"deleting tag: {tag_id}") - Tags.delete_tag_by_name_and_user_id(tag_id, user.id) + await Tags.delete_tag_by_name_and_user_id(tag_id, user.id) else: for tag_id in chat.meta.get("tags", []): tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)