mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 04:45:19 +00:00
wip: prompts
This commit is contained in:
parent
f4cd24d2ca
commit
93d27b84d4
3 changed files with 27 additions and 25 deletions
|
|
@ -69,7 +69,7 @@ class PromptForm(BaseModel):
|
|||
|
||||
|
||||
class PromptsTable:
|
||||
def insert_new_prompt(
|
||||
async def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
prompt = PromptModel(
|
||||
|
|
@ -83,9 +83,9 @@ class PromptsTable:
|
|||
try:
|
||||
async with get_db() as db:
|
||||
result = Prompt(**prompt.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
await db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
if result:
|
||||
return PromptModel.model_validate(result)
|
||||
else:
|
||||
|
|
@ -93,10 +93,10 @@ class PromptsTable:
|
|||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||
async def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
prompt = await db.query(Prompt).filter_by(command=command).first()
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
|
|
@ -105,7 +105,9 @@ class PromptsTable:
|
|||
async with get_db() as db:
|
||||
prompts = []
|
||||
|
||||
for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all():
|
||||
for prompt in (
|
||||
await db.query(Prompt).order_by(Prompt.timestamp.desc()).all()
|
||||
):
|
||||
user = await Users.get_user_by_id(prompt.user_id)
|
||||
prompts.append(
|
||||
PromptUserResponse.model_validate(
|
||||
|
|
@ -118,10 +120,10 @@ class PromptsTable:
|
|||
|
||||
return prompts
|
||||
|
||||
def get_prompts_by_user_id(
|
||||
async def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[PromptUserResponse]:
|
||||
prompts = self.get_prompts()
|
||||
prompts = await self.get_prompts()
|
||||
|
||||
return [
|
||||
prompt
|
||||
|
|
@ -130,26 +132,26 @@ class PromptsTable:
|
|||
or await has_access(user_id, permission, prompt.access_control)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
async def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
) -> Optional[PromptModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
prompt = await db.query(Prompt).filter_by(command=command).first()
|
||||
prompt.title = form_data.title
|
||||
prompt.content = form_data.content
|
||||
prompt.access_control = form_data.access_control
|
||||
prompt.timestamp = int(time.time())
|
||||
db.commit()
|
||||
await db.commit()
|
||||
return PromptModel.model_validate(prompt)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_prompt_by_command(self, command: str) -> bool:
|
||||
async def delete_prompt_by_command(self, command: str) -> bool:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Prompt).filter_by(command=command).delete()
|
||||
db.commit()
|
||||
await db.query(Prompt).filter_by(command=command).delete()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ router = APIRouter()
|
|||
@router.get("/", response_model=list[PromptModel])
|
||||
async def get_prompts(user=Depends(get_verified_user)):
|
||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||
prompts = Prompts.get_prompts()
|
||||
prompts = await Prompts.get_prompts()
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "read")
|
||||
prompts = await Prompts.get_prompts_by_user_id(user.id, "read")
|
||||
|
||||
return prompts
|
||||
|
||||
|
|
@ -32,9 +32,9 @@ async def get_prompts(user=Depends(get_verified_user)):
|
|||
@router.get("/list", response_model=list[PromptUserResponse])
|
||||
async def get_prompt_list(user=Depends(get_verified_user)):
|
||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||
prompts = Prompts.get_prompts()
|
||||
prompts = await Prompts.get_prompts()
|
||||
else:
|
||||
prompts = Prompts.get_prompts_by_user_id(user.id, "write")
|
||||
prompts = await Prompts.get_prompts_by_user_id(user.id, "write")
|
||||
|
||||
return prompts
|
||||
|
||||
|
|
@ -56,9 +56,9 @@ async def create_new_prompt(
|
|||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||
prompt = await Prompts.get_prompt_by_command(form_data.command)
|
||||
if prompt is None:
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||
prompt = await Prompts.insert_new_prompt(user.id, form_data)
|
||||
|
||||
if prompt:
|
||||
return prompt
|
||||
|
|
@ -141,7 +141,7 @@ async def update_prompt_by_command(
|
|||
|
||||
@router.delete("/command/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
prompt = await Prompts.get_prompt_by_command(f"/{command}")
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -158,5 +158,5 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||
result = await Prompts.delete_prompt_by_command(f"/{command}")
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ def get_permissions(
|
|||
return permissions
|
||||
|
||||
|
||||
def has_permission(
|
||||
async def has_permission(
|
||||
user_id: str,
|
||||
permission_key: str,
|
||||
default_permissions: Dict[str, Any] = {},
|
||||
|
|
@ -93,7 +93,7 @@ def has_permission(
|
|||
permission_hierarchy = permission_key.split(".")
|
||||
|
||||
# Retrieve user group permissions
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id)
|
||||
|
||||
for group in user_groups:
|
||||
group_permissions = group.permissions
|
||||
|
|
|
|||
Loading…
Reference in a new issue