wip: users, tools

This commit is contained in:
Timothy Jaeryang Baek 2025-08-14 16:19:54 +04:00
parent 44e9ae243d
commit e1da74541b
32 changed files with 289 additions and 248 deletions

View file

@ -174,7 +174,7 @@ async def generate_function_chat_completion(
pipe_id, _ = pipe_id.split(".", 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
async def get_function_params(function_module, form_data, user, extra_params=None):
if extra_params is None:
extra_params = {}
@ -187,7 +187,9 @@ async def generate_function_chat_completion(
}
if "__user__" in params and hasattr(function_module, "UserValves"):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
user_valves = await Functions.get_user_valves_by_id_and_user_id(
pipe_id, user.id
)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
except Exception as e:
@ -232,7 +234,7 @@ async def generate_function_chat_completion(
"__metadata__": metadata,
"__request__": request,
}
extra_params["__tools__"] = get_tools(
extra_params["__tools__"] = await get_tools(
request,
tool_ids,
user,
@ -261,7 +263,7 @@ async def generate_function_chat_completion(
function_module = get_function_module_by_id(request, pipe_id)
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
params = await get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):

View file

@ -530,7 +530,7 @@ async def lifespan(app: FastAPI):
# This should be blocking (sync) so functions are not deactivated on first /get_models calls
# when the first user lands on the / route.
log.info("Installing external dependencies of functions and tools...")
install_tool_and_function_dependencies()
await install_tool_and_function_dependencies()
app.state.redis = get_redis_connection(
redis_url=REDIS_URL,
@ -1267,11 +1267,11 @@ if audit_level != AuditLevel.NONE:
async def get_models(
request: Request, refresh: bool = False, user=Depends(get_verified_user)
):
def get_filtered_models(models, user):
async def get_filtered_models(models, user):
filtered_models = []
for model in models:
if model.get("arena"):
if has_access(
if await has_access(
user.id,
type="read",
access_control=model.get("info", {})
@ -1286,7 +1286,7 @@ async def get_models(
if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
or user.id == model_info.user_id
or has_access(
or await has_access(
user.id, type="read", access_control=model_info.access_control
)
):
@ -1334,7 +1334,7 @@ async def get_models(
user.role == "user"
or (user.role == "admin" and not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
) and not BYPASS_MODEL_ACCESS_CONTROL:
models = get_filtered_models(models, user)
models = await get_filtered_models(models, user)
log.debug(
f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}"
@ -1408,7 +1408,7 @@ async def chat_completion(
user.role != "admin" or not ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS
):
try:
check_model_access(user, model)
await check_model_access(user, model)
except Exception as e:
raise e
else:
@ -1628,7 +1628,7 @@ async def get_app_config(request: Request):
if data is not None and "id" in data:
user = await Users.get_user_by_id(data["id"])
user_count = Users.get_num_users()
user_count = await Users.get_num_users()
onboarding = False
if user is None:

View file

@ -101,7 +101,7 @@ class ChannelTable:
channel
for channel in channels
if channel.user_id == user_id
or has_access(user_id, permission, channel.access_control)
or await has_access(user_id, permission, channel.access_control)
]
async def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:

View file

@ -143,7 +143,7 @@ class KnowledgeTable:
)
return knowledge_bases
def get_knowledge_bases_by_user_id(
async def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases()
@ -151,31 +151,31 @@ class KnowledgeTable:
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or has_access(user_id, permission, knowledge_base.access_control)
or await has_access(user_id, permission, knowledge_base.access_control)
]
def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
async def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]:
try:
async with get_db() as db:
knowledge = db.query(Knowledge).filter_by(id=id).first()
knowledge = await db.query(Knowledge).filter_by(id=id).first()
return KnowledgeModel.model_validate(knowledge) if knowledge else None
except Exception:
return None
def update_knowledge_by_id(
async def update_knowledge_by_id(
self, id: str, form_data: KnowledgeForm, overwrite: bool = False
) -> Optional[KnowledgeModel]:
try:
async with get_db() as db:
knowledge = self.get_knowledge_by_id(id=id)
db.query(Knowledge).filter_by(id=id).update(
knowledge = await self.get_knowledge_by_id(id=id)
await db.query(Knowledge).filter_by(id=id).update(
{
**form_data.model_dump(),
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_knowledge_by_id(id=id)
await db.commit()
return await self.get_knowledge_by_id(id=id)
except Exception as e:
log.exception(e)
return None

View file

@ -195,7 +195,7 @@ class ModelsTable:
for model in db.query(Model).filter(Model.base_model_id == None).all()
]
def get_models_by_user_id(
async def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelUserResponse]:
models = self.get_models()
@ -203,7 +203,7 @@ class ModelsTable:
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
or await has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:

View file

@ -96,20 +96,20 @@ class NoteTable:
db.commit()
return note
def get_notes(self) -> list[NoteModel]:
async def get_notes(self) -> list[NoteModel]:
async with get_db() as db:
notes = db.query(Note).order_by(Note.updated_at.desc()).all()
notes = await db.query(Note).order_by(Note.updated_at.desc()).all()
return [NoteModel.model_validate(note) for note in notes]
def get_notes_by_user_id(
async def get_notes_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[NoteModel]:
notes = self.get_notes()
notes = await self.get_notes()
return [
note
for note in notes
if note.user_id == user_id
or has_access(user_id, permission, note.access_control)
or await has_access(user_id, permission, note.access_control)
]
def get_note_by_id(self, id: str) -> Optional[NoteModel]:

View file

@ -127,7 +127,7 @@ class PromptsTable:
prompt
for prompt in prompts
if prompt.user_id == user_id
or has_access(user_id, permission, prompt.access_control)
or await has_access(user_id, permission, prompt.access_control)
]
def update_prompt_by_command(

View file

@ -107,7 +107,7 @@ class ToolValves(BaseModel):
class ToolsTable:
def insert_new_tool(
async def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: list[dict]
) -> Optional[ToolModel]:
async with get_db() as db:
@ -123,9 +123,9 @@ class ToolsTable:
try:
result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
await db.add(result)
await db.commit()
await db.refresh(result)
if result:
return ToolModel.model_validate(result)
else:
@ -134,10 +134,10 @@ class ToolsTable:
log.exception(f"Error creating a new tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
async def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
async with get_db() as db:
tool = db.get(Tool, id)
tool = await db.get(Tool, id)
return ToolModel.model_validate(tool)
except Exception:
return None
@ -145,7 +145,7 @@ class ToolsTable:
async def get_tools(self) -> list[ToolUserModel]:
async with get_db() as db:
tools = []
for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
for tool in await db.query(Tool).order_by(Tool.updated_at.desc()).all():
user = await Users.get_user_by_id(tool.user_id)
tools.append(
ToolUserModel.model_validate(
@ -157,35 +157,37 @@ class ToolsTable:
)
return tools
def get_tools_by_user_id(
async def get_tools_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ToolUserModel]:
tools = self.get_tools()
tools = await self.get_tools()
return [
tool
for tool in tools
if tool.user_id == user_id
or has_access(user_id, permission, tool.access_control)
or await has_access(user_id, permission, tool.access_control)
]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
async def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
async with get_db() as db:
tool = db.get(Tool, id)
tool = await db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
log.exception(f"Error getting tool valves by id {id}: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
async def update_tool_valves_by_id(
self, id: str, valves: dict
) -> Optional[ToolValves]:
try:
async with get_db() as db:
db.query(Tool).filter_by(id=id).update(
await db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())}
)
db.commit()
return self.get_tool_by_id(id)
await db.commit()
return await self.get_tool_by_id(id)
except Exception:
return None
@ -225,7 +227,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings})
await Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["tools"]["valves"][id]
except Exception as e:
@ -234,25 +236,25 @@ class ToolsTable:
)
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
async def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
async with get_db() as db:
db.query(Tool).filter_by(id=id).update(
await db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
db.commit()
await db.commit()
tool = db.query(Tool).get(id)
db.refresh(tool)
tool = await db.query(Tool).get(id)
await db.refresh(tool)
return ToolModel.model_validate(tool)
except Exception:
return None
def delete_tool_by_id(self, id: str) -> bool:
async def delete_tool_by_id(self, id: str) -> bool:
try:
async with get_db() as db:
db.query(Tool).filter_by(id=id).delete()
db.commit()
await db.query(Tool).filter_by(id=id).delete()
await db.commit()
return True
except Exception:

View file

@ -517,7 +517,7 @@ def get_sources_from_items(
if note and (
user.role == "admin"
or note.user_id == user.id
or has_access(user.id, "read", note.access_control)
or await has_access(user.id, "read", note.access_control)
):
# User has access to the note
query_result = {
@ -581,7 +581,7 @@ def get_sources_from_items(
if knowledge_base and (
user.role == "admin"
or has_access(user.id, "read", knowledge_base.access_control)
or await has_access(user.id, "read", knowledge_base.access_control)
):
file_ids = knowledge_base.data.get("file_ids", [])

View file

@ -135,7 +135,7 @@ async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_verified_user)
):
if session_user:
user = Users.update_user_by_id(
user = await Users.update_user_by_id(
session_user.id,
{"profile_image_url": form_data.profile_image_url, "name": form_data.name},
)
@ -348,12 +348,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
if not connection_user.bind():
raise HTTPException(400, "Authentication failed.")
user = Users.get_user_by_email(email)
user = await Users.get_user_by_email(email)
if not user:
try:
role = (
"admin"
if not Users.has_users()
if not await Users.has_users()
else request.app.state.config.DEFAULT_USER_ROLE
)
@ -463,7 +463,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email)
if not Users.get_user_by_email(email.lower()):
if not await Users.get_user_by_email(email.lower()):
await signup(
request,
response,
@ -484,10 +484,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
admin_email = "admin@localhost"
admin_password = "admin"
if Users.get_user_by_email(admin_email.lower()):
if await Users.get_user_by_email(admin_email.lower()):
user = await Auths.authenticate_user(admin_email.lower(), admin_password)
else:
if Users.has_users():
if await Users.has_users():
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup(
@ -556,7 +556,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm):
@router.post("/signup", response_model=SessionUserResponse)
async def signup(request: Request, response: Response, form_data: SignupForm):
has_users = Users.has_users()
has_users = await Users.has_users()
if WEBUI_AUTH:
if (
@ -577,7 +577,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
if await Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
@ -733,7 +733,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
if await Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try:
@ -780,11 +780,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}")
if admin_email:
admin = Users.get_user_by_email(admin_email)
admin = await Users.get_user_by_email(admin_email)
if admin:
admin_name = admin.name
else:
admin = Users.get_first_user()
admin = await Users.get_first_user()
if admin:
admin_email = admin.email
admin_name = admin.name
@ -1025,7 +1025,7 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
)
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
success = await Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
@ -1038,14 +1038,14 @@ async def generate_api_key(request: Request, user=Depends(get_current_user)):
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None)
success = await Users.update_user_api_key_by_id(user.id, None)
return success
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id)
api_key = await Users.get_user_api_key_by_id(user.id)
if api_key:
return {
"api_key": api_key,

View file

@ -80,7 +80,7 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -157,7 +157,7 @@ async def get_channel_messages(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -197,7 +197,7 @@ async def get_channel_messages(
async def send_notification(name, webui_url, channel, message, active_user_ids):
users = get_users_with_access("read", channel.access_control)
users = await get_users_with_access("read", channel.access_control)
for user in users:
if user.id in active_user_ids:
@ -236,7 +236,7 @@ async def post_new_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -341,7 +341,7 @@ async def get_channel_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -390,7 +390,7 @@ async def get_channel_thread_messages(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -452,7 +452,9 @@ async def update_message_by_id(
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control)
and not await has_access(
user.id, type="read", access_control=channel.access_control
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -512,7 +514,7 @@ async def add_reaction_to_message(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -578,7 +580,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role != "admin" and not has_access(
if user.role != "admin" and not await has_access(
user.id, type="read", access_control=channel.access_control
):
raise HTTPException(
@ -661,7 +663,9 @@ async def delete_message_by_id(
if (
user.role != "admin"
and message.user_id != user.id
and not has_access(user.id, type="read", access_control=channel.access_control)
and not await has_access(
user.id, type="read", access_control=channel.access_control
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()

View file

@ -440,10 +440,10 @@ async def update_function_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
user_valves = await Functions.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
@ -482,7 +482,7 @@ async def get_function_user_valves_spec_by_id(
async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
function = Functions.get_function_by_id(id)
function = await Functions.get_function_by_id(id)
if function:
function_module, function_type, frontmatter = get_function_module_from_cache(
@ -495,7 +495,7 @@ async def update_function_user_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Functions.update_user_valves_by_id_and_user_id(
await Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()

View file

@ -90,7 +90,7 @@ async def update_group_by_id(
):
try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.update_group_by_id(id, form_data)
if group:
@ -119,7 +119,7 @@ async def add_user_to_group(
):
try:
if form_data.user_ids:
form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids)
form_data.user_ids = await Users.get_valid_user_ids(form_data.user_ids)
group = Groups.add_users_to_group(id, form_data.user_ids)
if group:

View file

@ -261,11 +261,11 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if (
user.role == "admin"
or knowledge.user_id == user.id
or has_access(user.id, "read", knowledge.access_control)
or await has_access(user.id, "read", knowledge.access_control)
):
file_ids = knowledge.data.get("file_ids", []) if knowledge.data else []
files = Files.get_file_metadatas_by_ids(file_ids)
files = await Files.get_file_metadatas_by_ids(file_ids)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
@ -298,7 +298,7 @@ async def update_knowledge_by_id(
# Is the user the original creator, in a group with write access, or an admin
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -348,7 +348,7 @@ def add_file_to_knowledge_by_id(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -432,7 +432,7 @@ def update_file_from_knowledge_by_id(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
@ -503,7 +503,7 @@ def remove_file_from_knowledge_by_id(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -591,7 +591,7 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -654,7 +654,7 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -697,7 +697,7 @@ def add_files_to_knowledge_batch(
if (
knowledge.user_id != user.id
and not has_access(user.id, "write", knowledge.access_control)
and not await has_access(user.id, "write", knowledge.access_control)
and user.role != "admin"
):
raise HTTPException(

View file

@ -119,7 +119,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
if (
(user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS)
or model.user_id == user.id
or has_access(user.id, "read", model.access_control)
or await has_access(user.id, "read", model.access_control)
):
return model
else:
@ -141,7 +141,7 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user)):
if (
user.role == "admin"
or model.user_id == user.id
or has_access(user.id, "write", model.access_control)
or await has_access(user.id, "write", model.access_control)
):
model = Models.toggle_model_by_id(id)
@ -185,7 +185,7 @@ async def update_model_by_id(
if (
model.user_id != user.id
and not has_access(user.id, "write", model.access_control)
and not await has_access(user.id, "write", model.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -204,7 +204,7 @@ async def update_model_by_id(
@router.delete("/model/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
model = await Models.get_model_by_id(id)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -214,18 +214,18 @@ async def delete_model_by_id(id: str, user=Depends(get_verified_user)):
if (
user.role != "admin"
and model.user_id != user.id
and not has_access(user.id, "write", model.access_control)
and not await has_access(user.id, "write", model.access_control)
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Models.delete_model_by_id(id)
result = await Models.delete_model_by_id(id)
return result
@router.delete("/delete/all", response_model=bool)
async def delete_all_models(user=Depends(get_admin_user)):
result = Models.delete_all_models()
result = await Models.delete_all_models()
return result

View file

@ -133,7 +133,11 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us
if user.role != "admin" and (
user.id != note.user_id
and (not has_access(user.id, type="read", access_control=note.access_control))
and (
not await has_access(
user.id, type="read", access_control=note.access_control
)
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -167,7 +171,9 @@ async def update_note_by_id(
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
and not await has_access(
user.id, type="write", access_control=note.access_control
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
@ -212,7 +218,9 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified
if user.role != "admin" and (
user.id != note.user_id
and not has_access(user.id, type="write", access_control=note.access_control)
and not await has_access(
user.id, type="write", access_control=note.access_control
)
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()

View file

@ -443,8 +443,10 @@ async def get_filtered_models(models, user):
for model in models.get("models", []):
model_info = Models.get_model_by_id(model["model"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
if user.id == model_info.user_id or (
await has_access(
user.id, type="read", access_control=model_info.access_control
)
):
filtered_models.append(model)
return filtered_models
@ -1336,8 +1338,10 @@ async def generate_chat_completion(
if not bypass_filter and user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
or (
await has_access(
user.id, type="read", access_control=model_info.access_control
)
)
):
raise HTTPException(
@ -1442,8 +1446,10 @@ async def generate_openai_completion(
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
or (
await has_access(
user.id, type="read", access_control=model_info.access_control
)
)
):
raise HTTPException(
@ -1525,8 +1531,10 @@ async def generate_openai_chat_completion(
if user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
or (
await has_access(
user.id, type="read", access_control=model_info.access_control
)
)
):
raise HTTPException(
@ -1623,8 +1631,10 @@ async def get_openai_models(
for model in models:
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
if user.id == model_info.user_id or (
await has_access(
user.id, type="read", access_control=model_info.access_control
)
):
filtered_models.append(model)
models = filtered_models

View file

@ -385,8 +385,10 @@ async def get_filtered_models(models, user):
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
if user.id == model_info.user_id or (
await has_access(
user.id, type="read", access_control=model_info.access_control
)
):
filtered_models.append(model)
return filtered_models
@ -756,7 +758,7 @@ async def generate_chat_completion(
if not bypass_filter and user.role == "user":
if not (
user.id == model_info.user_id
or has_access(
or await has_access(
user.id, type="read", access_control=model_info.access_control
)
):

View file

@ -79,13 +79,13 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
prompt = await Prompts.get_prompt_by_command(f"/{command}")
if prompt:
if (
user.role == "admin"
or prompt.user_id == user.id
or has_access(user.id, "read", prompt.access_control)
or await has_access(user.id, "read", prompt.access_control)
):
return prompt
else:
@ -106,7 +106,7 @@ async def update_prompt_by_command(
form_data: PromptForm,
user=Depends(get_verified_user),
):
prompt = Prompts.get_prompt_by_command(f"/{command}")
prompt = await Prompts.get_prompt_by_command(f"/{command}")
if not prompt:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -116,7 +116,7 @@ async def update_prompt_by_command(
# Is the user the original creator, in a group with write access, or an admin
if (
prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control)
and not await has_access(user.id, "write", prompt.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -124,7 +124,7 @@ async def update_prompt_by_command(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
prompt = await Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
else:
@ -150,7 +150,7 @@ async def delete_prompt_by_command(command: str, user=Depends(get_verified_user)
if (
prompt.user_id != user.id
and not has_access(user.id, "write", prompt.access_control)
and not await has_access(user.id, "write", prompt.access_control)
and user.role != "admin"
):
raise HTTPException(

View file

@ -297,7 +297,7 @@ def get_scim_auth(
)
def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
async def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
"""Convert internal User model to SCIM User"""
# Parse display name into name components
name_parts = user.name.split(" ", 1) if user.name else ["", ""]
@ -346,7 +346,7 @@ def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
)
def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
async def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
"""Convert internal Group model to SCIM Group"""
members = []
for user_id in group.user_ids:
@ -493,20 +493,20 @@ async def get_users(
# In production, you'd want a more robust filter parser
if "userName eq" in filter:
email = filter.split('"')[1]
user = Users.get_user_by_email(email)
user = await Users.get_user_by_email(email)
users_list = [user] if user else []
total = 1 if user else 0
else:
response = Users.get_users(skip=skip, limit=limit)
response = await Users.get_users(skip=skip, limit=limit)
users_list = response["users"]
total = response["total"]
else:
response = Users.get_users(skip=skip, limit=limit)
response = await Users.get_users(skip=skip, limit=limit)
users_list = response["users"]
total = response["total"]
# Convert to SCIM format
scim_users = [user_to_scim(user, request) for user in users_list]
scim_users = [await user_to_scim(user, request) for user in users_list]
return SCIMListResponse(
totalResults=total,
@ -529,7 +529,7 @@ async def get_user(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
)
return user_to_scim(user, request)
return await user_to_scim(user, request)
@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
@ -540,7 +540,7 @@ async def create_user(
):
"""Create SCIM User"""
# Check if user already exists
existing_user = Users.get_user_by_email(user_data.userName)
existing_user = await Users.get_user_by_email(user_data.userName)
if existing_user:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
@ -579,7 +579,7 @@ async def create_user(
detail="Failed to create user",
)
return user_to_scim(new_user, request)
return await user_to_scim(new_user, request)
@router.put("/Users/{user_id}", response_model=SCIMUser)
@ -623,14 +623,14 @@ async def update_user(
update_data["profile_image_url"] = user_data.photos[0].value
# Update user
updated_user = Users.update_user_by_id(user_id, update_data)
updated_user = await Users.update_user_by_id(user_id, update_data)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to update user",
)
return user_to_scim(updated_user, request)
return await user_to_scim(updated_user, request)
@router.patch("/Users/{user_id}", response_model=SCIMUser)
@ -669,7 +669,7 @@ async def patch_user(
# Update user
if update_data:
updated_user = Users.update_user_by_id(user_id, update_data)
updated_user = await Users.update_user_by_id(user_id, update_data)
if not updated_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -678,7 +678,7 @@ async def patch_user(
else:
updated_user = user
return user_to_scim(updated_user, request)
return await user_to_scim(updated_user, request)
@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@ -695,7 +695,7 @@ async def delete_user(
detail=f"User {user_id} not found",
)
success = Users.delete_user_by_id(user_id)
success = await Users.delete_user_by_id(user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -716,7 +716,7 @@ async def get_groups(
):
"""List SCIM Groups"""
# Get all groups
groups_list = Groups.get_groups()
groups_list = await Groups.get_groups()
# Apply pagination
total = len(groups_list)
@ -725,7 +725,7 @@ async def get_groups(
paginated_groups = groups_list[start:end]
# Convert to SCIM format
scim_groups = [group_to_scim(group, request) for group in paginated_groups]
scim_groups = [await group_to_scim(group, request) for group in paginated_groups]
return SCIMListResponse(
totalResults=total,
@ -742,14 +742,14 @@ async def get_group(
_: bool = Depends(get_scim_auth),
):
"""Get SCIM Group by ID"""
group = Groups.get_group_by_id(group_id)
group = await Groups.get_group_by_id(group_id)
if not group:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Group {group_id} not found",
)
return group_to_scim(group, request)
return await group_to_scim(group, request)
@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
@ -774,14 +774,14 @@ async def create_group(
)
# Need to get the creating user's ID - we'll use the first admin
admin_user = Users.get_super_admin_user()
admin_user = await Users.get_super_admin_user()
if not admin_user:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="No admin user found",
)
new_group = Groups.insert_new_group(admin_user.id, form)
new_group = await Groups.insert_new_group(admin_user.id, form)
if not new_group:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -797,10 +797,10 @@ async def create_group(
description=new_group.description,
user_ids=member_ids,
)
Groups.update_group_by_id(new_group.id, update_form)
new_group = Groups.get_group_by_id(new_group.id)
await Groups.update_group_by_id(new_group.id, update_form)
new_group = await Groups.get_group_by_id(new_group.id)
return group_to_scim(new_group, request)
return await group_to_scim(new_group, request)
@router.put("/Groups/{group_id}", response_model=SCIMGroup)
@ -839,7 +839,7 @@ async def update_group(
detail="Failed to update group",
)
return group_to_scim(updated_group, request)
return await group_to_scim(updated_group, request)
@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
@ -899,7 +899,7 @@ async def patch_group(
detail="Failed to update group",
)
return group_to_scim(updated_group, request)
return await group_to_scim(updated_group, request)
@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)

View file

@ -49,7 +49,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
request.app.state.config.TOOL_SERVER_CONNECTIONS
)
tools = Tools.get_tools()
tools = await Tools.get_tools()
for server in request.app.state.TOOL_SERVERS:
tools.append(
ToolUserResponse(
@ -83,7 +83,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
tool
for tool in tools
if tool.user_id == user.id
or has_access(user.id, "read", tool.access_control)
or await has_access(user.id, "read", tool.access_control)
]
return tools
@ -96,9 +96,9 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
@router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)):
if user.role == "admin" and ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS:
tools = Tools.get_tools()
tools = await Tools.get_tools()
else:
tools = Tools.get_tools_by_user_id(user.id, "write")
tools = await Tools.get_tools_by_user_id(user.id, "write")
return tools
@ -184,7 +184,7 @@ async def load_tool_from_url(
@router.get("/export", response_model=list[ToolModel])
async def export_tools(user=Depends(get_admin_user)):
tools = Tools.get_tools()
tools = await Tools.get_tools()
return tools
@ -215,11 +215,11 @@ async def create_new_tools(
form_data.id = form_data.id.lower()
tools = Tools.get_tool_by_id(form_data.id)
tools = await Tools.get_tool_by_id(form_data.id)
if tools is None:
try:
form_data.content = replace_imports(form_data.content)
tool_module, frontmatter = load_tool_module_by_id(
tool_module, frontmatter = await load_tool_module_by_id(
form_data.id, content=form_data.content
)
form_data.meta.manifest = frontmatter
@ -228,7 +228,7 @@ async def create_new_tools(
TOOLS[form_data.id] = tool_module
specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs)
tools = await Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
@ -260,13 +260,13 @@ async def create_new_tools(
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if tools:
if (
user.role == "admin"
or tools.user_id == user.id
or has_access(user.id, "read", tools.access_control)
or await has_access(user.id, "read", tools.access_control)
):
return tools
else:
@ -288,7 +288,7 @@ async def update_tools_by_id(
form_data: ToolForm,
user=Depends(get_verified_user),
):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -298,7 +298,7 @@ async def update_tools_by_id(
# Is the user the original creator, in a group with write access, or an admin
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and not await has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -308,7 +308,9 @@ async def update_tools_by_id(
try:
form_data.content = replace_imports(form_data.content)
tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
tool_module, frontmatter = await load_tool_module_by_id(
id, content=form_data.content
)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
@ -322,7 +324,7 @@ async def update_tools_by_id(
}
log.debug(updated)
tools = Tools.update_tool_by_id(id, updated)
tools = await Tools.update_tool_by_id(id, updated)
if tools:
return tools
@ -348,7 +350,7 @@ async def update_tools_by_id(
async def delete_tools_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -357,7 +359,7 @@ async def delete_tools_by_id(
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and not await has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -365,7 +367,7 @@ async def delete_tools_by_id(
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
result = Tools.delete_tool_by_id(id)
result = await Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
@ -381,10 +383,10 @@ async def delete_tools_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if tools:
try:
valves = Tools.get_tool_valves_by_id(id)
valves = await Tools.get_tool_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
@ -407,12 +409,12 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "Valves"):
@ -435,7 +437,7 @@ async def get_tools_valves_spec_by_id(
async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -444,7 +446,7 @@ async def update_tools_valves_by_id(
if (
tools.user_id != user.id
and not has_access(user.id, "write", tools.access_control)
and not await has_access(user.id, "write", tools.access_control)
and user.role != "admin"
):
raise HTTPException(
@ -455,7 +457,7 @@ async def update_tools_valves_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if not hasattr(tools_module, "Valves"):
@ -468,7 +470,7 @@ async def update_tools_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump())
await Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
log.exception(f"Failed to update tool valves by id {id}: {e}")
@ -485,10 +487,10 @@ async def update_tools_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if tools:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
user_valves = await Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
@ -506,12 +508,12 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
@ -529,13 +531,13 @@ async def get_tools_user_valves_spec_by_id(
async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
tools = Tools.get_tool_by_id(id)
tools = await Tools.get_tool_by_id(id)
if tools:
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
tools_module, _ = load_tool_module_by_id(id)
tools_module, _ = await load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
@ -544,7 +546,7 @@ async def update_tools_user_valves_by_id(
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Tools.update_user_valves_by_id_and_user_id(
await Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()

View file

@ -88,14 +88,14 @@ async def get_users(
if direction:
filter["direction"] = direction
return Users.get_users(filter=filter, skip=skip, limit=limit)
return await Users.get_users(filter=filter, skip=skip, limit=limit)
@router.get("/all", response_model=UserInfoListResponse)
async def get_all_users(
user=Depends(get_admin_user),
):
return Users.get_users()
return await Users.get_users()
############################
@ -205,7 +205,7 @@ async def update_default_user_permissions(
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
user = await Users.get_user_by_id(user.id)
if user:
return user.settings
else:
@ -237,7 +237,7 @@ async def update_user_settings_by_session_user(
# If the user is not an admin and does not have permission to use tool servers, remove the key
updated_user_settings["ui"].pop("toolServers", None)
user = Users.update_user_settings_by_id(user.id, updated_user_settings)
user = await Users.update_user_settings_by_id(user.id, updated_user_settings)
if user:
return user.settings
else:
@ -254,7 +254,7 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
user = await Users.get_user_by_id(user.id)
if user:
return user.info
else:
@ -273,12 +273,14 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
user = await Users.get_user_by_id(user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
user = await Users.update_user_by_id(
user.id, {"info": {**user.info, **form_data}}
)
if user:
return user.info
else:
@ -343,7 +345,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.get("/{user_id}/profile/image")
async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
user = Users.get_user_by_id(user_id)
user = await Users.get_user_by_id(user_id)
if user:
if user.profile_image_url:
# check if it's url or base64
@ -398,7 +400,7 @@ async def update_user_by_id(
):
# Prevent modification of the primary admin user by other admins
try:
first_user = Users.get_first_user()
first_user = await Users.get_first_user()
if first_user:
if user_id == first_user.id:
if session_user.id != user_id:
@ -472,7 +474,7 @@ async def update_user_by_id(
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
# Prevent deletion of the primary admin user
try:
first_user = Users.get_first_user()
first_user = await Users.get_first_user()
if first_user and user_id == first_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,

View file

@ -263,7 +263,7 @@ async def connect(sid, environ, auth):
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
user = await Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.model_dump()
@ -284,7 +284,7 @@ async def user_join(sid, data):
if data is None or "id" not in data:
return
user = Users.get_user_by_id(data["id"])
user = await Users.get_user_by_id(data["id"])
if not user:
return
@ -312,7 +312,7 @@ async def join_channel(sid, data):
if data is None or "id" not in data:
return
user = Users.get_user_by_id(data["id"])
user = await Users.get_user_by_id(data["id"])
if not user:
return
@ -333,7 +333,7 @@ async def join_note(sid, data):
if token_data is None or "id" not in token_data:
return
user = Users.get_user_by_id(token_data["id"])
user = await Users.get_user_by_id(token_data["id"])
if not user:
return
@ -345,7 +345,9 @@ async def join_note(sid, data):
if (
user.role != "admin"
and user.id != note.user_id
and not has_access(user.id, type="read", access_control=note.access_control)
and not await has_access(
user.id, type="read", access_control=note.access_control
)
):
log.error(f"User {user.id} does not have access to note {data['note_id']}")
return
@ -400,7 +402,7 @@ async def ydoc_document_join(sid, data):
if (
user.get("role") != "admin"
and user.get("id") != note.user_id
and not has_access(
and not await has_access(
user.get("id"), type="read", access_control=note.access_control
)
):
@ -470,14 +472,14 @@ async def document_save_handler(document_id, data, user):
if (
user.get("role") != "admin"
and user.get("id") != note.user_id
and not has_access(
and not await has_access(
user.get("id"), type="read", access_control=note.access_control
)
):
log.error(f"User {user.get('id')} does not have access to note {note_id}")
return
Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
await Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
@sio.on("ydoc:document:state")

View file

@ -107,7 +107,7 @@ def has_permission(
return get_permission(default_permissions, permission_hierarchy)
def has_access(
async def has_access(
user_id: str,
type: str = "write",
access_control: Optional[dict] = None,
@ -115,7 +115,7 @@ def has_access(
if access_control is None:
return type == "read"
user_groups = Groups.get_groups_by_member_id(user_id)
user_groups = await Groups.get_groups_by_member_id(user_id)
user_group_ids = [group.id for group in user_groups]
permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", [])
@ -127,11 +127,11 @@ def has_access(
# Get all users with access to a resource
def get_users_with_access(
async def get_users_with_access(
type: str = "write", access_control: Optional[dict] = None
) -> List[UserModel]:
if access_control is None:
return Users.get_users()
return await Users.get_users()
permission_access = access_control.get(type, {})
permitted_group_ids = permission_access.get("group_ids", [])
@ -140,8 +140,8 @@ def get_users_with_access(
user_ids_with_access = set(permitted_user_ids)
for group_id in permitted_group_ids:
group_user_ids = Groups.get_group_user_ids_by_id(group_id)
group_user_ids = await Groups.get_group_user_ids_by_id(group_id)
if group_user_ids:
user_ids_with_access.update(group_user_ids)
return Users.get_users_by_user_ids(list(user_ids_with_access))
return await Users.get_users_by_user_ids(list(user_ids_with_access))

View file

@ -206,7 +206,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
return None
def get_current_user(
async def get_current_user(
request: Request,
response: Response,
background_tasks: BackgroundTasks,
@ -270,7 +270,7 @@ def get_current_user(
)
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
user = await Users.get_user_by_id(data["id"])
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -303,6 +303,7 @@ def get_current_user(
# Refresh the user's last active timestamp asynchronously
# to prevent blocking the request
if background_tasks:
background_tasks.add_task(Users.update_user_last_active_by_id, user.id)
return user
else:
@ -312,8 +313,8 @@ def get_current_user(
)
def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key)
async def get_current_user_by_api_key(api_key: str):
user = await Users.get_user_by_api_key(api_key)
if user is None:
raise HTTPException(
@ -329,7 +330,7 @@ def get_current_user_by_api_key(api_key: str):
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
Users.update_user_last_active_by_id(user.id)
await Users.update_user_last_active_by_id(user.id)
return user

View file

@ -199,7 +199,7 @@ async def generate_chat_completion(
# Check if user has access to the model
if not bypass_filter and user.role == "user":
try:
check_model_access(user, model)
await check_model_access(user, model)
except Exception as e:
raise e
@ -424,8 +424,10 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
**(
await Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
)
)
except Exception as e:

View file

@ -70,7 +70,7 @@ async def generate_embeddings(
# Access filtering
if not getattr(request.state, "direct", False):
if not bypass_filter and user.role == "user":
check_model_access(user, model)
await check_model_access(user, model)
# Ollama backend
if model.get("owned_by") == "ollama":

View file

@ -109,8 +109,10 @@ async def process_filter_functions(
if hasattr(function_module, "UserValves"):
try:
params["__user__"]["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, params["__user__"]["id"]
**(
await Functions.get_user_valves_by_id_and_user_id(
filter_id, params["__user__"]["id"]
)
)
)
except Exception as e:

View file

@ -311,9 +311,9 @@ async def get_all_models(request, refresh: bool = False, user: UserModel = None)
return models
def check_model_access(user, model):
async def check_model_access(user, model):
if model.get("arena"):
if not has_access(
if not await has_access(
user.id,
type="read",
access_control=model.get("info", {})
@ -322,12 +322,12 @@ def check_model_access(user, model):
):
raise Exception("Model not found")
else:
model_info = Models.get_model_by_id(model.get("id"))
model_info = await Models.get_model_by_id(model.get("id"))
if not model_info:
raise Exception("Model not found")
elif not (
user.id == model_info.user_id
or has_access(
or await has_access(
user.id, type="read", access_control=model_info.access_control
)
):

View file

@ -89,8 +89,8 @@ class OAuthManager:
def get_client(self, provider_name):
return self.oauth.create_client(provider_name)
def get_user_role(self, user, user_data):
user_count = Users.get_num_users()
async def get_user_role(self, user, user_data):
user_count = await Users.get_num_users()
if user and user_count == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
log.debug("Assigning the only user the admin role")
@ -147,7 +147,7 @@ class OAuthManager:
return role
def update_user_groups(self, user, user_data, default_permissions):
async def update_user_groups(self, user, user_data, default_permissions):
log.debug("Running OAUTH Group management")
oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM
@ -172,8 +172,10 @@ class OAuthManager:
else:
user_oauth_groups = []
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
all_available_groups: list[GroupModel] = Groups.get_groups()
user_current_groups: list[GroupModel] = await Groups.get_groups_by_member_id(
user.id
)
all_available_groups: list[GroupModel] = await Groups.get_groups()
# Create groups if they don't exist and creation is enabled
if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION:
@ -181,7 +183,7 @@ class OAuthManager:
all_group_names = {g.name for g in all_available_groups}
groups_created = False
# Determine creator ID: Prefer admin, fallback to current user if no admin exists
admin_user = Users.get_super_admin_user()
admin_user = await Users.get_super_admin_user()
creator_id = admin_user.id if admin_user else user.id
log.debug(f"Using creator ID {creator_id} for potential group creation.")
@ -198,7 +200,7 @@ class OAuthManager:
user_ids=[], # Start with no users, user will be added later by subsequent logic
)
# Use determined creator ID (admin or fallback to current user)
created_group = Groups.insert_new_group(
created_group = await Groups.insert_new_group(
creator_id, new_group_form
)
if created_group:
@ -217,7 +219,7 @@ class OAuthManager:
# Refresh the list of all available groups if any were created
if groups_created:
all_available_groups = Groups.get_groups()
all_available_groups = await Groups.get_groups()
log.debug("Refreshed list of all available groups after creation.")
log.debug(f"Oauth Groups claim: {oauth_claim}")
@ -430,21 +432,21 @@ class OAuthManager:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
user = await Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL:
# Check if the user exists by email
user = Users.get_user_by_email(email)
user = await Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
await Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = self.get_user_role(user, user_data)
determined_role = await self.get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
await Users.update_user_role_by_id(user.id, determined_role)
# Update profile picture if enabled and different from current
if auth_manager_config.OAUTH_UPDATE_PICTURE_ON_LOGIN:
@ -457,7 +459,7 @@ class OAuthManager:
new_picture_url, token.get("access_token")
)
if processed_picture_url != user.profile_image_url:
Users.update_user_profile_image_url_by_id(
await Users.update_user_profile_image_url_by_id(
user.id, processed_picture_url
)
log.debug(f"Updated profile picture for user {user.email}")
@ -466,7 +468,7 @@ class OAuthManager:
# If the user does not exist, check if signups are enabled
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(email)
existing_user = await Users.get_user_by_email(email)
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
@ -488,7 +490,7 @@ class OAuthManager:
log.warning("Username claim is missing, using email as name")
name = email
role = self.get_user_role(None, user_data)
role = await self.get_user_role(None, user_data)
user = await Auths.insert_new_auth(
email=email,
@ -523,7 +525,7 @@ class OAuthManager:
)
if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT and user.role != "admin":
self.update_user_groups(
await self.update_user_groups(
user=user,
user_data=user_data,
default_permissions=request.app.state.config.USER_PERMISSIONS,

View file

@ -68,17 +68,17 @@ def replace_imports(content):
return content
def load_tool_module_by_id(tool_id, content=None):
async def load_tool_module_by_id(tool_id, content=None):
if content is None:
tool = Tools.get_tool_by_id(tool_id)
tool = await Tools.get_tool_by_id(tool_id)
if not tool:
raise Exception(f"Toolkit not found: {tool_id}")
content = tool.content
content = replace_imports(content)
Tools.update_tool_by_id(tool_id, {"content": content})
await Tools.update_tool_by_id(tool_id, {"content": content})
else:
frontmatter = extract_frontmatter(content)
# Install required packages found within the frontmatter
@ -241,7 +241,7 @@ def install_frontmatter_requirements(requirements: str):
log.info("No requirements found in frontmatter.")
def install_tool_and_function_dependencies():
async def install_tool_and_function_dependencies():
"""
Install all dependencies for all admin tools and active functions.
@ -249,8 +249,8 @@ def install_tool_and_function_dependencies():
and then installing them using pip. Duplicates or similar version specifications are
handled by pip as much as possible.
"""
function_list = Functions.get_functions(active_only=True)
tool_list = Tools.get_tools()
function_list = await Functions.get_functions(active_only=True)
tool_list = await Tools.get_tools()
all_dependencies = ""
try:

View file

@ -68,13 +68,13 @@ def get_async_tool_function_and_apply_extra_params(
return new_function
def get_tools(
async def get_tools(
request: Request, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
tools_dict = {}
for tool_id in tool_ids:
tool = Tools.get_tool_by_id(tool_id)
tool = await Tools.get_tool_by_id(tool_id)
if tool is None:
if tool_id.startswith("server:"):
server_idx = int(tool_id.split(":")[1])
@ -140,18 +140,18 @@ def get_tools(
else:
module = request.app.state.TOOLS.get(tool_id, None)
if module is None:
module, _ = load_tool_module_by_id(tool_id)
module, _ = await load_tool_module_by_id(tool_id)
request.app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id
# Set valves for the tool
if hasattr(module, "valves") and hasattr(module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id) or {}
valves = await Tools.get_tool_valves_by_id(tool_id) or {}
module.valves = module.Valves(**valves)
if hasattr(module, "UserValves"):
extra_params["__user__"]["valves"] = module.UserValves( # type: ignore
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
**(await Tools.get_user_valves_by_id_and_user_id(tool_id, user.id))
)
for spec in tool.specs: