wip: memories

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:41:35 +04:00
parent c512bf3559
commit 652dcabd86
2 changed files with 31 additions and 31 deletions

View file

@ -37,7 +37,7 @@ class MemoryModel(BaseModel):
class MemoriesTable:
def insert_new_memory(
async def insert_new_memory(
self,
user_id: str,
content: str,
@ -55,15 +55,15 @@ class MemoriesTable:
}
)
result = Memory(**memory.model_dump())
db.add(result)
db.commit()
db.refresh(result)
await db.add(result)
await db.commit()
await db.refresh(result)
if result:
return MemoryModel.model_validate(result)
else:
return None
def update_memory_by_id_and_user_id(
async def update_memory_by_id_and_user_id(
self,
id: str,
user_id: str,
@ -71,73 +71,73 @@ class MemoriesTable:
) -> Optional[MemoryModel]:
async with get_db() as db:
try:
memory = db.get(Memory, id)
memory = await db.get(Memory, id)
if not memory or memory.user_id != user_id:
return None
memory.content = content
memory.updated_at = int(time.time())
db.commit()
return self.get_memory_by_id(id)
await db.commit()
return await self.get_memory_by_id(id)
except Exception:
return None
def get_memories(self) -> list[MemoryModel]:
async def get_memories(self) -> list[MemoryModel]:
async with get_db() as db:
try:
memories = db.query(Memory).all()
memories = await db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except Exception:
return None
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
async def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
async with get_db() as db:
try:
memories = db.query(Memory).filter_by(user_id=user_id).all()
memories = await db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except Exception:
return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
async def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
async with get_db() as db:
try:
memory = db.get(Memory, id)
memory = await db.get(Memory, id)
return MemoryModel.model_validate(memory)
except Exception:
return None
def delete_memory_by_id(self, id: str) -> bool:
async def delete_memory_by_id(self, id: str) -> bool:
async with get_db() as db:
try:
db.query(Memory).filter_by(id=id).delete()
db.commit()
await db.query(Memory).filter_by(id=id).delete()
await db.commit()
return True
except Exception:
return False
def delete_memories_by_user_id(self, user_id: str) -> bool:
async def delete_memories_by_user_id(self, user_id: str) -> bool:
async with get_db() as db:
try:
db.query(Memory).filter_by(user_id=user_id).delete()
db.commit()
await db.query(Memory).filter_by(user_id=user_id).delete()
await db.commit()
return True
except Exception:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
async def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
async with get_db() as db:
try:
memory = db.get(Memory, id)
memory = await db.get(Memory, id)
if not memory or memory.user_id != user_id:
return None
# Delete the memory
db.delete(memory)
db.commit()
await db.delete(memory)
await db.commit()
return True
except Exception:

View file

@ -27,7 +27,7 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=list[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id)
return await Memories.get_memories_by_user_id(user.id)
############################
@ -49,7 +49,7 @@ async def add_memory(
form_data: AddMemoryForm,
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory = await Memories.insert_new_memory(user.id, form_data.content)
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
@ -82,7 +82,7 @@ class QueryMemoryForm(BaseModel):
async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
memories = Memories.get_memories_by_user_id(user.id)
memories = await Memories.get_memories_by_user_id(user.id)
if not memories:
raise HTTPException(status_code=404, detail="No memories found for user")
@ -104,7 +104,7 @@ async def reset_memory_from_vector_db(
):
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
memories = await Memories.get_memories_by_user_id(user.id)
VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}",
items=[
@ -133,7 +133,7 @@ async def reset_memory_from_vector_db(
@router.delete("/delete/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(user.id)
result = await Memories.delete_memories_by_user_id(user.id)
if result:
try:
@ -157,7 +157,7 @@ async def update_memory_by_id(
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
):
memory = Memories.update_memory_by_id_and_user_id(
memory = await Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content
)
if memory is None:
@ -191,7 +191,7 @@ async def update_memory_by_id(
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
result = await Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
VECTOR_DB_CLIENT.delete(