wip: memories

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:41:35 +04:00
parent c512bf3559
commit 652dcabd86
2 changed files with 31 additions and 31 deletions

View file

@ -37,7 +37,7 @@ class MemoryModel(BaseModel):
class MemoriesTable: class MemoriesTable:
def insert_new_memory( async def insert_new_memory(
self, self,
user_id: str, user_id: str,
content: str, content: str,
@ -55,15 +55,15 @@ class MemoriesTable:
} }
) )
result = Memory(**memory.model_dump()) result = Memory(**memory.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return MemoryModel.model_validate(result) return MemoryModel.model_validate(result)
else: else:
return None return None
def update_memory_by_id_and_user_id( async def update_memory_by_id_and_user_id(
self, self,
id: str, id: str,
user_id: str, user_id: str,
@ -71,73 +71,73 @@ class MemoriesTable:
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
async with get_db() as db: async with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = await db.get(Memory, id)
if not memory or memory.user_id != user_id: if not memory or memory.user_id != user_id:
return None return None
memory.content = content memory.content = content
memory.updated_at = int(time.time()) memory.updated_at = int(time.time())
db.commit() await db.commit()
return self.get_memory_by_id(id) return await self.get_memory_by_id(id)
except Exception: except Exception:
return None return None
def get_memories(self) -> list[MemoryModel]: async def get_memories(self) -> list[MemoryModel]:
async with get_db() as db: async with get_db() as db:
try: try:
memories = db.query(Memory).all() memories = await db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except Exception: except Exception:
return None return None
def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]: async def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
async with get_db() as db: async with get_db() as db:
try: try:
memories = db.query(Memory).filter_by(user_id=user_id).all() memories = await db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories] return [MemoryModel.model_validate(memory) for memory in memories]
except Exception: except Exception:
return None return None
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: async def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
async with get_db() as db: async with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = await db.get(Memory, id)
return MemoryModel.model_validate(memory) return MemoryModel.model_validate(memory)
except Exception: except Exception:
return None return None
def delete_memory_by_id(self, id: str) -> bool: async def delete_memory_by_id(self, id: str) -> bool:
async with get_db() as db: async with get_db() as db:
try: try:
db.query(Memory).filter_by(id=id).delete() await db.query(Memory).filter_by(id=id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_memories_by_user_id(self, user_id: str) -> bool: async def delete_memories_by_user_id(self, user_id: str) -> bool:
async with get_db() as db: async with get_db() as db:
try: try:
db.query(Memory).filter_by(user_id=user_id).delete() await db.query(Memory).filter_by(user_id=user_id).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:
return False return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: async def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
async with get_db() as db: async with get_db() as db:
try: try:
memory = db.get(Memory, id) memory = await db.get(Memory, id)
if not memory or memory.user_id != user_id: if not memory or memory.user_id != user_id:
return None return None
# Delete the memory # Delete the memory
db.delete(memory) await db.delete(memory)
db.commit() await db.commit()
return True return True
except Exception: except Exception:

View file

@ -27,7 +27,7 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=list[MemoryModel]) @router.get("/", response_model=list[MemoryModel])
async def get_memories(user=Depends(get_verified_user)): async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id) return await Memories.get_memories_by_user_id(user.id)
############################ ############################
@ -49,7 +49,7 @@ async def add_memory(
form_data: AddMemoryForm, form_data: AddMemoryForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
memory = Memories.insert_new_memory(user.id, form_data.content) memory = await Memories.insert_new_memory(user.id, form_data.content)
VECTOR_DB_CLIENT.upsert( VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
@ -82,7 +82,7 @@ class QueryMemoryForm(BaseModel):
async def query_memory( async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
): ):
memories = Memories.get_memories_by_user_id(user.id) memories = await Memories.get_memories_by_user_id(user.id)
if not memories: if not memories:
raise HTTPException(status_code=404, detail="No memories found for user") raise HTTPException(status_code=404, detail="No memories found for user")
@ -104,7 +104,7 @@ async def reset_memory_from_vector_db(
): ):
VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}") VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id) memories = await Memories.get_memories_by_user_id(user.id)
VECTOR_DB_CLIENT.upsert( VECTOR_DB_CLIENT.upsert(
collection_name=f"user-memory-{user.id}", collection_name=f"user-memory-{user.id}",
items=[ items=[
@ -133,7 +133,7 @@ async def reset_memory_from_vector_db(
@router.delete("/delete/user", response_model=bool) @router.delete("/delete/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)): async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(user.id) result = await Memories.delete_memories_by_user_id(user.id)
if result: if result:
try: try:
@ -157,7 +157,7 @@ async def update_memory_by_id(
form_data: MemoryUpdateModel, form_data: MemoryUpdateModel,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
memory = Memories.update_memory_by_id_and_user_id( memory = await Memories.update_memory_by_id_and_user_id(
memory_id, user.id, form_data.content memory_id, user.id, form_data.content
) )
if memory is None: if memory is None:
@ -191,7 +191,7 @@ async def update_memory_by_id(
@router.delete("/{memory_id}", response_model=bool) @router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) result = await Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result: if result:
VECTOR_DB_CLIENT.delete( VECTOR_DB_CLIENT.delete(