From 93d27b84d4ad7e981abacd601c545347fd415637 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 14 Aug 2025 16:23:37 +0400 Subject: [PATCH] wip: prompts --- backend/open_webui/models/prompts.py | 32 ++++++++++++---------- backend/open_webui/routers/prompts.py | 16 +++++------ backend/open_webui/utils/access_control.py | 4 +-- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 6c94db504b..51968cfe28 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -69,7 +69,7 @@ class PromptForm(BaseModel): class PromptsTable: - def insert_new_prompt( + async def insert_new_prompt( self, user_id: str, form_data: PromptForm ) -> Optional[PromptModel]: prompt = PromptModel( @@ -83,9 +83,9 @@ class PromptsTable: try: async with get_db() as db: result = Prompt(**prompt.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + await db.add(result) + await db.commit() + await db.refresh(result) if result: return PromptModel.model_validate(result) else: @@ -93,10 +93,10 @@ class PromptsTable: except Exception: return None - def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + async def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: async with get_db() as db: - prompt = db.query(Prompt).filter_by(command=command).first() + prompt = await db.query(Prompt).filter_by(command=command).first() return PromptModel.model_validate(prompt) except Exception: return None @@ -105,7 +105,9 @@ class PromptsTable: async with get_db() as db: prompts = [] - for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): + for prompt in ( + await db.query(Prompt).order_by(Prompt.timestamp.desc()).all() + ): user = await Users.get_user_by_id(prompt.user_id) prompts.append( PromptUserResponse.model_validate( @@ -118,10 +120,10 @@ class PromptsTable: return prompts - def get_prompts_by_user_id( + async def get_prompts_by_user_id( self, user_id: str, permission: str = "write" ) -> list[PromptUserResponse]: - prompts = self.get_prompts() + prompts = await self.get_prompts() return [ prompt @@ -130,26 +132,26 @@ class PromptsTable: or await has_access(user_id, permission, prompt.access_control) ] - def update_prompt_by_command( + async def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: async with get_db() as db: - prompt = db.query(Prompt).filter_by(command=command).first() + prompt = await db.query(Prompt).filter_by(command=command).first() prompt.title = form_data.title prompt.content = form_data.content prompt.access_control = form_data.access_control prompt.timestamp = int(time.time()) - db.commit() + await db.commit() return PromptModel.model_validate(prompt) except Exception: return None - def delete_prompt_by_command(self, command: str) -> bool: + async def delete_prompt_by_command(self, command: str) -> bool: try: async with get_db() as db: - db.query(Prompt).filter_by(command=command).delete() - db.commit() + await db.query(Prompt).filter_by(command=command).delete() + await db.commit() return True except Exception: diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 3988bbd236..414e2079fd 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -22,9 +22,9 @@ router = APIRouter() @router.get("/", response_model=list[PromptModel]) async def get_prompts(user=Depends(get_verified_user)): if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: - prompts = Prompts.get_prompts() + prompts = await Prompts.get_prompts() else: - prompts = Prompts.get_prompts_by_user_id(user.id, "read") + prompts = await Prompts.get_prompts_by_user_id(user.id, "read") return prompts @@ -32,9 +32,9 @@ async def get_prompts(user=Depends(get_verified_user)): @router.get("/list", response_model=list[PromptUserResponse]) async def get_prompt_list(user=Depends(get_verified_user)): if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS: - prompts = Prompts.get_prompts() + prompts = await Prompts.get_prompts() else: - prompts = Prompts.get_prompts_by_user_id(user.id, "write") + prompts = await Prompts.get_prompts_by_user_id(user.id, "write") return prompts @@ -56,9 +56,9 @@ async def create_new_prompt( detail=ERROR_MESSAGES.UNAUTHORIZED, ) - prompt = Prompts.get_prompt_by_command(form_data.command) + prompt = await Prompts.get_prompt_by_command(form_data.command) if prompt is None: - prompt = Prompts.insert_new_prompt(user.id, form_data) + prompt = await Prompts.insert_new_prompt(user.id, form_data) if prompt: return prompt @@ -141,7 +141,7 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") + prompt = await Prompts.get_prompt_by_command(f"/{command}") if not prompt: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -158,5 +158,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user) detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Prompts.delete_prompt_by_command(f"/{command}") + result = await Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index ec732c0408..7c7b379e3e 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -69,7 +69,7 @@ def get_permissions( return permissions -def has_permission( +async def has_permission( user_id: str, permission_key: str, default_permissions: Dict[str, Any] = {}, @@ -93,7 +93,7 @@ def has_permission( permission_hierarchy = permission_key.split(".") # Retrieve user group permissions - user_groups = Groups.get_groups_by_member_id(user_id) + user_groups = await Groups.get_groups_by_member_id(user_id) for group in user_groups: group_permissions = group.permissions