mirror of
https://github.com/open-webui/open-webui.git
synced 2025-12-13 04:45:19 +00:00
wip: users, tools
This commit is contained in:
parent
44e9ae243d
commit
e1da74541b
32 changed files with 289 additions and 248 deletions
|
|
@ -174,7 +174,7 @@ async def generate_function_chat_completion(
|
|||
pipe_id, _ = pipe_id.split(".", 1)
|
||||
return pipe_id
|
||||
|
||||
def get_function_params(function_module, form_data, user, extra_params=None):
|
||||
async def get_function_params(function_module, form_data, user, extra_params=None):
|
||||
if extra_params is None:
|
||||
extra_params = {}
|
||||
|
||||
|
|
@ -187,7 +187,9 @@ async def generate_function_chat_completion(
|
|||
}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
|
||||
user_valves = await Functions.get_user_valves_by_id_and_user_id(
|
||||
pipe_id, user.id
|
||||
)
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
|
||||
except Exception as e:
|
||||
|
|
@ -232,7 +234,7 @@ async def generate_function_chat_completion(
|
|||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
}
|
||||
extra_params["__tools__"] = get_tools(
|
||||
extra_params["__tools__"] = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
|
|
@ -261,7 +263,7 @@ async def generate_function_chat_completion(
|
|||
function_module = get_function_module_by_id(request, pipe_id)
|
||||
|
||||
pipe = function_module.pipe
|
||||
params = get_function_params(function_module, form_data, user, extra_params)
|
||||
params = await get_function_params(function_module, form_data, user, extra_params)
|
||||
|
||||
if form_data.get("stream", False):
|
||||
|
||||
|
|
|
|||
|
|
@ -530,7 +530,7 @@ async def lifespan(app: FastAPI):
|
|||
# This should be blocking (sync) so functions are not deactivated on first /get_models calls
|
||||
# when the first user lands on the / route.
|
||||
log.info("Installing external dependencies of functions and tools...")
|
||||
install_tool_and_function_dependencies()
|
||||
await install_tool_and_function_dependencies()
|
||||
|
||||
app.state.redis = get_redis_connection(
|
||||
redis_url=REDIS_URL,
|
||||
|
|
@ -1267,11 +1267,11 @@ if audit_level != AuditLevel.NONE:
|
|||
async def get_models(
|
||||
request: Request, refresh: bool = False, user=Depends(get_verified_user)
|
||||
):
|
||||
def get_filtered_models(models, user):
|
||||
async def get_filtered_models(models, user):
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
if model.get("arena"):
|
||||
if has_access(
|
||||
if await has_access(
|
||||
user.id,
|
||||
type="read",
|
||||
access_control=model.get("info", {})
|
||||
|
|
@ -1286,7 +1286,7 @@ async def get_models(
|
|||
if (
|
||||
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||
or user.id == model_info.user_id
|
||||
or has_access(
|
||||
or await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
|
|
@ -1334,7 +1334,7 @@ async def get_models(
|
|||
user.role == "user"
|
||||
or (user.role == "admin" and not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||
) and not BYPASS_MODEL_ACCESS_CONTROL:
|
||||
models = get_filtered_models(models, user)
|
||||
models = await get_filtered_models(models, user)
|
||||
|
||||
log.debug(
|
||||
f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"
|
||||
|
|
@ -1408,7 +1408,7 @@ async def chat_completion(
|
|||
user.role != "admin" or not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
|
||||
):
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
await check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
|
|
@ -1628,7 +1628,7 @@ async def get_app_config(request: Request):
|
|||
if data is not None and "id" in data:
|
||||
user = await Users.get_user_by_id(data["id"])
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
user_count = await Users.get_num_users()
|
||||
onboarding = False
|
||||
|
||||
if user is None:
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class ChannelTable:
|
|||
channel
|
||||
for channel in channels
|
||||
if channel.user_id == user_id
|
||||
or has_access(user_id, permission, channel.access_control)
|
||||
or await has_access(user_id, permission, channel.access_control)
|
||||
]
|
||||
|
||||
async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class KnowledgeTable:
|
|||
)
|
||||
return knowledge_bases
|
||||
|
||||
def get_knowledge_bases_by_user_id(
|
||||
async def get_knowledge_bases_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[KnowledgeUserModel]:
|
||||
knowledge_bases = self.get_knowledge_bases()
|
||||
|
|
@ -151,31 +151,31 @@ class KnowledgeTable:
|
|||
knowledge_base
|
||||
for knowledge_base in knowledge_bases
|
||||
if knowledge_base.user_id == user_id
|
||||
or has_access(user_id, permission, knowledge_base.access_control)
|
||||
or await has_access(user_id, permission, knowledge_base.access_control)
|
||||
]
|
||||
|
||||
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||
async def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
||||
knowledge = await db.query(Knowledge).filter_by(id=id).first()
|
||||
return KnowledgeModel.model_validate(knowledge) if knowledge else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_knowledge_by_id(
|
||||
async def update_knowledge_by_id(
|
||||
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
|
||||
) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
knowledge = self.get_knowledge_by_id(id=id)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
knowledge = await self.get_knowledge_by_id(id=id)
|
||||
await db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_knowledge_by_id(id=id)
|
||||
await db.commit()
|
||||
return await self.get_knowledge_by_id(id=id)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ class ModelsTable:
|
|||
for model in db.query(Model).filter(Model.base_model_id == None).all()
|
||||
]
|
||||
|
||||
def get_models_by_user_id(
|
||||
async def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ModelUserResponse]:
|
||||
models = self.get_models()
|
||||
|
|
@ -203,7 +203,7 @@ class ModelsTable:
|
|||
model
|
||||
for model in models
|
||||
if model.user_id == user_id
|
||||
or has_access(user_id, permission, model.access_control)
|
||||
or await has_access(user_id, permission, model.access_control)
|
||||
]
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
|
|
|
|||
|
|
@ -96,20 +96,20 @@ class NoteTable:
|
|||
db.commit()
|
||||
return note
|
||||
|
||||
def get_notes(self) -> list[NoteModel]:
|
||||
async def get_notes(self) -> list[NoteModel]:
|
||||
async with get_db() as db:
|
||||
notes = db.query(Note).order_by(Note.updated_at.desc()).all()
|
||||
notes = await db.query(Note).order_by(Note.updated_at.desc()).all()
|
||||
return [NoteModel.model_validate(note) for note in notes]
|
||||
|
||||
def get_notes_by_user_id(
|
||||
async def get_notes_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[NoteModel]:
|
||||
notes = self.get_notes()
|
||||
notes = await self.get_notes()
|
||||
return [
|
||||
note
|
||||
for note in notes
|
||||
if note.user_id == user_id
|
||||
or has_access(user_id, permission, note.access_control)
|
||||
or await has_access(user_id, permission, note.access_control)
|
||||
]
|
||||
|
||||
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class PromptsTable:
|
|||
prompt
|
||||
for prompt in prompts
|
||||
if prompt.user_id == user_id
|
||||
or has_access(user_id, permission, prompt.access_control)
|
||||
or await has_access(user_id, permission, prompt.access_control)
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class ToolValves(BaseModel):
|
|||
|
||||
|
||||
class ToolsTable:
|
||||
def insert_new_tool(
|
||||
async def insert_new_tool(
|
||||
self, user_id: str, form_data: ToolForm, specs: list[dict]
|
||||
) -> Optional[ToolModel]:
|
||||
async with get_db() as db:
|
||||
|
|
@ -123,9 +123,9 @@ class ToolsTable:
|
|||
|
||||
try:
|
||||
result = Tool(**tool.model_dump())
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
await db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
if result:
|
||||
return ToolModel.model_validate(result)
|
||||
else:
|
||||
|
|
@ -134,10 +134,10 @@ class ToolsTable:
|
|||
log.exception(f"Error creating a new tool: {e}")
|
||||
return None
|
||||
|
||||
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||
async def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
tool = db.get(Tool, id)
|
||||
tool = await db.get(Tool, id)
|
||||
return ToolModel.model_validate(tool)
|
||||
except Exception:
|
||||
return None
|
||||
|
|
@ -145,7 +145,7 @@ class ToolsTable:
|
|||
async def get_tools(self) -> list[ToolUserModel]:
|
||||
async with get_db() as db:
|
||||
tools = []
|
||||
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
|
||||
for tool in await db.query(Tool).order_by(Tool.updated_at.desc()).all():
|
||||
user = await Users.get_user_by_id(tool.user_id)
|
||||
tools.append(
|
||||
ToolUserModel.model_validate(
|
||||
|
|
@ -157,35 +157,37 @@ class ToolsTable:
|
|||
)
|
||||
return tools
|
||||
|
||||
def get_tools_by_user_id(
|
||||
async def get_tools_by_user_id(
|
||||
self, user_id: str, permission: str = "write"
|
||||
) -> list[ToolUserModel]:
|
||||
tools = self.get_tools()
|
||||
tools = await self.get_tools()
|
||||
|
||||
return [
|
||||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user_id
|
||||
or has_access(user_id, permission, tool.access_control)
|
||||
or await has_access(user_id, permission, tool.access_control)
|
||||
]
|
||||
|
||||
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
async def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
tool = db.get(Tool, id)
|
||||
tool = await db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting tool valves by id {id}: {e}")
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
|
||||
async def update_tool_valves_by_id(
|
||||
self, id: str, valves: dict
|
||||
) -> Optional[ToolValves]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
await db.query(Tool).filter_by(id=id).update(
|
||||
{"valves": valves, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_tool_by_id(id)
|
||||
await db.commit()
|
||||
return await self.get_tool_by_id(id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
|
@ -225,7 +227,7 @@ class ToolsTable:
|
|||
user_settings["tools"]["valves"][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
await Users.update_user_by_id(user_id, {"settings": user_settings})
|
||||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
except Exception as e:
|
||||
|
|
@ -234,25 +236,25 @@ class ToolsTable:
|
|||
)
|
||||
return None
|
||||
|
||||
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
async def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
await db.query(Tool).filter_by(id=id).update(
|
||||
{**updated, "updated_at": int(time.time())}
|
||||
)
|
||||
db.commit()
|
||||
await db.commit()
|
||||
|
||||
tool = db.query(Tool).get(id)
|
||||
db.refresh(tool)
|
||||
tool = await db.query(Tool).get(id)
|
||||
await db.refresh(tool)
|
||||
return ToolModel.model_validate(tool)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_tool_by_id(self, id: str) -> bool:
|
||||
async def delete_tool_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
async with get_db() as db:
|
||||
db.query(Tool).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
await db.query(Tool).filter_by(id=id).delete()
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -517,7 +517,7 @@ def get_sources_from_items(
|
|||
if note and (
|
||||
user.role == "admin"
|
||||
or note.user_id == user.id
|
||||
or has_access(user.id, "read", note.access_control)
|
||||
or await has_access(user.id, "read", note.access_control)
|
||||
):
|
||||
# User has access to the note
|
||||
query_result = {
|
||||
|
|
@ -581,7 +581,7 @@ def get_sources_from_items(
|
|||
|
||||
if knowledge_base and (
|
||||
user.role == "admin"
|
||||
or has_access(user.id, "read", knowledge_base.access_control)
|
||||
or await has_access(user.id, "read", knowledge_base.access_control)
|
||||
):
|
||||
|
||||
file_ids = knowledge_base.data.get("file_ids", [])
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ async def update_profile(
|
|||
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
|
||||
):
|
||||
if session_user:
|
||||
user = Users.update_user_by_id(
|
||||
user = await Users.update_user_by_id(
|
||||
session_user.id,
|
||||
{"profile_image_url": form_data.profile_image_url, "name": form_data.name},
|
||||
)
|
||||
|
|
@ -348,12 +348,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
|||
if not connection_user.bind():
|
||||
raise HTTPException(400, "Authentication failed.")
|
||||
|
||||
user = Users.get_user_by_email(email)
|
||||
user = await Users.get_user_by_email(email)
|
||||
if not user:
|
||||
try:
|
||||
role = (
|
||||
"admin"
|
||||
if not Users.has_users()
|
||||
if not await Users.has_users()
|
||||
else request.app.state.config.DEFAULT_USER_ROLE
|
||||
)
|
||||
|
||||
|
|
@ -463,7 +463,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
||||
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
|
||||
|
||||
if not Users.get_user_by_email(email.lower()):
|
||||
if not await Users.get_user_by_email(email.lower()):
|
||||
await signup(
|
||||
request,
|
||||
response,
|
||||
|
|
@ -484,10 +484,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
admin_email = "admin@localhost"
|
||||
admin_password = "admin"
|
||||
|
||||
if Users.get_user_by_email(admin_email.lower()):
|
||||
if await Users.get_user_by_email(admin_email.lower()):
|
||||
user = await Auths.authenticate_user(admin_email.lower(), admin_password)
|
||||
else:
|
||||
if Users.has_users():
|
||||
if await Users.has_users():
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
|
||||
|
||||
await signup(
|
||||
|
|
@ -556,7 +556,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
|
|||
|
||||
@router.post("/signup", response_model=SessionUserResponse)
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
has_users = Users.has_users()
|
||||
has_users = await Users.has_users()
|
||||
|
||||
if WEBUI_AUTH:
|
||||
if (
|
||||
|
|
@ -577,7 +577,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
)
|
||||
|
||||
if Users.get_user_by_email(form_data.email.lower()):
|
||||
if await Users.get_user_by_email(form_data.email.lower()):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
|
|
@ -733,7 +733,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
|||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
)
|
||||
|
||||
if Users.get_user_by_email(form_data.email.lower()):
|
||||
if await Users.get_user_by_email(form_data.email.lower()):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
try:
|
||||
|
|
@ -780,11 +780,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
|||
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
|
||||
|
||||
if admin_email:
|
||||
admin = Users.get_user_by_email(admin_email)
|
||||
admin = await Users.get_user_by_email(admin_email)
|
||||
if admin:
|
||||
admin_name = admin.name
|
||||
else:
|
||||
admin = Users.get_first_user()
|
||||
admin = await Users.get_first_user()
|
||||
if admin:
|
||||
admin_email = admin.email
|
||||
admin_name = admin.name
|
||||
|
|
@ -1025,7 +1025,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
|||
)
|
||||
|
||||
api_key = create_api_key()
|
||||
success = Users.update_user_api_key_by_id(user.id, api_key)
|
||||
success = await Users.update_user_api_key_by_id(user.id, api_key)
|
||||
|
||||
if success:
|
||||
return {
|
||||
|
|
@ -1038,14 +1038,14 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
|
|||
# delete api key
|
||||
@router.delete("/api_key", response_model=bool)
|
||||
async def delete_api_key(user=Depends(get_current_user)):
|
||||
success = Users.update_user_api_key_by_id(user.id, None)
|
||||
success = await Users.update_user_api_key_by_id(user.id, None)
|
||||
return success
|
||||
|
||||
|
||||
# get api key
|
||||
@router.get("/api_key", response_model=ApiKey)
|
||||
async def get_api_key(user=Depends(get_current_user)):
|
||||
api_key = Users.get_user_api_key_by_id(user.id)
|
||||
api_key = await Users.get_user_api_key_by_id(user.id)
|
||||
if api_key:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -157,7 +157,7 @@ async def get_channel_messages(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -197,7 +197,7 @@ async def get_channel_messages(
|
|||
|
||||
|
||||
async def send_notification(name, webui_url, channel, message, active_user_ids):
|
||||
users = get_users_with_access("read", channel.access_control)
|
||||
users = await get_users_with_access("read", channel.access_control)
|
||||
|
||||
for user in users:
|
||||
if user.id in active_user_ids:
|
||||
|
|
@ -236,7 +236,7 @@ async def post_new_message(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -341,7 +341,7 @@ async def get_channel_message(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -390,7 +390,7 @@ async def get_channel_thread_messages(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -452,7 +452,9 @@ async def update_message_by_id(
|
|||
if (
|
||||
user.role != "admin"
|
||||
and message.user_id != user.id
|
||||
and not has_access(user.id, type="read", access_control=channel.access_control)
|
||||
and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
@ -512,7 +514,7 @@ async def add_reaction_to_message(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -578,7 +580,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
if user.role != "admin" and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -661,7 +663,9 @@ async def delete_message_by_id(
|
|||
if (
|
||||
user.role != "admin"
|
||||
and message.user_id != user.id
|
||||
and not has_access(user.id, type="read", access_control=channel.access_control)
|
||||
and not await has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
|
|||
|
|
@ -440,10 +440,10 @@ async def update_function_valves_by_id(
|
|||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
if function:
|
||||
try:
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
user_valves = await Functions.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
return user_valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -482,7 +482,7 @@ async def get_function_user_valves_spec_by_id(
|
|||
async def update_function_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
function = Functions.get_function_by_id(id)
|
||||
function = await Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
function_module, function_type, frontmatter = get_function_module_from_cache(
|
||||
|
|
@ -495,7 +495,7 @@ async def update_function_user_valves_by_id(
|
|||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
user_valves = UserValves(**form_data)
|
||||
Functions.update_user_valves_by_id_and_user_id(
|
||||
await Functions.update_user_valves_by_id_and_user_id(
|
||||
id, user.id, user_valves.model_dump()
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ async def update_group_by_id(
|
|||
):
|
||||
try:
|
||||
if form_data.user_ids:
|
||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
||||
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
|
||||
|
||||
group = Groups.update_group_by_id(id, form_data)
|
||||
if group:
|
||||
|
|
@ -119,7 +119,7 @@ async def add_user_to_group(
|
|||
):
|
||||
try:
|
||||
if form_data.user_ids:
|
||||
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
|
||||
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
|
||||
|
||||
group = Groups.add_users_to_group(id, form_data.user_ids)
|
||||
if group:
|
||||
|
|
|
|||
|
|
@ -261,11 +261,11 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
user.role == "admin"
|
||||
or knowledge.user_id == user.id
|
||||
or has_access(user.id, "read", knowledge.access_control)
|
||||
or await has_access(user.id, "read", knowledge.access_control)
|
||||
):
|
||||
|
||||
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
|
||||
files = Files.get_file_metadatas_by_ids(file_ids)
|
||||
files = await Files.get_file_metadatas_by_ids(file_ids)
|
||||
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
|
|
@ -298,7 +298,7 @@ async def update_knowledge_by_id(
|
|||
# Is the user the original creator, in a group with write access, or an admin
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -348,7 +348,7 @@ def add_file_to_knowledge_by_id(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -432,7 +432,7 @@ def update_file_from_knowledge_by_id(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
|
||||
|
|
@ -503,7 +503,7 @@ def remove_file_from_knowledge_by_id(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -591,7 +591,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -654,7 +654,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -697,7 +697,7 @@ def add_files_to_knowledge_batch(
|
|||
|
||||
if (
|
||||
knowledge.user_id != user.id
|
||||
and not has_access(user.id, "write", knowledge.access_control)
|
||||
and not await has_access(user.id, "write", knowledge.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
|
||||
or model.user_id == user.id
|
||||
or has_access(user.id, "read", model.access_control)
|
||||
or await has_access(user.id, "read", model.access_control)
|
||||
):
|
||||
return model
|
||||
else:
|
||||
|
|
@ -141,7 +141,7 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
user.role == "admin"
|
||||
or model.user_id == user.id
|
||||
or has_access(user.id, "write", model.access_control)
|
||||
or await has_access(user.id, "write", model.access_control)
|
||||
):
|
||||
model = Models.toggle_model_by_id(id)
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ async def update_model_by_id(
|
|||
|
||||
if (
|
||||
model.user_id != user.id
|
||||
and not has_access(user.id, "write", model.access_control)
|
||||
and not await has_access(user.id, "write", model.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -204,7 +204,7 @@ async def update_model_by_id(
|
|||
|
||||
@router.delete("/model/delete", response_model=bool)
|
||||
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
model = await Models.get_model_by_id(id)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -214,18 +214,18 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
|
|||
if (
|
||||
user.role != "admin"
|
||||
and model.user_id != user.id
|
||||
and not has_access(user.id, "write", model.access_control)
|
||||
and not await has_access(user.id, "write", model.access_control)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
result = Models.delete_model_by_id(id)
|
||||
result = await Models.delete_model_by_id(id)
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/delete/all", response_model=bool)
|
||||
async def delete_all_models(user=Depends(get_admin_user)):
|
||||
result = Models.delete_all_models()
|
||||
result = await Models.delete_all_models()
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -133,7 +133,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
|
|||
|
||||
if user.role != "admin" and (
|
||||
user.id != note.user_id
|
||||
and (not has_access(user.id, type="read", access_control=note.access_control))
|
||||
and (
|
||||
not await has_access(
|
||||
user.id, type="read", access_control=note.access_control
|
||||
)
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
@ -167,7 +171,9 @@ async def update_note_by_id(
|
|||
|
||||
if user.role != "admin" and (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
and not await has_access(
|
||||
user.id, type="write", access_control=note.access_control
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
@ -212,7 +218,9 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
|
|||
|
||||
if user.role != "admin" and (
|
||||
user.id != note.user_id
|
||||
and not has_access(user.id, type="write", access_control=note.access_control)
|
||||
and not await has_access(
|
||||
user.id, type="write", access_control=note.access_control
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
|
|
|
|||
|
|
@ -443,8 +443,10 @@ async def get_filtered_models(models, user):
|
|||
for model in models.get("models", []):
|
||||
model_info = Models.get_model_by_id(model["model"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
if user.id == model_info.user_id or (
|
||||
await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
filtered_models.append(model)
|
||||
return filtered_models
|
||||
|
|
@ -1336,8 +1338,10 @@ async def generate_chat_completion(
|
|||
if not bypass_filter and user.role == "user":
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
or (
|
||||
await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -1442,8 +1446,10 @@ async def generate_openai_completion(
|
|||
if user.role == "user":
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
or (
|
||||
await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -1525,8 +1531,10 @@ async def generate_openai_chat_completion(
|
|||
if user.role == "user":
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
or (
|
||||
await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -1623,8 +1631,10 @@ async def get_openai_models(
|
|||
for model in models:
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
if user.id == model_info.user_id or (
|
||||
await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
filtered_models.append(model)
|
||||
models = filtered_models
|
||||
|
|
|
|||
|
|
@ -385,8 +385,10 @@ async def get_filtered_models(models, user):
|
|||
for model in models.get("data", []):
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if user.id == model_info.user_id or has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
if user.id == model_info.user_id or (
|
||||
await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
filtered_models.append(model)
|
||||
return filtered_models
|
||||
|
|
@ -756,7 +758,7 @@ async def generate_chat_completion(
|
|||
if not bypass_filter and user.role == "user":
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or has_access(
|
||||
or await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
|
|
|
|||
|
|
@ -79,13 +79,13 @@ async def create_new_prompt(
|
|||
|
||||
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
prompt = await Prompts.get_prompt_by_command(f"/{command}")
|
||||
|
||||
if prompt:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or prompt.user_id == user.id
|
||||
or has_access(user.id, "read", prompt.access_control)
|
||||
or await has_access(user.id, "read", prompt.access_control)
|
||||
):
|
||||
return prompt
|
||||
else:
|
||||
|
|
@ -106,7 +106,7 @@ async def update_prompt_by_command(
|
|||
form_data: PromptForm,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
prompt = await Prompts.get_prompt_by_command(f"/{command}")
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -116,7 +116,7 @@ async def update_prompt_by_command(
|
|||
# Is the user the original creator, in a group with write access, or an admin
|
||||
if (
|
||||
prompt.user_id != user.id
|
||||
and not has_access(user.id, "write", prompt.access_control)
|
||||
and not await has_access(user.id, "write", prompt.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -124,7 +124,7 @@ async def update_prompt_by_command(
|
|||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
|
||||
prompt = await Prompts.update_prompt_by_command(f"/{command}", form_data)
|
||||
if prompt:
|
||||
return prompt
|
||||
else:
|
||||
|
|
@ -150,7 +150,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
|
|||
|
||||
if (
|
||||
prompt.user_id != user.id
|
||||
and not has_access(user.id, "write", prompt.access_control)
|
||||
and not await has_access(user.id, "write", prompt.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ def get_scim_auth(
|
|||
)
|
||||
|
||||
|
||||
def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
||||
async def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
||||
"""Convert internal User model to SCIM User"""
|
||||
# Parse display name into name components
|
||||
name_parts = user.name.split(" ", 1) if user.name else ["", ""]
|
||||
|
|
@ -346,7 +346,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
|
|||
)
|
||||
|
||||
|
||||
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
||||
async def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
|
||||
"""Convert internal Group model to SCIM Group"""
|
||||
members = []
|
||||
for user_id in group.user_ids:
|
||||
|
|
@ -493,20 +493,20 @@ async def get_users(
|
|||
# In production, you'd want a more robust filter parser
|
||||
if "userName eq" in filter:
|
||||
email = filter.split('"')[1]
|
||||
user = Users.get_user_by_email(email)
|
||||
user = await Users.get_user_by_email(email)
|
||||
users_list = [user] if user else []
|
||||
total = 1 if user else 0
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit)
|
||||
response = await Users.get_users(skip=skip, limit=limit)
|
||||
users_list = response["users"]
|
||||
total = response["total"]
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit)
|
||||
response = await Users.get_users(skip=skip, limit=limit)
|
||||
users_list = response["users"]
|
||||
total = response["total"]
|
||||
|
||||
# Convert to SCIM format
|
||||
scim_users = [user_to_scim(user, request) for user in users_list]
|
||||
scim_users = [await user_to_scim(user, request) for user in users_list]
|
||||
|
||||
return SCIMListResponse(
|
||||
totalResults=total,
|
||||
|
|
@ -529,7 +529,7 @@ async def get_user(
|
|||
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
|
||||
)
|
||||
|
||||
return user_to_scim(user, request)
|
||||
return await user_to_scim(user, request)
|
||||
|
||||
|
||||
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
|
||||
|
|
@ -540,7 +540,7 @@ async def create_user(
|
|||
):
|
||||
"""Create SCIM User"""
|
||||
# Check if user already exists
|
||||
existing_user = Users.get_user_by_email(user_data.userName)
|
||||
existing_user = await Users.get_user_by_email(user_data.userName)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
|
|
@ -579,7 +579,7 @@ async def create_user(
|
|||
detail="Failed to create user",
|
||||
)
|
||||
|
||||
return user_to_scim(new_user, request)
|
||||
return await user_to_scim(new_user, request)
|
||||
|
||||
|
||||
@router.put("/Users/{user_id}", response_model=SCIMUser)
|
||||
|
|
@ -623,14 +623,14 @@ async def update_user(
|
|||
update_data["profile_image_url"] = user_data.photos[0].value
|
||||
|
||||
# Update user
|
||||
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||
updated_user = await Users.update_user_by_id(user_id, update_data)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update user",
|
||||
)
|
||||
|
||||
return user_to_scim(updated_user, request)
|
||||
return await user_to_scim(updated_user, request)
|
||||
|
||||
|
||||
@router.patch("/Users/{user_id}", response_model=SCIMUser)
|
||||
|
|
@ -669,7 +669,7 @@ async def patch_user(
|
|||
|
||||
# Update user
|
||||
if update_data:
|
||||
updated_user = Users.update_user_by_id(user_id, update_data)
|
||||
updated_user = await Users.update_user_by_id(user_id, update_data)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -678,7 +678,7 @@ async def patch_user(
|
|||
else:
|
||||
updated_user = user
|
||||
|
||||
return user_to_scim(updated_user, request)
|
||||
return await user_to_scim(updated_user, request)
|
||||
|
||||
|
||||
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
|
@ -695,7 +695,7 @@ async def delete_user(
|
|||
detail=f"User {user_id} not found",
|
||||
)
|
||||
|
||||
success = Users.delete_user_by_id(user_id)
|
||||
success = await Users.delete_user_by_id(user_id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -716,7 +716,7 @@ async def get_groups(
|
|||
):
|
||||
"""List SCIM Groups"""
|
||||
# Get all groups
|
||||
groups_list = Groups.get_groups()
|
||||
groups_list = await Groups.get_groups()
|
||||
|
||||
# Apply pagination
|
||||
total = len(groups_list)
|
||||
|
|
@ -725,7 +725,7 @@ async def get_groups(
|
|||
paginated_groups = groups_list[start:end]
|
||||
|
||||
# Convert to SCIM format
|
||||
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
|
||||
scim_groups = [await group_to_scim(group, request) for group in paginated_groups]
|
||||
|
||||
return SCIMListResponse(
|
||||
totalResults=total,
|
||||
|
|
@ -742,14 +742,14 @@ async def get_group(
|
|||
_: bool = Depends(get_scim_auth),
|
||||
):
|
||||
"""Get SCIM Group by ID"""
|
||||
group = Groups.get_group_by_id(group_id)
|
||||
group = await Groups.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Group {group_id} not found",
|
||||
)
|
||||
|
||||
return group_to_scim(group, request)
|
||||
return await group_to_scim(group, request)
|
||||
|
||||
|
||||
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
|
||||
|
|
@ -774,14 +774,14 @@ async def create_group(
|
|||
)
|
||||
|
||||
# Need to get the creating user's ID - we'll use the first admin
|
||||
admin_user = Users.get_super_admin_user()
|
||||
admin_user = await Users.get_super_admin_user()
|
||||
if not admin_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="No admin user found",
|
||||
)
|
||||
|
||||
new_group = Groups.insert_new_group(admin_user.id, form)
|
||||
new_group = await Groups.insert_new_group(admin_user.id, form)
|
||||
if not new_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -797,10 +797,10 @@ async def create_group(
|
|||
description=new_group.description,
|
||||
user_ids=member_ids,
|
||||
)
|
||||
Groups.update_group_by_id(new_group.id, update_form)
|
||||
new_group = Groups.get_group_by_id(new_group.id)
|
||||
await Groups.update_group_by_id(new_group.id, update_form)
|
||||
new_group = await Groups.get_group_by_id(new_group.id)
|
||||
|
||||
return group_to_scim(new_group, request)
|
||||
return await group_to_scim(new_group, request)
|
||||
|
||||
|
||||
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
|
|
@ -839,7 +839,7 @@ async def update_group(
|
|||
detail="Failed to update group",
|
||||
)
|
||||
|
||||
return group_to_scim(updated_group, request)
|
||||
return await group_to_scim(updated_group, request)
|
||||
|
||||
|
||||
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
|
||||
|
|
@ -899,7 +899,7 @@ async def patch_group(
|
|||
detail="Failed to update group",
|
||||
)
|
||||
|
||||
return group_to_scim(updated_group, request)
|
||||
return await group_to_scim(updated_group, request)
|
||||
|
||||
|
||||
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
request.app.state.config.TOOL_SERVER_CONNECTIONS
|
||||
)
|
||||
|
||||
tools = Tools.get_tools()
|
||||
tools = await Tools.get_tools()
|
||||
for server in request.app.state.TOOL_SERVERS:
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
|
|
@ -83,7 +83,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
tool
|
||||
for tool in tools
|
||||
if tool.user_id == user.id
|
||||
or has_access(user.id, "read", tool.access_control)
|
||||
or await has_access(user.id, "read", tool.access_control)
|
||||
]
|
||||
return tools
|
||||
|
||||
|
|
@ -96,9 +96,9 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
|
|||
@router.get("/list", response_model=list[ToolUserResponse])
|
||||
async def get_tool_list(user=Depends(get_verified_user)):
|
||||
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
|
||||
tools = Tools.get_tools()
|
||||
tools = await Tools.get_tools()
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "write")
|
||||
tools = await Tools.get_tools_by_user_id(user.id, "write")
|
||||
return tools
|
||||
|
||||
|
||||
|
|
@ -184,7 +184,7 @@ async def load_tool_from_url(
|
|||
|
||||
@router.get("/export", response_model=list[ToolModel])
|
||||
async def export_tools(user=Depends(get_admin_user)):
|
||||
tools = Tools.get_tools()
|
||||
tools = await Tools.get_tools()
|
||||
return tools
|
||||
|
||||
|
||||
|
|
@ -215,11 +215,11 @@ async def create_new_tools(
|
|||
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
tools = Tools.get_tool_by_id(form_data.id)
|
||||
tools = await Tools.get_tool_by_id(form_data.id)
|
||||
if tools is None:
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
tool_module, frontmatter = load_tool_module_by_id(
|
||||
tool_module, frontmatter = await load_tool_module_by_id(
|
||||
form_data.id, content=form_data.content
|
||||
)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
|
@ -228,7 +228,7 @@ async def create_new_tools(
|
|||
TOOLS[form_data.id] = tool_module
|
||||
|
||||
specs = get_tool_specs(TOOLS[form_data.id])
|
||||
tools = Tools.insert_new_tool(user.id, form_data, specs)
|
||||
tools = await Tools.insert_new_tool(user.id, form_data, specs)
|
||||
|
||||
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
|
||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -260,13 +260,13 @@ async def create_new_tools(
|
|||
|
||||
@router.get("/id/{id}", response_model=Optional[ToolModel])
|
||||
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
|
||||
if tools:
|
||||
if (
|
||||
user.role == "admin"
|
||||
or tools.user_id == user.id
|
||||
or has_access(user.id, "read", tools.access_control)
|
||||
or await has_access(user.id, "read", tools.access_control)
|
||||
):
|
||||
return tools
|
||||
else:
|
||||
|
|
@ -288,7 +288,7 @@ async def update_tools_by_id(
|
|||
form_data: ToolForm,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -298,7 +298,7 @@ async def update_tools_by_id(
|
|||
# Is the user the original creator, in a group with write access, or an admin
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and not await has_access(user.id, "write", tools.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -308,7 +308,9 @@ async def update_tools_by_id(
|
|||
|
||||
try:
|
||||
form_data.content = replace_imports(form_data.content)
|
||||
tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
|
||||
tool_module, frontmatter = await load_tool_module_by_id(
|
||||
id, content=form_data.content
|
||||
)
|
||||
form_data.meta.manifest = frontmatter
|
||||
|
||||
TOOLS = request.app.state.TOOLS
|
||||
|
|
@ -322,7 +324,7 @@ async def update_tools_by_id(
|
|||
}
|
||||
|
||||
log.debug(updated)
|
||||
tools = Tools.update_tool_by_id(id, updated)
|
||||
tools = await Tools.update_tool_by_id(id, updated)
|
||||
|
||||
if tools:
|
||||
return tools
|
||||
|
|
@ -348,7 +350,7 @@ async def update_tools_by_id(
|
|||
async def delete_tools_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -357,7 +359,7 @@ async def delete_tools_by_id(
|
|||
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and not await has_access(user.id, "write", tools.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -365,7 +367,7 @@ async def delete_tools_by_id(
|
|||
detail=ERROR_MESSAGES.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
result = Tools.delete_tool_by_id(id)
|
||||
result = await Tools.delete_tool_by_id(id)
|
||||
if result:
|
||||
TOOLS = request.app.state.TOOLS
|
||||
if id in TOOLS:
|
||||
|
|
@ -381,10 +383,10 @@ async def delete_tools_by_id(
|
|||
|
||||
@router.get("/id/{id}/valves", response_model=Optional[dict])
|
||||
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if tools:
|
||||
try:
|
||||
valves = Tools.get_tool_valves_by_id(id)
|
||||
valves = await Tools.get_tool_valves_by_id(id)
|
||||
return valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -407,12 +409,12 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
|
|||
async def get_tools_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if tools:
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
tools_module, _ = await load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if hasattr(tools_module, "Valves"):
|
||||
|
|
@ -435,7 +437,7 @@ async def get_tools_valves_spec_by_id(
|
|||
async def update_tools_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -444,7 +446,7 @@ async def update_tools_valves_by_id(
|
|||
|
||||
if (
|
||||
tools.user_id != user.id
|
||||
and not has_access(user.id, "write", tools.access_control)
|
||||
and not await has_access(user.id, "write", tools.access_control)
|
||||
and user.role != "admin"
|
||||
):
|
||||
raise HTTPException(
|
||||
|
|
@ -455,7 +457,7 @@ async def update_tools_valves_by_id(
|
|||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
tools_module, _ = await load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if not hasattr(tools_module, "Valves"):
|
||||
|
|
@ -468,7 +470,7 @@ async def update_tools_valves_by_id(
|
|||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
valves = Valves(**form_data)
|
||||
Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
await Tools.update_tool_valves_by_id(id, valves.model_dump())
|
||||
return valves.model_dump()
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to update tool valves by id {id}: {e}")
|
||||
|
|
@ -485,10 +487,10 @@ async def update_tools_valves_by_id(
|
|||
|
||||
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
|
||||
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if tools:
|
||||
try:
|
||||
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
user_valves = await Tools.get_user_valves_by_id_and_user_id(id, user.id)
|
||||
return user_valves
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -506,12 +508,12 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
|
|||
async def get_tools_user_valves_spec_by_id(
|
||||
request: Request, id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
if tools:
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
tools_module, _ = await load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if hasattr(tools_module, "UserValves"):
|
||||
|
|
@ -529,13 +531,13 @@ async def get_tools_user_valves_spec_by_id(
|
|||
async def update_tools_user_valves_by_id(
|
||||
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
tools = Tools.get_tool_by_id(id)
|
||||
tools = await Tools.get_tool_by_id(id)
|
||||
|
||||
if tools:
|
||||
if id in request.app.state.TOOLS:
|
||||
tools_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
tools_module, _ = load_tool_module_by_id(id)
|
||||
tools_module, _ = await load_tool_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = tools_module
|
||||
|
||||
if hasattr(tools_module, "UserValves"):
|
||||
|
|
@ -544,7 +546,7 @@ async def update_tools_user_valves_by_id(
|
|||
try:
|
||||
form_data = {k: v for k, v in form_data.items() if v is not None}
|
||||
user_valves = UserValves(**form_data)
|
||||
Tools.update_user_valves_by_id_and_user_id(
|
||||
await Tools.update_user_valves_by_id_and_user_id(
|
||||
id, user.id, user_valves.model_dump()
|
||||
)
|
||||
return user_valves.model_dump()
|
||||
|
|
|
|||
|
|
@ -88,14 +88,14 @@ async def get_users(
|
|||
if direction:
|
||||
filter["direction"] = direction
|
||||
|
||||
return Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
return await Users.get_users(filter=filter, skip=skip, limit=limit)
|
||||
|
||||
|
||||
@router.get("/all", response_model=UserInfoListResponse)
|
||||
async def get_all_users(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
return Users.get_users()
|
||||
return await Users.get_users()
|
||||
|
||||
|
||||
############################
|
||||
|
|
@ -205,7 +205,7 @@ async def update_default_user_permissions(
|
|||
|
||||
@router.get("/user/settings", response_model=Optional[UserSettings])
|
||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
user = await Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
|
@ -237,7 +237,7 @@ async def update_user_settings_by_session_user(
|
|||
# If the user is not an admin and does not have permission to use tool servers, remove the key
|
||||
updated_user_settings["ui"].pop("toolServers", None)
|
||||
|
||||
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
|
||||
user = await Users.update_user_settings_by_id(user.id, updated_user_settings)
|
||||
if user:
|
||||
return user.settings
|
||||
else:
|
||||
|
|
@ -254,7 +254,7 @@ async def update_user_settings_by_session_user(
|
|||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
user = await Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
|
|
@ -273,12 +273,14 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
|||
async def update_user_info_by_session_user(
|
||||
form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
user = await Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
user = await Users.update_user_by_id(
|
||||
user.id, {"info": {**user.info, **form_data}}
|
||||
)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
|
|
@ -343,7 +345,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
|||
|
||||
@router.get("/{user_id}/profile/image")
|
||||
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user_id)
|
||||
user = await Users.get_user_by_id(user_id)
|
||||
if user:
|
||||
if user.profile_image_url:
|
||||
# check if it's url or base64
|
||||
|
|
@ -398,7 +400,7 @@ async def update_user_by_id(
|
|||
):
|
||||
# Prevent modification of the primary admin user by other admins
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
first_user = await Users.get_first_user()
|
||||
if first_user:
|
||||
if user_id == first_user.id:
|
||||
if session_user.id != user_id:
|
||||
|
|
@ -472,7 +474,7 @@ async def update_user_by_id(
|
|||
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
# Prevent deletion of the primary admin user
|
||||
try:
|
||||
first_user = Users.get_first_user()
|
||||
first_user = await Users.get_first_user()
|
||||
if first_user and user_id == first_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
|
|||
|
|
@ -263,7 +263,7 @@ async def connect(sid, environ, auth):
|
|||
data = decode_token(auth["token"])
|
||||
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = await Users.get_user_by_id(data["id"])
|
||||
|
||||
if user:
|
||||
SESSION_POOL[sid] = user.model_dump()
|
||||
|
|
@ -284,7 +284,7 @@ async def user_join(sid, data):
|
|||
if data is None or "id" not in data:
|
||||
return
|
||||
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = await Users.get_user_by_id(data["id"])
|
||||
if not user:
|
||||
return
|
||||
|
||||
|
|
@ -312,7 +312,7 @@ async def join_channel(sid, data):
|
|||
if data is None or "id" not in data:
|
||||
return
|
||||
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = await Users.get_user_by_id(data["id"])
|
||||
if not user:
|
||||
return
|
||||
|
||||
|
|
@ -333,7 +333,7 @@ async def join_note(sid, data):
|
|||
if token_data is None or "id" not in token_data:
|
||||
return
|
||||
|
||||
user = Users.get_user_by_id(token_data["id"])
|
||||
user = await Users.get_user_by_id(token_data["id"])
|
||||
if not user:
|
||||
return
|
||||
|
||||
|
|
@ -345,7 +345,9 @@ async def join_note(sid, data):
|
|||
if (
|
||||
user.role != "admin"
|
||||
and user.id != note.user_id
|
||||
and not has_access(user.id, type="read", access_control=note.access_control)
|
||||
and not await has_access(
|
||||
user.id, type="read", access_control=note.access_control
|
||||
)
|
||||
):
|
||||
log.error(f"User {user.id} does not have access to note {data['note_id']}")
|
||||
return
|
||||
|
|
@ -400,7 +402,7 @@ async def ydoc_document_join(sid, data):
|
|||
if (
|
||||
user.get("role") != "admin"
|
||||
and user.get("id") != note.user_id
|
||||
and not has_access(
|
||||
and not await has_access(
|
||||
user.get("id"), type="read", access_control=note.access_control
|
||||
)
|
||||
):
|
||||
|
|
@ -470,14 +472,14 @@ async def document_save_handler(document_id, data, user):
|
|||
if (
|
||||
user.get("role") != "admin"
|
||||
and user.get("id") != note.user_id
|
||||
and not has_access(
|
||||
and not await has_access(
|
||||
user.get("id"), type="read", access_control=note.access_control
|
||||
)
|
||||
):
|
||||
log.error(f"User {user.get('id')} does not have access to note {note_id}")
|
||||
return
|
||||
|
||||
Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
|
||||
await Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
|
||||
|
||||
|
||||
@sio.on("ydoc:document:state")
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ def has_permission(
|
|||
return get_permission(default_permissions, permission_hierarchy)
|
||||
|
||||
|
||||
def has_access(
|
||||
async def has_access(
|
||||
user_id: str,
|
||||
type: str = "write",
|
||||
access_control: Optional[dict] = None,
|
||||
|
|
@ -115,7 +115,7 @@ def has_access(
|
|||
if access_control is None:
|
||||
return type == "read"
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id)
|
||||
user_group_ids = [group.id for group in user_groups]
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
|
|
@ -127,11 +127,11 @@ def has_access(
|
|||
|
||||
|
||||
# Get all users with access to a resource
|
||||
def get_users_with_access(
|
||||
async def get_users_with_access(
|
||||
type: str = "write", access_control: Optional[dict] = None
|
||||
) -> List[UserModel]:
|
||||
if access_control is None:
|
||||
return Users.get_users()
|
||||
return await Users.get_users()
|
||||
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
|
|
@ -140,8 +140,8 @@ def get_users_with_access(
|
|||
user_ids_with_access = set(permitted_user_ids)
|
||||
|
||||
for group_id in permitted_group_ids:
|
||||
group_user_ids = Groups.get_group_user_ids_by_id(group_id)
|
||||
group_user_ids = await Groups.get_group_user_ids_by_id(group_id)
|
||||
if group_user_ids:
|
||||
user_ids_with_access.update(group_user_ids)
|
||||
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access))
|
||||
return await Users.get_users_by_user_ids(list(user_ids_with_access))
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
|
|||
return None
|
||||
|
||||
|
||||
def get_current_user(
|
||||
async def get_current_user(
|
||||
request: Request,
|
||||
response: Response,
|
||||
background_tasks: BackgroundTasks,
|
||||
|
|
@ -270,7 +270,7 @@ def get_current_user(
|
|||
)
|
||||
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = await Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
|
@ -303,6 +303,7 @@ def get_current_user(
|
|||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
if background_tasks:
|
||||
|
||||
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
|
||||
return user
|
||||
else:
|
||||
|
|
@ -312,8 +313,8 @@ def get_current_user(
|
|||
)
|
||||
|
||||
|
||||
def get_current_user_by_api_key(api_key: str):
|
||||
user = Users.get_user_by_api_key(api_key)
|
||||
async def get_current_user_by_api_key(api_key: str):
|
||||
user = await Users.get_user_by_api_key(api_key)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -329,7 +330,7 @@ def get_current_user_by_api_key(api_key: str):
|
|||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
|
||||
Users.update_user_last_active_by_id(user.id)
|
||||
await Users.update_user_last_active_by_id(user.id)
|
||||
|
||||
return user
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ async def generate_chat_completion(
|
|||
# Check if user has access to the model
|
||||
if not bypass_filter and user.role == "user":
|
||||
try:
|
||||
check_model_access(user, model)
|
||||
await check_model_access(user, model)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
|
@ -424,8 +424,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
|||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
__user__["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
action_id, user.id
|
||||
**(
|
||||
await Functions.get_user_valves_by_id_and_user_id(
|
||||
action_id, user.id
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ async def generate_embeddings(
|
|||
# Access filtering
|
||||
if not getattr(request.state, "direct", False):
|
||||
if not bypass_filter and user.role == "user":
|
||||
check_model_access(user, model)
|
||||
await check_model_access(user, model)
|
||||
|
||||
# Ollama backend
|
||||
if model.get("owned_by") == "ollama":
|
||||
|
|
|
|||
|
|
@ -109,8 +109,10 @@ async def process_filter_functions(
|
|||
if hasattr(function_module, "UserValves"):
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(
|
||||
**Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
**(
|
||||
await Functions.get_user_valves_by_id_and_user_id(
|
||||
filter_id, params["__user__"]["id"]
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -311,9 +311,9 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
|
|||
return models
|
||||
|
||||
|
||||
def check_model_access(user, model):
|
||||
async def check_model_access(user, model):
|
||||
if model.get("arena"):
|
||||
if not has_access(
|
||||
if not await has_access(
|
||||
user.id,
|
||||
type="read",
|
||||
access_control=model.get("info", {})
|
||||
|
|
@ -322,12 +322,12 @@ def check_model_access(user, model):
|
|||
):
|
||||
raise Exception("Model not found")
|
||||
else:
|
||||
model_info = Models.get_model_by_id(model.get("id"))
|
||||
model_info = await Models.get_model_by_id(model.get("id"))
|
||||
if not model_info:
|
||||
raise Exception("Model not found")
|
||||
elif not (
|
||||
user.id == model_info.user_id
|
||||
or has_access(
|
||||
or await has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
)
|
||||
):
|
||||
|
|
|
|||
|
|
@ -89,8 +89,8 @@ class OAuthManager:
|
|||
def get_client(self, provider_name):
|
||||
return self.oauth.create_client(provider_name)
|
||||
|
||||
def get_user_role(self, user, user_data):
|
||||
user_count = Users.get_num_users()
|
||||
async def get_user_role(self, user, user_data):
|
||||
user_count = await Users.get_num_users()
|
||||
if user and user_count == 1:
|
||||
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
|
||||
log.debug("Assigning the only user the admin role")
|
||||
|
|
@ -147,7 +147,7 @@ class OAuthManager:
|
|||
|
||||
return role
|
||||
|
||||
def update_user_groups(self, user, user_data, default_permissions):
|
||||
async def update_user_groups(self, user, user_data, default_permissions):
|
||||
log.debug("Running OAUTH Group management")
|
||||
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
|
||||
|
||||
|
|
@ -172,8 +172,10 @@ class OAuthManager:
|
|||
else:
|
||||
user_oauth_groups = []
|
||||
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
user_current_groups: list[GroupModel] = await Groups.get_groups_by_member_id(
|
||||
user.id
|
||||
)
|
||||
all_available_groups: list[GroupModel] = await Groups.get_groups()
|
||||
|
||||
# Create groups if they don't exist and creation is enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
|
||||
|
|
@ -181,7 +183,7 @@ class OAuthManager:
|
|||
all_group_names = {g.name for g in all_available_groups}
|
||||
groups_created = False
|
||||
# Determine creator ID: Prefer admin, fallback to current user if no admin exists
|
||||
admin_user = Users.get_super_admin_user()
|
||||
admin_user = await Users.get_super_admin_user()
|
||||
creator_id = admin_user.id if admin_user else user.id
|
||||
log.debug(f"Using creator ID {creator_id} for potential group creation.")
|
||||
|
||||
|
|
@ -198,7 +200,7 @@ class OAuthManager:
|
|||
user_ids=[], # Start with no users, user will be added later by subsequent logic
|
||||
)
|
||||
# Use determined creator ID (admin or fallback to current user)
|
||||
created_group = Groups.insert_new_group(
|
||||
created_group = await Groups.insert_new_group(
|
||||
creator_id, new_group_form
|
||||
)
|
||||
if created_group:
|
||||
|
|
@ -217,7 +219,7 @@ class OAuthManager:
|
|||
|
||||
# Refresh the list of all available groups if any were created
|
||||
if groups_created:
|
||||
all_available_groups = Groups.get_groups()
|
||||
all_available_groups = await Groups.get_groups()
|
||||
log.debug("Refreshed list of all available groups after creation.")
|
||||
|
||||
log.debug(f"Oauth Groups claim: {oauth_claim}")
|
||||
|
|
@ -430,21 +432,21 @@ class OAuthManager:
|
|||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
|
||||
|
||||
# Check if the user exists
|
||||
user = Users.get_user_by_oauth_sub(provider_sub)
|
||||
user = await Users.get_user_by_oauth_sub(provider_sub)
|
||||
|
||||
if not user:
|
||||
# If the user does not exist, check if merging is enabled
|
||||
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
|
||||
# Check if the user exists by email
|
||||
user = Users.get_user_by_email(email)
|
||||
user = await Users.get_user_by_email(email)
|
||||
if user:
|
||||
# Update the user with the new oauth sub
|
||||
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||
await Users.update_user_oauth_sub_by_id(user.id, provider_sub)
|
||||
|
||||
if user:
|
||||
determined_role = self.get_user_role(user, user_data)
|
||||
determined_role = await self.get_user_role(user, user_data)
|
||||
if user.role != determined_role:
|
||||
Users.update_user_role_by_id(user.id, determined_role)
|
||||
await Users.update_user_role_by_id(user.id, determined_role)
|
||||
|
||||
# Update profile picture if enabled and different from current
|
||||
if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
|
||||
|
|
@ -457,7 +459,7 @@ class OAuthManager:
|
|||
new_picture_url, token.get("access_token")
|
||||
)
|
||||
if processed_picture_url != user.profile_image_url:
|
||||
Users.update_user_profile_image_url_by_id(
|
||||
await Users.update_user_profile_image_url_by_id(
|
||||
user.id, processed_picture_url
|
||||
)
|
||||
log.debug(f"Updated profile picture for user {user.email}")
|
||||
|
|
@ -466,7 +468,7 @@ class OAuthManager:
|
|||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(email)
|
||||
existing_user = await Users.get_user_by_email(email)
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
|
|
@ -488,7 +490,7 @@ class OAuthManager:
|
|||
log.warning("Username claim is missing, using email as name")
|
||||
name = email
|
||||
|
||||
role = self.get_user_role(None, user_data)
|
||||
role = await self.get_user_role(None, user_data)
|
||||
|
||||
user = await Auths.insert_new_auth(
|
||||
email=email,
|
||||
|
|
@ -523,7 +525,7 @@ class OAuthManager:
|
|||
)
|
||||
|
||||
if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin":
|
||||
self.update_user_groups(
|
||||
await self.update_user_groups(
|
||||
user=user,
|
||||
user_data=user_data,
|
||||
default_permissions=request.app.state.config.USER_PERMISSIONS,
|
||||
|
|
|
|||
|
|
@ -68,17 +68,17 @@ def replace_imports(content):
|
|||
return content
|
||||
|
||||
|
||||
def load_tool_module_by_id(tool_id, content=None):
|
||||
async def load_tool_module_by_id(tool_id, content=None):
|
||||
|
||||
if content is None:
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
tool = await Tools.get_tool_by_id(tool_id)
|
||||
if not tool:
|
||||
raise Exception(f"Toolkit not found: {tool_id}")
|
||||
|
||||
content = tool.content
|
||||
|
||||
content = replace_imports(content)
|
||||
Tools.update_tool_by_id(tool_id, {"content": content})
|
||||
await Tools.update_tool_by_id(tool_id, {"content": content})
|
||||
else:
|
||||
frontmatter = extract_frontmatter(content)
|
||||
# Install required packages found within the frontmatter
|
||||
|
|
@ -241,7 +241,7 @@ def install_frontmatter_requirements(requirements: str):
|
|||
log.info("No requirements found in frontmatter.")
|
||||
|
||||
|
||||
def install_tool_and_function_dependencies():
|
||||
async def install_tool_and_function_dependencies():
|
||||
"""
|
||||
Install all dependencies for all admin tools and active functions.
|
||||
|
||||
|
|
@ -249,8 +249,8 @@ def install_tool_and_function_dependencies():
|
|||
and then installing them using pip. Duplicates or similar version specifications are
|
||||
handled by pip as much as possible.
|
||||
"""
|
||||
function_list = Functions.get_functions(active_only=True)
|
||||
tool_list = Tools.get_tools()
|
||||
function_list = await Functions.get_functions(active_only=True)
|
||||
tool_list = await Tools.get_tools()
|
||||
|
||||
all_dependencies = ""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -68,13 +68,13 @@ def get_async_tool_function_and_apply_extra_params(
|
|||
return new_function
|
||||
|
||||
|
||||
def get_tools(
|
||||
async def get_tools(
|
||||
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
|
||||
) -> dict[str, dict]:
|
||||
tools_dict = {}
|
||||
|
||||
for tool_id in tool_ids:
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
tool = await Tools.get_tool_by_id(tool_id)
|
||||
if tool is None:
|
||||
if tool_id.startswith("server:"):
|
||||
server_idx = int(tool_id.split(":")[1])
|
||||
|
|
@ -140,18 +140,18 @@ def get_tools(
|
|||
else:
|
||||
module = request.app.state.TOOLS.get(tool_id, None)
|
||||
if module is None:
|
||||
module, _ = load_tool_module_by_id(tool_id)
|
||||
module, _ = await load_tool_module_by_id(tool_id)
|
||||
request.app.state.TOOLS[tool_id] = module
|
||||
|
||||
extra_params["__id__"] = tool_id
|
||||
|
||||
# Set valves for the tool
|
||||
if hasattr(module, "valves") and hasattr(module, "Valves"):
|
||||
valves = Tools.get_tool_valves_by_id(tool_id) or {}
|
||||
valves = await Tools.get_tool_valves_by_id(tool_id) or {}
|
||||
module.valves = module.Valves(**valves)
|
||||
if hasattr(module, "UserValves"):
|
||||
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
|
||||
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
|
||||
**(await Tools.get_user_valves_by_id_and_user_id(tool_id, user.id))
|
||||
)
|
||||
|
||||
for spec in tool.specs:
|
||||
|
|
|
|||
Loading…
Reference in a new issue