mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-11 20:05:19 +00:00
wip: tags
This commit is contained in:
parent
e1da74541b
commit
f4cd24d2ca
2 changed files with 21 additions and 18 deletions
|
|
@ -47,15 +47,15 @@ class TagChatIdForm(BaseModel):
|
|||
|
||||
|
||||
class TagTable:
|
||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
async def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
async with get_db() as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
try:
|
||||
result = Tag(**tag.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
await db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
if result:
|
||||
return TagModel.model_validate(result)
|
||||
else:
|
||||
|
|
@ -64,42 +64,44 @@ class TagTable:
|
|||
log.exception(f"Error inserting a new tag: {e}")
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
async def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str
|
||||
) -> Optional[TagModel]:
|
||||
try:
|
||||
id = name.replace(" ", "_").lower()
|
||||
async with get_db() as db:
|
||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||
tag = await db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||
return TagModel.model_validate(tag)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
||||
async def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
||||
async with get_db() as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
for tag in (await db.query(Tag).filter_by(user_id=user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_ids_and_user_id(
|
||||
async def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str
|
||||
) -> list[TagModel]:
|
||||
async with get_db() as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
||||
await db.query(Tag)
|
||||
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||
res = await db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||
log.debug(f"res: {res}")
|
||||
db.commit()
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tag: {e}")
|
||||
|
|
|
|||
|
|
@ -148,9 +148,10 @@ async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)
|
|||
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
|
||||
if (
|
||||
tag_id != "none"
|
||||
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
|
||||
and await Tags.get_tag_by_name_and_user_id(tag_name, user.id)
|
||||
is None
|
||||
):
|
||||
Tags.insert_new_tag(tag_name, user.id)
|
||||
await Tags.insert_new_tag(tag_name, user.id)
|
||||
|
||||
return ChatResponse(**chat.model_dump())
|
||||
except Exception as e:
|
||||
|
|
@ -261,7 +262,7 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
|
|||
@router.get("/all/tags", response_model=list[TagModel])
|
||||
async def get_all_user_tags(user=Depends(get_verified_user)):
|
||||
try:
|
||||
tags = Tags.get_tags_by_user_id(user.id)
|
||||
tags = await Tags.get_tags_by_user_id(user.id)
|
||||
return tags
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
|
@ -556,7 +557,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
|||
chat = await Chats.get_chat_by_id(id)
|
||||
for tag in chat.meta.get("tags", []):
|
||||
if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
await Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||
|
||||
result = await Chats.delete_chat_by_id_and_user_id(id, user.id)
|
||||
return result
|
||||
|
|
@ -694,7 +695,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
|||
== 0
|
||||
):
|
||||
log.debug(f"deleting tag: {tag_id}")
|
||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||
await Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||
else:
|
||||
for tag_id in chat.meta.get("tags", []):
|
||||
tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)
|
||||
|
|
|
|||
Loading…
Reference in a new issue