wip: prompts

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:23:37 +04:00
parent f4cd24d2ca
commit 93d27b84d4
3 changed files with 27 additions and 25 deletions

View file

@ -69,7 +69,7 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def insert_new_prompt( async def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
prompt = PromptModel( prompt = PromptModel(
@ -83,9 +83,9 @@ class PromptsTable:
try: try:
async with get_db() as db: async with get_db() as db:
result = Prompt(**prompt.model_dump()) result = Prompt(**prompt.model_dump())
db.add(result) await db.add(result)
db.commit() await db.commit()
db.refresh(result) await db.refresh(result)
if result: if result:
return PromptModel.model_validate(result) return PromptModel.model_validate(result)
else: else:
@ -93,10 +93,10 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: async def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
async with get_db() as db: async with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = await db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except Exception: except Exception:
return None return None
@ -105,7 +105,9 @@ class PromptsTable:
async with get_db() as db: async with get_db() as db:
prompts = [] prompts = []
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): for prompt in (
await db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
):
user = await Users.get_user_by_id(prompt.user_id) user = await Users.get_user_by_id(prompt.user_id)
prompts.append( prompts.append(
PromptUserResponse.model_validate( PromptUserResponse.model_validate(
@ -118,10 +120,10 @@ class PromptsTable:
return prompts return prompts
def get_prompts_by_user_id( async def get_prompts_by_user_id(
self, user_id: str, permission: str = "write" self, user_id: str, permission: str = "write"
) -> list[PromptUserResponse]: ) -> list[PromptUserResponse]:
prompts = self.get_prompts() prompts = await self.get_prompts()
return [ return [
prompt prompt
@ -130,26 +132,26 @@ class PromptsTable:
or await has_access(user_id, permission, prompt.access_control) or await has_access(user_id, permission, prompt.access_control)
] ]
def update_prompt_by_command( async def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
async with get_db() as db: async with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = await db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title prompt.title = form_data.title
prompt.content = form_data.content prompt.content = form_data.content
prompt.access_control = form_data.access_control prompt.access_control = form_data.access_control
prompt.timestamp = int(time.time()) prompt.timestamp = int(time.time())
db.commit() await db.commit()
return PromptModel.model_validate(prompt) return PromptModel.model_validate(prompt)
except Exception: except Exception:
return None return None
def delete_prompt_by_command(self, command: str) -> bool: async def delete_prompt_by_command(self, command: str) -> bool:
try: try:
async with get_db() as db: async with get_db() as db:
db.query(Prompt).filter_by(command=command).delete() await db.query(Prompt).filter_by(command=command).delete()
db.commit() await db.commit()
return True return True
except Exception: except Exception:

View file

@ -22,9 +22,9 @@ router = APIRouter()
@router.get("/", response_model=list[PromptModel]) @router.get("/", response_model=list[PromptModel])
async def get_prompts(user=Depends(get_verified_user)): async def get_prompts(user=Depends(get_verified_user)):
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
prompts = Prompts.get_prompts() prompts = await Prompts.get_prompts()
else: else:
prompts = Prompts.get_prompts_by_user_id(user.id, "read") prompts = await Prompts.get_prompts_by_user_id(user.id, "read")
return prompts return prompts
@ -32,9 +32,9 @@ async def get_prompts(user=Depends(get_verified_user)):
@router.get("/list", response_model=list[PromptUserResponse]) @router.get("/list", response_model=list[PromptUserResponse])
async def get_prompt_list(user=Depends(get_verified_user)): async def get_prompt_list(user=Depends(get_verified_user)):
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
prompts = Prompts.get_prompts() prompts = await Prompts.get_prompts()
else: else:
prompts = Prompts.get_prompts_by_user_id(user.id, "write") prompts = await Prompts.get_prompts_by_user_id(user.id, "write")
return prompts return prompts
@ -56,9 +56,9 @@ async def create_new_prompt(
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
prompt = Prompts.get_prompt_by_command(form_data.command) prompt = await Prompts.get_prompt_by_command(form_data.command)
if prompt is None: if prompt is None:
prompt = Prompts.insert_new_prompt(user.id, form_data) prompt = await Prompts.insert_new_prompt(user.id, form_data)
if prompt: if prompt:
return prompt return prompt
@ -141,7 +141,7 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool) @router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}") prompt = await Prompts.get_prompt_by_command(f"/{command}")
if not prompt: if not prompt:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -158,5 +158,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Prompts.delete_prompt_by_command(f"/{command}") result = await Prompts.delete_prompt_by_command(f"/{command}")
return result return result

View file

@ -69,7 +69,7 @@ def get_permissions(
return permissions return permissions
def has_permission( async def has_permission(
user_id: str, user_id: str,
permission_key: str, permission_key: str,
default_permissions: Dict[str, Any] = {}, default_permissions: Dict[str, Any] = {},
@ -93,7 +93,7 @@ def has_permission(
permission_hierarchy = permission_key.split(".") permission_hierarchy = permission_key.split(".")
# Retrieve user group permissions # Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id) user_groups = await Groups.get_groups_by_member_id(user_id)
for group in user_groups: for group in user_groups:
group_permissions = group.permissions group_permissions = group.permissions