wip: tags

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:21:18 +04:00
parent e1da74541b
commit f4cd24d2ca
2 changed files with 21 additions and 18 deletions

View file

@ -47,15 +47,15 @@ class TagChatIdForm(BaseModel):
class TagTable:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
async def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
async with get_db() as db:
id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag(**tag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
await db.add(result)
await db.commit()
await db.refresh(result)
if result:
return TagModel.model_validate(result)
else:
@ -64,42 +64,44 @@ class TagTable:
log.exception(f"Error inserting a new tag: {e}")
return None
def get_tag_by_name_and_user_id(
async def get_tag_by_name_and_user_id(
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
id = name.replace(" ", "_").lower()
async with get_db() as db:
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
tag = await db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception:
return None
def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
async def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
async with get_db() as db:
return [
TagModel.model_validate(tag)
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
for tag in (await db.query(Tag).filter_by(user_id=user_id).all())
]
def get_tags_by_ids_and_user_id(
async def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str
) -> list[TagModel]:
async with get_db() as db:
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
await db.query(Tag)
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
.all()
)
]
def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
try:
async with get_db() as db:
id = name.replace(" ", "_").lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
res = await db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}")
db.commit()
await db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")

View file

@ -148,9 +148,10 @@ async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)
tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
if (
tag_id != "none"
and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
and await Tags.get_tag_by_name_and_user_id(tag_name, user.id)
is None
):
Tags.insert_new_tag(tag_name, user.id)
await Tags.insert_new_tag(tag_name, user.id)
return ChatResponse(**chat.model_dump())
except Exception as e:
@ -261,7 +262,7 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
@router.get("/all/tags", response_model=list[TagModel])
async def get_all_user_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
tags = await Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
@ -556,7 +557,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
chat = await Chats.get_chat_by_id(id)
for tag in chat.meta.get("tags", []):
if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
Tags.delete_tag_by_name_and_user_id(tag, user.id)
await Tags.delete_tag_by_name_and_user_id(tag, user.id)
result = await Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
@ -694,7 +695,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
== 0
):
log.debug(f"deleting tag: {tag_id}")
Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
await Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
else:
for tag_id in chat.meta.get("tags", []):
tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)