mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-12 12:25:20 +00:00
wip: tags
This commit is contained in:
parent
e1da74541b
commit
f4cd24d2ca
2 changed files with 21 additions and 18 deletions
|
|
@ -47,15 +47,15 @@ class TagChatIdForm(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TagTable:
|
class TagTable:
|
||||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
async def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
id = name.replace(" ", "_").lower()
|
id = name.replace(" ", "_").lower()
|
||||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||||
try:
|
try:
|
||||||
result = Tag(**tag.model_dump())
|
result = Tag(**tag.model_dump())
|
||||||
db.add(result)
|
await db.add(result)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(result)
|
await db.refresh(result)
|
||||||
if result:
|
if result:
|
||||||
return TagModel.model_validate(result)
|
return TagModel.model_validate(result)
|
||||||
else:
|
else:
|
||||||
|
|
@ -64,42 +64,44 @@ class TagTable:
|
||||||
log.exception(f"Error inserting a new tag: {e}")
|
log.exception(f"Error inserting a new tag: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tag_by_name_and_user_id(
|
async def get_tag_by_name_and_user_id(
|
||||||
self, name: str, user_id: str
|
self, name: str, user_id: str
|
||||||
) -> Optional[TagModel]:
|
) -> Optional[TagModel]:
|
||||||
try:
|
try:
|
||||||
id = name.replace(" ", "_").lower()
|
id = name.replace(" ", "_").lower()
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
tag = await db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||||
return TagModel.model_validate(tag)
|
return TagModel.model_validate(tag)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
async def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
return [
|
return [
|
||||||
TagModel.model_validate(tag)
|
TagModel.model_validate(tag)
|
||||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
for tag in (await db.query(Tag).filter_by(user_id=user_id).all())
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_tags_by_ids_and_user_id(
|
async def get_tags_by_ids_and_user_id(
|
||||||
self, ids: list[str], user_id: str
|
self, ids: list[str], user_id: str
|
||||||
) -> list[TagModel]:
|
) -> list[TagModel]:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
return [
|
return [
|
||||||
TagModel.model_validate(tag)
|
TagModel.model_validate(tag)
|
||||||
for tag in (
|
for tag in (
|
||||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
await db.query(Tag)
|
||||||
|
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
|
||||||
|
.all()
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
async with get_db() as db:
|
async with get_db() as db:
|
||||||
id = name.replace(" ", "_").lower()
|
id = name.replace(" ", "_").lower()
|
||||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
res = await db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||||
log.debug(f"res: {res}")
|
log.debug(f"res: {res}")
|
||||||
db.commit()
|
await db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"delete_tag: {e}")
|
log.error(f"delete_tag: {e}")
|
||||||
|
|
|
||||||
|
|
@ -148,9 +148,10 @@ async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)
|
||||||
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
|
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
|
||||||
if (
|
if (
|
||||||
tag_id != "none"
|
tag_id != "none"
|
||||||
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
|
and await Tags.get_tag_by_name_and_user_id(tag_name, user.id)
|
||||||
|
is None
|
||||||
):
|
):
|
||||||
Tags.insert_new_tag(tag_name, user.id)
|
await Tags.insert_new_tag(tag_name, user.id)
|
||||||
|
|
||||||
return ChatResponse(**chat.model_dump())
|
return ChatResponse(**chat.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -261,7 +262,7 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
|
||||||
@router.get("/all/tags", response_model=list[TagModel])
|
@router.get("/all/tags", response_model=list[TagModel])
|
||||||
async def get_all_user_tags(user=Depends(get_verified_user)):
|
async def get_all_user_tags(user=Depends(get_verified_user)):
|
||||||
try:
|
try:
|
||||||
tags = Tags.get_tags_by_user_id(user.id)
|
tags = await Tags.get_tags_by_user_id(user.id)
|
||||||
return tags
|
return tags
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
@ -556,7 +557,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
|
||||||
chat = await Chats.get_chat_by_id(id)
|
chat = await Chats.get_chat_by_id(id)
|
||||||
for tag in chat.meta.get("tags", []):
|
for tag in chat.meta.get("tags", []):
|
||||||
if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
|
||||||
Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
await Tags.delete_tag_by_name_and_user_id(tag, user.id)
|
||||||
|
|
||||||
result = await Chats.delete_chat_by_id_and_user_id(id, user.id)
|
result = await Chats.delete_chat_by_id_and_user_id(id, user.id)
|
||||||
return result
|
return result
|
||||||
|
|
@ -694,7 +695,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
|
||||||
== 0
|
== 0
|
||||||
):
|
):
|
||||||
log.debug(f"deleting tag: {tag_id}")
|
log.debug(f"deleting tag: {tag_id}")
|
||||||
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
await Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
|
||||||
else:
|
else:
|
||||||
for tag_id in chat.meta.get("tags", []):
|
for tag_id in chat.meta.get("tags", []):
|
||||||
tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)
|
tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue